虽然golang写protoc插件已经足够简单了golang源码分析:自定义proto插件,插件进程从标准输入读取出CodeGeneratorRequest数据,将CodeGeneratorResponse数据写到标准输出。但是,我们需要自己遍历protoc生成的proto对应的抽象语法树,相对而言还是比较有难度的。
https://github.com/lyft/protoc-gen-star 对上述过程进行了很好的封装,这样我们就可以通过定义简单的Module实现我们在抽象语法树上的操作,转化出我们需要目标代码,在每个Module内部我们通过访问者模式来处理proto每一个语法单元,非常简单快捷。入口函数我们只需要指定Module和生成代码后的后处理函数即可
代码语言:javascript复制pgs.Init(
pgs.DebugEnv("DEBUG"),
).RegisterModule(
ASTPrinter(),
JSONify(),
).RegisterPostProcessor(
pgsgo.GoFmt(),
).Render()
其中 ASTPrinter(),JSONify(),就是我们的Module。后处理函数调用了go fmt来对代码进行了格式化。
代码语言:javascript复制type goFmt struct{}
func (p goFmt) Process(in []byte) ([]byte, error) { return format.Source(in) }
它最终调用了golang源码包里的format函数源码位于type goFmt struct{}
代码语言:javascript复制func Source(src []byte) ([]byte, error) {}
下面我们就结合protoc-gen-star的示例代码插件testdata/protoc-gen-example/main.go,来进行源码分析。
首先我们编译一下
代码语言:javascript复制% go build -o ../protoc-gen-mypgs ./testdata/protoc-gen-example/
% cp ../protoc-gen-mypgs $GOPATH/bin/
然后我们就可以用这个插件来进行代码生成,比如我们的proto定义如下:
代码语言:javascript复制syntax = "proto3";
package api;
option go_package = "api/v1;v1";
message HelloRequest {
string msg = 1;
}
我们执行命令
代码语言:javascript复制protoc
-I ../exp2
--mypgs_out="foo=bar:."
../exp2/test.proto
就可以看到我们的两个Module分别生成了两个文件
test.pb.json.go
代码语言:javascript复制package v1
import (
"bytes"
"encoding/json"
"github.com/golang/protobuf/jsonpb"
)
// HelloRequestJSONMarshaler describes the default jsonpb.Marshaler used by all
// instances of HelloRequest. This struct is safe to replace or modify but
// should not be done so concurrently.
var HelloRequestJSONMarshaler = new(jsonpb.Marshaler)
// MarshalJSON satisfies the encoding/json Marshaler interface. This method
// uses the more correct jsonpb package to correctly marshal the message.
func (m *HelloRequest) MarshalJSON() ([]byte, error) {
if m == nil {
return json.Marshal(nil)
}
buf := &bytes.Buffer{}
if err := HelloRequestJSONMarshaler.Marshal(buf, m); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
var _ json.Marshaler = (*HelloRequest)(nil)
// HelloRequestJSONUnmarshaler describes the default jsonpb.Unmarshaler used by all
// instances of HelloRequest. This struct is safe to replace or modify but
// should not be done so concurrently.
var HelloRequestJSONUnmarshaler = new(jsonpb.Unmarshaler)
// UnmarshalJSON satisfies the encoding/json Unmarshaler interface. This method
// uses the more correct jsonpb package to correctly unmarshal the message.
func (m *HelloRequest) UnmarshalJSON(b []byte) error {
return HelloRequestJSONUnmarshaler.Unmarshal(bytes.NewReader(b), m)
}
var _ json.Unmarshaler = (*HelloRequest)(nil)
test.tree.txt
代码语言:javascript复制┳ File: test.proto
┣┳ Message: HelloRequest
┃┣━ msg
前一个文件实现了json序列化和反序列化,后一个文件将文件,message和field三级结构用树形结构展示了出来。
回过头来我们看下 JSONifyModule的核心函数
代码语言:javascript复制func (p *JSONifyModule) Execute(targets map[string]pgs.File, pkgs map[string]pgs.Package) []pgs.Artifact {
for _, t := range targets {
p.generate(t)
}
return p.Artifacts()
}
它是通过模板文件生成对应的序列化和反序列化代码
代码语言:javascript复制func (p *JSONifyModule) generate(f pgs.File) {
if len(f.Messages()) == 0 {
return
}
name := p.ctx.OutputPath(f).SetExt(".json.go")
p.AddGeneratorTemplateFile(name.String(), p.tpl, f)
}
类似的PrinterModule
代码语言:javascript复制func (p *PrinterModule) Execute(targets map[string]pgs.File, packages map[string]pgs.Package) []pgs.Artifact {
buf := &bytes.Buffer{}
for _, f := range targets {
p.printFile(f, buf)
}
return p.Artifacts()
}
代码语言:javascript复制
func (p *PrinterModule) printFile(f pgs.File, buf *bytes.Buffer) {
p.Push(f.Name().String())
defer p.Pop()
buf.Reset()
v := initPrintVisitor(buf, "")
p.CheckErr(pgs.Walk(v, f), "unable to print AST tree")
out := buf.String()
if ok, _ := p.Parameters().Bool("log_tree"); ok {
p.Logf("Proto Tree:n%s", out)
}
p.AddGeneratorFile(
f.InputPath().SetExt(".tree.txt").String(),
out,
)
}
可以看到PrinterModule的核心函数里调用了pgs.Walk(v, f)方法,其实就是访问者模式。
谁在用pgs呢?protoc-gen-validate包就在使用,我们可以看下它生成go代码的插件的的入口函数cmd/protoc-gen-validate-go/main.go
代码语言:javascript复制
func main() {
optional := uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
pgs.
Init(pgs.DebugEnv("DEBUG_PGV"), pgs.SupportedFeatures(&optional)).
RegisterModule(module.ValidatorForLanguage("go")).
RegisterPostProcessor(pgsgo.GoFmt()).
Render()
}
它的核心Module是ValidatorForLanguage源码位module/validate.go
代码语言:javascript复制func ValidatorForLanguage(lang string) pgs.Module {
return &Module{lang: lang, ModuleBase: &pgs.ModuleBase{}}
}
它的Execute方法定义如下
代码语言:javascript复制func (m *Module) Execute(targets map[string]pgs.File, pkgs map[string]pgs.Package) []pgs.Artifact {
lang := m.lang
langParamValue := m.Parameters().Str(langParam)
if lang == "" {
lang = langParamValue
m.Assert(lang != "", "`lang` parameter must be set")
} else if langParamValue != "" {
m.Fail("unknown `lang` parameter")
}
module := m.Parameters().Str(moduleParam)
// Process file-level templates
tpls := templates.Template(m.Parameters())[lang]
m.Assert(tpls != nil, "could not find templates for `lang`: ", lang)
for _, f := range targets {
m.Push(f.Name().String())
for _, msg := range f.AllMessages() {
m.CheckRules(msg)
}
for _, tpl := range tpls {
out := templates.FilePathFor(tpl)(f, m.ctx, tpl)
// A nil path means no output should be generated for this file - as controlled by
// implementation-specific FilePathFor implementations.
// Ex: Don't generate Java validators for files that don't reference PGV.
if out != nil {
outPath := strings.TrimLeft(strings.ReplaceAll(filepath.ToSlash(out.String()), module, ""), "/")
if opts := f.Descriptor().GetOptions(); opts != nil && opts.GetJavaMultipleFiles() && lang == "java" {
// TODO: Only Java supports multiple file generation. If more languages add multiple file generation
// support, the implementation should be made more inderect.
for _, msg := range f.Messages() {
m.AddGeneratorTemplateFile(java.JavaMultiFilePath(f, msg).String(), tpl, msg)
}
} else {
m.AddGeneratorTemplateFile(outPath, tpl, f)
}
}
}
m.Pop()
}
return m.Artifacts()
}
渲染每个模板,生成目标代码。模板定义位于protoc-gen-validate/templates/go 已message.go为例
代码语言:javascript复制const messageTpl = `
{{ $f := .Field }}{{ $r := .Rules }}
{{ template "required" . }}
{{ if .MessageRules.GetSkip }}
// skipping validation for {{ $f.Name }}
{{ else }}
if all {
switch v := interface{}({{ accessor . }}).(type) {
case interface{ ValidateAll() error }:
if err := v.ValidateAll(); err != nil {
errors = append(errors, {{ errCause . "err" "embedded message failed validation" }})
}
case interface{ Validate() error }:
{{- /* Support legacy validation for messages that were generated with a plugin version prior to existence of ValidateAll() */ -}}
if err := v.Validate(); err != nil {
errors = append(errors, {{ errCause . "err" "embedded message failed validation" }})
}
}
} else if v, ok := interface{}({{ accessor . }}).(interface{ Validate() error }); ok {
if err := v.Validate(); err != nil {
return {{ errCause . "err" "embedded message failed validation" }}
}
}
{{ end }}
`