gorm 改造 orders 代码
来源:6-9 【勤于思考,夯实学习成果】阶段练习题

Ganjr
2021-02-01
package dao
import (
"flash-sale/frontend/model"
"flash-sale/shared/id"
"flash-sale/shared/mysql/objid"
"fmt"
)
// Order is the GORM model for a row of the orders table (the join in
// GetOrdersWithInfo refers to this table as "orders").
type Order struct {
	ID        int               // row identifier; used as the primary key by GORM's convention for a field named ID
	UserId    int               // numeric user id, as produced by objid.FromID from an id.User
	ProductID int               // numeric product id, as produced by objid.FromID; joined against products.id
	Status    model.OrderStatus // order status, copied unchanged from model.Order
}
// CreateOrder converts the string user/product IDs in o to their numeric
// database form, inserts a new row into the orders table and returns it.
//
// The returned *Order is the value handed to GORM's Create, so it carries
// whatever GORM writes back after the insert (e.g. the generated ID).
// ID-conversion errors are wrapped with %w so callers can still inspect the
// underlying cause with errors.Is/As; insert errors are returned unchanged.
func (m *Mysql) CreateOrder(o *model.Order) (*Order, error) {
	uid, err := objid.FromID(id.User(o.UserId))
	if err != nil {
		return nil, fmt.Errorf("invalid uid: %w", err)
	}
	pid, err := objid.FromID(id.Product(o.ProductID))
	if err != nil {
		return nil, fmt.Errorf("invalid pid: %w", err)
	}
	ord := &Order{
		UserId:    uid,
		ProductID: pid,
		Status:    o.Status,
	}
	if err := m.Db.Create(ord).Error; err != nil {
		return nil, err
	}
	return ord, nil
}
// GetOrder returns one order belonging to the user uid.
//
// Rows are filtered on the user_id column. When the user has several orders,
// First returns the one with the lowest primary key. When the user has none,
// GORM's record-not-found error is returned unchanged so callers can compare
// it directly.
func (m *Mysql) GetOrder(uid id.User) (*Order, error) {
	userID, err := objid.FromID(uid)
	if err != nil {
		return nil, fmt.Errorf("invalid id: %w", err)
	}
	var o Order
	// The filter must be an explicit Where: First only turns the primary key of
	// its destination struct into a condition, so setting UserId on the
	// destination alone would be ignored and an arbitrary user's order returned.
	// A string condition is used (not a struct condition) because GORM drops
	// zero-valued fields from struct conditions.
	res := m.Db.Where("user_id = ?", userID).First(&o)
	if res.Error != nil {
		return nil, res.Error
	}
	return &o, nil
}
// GetOrders loads every row of the orders table. No ordering is requested, so
// the row order is whatever the database returns. On failure the database
// error is returned unchanged together with a nil slice.
func (m *Mysql) GetOrders() ([]*Order, error) {
	var orders []*Order
	if err := m.Db.Find(&orders).Error; err != nil {
		return nil, err
	}
	return orders, nil
}
// GetOrdersWithInfo returns every order's id and status together with the
// name of its product. A left join is used, so orders whose product row is
// missing are still returned (with an empty product name).
//
// NOTE(review): products.name comes back under the column name "name"; confirm
// model.OrderWithInfo has a field GORM maps to that column, otherwise the
// product name is silently dropped (an alias such as "AS product_name" may be
// needed).
func (m *Mysql) GetOrdersWithInfo() ([]*model.OrderWithInfo, error) {
	var infos []*model.OrderWithInfo
	err := m.Db.Model(&Order{}).
		Select("orders.id, products.name, orders.status").
		Joins("left join products on orders.product_id = products.id").
		Scan(&infos).
		Error
	if err != nil {
		return nil, err
	}
	return infos, nil
}
// UpdateOrder overwrites the user, product and status of order oid with the
// values carried by o.
//
// The new values are passed to GORM as a column map instead of an Order
// struct. GORM's Updates skips zero-valued fields of a struct argument, so a
// status equal to its type's zero value would otherwise never be written.
// ID-conversion errors are wrapped with %w; database errors are returned
// unchanged. No error is reported when oid matches no row.
func (m *Mysql) UpdateOrder(oid id.Order, o *model.Order) error {
	orderID, err := objid.FromID(oid)
	if err != nil {
		return fmt.Errorf("invalid id: %w", err)
	}
	userID, err := objid.FromID(id.User(o.UserId))
	if err != nil {
		return fmt.Errorf("invalid uid: %w", err)
	}
	productID, err := objid.FromID(id.Product(o.ProductID))
	if err != nil {
		return fmt.Errorf("invalid pid: %w", err)
	}
	res := m.Db.Model(&Order{ID: orderID}).Updates(map[string]interface{}{
		"user_id":    userID,
		"product_id": productID,
		"status":     o.Status,
	})
	return res.Error
}
// DeleteOrder removes the order identified by oid from the orders table.
//
// An ID-conversion error is wrapped with %w so the cause stays inspectable;
// a database error is returned unchanged. Deleting a non-existent order is
// not reported as an error.
func (m *Mysql) DeleteOrder(oid id.Order) error {
	orderID, err := objid.FromID(oid)
	if err != nil {
		return fmt.Errorf("invalid id: %w", err)
	}
	return m.Db.Delete(&Order{ID: orderID}).Error
}
附上单元测试:
package dao
import (
	"fmt"
	"os"
	"testing"

	"flash-sale/frontend/model"
	"flash-sale/shared/id"
	"flash-sale/shared/mysql/objid"
	mysqltesting "flash-sale/shared/mysql/testing"

	"github.com/google/go-cmp/cmp"
)
// TestMysql_CreateOrder inserts a few orders and checks each insert succeeds,
// receives an ID and keeps the requested status.
func TestMysql_CreateOrder(t *testing.T) {
	db, err := mysqltesting.NewDB()
	if err != nil {
		t.Fatalf("cannot get database: %v", err)
	}
	err = mysqltesting.CreateTables(db)
	if err != nil {
		t.Fatalf("cannot create tables: %v", err)
	}
	m := &Mysql{Db: db}
	cases := []struct {
		userId    string
		productID string
		status    model.OrderStatus
	}{
		{
			userId:    "123",
			productID: "412341",
			status:    model.Unsubmitted,
		},
		{
			userId:    "1234",
			productID: "12351245",
			status:    model.Pending,
		},
	}
	for _, cc := range cases {
		ord, err := m.CreateOrder(&model.Order{
			UserId:    cc.userId,
			ProductID: cc.productID,
			Status:    cc.status,
		})
		if err != nil {
			t.Errorf("%s: error creating order: %v", cc.userId, err)
			continue
		}
		// The ID is filled in by the insert; other tests rely on it.
		if ord.ID == 0 {
			t.Errorf("%s: created order has no ID", cc.userId)
		}
		if diff := cmp.Diff(cc.status, ord.Status); diff != "" {
			t.Errorf("%s: status differs; -want +got: %s", cc.userId, diff)
		}
	}
}
// TestMysql_GetOrder creates an order for a user and checks that GetOrder,
// queried with that user's ID, returns the same order.
func TestMysql_GetOrder(t *testing.T) {
	db, err := mysqltesting.NewDB()
	if err != nil {
		t.Fatalf("cannot get database: %v", err)
	}
	err = mysqltesting.CreateTables(db)
	if err != nil {
		t.Fatalf("cannot create tables: %v", err)
	}
	m := &Mysql{Db: db}
	const userID = "412351"
	want, err := m.CreateOrder(&model.Order{
		UserId:    userID,
		ProductID: "12341",
		Status:    model.Unsubmitted,
	})
	if err != nil {
		t.Fatalf("cannot create order: %v", err)
	}
	// GetOrder looks orders up by user, so query with the user ID the order
	// was created for, not with the order's own ID.
	got, err := m.GetOrder(id.User(userID))
	if err != nil {
		t.Fatalf("cannot get order: %v", err)
	}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("result differs; -want +got: %s", diff)
	}
}
// TestMysql_GetOrders inserts several distinct orders and checks GetOrders
// returns exactly those rows.
//
// NOTE(review): this compares against the whole orders table, so it assumes
// mysqltesting.CreateTables leaves the table empty and that rows come back in
// insertion order (GetOrders does not sort) — confirm both.
func TestMysql_GetOrders(t *testing.T) {
	rows := []struct {
		userId    string
		productID string
		status    model.OrderStatus
	}{
		{userId: "293784907", productID: "123498", status: model.Unsubmitted},
		{userId: "293784908", productID: "123499", status: model.Pending},
		{userId: "293784909", productID: "123500", status: model.Unsubmitted},
	}
	db, err := mysqltesting.NewDB()
	if err != nil {
		t.Fatalf("cannot get database: %v", err)
	}
	err = mysqltesting.CreateTables(db)
	if err != nil {
		t.Fatalf("cannot create tables: %v", err)
	}
	m := &Mysql{Db: db}
	wants := make([]*Order, 0, len(rows))
	for _, row := range rows {
		want, err := m.CreateOrder(&model.Order{
			UserId:    row.userId,
			ProductID: row.productID,
			Status:    row.status,
		})
		if err != nil {
			t.Fatalf("cannot create order: %v", err)
		}
		wants = append(wants, want)
	}
	gots, err := m.GetOrders()
	if err != nil {
		t.Fatalf("cannot get orders: %v", err)
	}
	if diff := cmp.Diff(wants, gots); diff != "" {
		t.Errorf("result differs; -want +got: %s", diff)
	}
}
// TestMysql_GetOrdersWithInfo creates a product and an order for it, then
// checks GetOrdersWithInfo succeeds and returns at least that order. The rows
// are printed for manual inspection of the joined product name.
func TestMysql_GetOrdersWithInfo(t *testing.T) {
	db, err := mysqltesting.NewDB()
	if err != nil {
		t.Fatalf("cannot get database: %v", err)
	}
	err = mysqltesting.CreateTables(db)
	if err != nil {
		t.Fatalf("cannot create tables: %v", err)
	}
	m := &Mysql{Db: db}
	prod, err := m.CreateProduct(&model.Product{
		Name:  "product1",
		Num:   666,
		Image: "xxx/xxx.gif",
		Url:   "http://xxx.com/xxx",
	})
	if err != nil {
		t.Fatalf("cannot create product: %v", err)
	}
	res, err := m.CreateOrder(&model.Order{
		UserId:    "412351",
		ProductID: objid.ToProductID(prod.Id).String(),
		Status:    model.Unsubmitted,
	})
	if err != nil {
		t.Fatalf("cannot create order: %v", err)
	}
	ords, err := m.GetOrdersWithInfo()
	if err != nil {
		t.Fatalf("cannot get orders with info: %v", err)
	}
	// The query left-joins from orders, so the order created above must appear.
	if len(ords) == 0 {
		t.Errorf("expected at least one order with info, got none")
	}
	fmt.Println("prod: ", prod.Id, *prod.Product)
	fmt.Println("res: ", *res)
	for _, ord := range ords {
		fmt.Println(*ord)
	}
}
// TestMysql_UpdateOrder creates an order, updates it and reads the row back
// by primary key to confirm the new status was actually stored.
func TestMysql_UpdateOrder(t *testing.T) {
	db, err := mysqltesting.NewDB()
	if err != nil {
		t.Fatalf("cannot get database: %v", err)
	}
	err = mysqltesting.CreateTables(db)
	if err != nil {
		t.Fatalf("cannot create tables: %v", err)
	}
	m := &Mysql{Db: db}
	res, err := m.CreateOrder(&model.Order{
		UserId:    "2341234",
		ProductID: "12341234",
		Status:    model.Unsubmitted,
	})
	if err != nil {
		t.Fatalf("cannot create order: %v", err)
	}
	err = m.UpdateOrder(objid.ToOrderID(res.ID), &model.Order{
		UserId:    "12341234",
		ProductID: "1234134",
		Status:    model.Pending,
	})
	if err != nil {
		t.Fatalf("cannot update order: %v", err)
	}
	var got Order
	if err := m.Db.First(&got, res.ID).Error; err != nil {
		t.Fatalf("cannot read back order %d: %v", res.ID, err)
	}
	if diff := cmp.Diff(model.Pending, got.Status); diff != "" {
		t.Errorf("status not updated; -want +got: %s", diff)
	}
}
// TestMysql_DeleteOrder creates an order, deletes it and checks that the row
// can no longer be loaded by primary key.
func TestMysql_DeleteOrder(t *testing.T) {
	db, err := mysqltesting.NewDB()
	if err != nil {
		t.Fatalf("cannot get database: %v", err)
	}
	err = mysqltesting.CreateTables(db)
	if err != nil {
		t.Fatalf("cannot create tables: %v", err)
	}
	m := &Mysql{Db: db}
	res, err := m.CreateOrder(&model.Order{
		UserId:    "2341234",
		ProductID: "12341234",
		Status:    model.Unsubmitted,
	})
	if err != nil {
		t.Fatalf("cannot create order: %v", err)
	}
	err = m.DeleteOrder(objid.ToOrderID(res.ID))
	if err != nil {
		t.Fatalf("cannot delete order: %v", err)
	}
	// Order has no soft-delete field, so a successful delete removes the row.
	if err := m.Db.First(&Order{}, res.ID).Error; err == nil {
		t.Errorf("order %d still exists after delete", res.ID)
	}
}
// TestMain runs the package's tests through mysqltesting.RunWithMysqlInDocker
// and exits with the status code it returns. It requires the "os" import,
// which was missing from this test file.
func TestMain(m *testing.M) {
	os.Exit(mysqltesting.RunWithMysqlInDocker(m))
}
疑问:gorm 有强类型 join 的表达方式吗?
res := m.Db.Model(&Order{}).Select("orders.id, products.name, orders.status").
Joins("left join products on orders.product_id = products.id").Scan(&results)
写回答
1回答
-
nice 非常好
00
相似问题