grpc-go之参数验证(五)

2022-10-11 19:39:01 浏览数 (1)

介绍

参数验证是一个非常常用的场景, grpc-go中一般地我们会直接使用使用第三方插件go-proto-validators自动生成验证规则, 然后配合grpc-go的拦截器来实现参数验证的逻辑.

具体讲解前先安装一下go-proto-validators

代码语言:txt复制
go get github.com/mwitkow/go-proto-validators

案例介绍

拦截器定义

接下来定义参数验证的拦截器grpc_validator/validator.go

代码语言:go复制
package grpc_validator

import (
	"context"
	"fmt"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// The validate interface starting with protoc-gen-validate v0.6.0.
// See https://github.com/envoyproxy/protoc-gen-validate/pull/455.
type validator interface {
	Validate(all bool) error
}

// The validate interface prior to protoc-gen-validate v0.6.0.
type validatorLegacy interface {
	Validate() error
}

func validate(req interface{}) error {
	switch v := req.(type) {
	case validatorLegacy:
		if err := v.Validate(); err != nil {
			return status.Error(codes.InvalidArgument, err.Error())
		}
	case validator:
		if err := v.Validate(false); err != nil {
			return status.Error(codes.InvalidArgument, err.Error())
		}
	}
	return nil
}

// UnaryServerInterceptor 接收到数据后需要检查参数是否合法
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		fmt.Println("参数验证...")
		if err := validate(req); err != nil {
			return nil, err
		}
		return handler(ctx, req)
	}
}

// UnaryClientInterceptor 发送数据前需要检查参数是否合法
func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		fmt.Println("参数验证...")
		if err := validate(req); err != nil {
			return err
		}
		return invoker(ctx, method, req, reply, cc, opts...)
	}
}

// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
func StreamServerInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		wrapper := &serverWrapper{stream}
		return handler(srv, wrapper)
	}
}

type serverWrapper struct {
	grpc.ServerStream
}

// RecvMsg 接收消息后, 应该先检查一下参数是否合法
func (s *serverWrapper) RecvMsg(m interface{}) error {
	if err := s.ServerStream.RecvMsg(m); err != nil {
		return err
	}

	if err := validate(m); err != nil {
		return err
	}

	return nil
}

// SendMsg 发送前, 应该先检查一下参数是否合法
func (s *serverWrapper) SendMsg(m interface{}) error {
	if err := validate(m); err != nil {
		return err
	}
	if err := s.ServerStream.SendMsg(m); err != nil {
		return err
	}
	return nil
}

// clientWrapper  用于包装 grpc.ClientStream 结构体并拦截其对应的方法。
type clientWrapper struct {
	grpc.ClientStream
}

func newWrappedClientStream(c grpc.ClientStream) grpc.ClientStream {
	return &clientWrapper{c}
}

func (c *clientWrapper) RecvMsg(m interface{}) error {
	if err := c.ClientStream.RecvMsg(m); err != nil {
		return err
	}
	if err := validate(m); err != nil {
		return err
	}

	return nil
}

func (c *clientWrapper) SendMsg(m interface{}) error {
	if err := validate(m); err != nil {
		return err
	}
	if err := c.ClientStream.SendMsg(m); err != nil {
		return err
	}
	return nil
}

// ClientStreamInterceptor
func ClientStreamInterceptor() grpc.StreamClientInterceptor {
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
		streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		s, err := streamer(ctx, desc, cc, method, opts...)
		if err != nil {
			return nil, err
		}
		return newWrappedClientStream(s), nil
	}
}

自动生成go代码

前面只是定义了一下拦截器, 实际验证代码还需要通过工具生成, 先修改hello_world.pb文件

代码语言:text复制
//声明proto的版本 只有 proto3 才支持 gRPC
syntax = "proto3";
// 将编译后文件输出在 pb 目录
option go_package = "./pb";
// 指定当前proto文件属于hello world包
package pb;

import "github.com/mwitkow/go-proto-validators/validator.proto";
import "google/protobuf/descriptor.proto";

// 定义一个名叫 greeting 的服务
service Greeter {
  // 该服务包含一个 SayHello 方法 HelloRequest、HelloReply分别为该方法的输入与输出
  rpc SayHello (HelloRequest) returns (HelloReply) {}
}
// 具体的参数定义
message HelloRequest {
  string name = 1 [(validator.field) = {regex: "^[a-z]{2,16}$"}];
}

message HelloReply {
  string message = 1[(validator.field) = {regex: "^[a-z]{10,120}$"}];
}

语法具体可以参考validator.proto文件

前置检查

  • 需要进入$GOPATH/src/github.com/mwitkow/go-proto-validators执行下go mod tidy保证插件完全安装成功
  • 需要下载官方descriptor.pb并放到目录$GOPATH/src/google/protobuf/

执行命令

通过下面的命令生成go代码

代码语言:shell复制
protoc --proto_path=$GOPATH/src  --proto_path=. --go_out=. --go-grpc_out=.  --govalidators_out=. ./pb/hello_world.proto 

执行完会生成以下文件, 其中hello_world.validator.pb.go就是用来做参数验证的.

代码语言:txt复制
.
├── hello_world.pb.go
├── hello_world.proto
├── hello_world.validator.pb.go
├── hello_world_grpc.pb.go

配置拦截器

server/main.go

代码语言:go复制
s := grpc.NewServer(
		grpc.Creds(credentials.NewServerTLSFromCert(&cert)),
		grpc.ChainUnaryInterceptor(
			grpc_recovery.UnaryServerInterceptor(opts...),
			grpc_validator.UnaryServerInterceptor(),
			myEnsureValidToken),

		grpc.ChainStreamInterceptor(
			grpc_recovery.StreamServerInterceptor(opts...),
			grpc_validator.StreamServerInterceptor(),
			streamInterceptorAuth),
	)

client/main.go

代码语言:go复制
	conn, err := grpc.Dial(address, grpc.WithDefaultServiceConfig(retryPolicy,
		grpc.WithTransportCredentials(cred),
		grpc.WithPerRPCCredentials(userPwdAuth),
		grpc.WithPerRPCCredentials(oauthAuth),
		grpc.WithPerRPCCredentials(jwtAuth),
		grpc.WithChainUnaryInterceptor(grpc_validator.UnaryClientInterceptor()),
		grpc.WithStreamInterceptor(grpc_validator.ClientStreamInterceptor()),
	)
	

输出效果

服务端:

image.pngimage.png

客户端:

image.pngimage.png

0 人点赞