Go source code analysis: protoc-gen-star and protoc-gen-validate

2023-03-14 20:47:08

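Writing a protoc plugin in Go is already fairly simple (see the earlier article on custom proto plugins): the plugin process reads a CodeGeneratorRequest from standard input and writes a CodeGeneratorResponse to standard output. Even so, we still have to traverse the abstract syntax tree that protoc produces for the proto files ourselves, which is comparatively hard.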

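https://github.com/lyft/protoc-gen-star wraps this process nicely. We define simple Modules that operate on the abstract syntax tree and turn it into the target code we need. Inside each Module, the visitor pattern handles every syntactic unit of the proto file, which is quick and convenient. The entry function only has to register the Modules and the post-processors that run on the generated code: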

pgs.Init(
    pgs.DebugEnv("DEBUG"),
  ).RegisterModule(
    ASTPrinter(),
    JSONify(),
  ).RegisterPostProcessor(
    pgsgo.GoFmt(),
  ).Render()

Here ASTPrinter() and JSONify() are our Modules. The post-processor formats the generated code the same way go fmt would:

type goFmt struct{}
func (p goFmt) Process(in []byte) ([]byte, error) { return format.Source(in) }

It ultimately calls the Source function of the go/format package in the Go standard library:

func Source(src []byte) ([]byte, error) {}
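
To get a feel for what this post-processing step does to generator output, here is a minimal sketch that runs format.Source on a deliberately messy snippet. The helper name formatGenerated and the sample input are illustrative assumptions and are not part of protoc-gen-star.

```go
package main

import (
	"fmt"
	"go/format"
)

// formatGenerated is a hypothetical helper that does what a GoFmt-style
// post-processor does: it hands the raw generated bytes to go/format.
func formatGenerated(src []byte) ([]byte, error) {
	return format.Source(src)
}

func main() {
	// Badly spaced but syntactically valid Go, as a template might emit it.
	raw := []byte("package v1\nfunc  Hello( ) string {return \"hi\"}\n")

	out, err := formatGenerated(raw)
	if err != nil {
		fmt.Println("format error:", err)
		return
	}
	fmt.Print(string(out))

	// Source that does not parse is rejected instead of being passed through.
	if _, err := formatGenerated([]byte("package v1\nfunc (")); err != nil {
		fmt.Println("invalid source rejected")
	}
}
```

Because format.Source returns an error for source that does not parse, a post-processor built this way also acts as a cheap syntax check on the template output.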

Next, we analyze the source code using protoc-gen-star's example plugin, testdata/protoc-gen-example/main.go.

First, build it:

%  go build  -o ../protoc-gen-mypgs ./testdata/protoc-gen-example/
% cp ../protoc-gen-mypgs $GOPATH/bin/

Now we can use this plugin to generate code. Suppose our proto definition is:

syntax = "proto3";

package api;

option go_package = "api/v1;v1";

message HelloRequest {
    string msg = 1;
}

We run the command:

protoc \
  -I ../exp2 \
  --mypgs_out="foo=bar:." \
  ../exp2/test.proto
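
In --mypgs_out="foo=bar:.", the part before the colon (foo=bar) is handed to the plugin as its parameter string, and the part after it is the output directory. As a rough sketch of how such a comma-separated key=value string can be turned into a lookup table (parseParams is a hypothetical helper for illustration, not protoc-gen-star's actual parser):

```go
package main

import (
	"fmt"
	"strings"
)

// parseParams splits a plugin parameter string such as "foo=bar,log_tree=true"
// into a map. A key without "=" is stored with an empty value.
// Hypothetical helper, written only to illustrate the format.
func parseParams(raw string) map[string]string {
	params := make(map[string]string)
	for _, kv := range strings.Split(raw, ",") {
		if kv == "" {
			continue // skip empty segments, e.g. when raw is ""
		}
		parts := strings.SplitN(kv, "=", 2)
		key := strings.TrimSpace(parts[0])
		val := ""
		if len(parts) == 2 {
			val = strings.TrimSpace(parts[1])
		}
		params[key] = val
	}
	return params
}

func main() {
	p := parseParams("foo=bar,log_tree=true")
	fmt.Println(p["foo"], p["log_tree"]) // prints "bar true"
}
```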

and we can see that our two Modules each generated one file:

test.pb.json.go

package v1

import (
  "bytes"
  "encoding/json"

  "github.com/golang/protobuf/jsonpb"
)

// HelloRequestJSONMarshaler describes the default jsonpb.Marshaler used by all
// instances of HelloRequest. This struct is safe to replace or modify but
// should not be done so concurrently.
var HelloRequestJSONMarshaler = new(jsonpb.Marshaler)

// MarshalJSON satisfies the encoding/json Marshaler interface. This method
// uses the more correct jsonpb package to correctly marshal the message.
func (m *HelloRequest) MarshalJSON() ([]byte, error) {
  if m == nil {
    return json.Marshal(nil)
  }

  buf := &bytes.Buffer{}
  if err := HelloRequestJSONMarshaler.Marshal(buf, m); err != nil {
    return nil, err
  }

  return buf.Bytes(), nil
}

var _ json.Marshaler = (*HelloRequest)(nil)

// HelloRequestJSONUnmarshaler describes the default jsonpb.Unmarshaler used by all
// instances of HelloRequest. This struct is safe to replace or modify but
// should not be done so concurrently.
var HelloRequestJSONUnmarshaler = new(jsonpb.Unmarshaler)

// UnmarshalJSON satisfies the encoding/json Unmarshaler interface. This method
// uses the more correct jsonpb package to correctly unmarshal the message.
func (m *HelloRequest) UnmarshalJSON(b []byte) error {
  return HelloRequestJSONUnmarshaler.Unmarshal(bytes.NewReader(b), m)
}

var _ json.Unmarshaler = (*HelloRequest)(nil)

test.tree.txt

┳ File: test.proto
┣┳ Message: HelloRequest
┃┣━ msg

The first file implements JSON serialization and deserialization. The second renders the three-level structure of file, message, and field as a tree.

Looking back, here is the core function of JSONifyModule:

func (p *JSONifyModule) Execute(targets map[string]pgs.File, pkgs map[string]pgs.Package) []pgs.Artifact {

  for _, t := range targets {
    p.generate(t)
  }

  return p.Artifacts()
}

It generates the serialization and deserialization code from a template:

func (p *JSONifyModule) generate(f pgs.File) {
  if len(f.Messages()) == 0 {
    return
  }

  name := p.ctx.OutputPath(f).SetExt(".json.go")
  p.AddGeneratorTemplateFile(name.String(), p.tpl, f)
}
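
The template-driven step can be imitated with nothing but the standard library. The sketch below assumes a simplified fileData struct in place of pgs.File and a toy template in place of the module's real one; like generate above, it skips files without messages.

```go
package main

import (
	"bytes"
	"fmt"
	"text/template"
)

// fileData is a hypothetical stand-in for pgs.File that carries only what
// the toy template needs.
type fileData struct {
	Package  string
	Messages []string
}

// jsonTpl is a toy template, not the one shipped with the example plugin.
var jsonTpl = template.Must(template.New("json").Parse(`package {{ .Package }}
{{ range .Messages }}
// {{ . }}JSONMarshaler is a placeholder for the per-message marshaler.
var {{ . }}JSONMarshaler = struct{}{}
{{ end }}`))

// render executes the template for one file and returns the generated source.
// Files without messages produce no output, mirroring the early return in
// generate.
func render(f fileData) (string, error) {
	if len(f.Messages) == 0 {
		return "", nil
	}
	var buf bytes.Buffer
	if err := jsonTpl.Execute(&buf, f); err != nil {
		return "", err
	}
	return buf.String(), nil
}

func main() {
	out, err := render(fileData{Package: "v1", Messages: []string{"HelloRequest"}})
	if err != nil {
		fmt.Println("render error:", err)
		return
	}
	fmt.Print(out)
}
```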

PrinterModule is similar:

func (p *PrinterModule) Execute(targets map[string]pgs.File, packages map[string]pgs.Package) []pgs.Artifact {
  buf := &bytes.Buffer{}

  for _, f := range targets {
    p.printFile(f, buf)
  }

  return p.Artifacts()
}

func (p *PrinterModule) printFile(f pgs.File, buf *bytes.Buffer) {
  p.Push(f.Name().String())
  defer p.Pop()

  buf.Reset()
  v := initPrintVisitor(buf, "")
  p.CheckErr(pgs.Walk(v, f), "unable to print AST tree")

  out := buf.String()

  if ok, _ := p.Parameters().Bool("log_tree"); ok {
    p.Logf("Proto Tree:\n%s", out)
  }

  p.AddGeneratorFile(
    f.InputPath().SetExt(".tree.txt").String(),
    out,
  )
}

As we can see, PrinterModule's core function calls pgs.Walk(v, f), which is simply the visitor pattern.
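
For readers less familiar with the pattern, here is a stripped-down sketch of a walk-plus-visitor pair over a file, message, and field tree. The node, visitor, walk, and printVisitor names are made up for this example; the real pgs.Visitor interface has one method per entity type and also returns errors.

```go
package main

import (
	"fmt"
	"strings"
)

// node is a hypothetical, simplified AST node (File, Message or Field).
type node struct {
	kind     string
	name     string
	children []*node
}

// visitor is called once per node. The visitor it returns is used for the
// node's children; returning nil stops the descent.
type visitor interface {
	Visit(n *node) visitor
}

// walk performs a depth-first traversal, handing every node to the visitor.
func walk(v visitor, n *node) {
	next := v.Visit(n)
	if next == nil {
		return
	}
	for _, c := range n.children {
		walk(next, c)
	}
}

// printVisitor writes one indented line per node.
type printVisitor struct {
	out   *strings.Builder
	depth int
}

func (p printVisitor) Visit(n *node) visitor {
	fmt.Fprintf(p.out, "%s%s: %s\n", strings.Repeat("  ", p.depth), n.kind, n.name)
	// Children are printed one level deeper.
	return printVisitor{out: p.out, depth: p.depth + 1}
}

// printTree renders the whole tree rooted at root as text.
func printTree(root *node) string {
	var sb strings.Builder
	walk(printVisitor{out: &sb}, root)
	return sb.String()
}

func main() {
	root := &node{kind: "File", name: "test.proto", children: []*node{
		{kind: "Message", name: "HelloRequest", children: []*node{
			{kind: "Field", name: "msg"},
		}},
	}}
	fmt.Print(printTree(root))
}
```

The traversal logic lives in walk alone, so a different output format only needs a different visitor.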

Who uses pgs? The protoc-gen-validate package does. Here is the entry function of its Go code generation plugin, cmd/protoc-gen-validate-go/main.go:

func main() {
  optional := uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
  pgs.
    Init(pgs.DebugEnv("DEBUG_PGV"), pgs.SupportedFeatures(&optional)).
    RegisterModule(module.ValidatorForLanguage("go")).
    RegisterPostProcessor(pgsgo.GoFmt()).
    Render()
}

Its core Module is created by ValidatorForLanguage, which is defined in module/validate.go:

func ValidatorForLanguage(lang string) pgs.Module {
  return &Module{lang: lang, ModuleBase: &pgs.ModuleBase{}}
}

Its Execute method is defined as follows:

func (m *Module) Execute(targets map[string]pgs.File, pkgs map[string]pgs.Package) []pgs.Artifact {
  lang := m.lang
  langParamValue := m.Parameters().Str(langParam)
  if lang == "" {
    lang = langParamValue
    m.Assert(lang != "", "`lang` parameter must be set")
  } else if langParamValue != "" {
    m.Fail("unknown `lang` parameter")
  }

  module := m.Parameters().Str(moduleParam)

  // Process file-level templates
  tpls := templates.Template(m.Parameters())[lang]
  m.Assert(tpls != nil, "could not find templates for `lang`: ", lang)

  for _, f := range targets {
    m.Push(f.Name().String())

    for _, msg := range f.AllMessages() {
      m.CheckRules(msg)
    }

    for _, tpl := range tpls {
      out := templates.FilePathFor(tpl)(f, m.ctx, tpl)

      // A nil path means no output should be generated for this file - as controlled by
      // implementation-specific FilePathFor implementations.
      // Ex: Don't generate Java validators for files that don't reference PGV.
      if out != nil {
        outPath := strings.TrimLeft(strings.ReplaceAll(filepath.ToSlash(out.String()), module, ""), "/")

        if opts := f.Descriptor().GetOptions(); opts != nil && opts.GetJavaMultipleFiles() && lang == "java" {
          // TODO: Only Java supports multiple file generation. If more languages add multiple file generation
          // support, the implementation should be made more inderect.
          for _, msg := range f.Messages() {
            m.AddGeneratorTemplateFile(java.JavaMultiFilePath(f, msg).String(), tpl, msg)
          }
        } else {
          m.AddGeneratorTemplateFile(outPath, tpl, f)
        }
      }
    }

    m.Pop()
  }

  return m.Artifacts()
}
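
The outPath expression in Execute removes the value of the module parameter from the generated path and then strips any leading slashes. A small standalone sketch of that one expression follows; trimModule and the sample paths are invented for illustration.

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// trimModule repeats the outPath expression from Execute: normalize the
// separators, drop every occurrence of the module prefix, then trim
// leading slashes. Hypothetical helper name.
func trimModule(path, module string) string {
	return strings.TrimLeft(strings.ReplaceAll(filepath.ToSlash(path), module, ""), "/")
}

func main() {
	// With a module prefix, the path becomes relative to that module.
	fmt.Println(trimModule("example.com/demo/api/v1/test.pb.validate.go", "example.com/demo"))
	// prints "api/v1/test.pb.validate.go"

	// With an empty module parameter, the path is left unchanged.
	fmt.Println(trimModule("api/v1/test.pb.validate.go", ""))
	// prints "api/v1/test.pb.validate.go"
}
```

Note that strings.ReplaceAll removes the module string wherever it occurs in the path, not only at the start.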

It renders each template to generate the target code. The template definitions live in protoc-gen-validate/templates/go; take message.go as an example:

const messageTpl = `
  {{ $f := .Field }}{{ $r := .Rules }}
  {{ template "required" . }}

  {{ if .MessageRules.GetSkip }}
    // skipping validation for {{ $f.Name }}
  {{ else }}
    if all {
      switch v := interface{}({{ accessor . }}).(type) {
        case interface{ ValidateAll() error }:
          if err := v.ValidateAll(); err != nil {
            errors = append(errors, {{ errCause . "err" "embedded message failed validation" }})
          }
        case interface{ Validate() error }:
          {{- /* Support legacy validation for messages that were generated with a plugin version prior to existence of ValidateAll() */ -}}
          if err := v.Validate(); err != nil {
            errors = append(errors, {{ errCause . "err" "embedded message failed validation" }})
          }
      }
    } else if v, ok := interface{}({{ accessor . }}).(interface{ Validate() error }); ok {
      if err := v.Validate(); err != nil {
        return {{ errCause . "err" "embedded message failed validation" }}
      }
    }
  {{ end }}
`
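
To make the shape of the rendered code easier to follow, the sketch below reproduces the same branching in plain Go. It is a hand-written approximation, not actual PGV output: the inner and legacy types are invented, and fmt.Errorf stands in for whatever the errCause template function expands to.

```go
package main

import (
	"errors"
	"fmt"
)

// inner is a hypothetical embedded message that supports both methods.
type inner struct{ msg string }

func (i *inner) Validate() error {
	if i.msg == "" {
		return errors.New("msg is empty")
	}
	return nil
}

func (i *inner) ValidateAll() error { return i.Validate() }

// legacy is a hypothetical message generated before ValidateAll existed.
type legacy struct{}

func (legacy) Validate() error { return errors.New("legacy failure") }

// validateEmbedded follows the template's branching: with all set it prefers
// ValidateAll and falls back to Validate; otherwise it only calls Validate.
func validateEmbedded(field interface{}, all bool) []error {
	var errs []error
	const reason = "embedded message failed validation"
	if all {
		switch v := field.(type) {
		case interface{ ValidateAll() error }:
			if err := v.ValidateAll(); err != nil {
				errs = append(errs, fmt.Errorf("%s: %w", reason, err))
			}
		case interface{ Validate() error }:
			// Fallback for messages that only have the older method.
			if err := v.Validate(); err != nil {
				errs = append(errs, fmt.Errorf("%s: %w", reason, err))
			}
		}
	} else if v, ok := field.(interface{ Validate() error }); ok {
		if err := v.Validate(); err != nil {
			return []error{fmt.Errorf("%s: %w", reason, err)}
		}
	}
	return errs
}

func main() {
	fmt.Println(validateEmbedded(&inner{}, true))
	fmt.Println(validateEmbedded(legacy{}, true))
	fmt.Println(validateEmbedded(&inner{msg: "hi"}, false))
}
```

Checking for the methods through anonymous interfaces means the generated code needs no shared base type: any embedded message that happens to have Validate or ValidateAll is picked up.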
