diff --git a/advance/template/gen/annotation/file.go b/advance/template/gen/annotation/file.go index 2b0cf3e09949dabbe1bf4024e8d82f33a9e36e6a..955a3877cbff065115936e490269a95438f3d064 100644 --- a/advance/template/gen/annotation/file.go +++ b/advance/template/gen/annotation/file.go @@ -10,11 +10,20 @@ type SingleFileEntryVisitor struct { } func (s *SingleFileEntryVisitor) Get() File { - panic("implement me") + if s.file == nil { + return File{} + } + return s.file.Get() } func (s *SingleFileEntryVisitor) Visit(node ast.Node) ast.Visitor { - panic("implement me") + if file, ok := node.(*ast.File); ok { + s.file = &fileVisitor{ + ans: newAnnotations(file, file.Doc), + } + return s.file + } + return s } type fileVisitor struct { @@ -24,11 +33,27 @@ type fileVisitor struct { } func (f *fileVisitor) Get() File { - panic("implement me") + types := make([]Type, 0, len(f.types)) + for _, t := range f.types { + types = append(types, t.Get()) + } + return File{ + Annotations: f.ans, + Types: types, + } } func (f *fileVisitor) Visit(node ast.Node) ast.Visitor { - panic("implement me") + typ, ok := node.(*ast.TypeSpec) + if ok { + res := &typeVisitor{ + ans: newAnnotations(typ, typ.Doc), + fields: make([]Field, 0), + } + f.types = append(f.types, res) + return res + } + return f } type File struct { @@ -42,11 +67,18 @@ type typeVisitor struct { } func (t *typeVisitor) Get() Type { - panic("implement me") + return Type{ + Annotations: t.ans, + Fields: t.fields, + } } func (t *typeVisitor) Visit(node ast.Node) (w ast.Visitor) { - panic("implement me") + if fd, ok := node.(*ast.Field); ok { + t.fields = append(t.fields, Field{Annotations: newAnnotations(fd, fd.Doc)}) + return nil + } + return t } type Type struct { diff --git a/advance/template/gen/http/gen_http.go b/advance/template/gen/http/gen_http.go index b0cb6d377d45ffd8805a9e0aa83ba64678b7b4b1..11b51706e26cd06bff7ce426ae9081863f909715 100644 --- a/advance/template/gen/http/gen_http.go +++ b/advance/template/gen/http/gen_http.go @@ -6,7 +6,45 @@ import ( ) // 这部分和课堂的很像,但是有一些地方被我改掉了 -const serviceTpl = ` +const serviceTpl = `package {{.Package}} + +import ( + "bytes" + "context" + "encoding/json" + "io/ioutil" + "net/http" +) + +{{ $service :=.GenName -}} +type {{ $service }} struct { + Endpoint string + Path string + Client http.Client +} +{{range $idx, $method := .Methods}} +func (s *{{$service}}) {{$method.Name}}(ctx context.Context, req *{{$method.ReqTypeName}}) (*{{$method.RespTypeName}}, error) { + url := s.Endpoint + s.Path + "/{{$method.Name}}" + bs, err := json.Marshal(req) + if err != nil { + return nil, err + } + body := &bytes.Buffer{} + body.Write(bs) + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, body) + if err != nil { + return nil, err + } + httpResp, err := s.Client.Do(httpReq) + if err != nil { + return nil, err + } + bs, err = ioutil.ReadAll(httpResp.Body) + resp := &{{$method.RespTypeName}}{} + err = json.Unmarshal(bs, resp) + return resp, err +} +{{end}} ` func Gen(writer io.Writer, def ServiceDefinition) error { diff --git a/advance/template/gen/main.go b/advance/template/gen/main.go index 27fc4eca54220daadd3ac4f357d4c24bc65a30de..36c445e461f2bc4d307f979bffc03065f612ad29 100644 --- a/advance/template/gen/main.go +++ b/advance/template/gen/main.go @@ -1,11 +1,21 @@ package main import ( + "bufio" + "bytes" + "errors" "fmt" - "gitee.com/geektime-geekbang/geektime-go/advance/template/gen/annotation" - "gitee.com/geektime-geekbang/geektime-go/advance/template/gen/http" + "go/ast" + "go/parser" + "go/token" "os" + "path" + "path/filepath" + "strings" "unicode" + + "gitee.com/geektime-geekbang/geektime-go/advance/template/gen/annotation" + "gitee.com/geektime-geekbang/geektime-go/advance/template/gen/http" ) // 实际上 main 函数这里要考虑接收参数 @@ -43,7 +53,25 @@ func gen(src string) error { // 根据 defs 来生成代码 // src 是源代码所在目录,在测试里面它是 ./testdata func genFiles(src string, defs []http.ServiceDefinition) error { - panic("implement me") + for _, def := range defs { + bs := &bytes.Buffer{} + err := http.Gen(bs, def) + file, err := os.OpenFile(path.Join(src, underscoreName(def.Name)+"_gen.go"), os.O_WRONLY|os.O_CREATE, 0666) + + if err != nil { + return err + } + defer file.Close() + + if err != nil { + fmt.Printf("open file error=%v\n", err) + return err + } + writer := bufio.NewWriter(file) + writer.Write(bs.Bytes()) + writer.Flush() + } + return nil } func parseFiles(srcFiles []string) ([]http.ServiceDefinition, error) { @@ -51,8 +79,15 @@ func parseFiles(srcFiles []string) ([]http.ServiceDefinition, error) { for _, src := range srcFiles { fmt.Println(src) // 你需要利用 annotation 里面的东西来扫描 src,然后生成 file - panic("implement me") - var file annotation.File + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, src, nil, parser.ParseComments) + if err != nil { + return nil, err + } + sfev := &annotation.SingleFileEntryVisitor{} + ast.Walk(sfev, f) + + file := sfev.Get() for _, typ := range file.Types { _, ok := typ.Annotations.Get("HttpClient") @@ -72,12 +107,87 @@ func parseFiles(srcFiles []string) ([]http.ServiceDefinition, error) { // 你需要利用 typ 来构造一个 http.ServiceDefinition // 注意你可能需要检测用户的定义是否符合你的预期 func parseServiceDefinition(pkg string, typ annotation.Type) (http.ServiceDefinition, error) { - panic("implement me") + result := &http.ServiceDefinition{ + Package: pkg, + } + for _, a := range typ.Annotations.Ans { + if a.Key == "ServiceName" { + result.Name = a.Value + } + } + if result.Name == "" { + result.Name = typ.Annotations.Node.Name.Name + } + method := &http.ServiceMethod{} + fields := typ.Fields + for _, field := range fields { + method.Name = field.Annotations.Node.Names[0].Name + + mAns := field.Annotations.Ans + for _, ma := range mAns { + if ma.Key == "Path" { + method.Path = ma.Value + } + } + if method.Path == "" { + method.Path = "/" + method.Name + } + params := field.Annotations.Node.Type.(*ast.FuncType).Params + if len(params.List) != 2 { + return *result, errors.New("gen: 方法必须接收两个参数,其中第一个参数是 context.Context,第二个参数请求") + } + method.ReqTypeName = params.List[1].Type.(*ast.StarExpr).X.(*ast.Ident).Name + results := field.Annotations.Node.Type.(*ast.FuncType).Results + if len(results.List) != 2 { + return *result, errors.New("gen: 方法必须返回两个参数,其中第一个返回值是响应,第二个返回值是error") + } + method.RespTypeName = results.List[0].Type.(*ast.StarExpr).X.(*ast.Ident).Name + result.Methods = append(result.Methods, *method) + } + return *result, nil } // 返回符合条件的 Go 源代码文件,也就是你要用 AST 来分析这些文件的代码 func scanFiles(src string) ([]string, error) { - panic("implement me") + srcFiles := make([]string, 0, 10) + files, err := os.ReadDir(src) + if err != nil { + return nil, err + } + for _, file := range files { + if strings.HasSuffix(file.Name(), ".go") && + !strings.HasSuffix(file.Name(), "_test.go") && + !strings.HasSuffix(file.Name(), "gen.go") { + src, err = filepath.Abs(src) + if err != nil { + return nil, err + } + srcFiles = append(srcFiles, filepath.Join(src, file.Name())) + } + } + return srcFiles, nil + //srcAbs, err := filepath.Abs(src) + //if err != nil { + // return nil, err + //} + //files := make([]string, 0) + //if err := filepath.Walk(src, func(filePath string, f os.FileInfo, err error) error { + // if f == nil { + // return err + // } + // if f.IsDir() { + // return nil + // } + // if strings.HasSuffix(f.Name(), ".go") && + // !strings.HasSuffix(f.Name(), "_test.go") && + // !strings.HasSuffix(f.Name(), "gen.go") { + // files = append(files, path.Join(srcAbs, filePath)) + // } + // return nil + //}); err != nil { + // return nil, err + //} + //return files, nil } // underscoreName 驼峰转字符串命名,在决定生成的文件名的时候需要这个方法 diff --git a/advance/template/gen/testdata/user_service_gen.txt b/advance/template/gen/testdata/user_service_gen.txt index 031e0a438766f2690d3c0974ab1c03e39c0606f4..35f7c4069a6df9aab5540b7a4bb3e4c24f244cb1 100644 --- a/advance/template/gen/testdata/user_service_gen.txt +++ b/advance/template/gen/testdata/user_service_gen.txt @@ -37,7 +37,7 @@ func (s *UserServiceGen) Get(ctx context.Context, req *GetUserReq) (*GetUserResp } func (s *UserServiceGen) Update(ctx context.Context, req *UpdateUserReq) (*UpdateUserResp, error) { - url := s.Endpoint + s.Path + "/user/update" + url := s.Endpoint + s.Path + "/Update" bs, err := json.Marshal(req) if err != nil { return nil, err diff --git a/orm/homework1/aggregate.go b/orm/homework1/aggregate.go index e34a8f7b9048768e005eb70d5abb393d9617f5c0..8bdcb1c3a44db34effc6a0a495dfd4c4919dc9e7 100644 --- a/orm/homework1/aggregate.go +++ b/orm/homework1/aggregate.go @@ -21,15 +21,27 @@ func (a Aggregate) As(alias string) Aggregate { // EQ 例如 C("id").Eq(12) func (a Aggregate) EQ(arg any) Predicate { - panic("implement me") + return Predicate{ + left: a, + op: opEQ, + right: exprOf(arg), + } } func (a Aggregate) LT(arg any) Predicate { - panic("implement me") + return Predicate{ + left: a, + op: opLT, + right: exprOf(arg), + } } func (a Aggregate) GT(arg any) Predicate { - panic("implement me") + return Predicate{ + left: a, + op: opGT, + right: exprOf(arg), + } } func Avg(c string) Aggregate { diff --git a/orm/homework1/select.go b/orm/homework1/select.go index b6295eb32b57200250d9ad3d700bd107d800eb4b..a1f4624a1cf794c24ad9738c8e80c38c5fac2d1e 100644 --- a/orm/homework1/select.go +++ b/orm/homework1/select.go @@ -67,7 +67,47 @@ func (s *Selector[T]) Build() (*Query, error) { } } - panic("implement me") + if len(s.orderBy) > 0 { + s.sb.WriteString(" ORDER BY ") + if err = s.buildOrderBy(); err != nil { + return nil, err + } + } + + if len(s.groupBy) > 0 { + s.sb.WriteString(" GROUP BY ") + for i, c := range s.groupBy { + if i > 0 { + s.sb.WriteByte(',') + } + if err = s.buildColumn(c.name, c.alias); err != nil { + return nil, err + } + } + } + + if len(s.having) > 0 { + s.sb.WriteString(" HAVING ") + if err = s.buildPredicates(s.having); err != nil { + return nil, err + } + } + + if s.limit > 0 { + s.sb.WriteString(" LIMIT ?") + s.addArgs(s.limit) + } + + if s.offset > 0 { + s.sb.WriteString(" OFFSET ?") + s.addArgs(s.offset) + } + + //panic("implement me") + + if s.having != nil { + + } s.sb.WriteString(";") return &Query{ @@ -159,7 +199,52 @@ func (s *Selector[T]) buildColumn(c string, alias string) error { } func (s *Selector[T]) buildExpression(e Expression) error { - panic("implement me") + if e == nil { + return nil + } + switch exp := e.(type) { + case Column: + return s.buildColumn(exp.name, exp.alias) + case Aggregate: + return s.buildAggregate(exp, false) + case value: + s.sb.WriteByte('?') + s.addArgs(exp.val) + case Predicate: + _, lp := exp.left.(Predicate) + if lp { + s.sb.WriteByte('(') + } + if err := s.buildExpression(exp.left); err != nil { + return err + } + if lp { + s.sb.WriteByte(')') + } + + // 可能只有左边 + if exp.op == "" { + return nil + } + + s.sb.WriteByte(' ') + s.sb.WriteString(exp.op.String()) + s.sb.WriteByte(' ') + + _, rp := exp.right.(Predicate) + if rp { + s.sb.WriteByte('(') + } + if err := s.buildExpression(exp.right); err != nil { + return err + } + if rp { + s.sb.WriteByte(')') + } + default: + return errs.NewErrUnsupportedExpressionType(exp) + } + return nil } // Where 用于构造 WHERE 查询条件。如果 ps 长度为 0,那么不会构造 WHERE 部分 @@ -271,9 +356,15 @@ type OrderBy struct { } func Asc(col string) OrderBy { - panic("implement me") + return OrderBy{ + col: col, + order: "ASC", + } } func Desc(col string) OrderBy { - panic("implement me") + return OrderBy{ + col: col, + order: "DESC", + } }