如何定义一个fake server,接受客户端的请求并返回期望的结果?它本质上是一个TCP服务器,定义过程如下:
代码语言:javascript复制 l, err := net.Listen("tcp", "127.0.0.1:4000")
c, err := l.Accept()
conn, err := server.NewConn(c, "root", "", server.EmptyHandler{})
for {
if err := conn.HandleCommand(); err != nil {
log.Fatal(err)
}
}
建立连接相关代码位于:github.com/go-mysql-org/go-mysql@v1.7.0/server/conn.go
代码语言:javascript复制func NewConn(conn net.Conn, user string, password string, h Handler) (*Conn, error) {
p := NewInMemoryProvider()
p.AddUser(user, password)
packetConn = packet.NewConn(conn)
c := &Conn{
Conn: packetConn,
serverConf: defaultServer,
credentialProvider: p,
h: h,
connectionID: atomic.AddUint32(&baseConnID, 1),
stmts: make(map[uint32]*Stmt),
salt: RandomBuf(20),
}
if err := c.handshake(); err != nil {
其中的InMemoryProvider内部使用一个sync.Map保存用户名到密码的映射:
// InMemoryProvider implements an in-memory credential provider.
type InMemoryProvider struct {
	// userPool maps username -> password.
	userPool sync.Map
}
其中packet.NewConn创建底层packet连接的过程如下:
代码语言:javascript复制func NewConn(conn net.Conn) *Conn {
c := new(Conn)
c.Conn = conn
c.bufPool = NewBufPool()
c.br = bufio.NewReaderSize(c, 65536) // 64kb
c.reader = c.br
c.copyNBuf = make([]byte, 16*1024)
return c
}
server.NewConn返回的Conn代表了一个mysql连接,它内嵌了*packet.Conn:
代码语言:javascript复制type Conn struct {
*packet.Conn
serverConf *Server
capability uint32
charset uint8
authPluginName string
attributes map[string]string
connectionID uint32
status uint16
warnings uint16
salt []byte // should be 8 12 for auth-plugin-data-part-1 and auth-plugin-data-part-2
credentialProvider CredentialProvider
user string
password string
cachingSha2FullAuth bool
h Handler
stmts map[uint32]*Stmt
stmtID uint32
closed sync2.AtomicBool
}
接着就是握手
代码语言:javascript复制func (c *Conn) handshake() error {
if err := c.writeInitialHandshake(); err != nil {
return err
}
if err := c.readHandshakeResponse(); err != nil {
然后进入连接的请求与响应处理流程,代码位于:github.com/go-mysql-org/go-mysql@v1.7.0/server/command.go。处理请求前,我们需要注册请求的处理器;库中默认实现了一个空的处理器,以及一个空的复制请求处理器。
// EmptyHandler is a stateless, do-nothing command handler. Only its UseDB
// method is shown in this excerpt.
type EmptyHandler struct{}

// UseDB accepts any database name and always reports success.
func (h EmptyHandler) UseDB(dbName string) error {
	return nil
}
代码语言:javascript复制type EmptyReplicationHandler struct {
EmptyHandler
}
它们实现的接口定义如下,包含7个方法,分别对应COM_INIT_DB、COM_QUERY、COM_FIELD_LIST、COM_STMT_PREPARE、COM_STMT_EXECUTE、COM_STMT_CLOSE这6个命令,以及库未直接处理的其他命令:
代码语言:javascript复制type Handler interface {
//handle COM_INIT_DB command, you can check whether the dbName is valid, or other.
UseDB(dbName string) error
//handle COM_QUERY command, like SELECT, INSERT, UPDATE, etc...
//If Result has a Resultset (SELECT, SHOW, etc...), we will send this as the response, otherwise, we will send Result
HandleQuery(query string) (*Result, error)
//handle COM_FILED_LIST command
HandleFieldList(table string, fieldWildcard string) ([]*Field, error)
//handle COM_STMT_PREPARE, params is the param number for this statement, columns is the column number
//context will be used later for statement execute
HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error)
//handle COM_STMT_EXECUTE, context is the previous one set in prepare
//query is the statement prepare query, and args is the params for this statement
HandleStmtExecute(context interface{}, query string, args []interface{}) (*Result, error)
//handle COM_STMT_CLOSE, context is the previous one set in prepare
//this handler has no response
HandleStmtClose(context interface{}) error
//handle any other command that is not currently handled by the library,
//default implementation for this method will return an ER_UNKNOWN_ERROR
HandleOtherCommand(cmd byte, data []byte) error
}
复制请求处理器接口包含3个方法:注册slave、dump binlog和按GTID dump binlog:
代码语言:javascript复制type ReplicationHandler interface {
// handle Replication command
HandleRegisterSlave(data []byte) error
HandleBinlogDump(pos Position) (*replication.BinlogStreamer, error)
HandleBinlogDumpGTID(gtidSet *MysqlGTIDSet) (*replication.BinlogStreamer, error)
}
请求处理循环是server的核心,上面的处理器函数都在这个处理流程中被调用:
代码语言:javascript复制func (c *Conn) HandleCommand() error {
data, err := c.ReadPacket()
v := c.dispatch(data)
err = c.WriteValue(v)
if c.Conn != nil {
c.ResetSequence()
if err != nil {
c.Close()
每次调用它会读取一个packet,分发给处理器,最后将结果写回客户端;调用方在循环中不断调用它。读取packet的代码位于:github.com/go-mysql-org/go-mysql@v1.7.0/packet/conn.go
代码语言:javascript复制func (c *Conn) ReadPacket() ([]byte, error) {
return c.ReadPacketReuseMem(nil)
}
数据会读入缓冲区:
代码语言:javascript复制func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
buf := utils.BytesBufferGet()
if err := c.ReadPacketTo(buf); err != nil {
if len(dst) > 0 {
result = append(dst, readBytes...)
分发的过程中,按照mysql的command来进行分发,然后调用处理器的函数来进行处理。
代码语言:javascript复制func (c *Conn) dispatch(data []byte) interface{} {
switch cmd {
case COM_QUIT:
c.Close()
c.Conn = nil
return noResponse{}
case COM_QUERY:
if r, err := c.h.HandleQuery(hack.String(data)); err != nil {
return err
} else {
return r
}
case COM_PING:
return nil
case COM_INIT_DB:
if err := c.h.UseDB(hack.String(data)); err != nil {
return err
} else {
return nil
}
case COM_FIELD_LIST:
index := bytes.IndexByte(data, 0x00)
table := hack.String(data[0:index])
wildcard := hack.String(data[index+1:])
if fs, err := c.h.HandleFieldList(table, wildcard); err != nil {
return err
} else {
return fs
}
case COM_STMT_PREPARE:
c.stmtID++
st := new(Stmt)
st.ID = c.stmtID
st.Query = hack.String(data)
var err error
if st.Params, st.Columns, st.Context, err = c.h.HandleStmtPrepare(st.Query); err != nil {
return err
} else {
st.ResetParams()
c.stmts[c.stmtID] = st
return st
}
case COM_STMT_EXECUTE:
if r, err := c.handleStmtExecute(data); err != nil {
return err
} else {
return r
}
case COM_STMT_CLOSE:
if err := c.handleStmtClose(data); err != nil {
return err
}
return noResponse{}
case COM_STMT_SEND_LONG_DATA:
if err := c.handleStmtSendLongData(data); err != nil {
return err
}
return noResponse{}
case COM_STMT_RESET:
if r, err := c.handleStmtReset(data); err != nil {
return err
} else {
return r
}
case COM_SET_OPTION:
if err := c.h.HandleOtherCommand(cmd, data); err != nil {
return err
}
return eofResponse{}
case COM_REGISTER_SLAVE:
if h, ok := c.h.(ReplicationHandler); ok {
return h.HandleRegisterSlave(data)
} else {
return c.h.HandleOtherCommand(cmd, data)
}
case COM_BINLOG_DUMP:
if h, ok := c.h.(ReplicationHandler); ok {
pos, err := parseBinlogDump(data)
if err != nil {
return err
}
if s, err := h.HandleBinlogDump(pos); err != nil {
return err
} else {
return s
}
} else {
return c.h.HandleOtherCommand(cmd, data)
}
case COM_BINLOG_DUMP_GTID:
if h, ok := c.h.(ReplicationHandler); ok {
gtidSet, err := parseBinlogDumpGTID(data)
if err != nil {
return err
}
if s, err := h.HandleBinlogDumpGTID(gtidSet); err != nil {
return err
} else {
return s
}
} else {
return c.h.HandleOtherCommand(cmd, data)
}
default:
return c.h.HandleOtherCommand(cmd, data)
最后是结果的返回流程:github.com/go-mysql-org/go-mysql@v1.7.0/server/resp.go,根据返回值的类型选择对应的写入函数,将响应写回客户端:
代码语言:javascript复制 func (c *Conn) WriteValue(value interface{}) error {
switch v := value.(type) {
case noResponse:
return nil
case eofResponse:
return c.writeEOF()
case error:
return c.writeError(v)
case nil:
return c.writeOK(nil)
case *Result:
if v != nil && v.Resultset != nil {
return c.writeResultset(v.Resultset)
} else {
return c.writeOK(v)
}
case []*Field:
return c.writeFieldList(v, nil)
case []FieldValue:
return c.writeFieldValues(v)
case *replication.BinlogStreamer:
return c.writeBinlogEvents(v)
case *Stmt:
return c.writePrepare(v)
default:
return fmt.Errorf("invalid response type %T", value)
}
}
其中预处理语句的响应由writePrepare写入,代码位于:github.com/go-mysql-org/go-mysql@v1.7.0/server/stmt.go
代码语言:javascript复制func (c *Conn) writePrepare(s *Stmt) error {
data := make([]byte, 4, 128)
//status ok
data = append(data, 0)
//stmt id
data = append(data, Uint32ToBytes(s.ID)...)
//number columns
data = append(data, Uint16ToBytes(uint16(s.Columns))...)
//number params
data = append(data, Uint16ToBytes(uint16(s.Params))...)
//filter [00]
data = append(data, 0)
//warning count
data = append(data, 0, 0)
if err := c.WritePacket(data); err != nil {
return err
}
if s.Params > 0 {
for i := 0; i < s.Params; i {
data = data[0:4]
data = append(data, paramFieldData...)
if err := c.WritePacket(data); err != nil {
return errors.Trace(err)
}
}
if err := c.writeEOF(); err != nil {
return err
}
}
if s.Columns > 0 {
for i := 0; i < s.Columns; i {
data = data[0:4]
data = append(data, columnFieldData...)
if err := c.WritePacket(data); err != nil {
return errors.Trace(err)
}
}
if err := c.writeEOF(); err != nil {
return err
}
}
return nil
}