golang源码分析:go-mysql-server(1)

2023-09-06 19:25:09 浏览数 (1)

go-mysql-server是一个用Go实现的、兼容MySQL协议的SQL引擎,并自带基于内存的存储实现。使用方法分为下面五步:创建engine、设置root账户、初始化配置、初始化server、开启服务。具体如下:

代码语言:javascript复制
  // Step 1: build an engine over a provider that holds the test database
  // and the information_schema database.
  engine := sqle.NewDefault(
    sql.NewDatabaseProvider(
      createTestDatabase(),
      information_schema.NewInformationSchemaDatabase(),
    ))
  // Step 2: register the root account on the engine's MySQLDb.
  engine.Analyzer.Catalog.MySQLDb.AddRootAccount()
  // Step 3: describe where the server listens.
  config := server.Config{
    Protocol: "tcp",
    Address:  "localhost:3306",
  }
  // Step 4: create the server around the engine.
  s, err := server.NewDefaultServer(config, engine)
  if err != nil {
    panic(err)
  }
  // Step 5: start serving. Start returns an error, so report it the same
  // way as the NewDefaultServer error instead of silently dropping it.
  if err = s.Start(); err != nil {
    panic(err)
  }

下面我们开始依次分析:

1,创建engine:

代码语言:javascript复制
engine := sqle.NewDefault(
    sql.NewDatabaseProvider(
      createTestDatabase(),
      information_schema.NewInformationSchemaDatabase(),
    ))

源码位于github.com/dolthub/go-mysql-server@v0.14.0/engine.go

代码语言:javascript复制
// NewDefault builds an Engine from the default analyzer for the given
// database provider. No engine Config is supplied (nil is passed to New).
func NewDefault(pro sql.DatabaseProvider) *Engine {
  return New(analyzer.NewDefault(pro), nil)
}
代码语言:javascript复制
func New(a *analyzer.Analyzer, cfg *Config) *Engine {
  ls := sql.NewLockSubsystem()
  a.Catalog.RegisterFunction(emptyCtx, sql.FunctionN{
  Name: "version",
  Fn:   function.NewVersion(cfg.VersionPostfix),
  })
  a.Catalog.RegisterFunction(emptyCtx, function.GetLockingFuncs(ls)...)
  
  return &Engine{
    Analyzer:          a,
    MemoryManager:     sql.NewMemoryManager(sql.ProcessMemory),
    ProcessList:       NewProcessList(),
    LS:                ls,
    BackgroundThreads: sql.NewBackgroundThreads(),
    IsReadOnly:        cfg.IsReadOnly,
    IsServerLocked:    cfg.IsServerLocked,
    PreparedData:      make(map[uint32]PreparedData),
    mu:                &sync.Mutex{},
  }

Engine的定义如下:

代码语言:javascript复制
// Engine bundles the analyzer with the shared subsystems that New creates
// for it (lock subsystem, process list, memory manager, background threads)
// together with the read-only and server-locked flags.
type Engine struct {
  // Analyzer is the analyzer the engine was constructed with.
  Analyzer          *analyzer.Analyzer
  // LS is the lock subsystem also handed to the locking functions.
  LS                *sql.LockSubsystem
  // ProcessList is populated with NewProcessList() by New.
  ProcessList       sql.ProcessList
  // MemoryManager is built from sql.ProcessMemory by New.
  MemoryManager     *sql.MemoryManager
  BackgroundThreads *sql.BackgroundThreads
  // IsReadOnly and IsServerLocked are copied from the Config given to New.
  IsReadOnly        bool
  IsServerLocked    bool
  // PreparedData is keyed by a uint32 id.
  // NOTE(review): presumably a connection/session id; confirm against callers.
  PreparedData      map[uint32]PreparedData
  // mu is a mutex; what it guards is not visible in this excerpt.
  mu                *sync.Mutex
}

其中ProcessList处理逻辑位于github.com/dolthub/go-mysql-server@v0.14.0/processlist.go

代码语言:javascript复制
// NewProcessList returns a ProcessList whose process map is initialised
// and empty.
func NewProcessList() *ProcessList {
  pl := new(ProcessList)
  pl.procs = map[uint64]*sql.Process{}
  return pl
}

github.com/dolthub/go-mysql-server@v0.14.0/sql/analyzer/analyzer.go

代码语言:javascript复制
// NewDefault creates an Analyzer for the given provider by running a fresh
// Builder with no further customisation.
func NewDefault(provider sql.DatabaseProvider) *Analyzer {
  builder := NewBuilder(provider)
  return builder.Build()
}
代码语言:javascript复制
func NewBuilder(pro sql.DatabaseProvider) *Builder {
  return &Builder{
    provider:        pro,

其中builder定义如下:

代码语言:javascript复制
// Builder collects the rule lists and settings from which Build produces an
// Analyzer.
// NOTE(review): the order in which the rule lists are applied is decided in
// Build, which is not shown in this excerpt.
type Builder struct {
  // Rule lists, grouped by the phase their names refer to.
  preAnalyzeRules     []Rule
  postAnalyzeRules    []Rule
  preValidationRules  []Rule
  postValidationRules []Rule
  onceBeforeRules     []Rule
  defaultRules        []Rule
  onceAfterRules      []Rule
  validationRules     []Rule
  afterAllRules       []Rule
  // provider is the database provider passed to NewBuilder.
  provider            sql.DatabaseProvider
  debug               bool
  parallelism         int
}  
代码语言:javascript复制
func (ab *Builder) Build() *Analyzer {
代码语言:javascript复制
// Analyzer holds the rule batches to apply, the catalog of databases and
// functions, and debugging switches.
type Analyzer struct {
  // Whether to log various debugging messages
  Debug bool
  // Whether to output the query plan at each step of the analyzer
  Verbose bool
  // A stack of debugger context. See PushDebugContext, PopDebugContext
  contextStack []string
  // NOTE(review): presumably the degree of parallelism taken from the
  // Builder; confirm in Build.
  Parallelism  int
  // Batches of Rules to apply.
  Batches []*Batch
  // Catalog of databases and registered functions.
  Catalog *Catalog
}

github.com/dolthub/go-mysql-server@v0.14.0/sql/provider.go

代码语言:javascript复制
// NewDatabaseProvider returns a DatabaseProvider holding the given
// databases. Each database is indexed under its lower-cased name, so a later
// database whose name differs only by case replaces an earlier one.
func NewDatabaseProvider(dbs ...Database) DatabaseProvider {
  byName := make(map[string]Database, len(dbs))
  for i := range dbs {
    key := strings.ToLower(dbs[i].Name())
    byName[key] = dbs[i]
  }
  return databaseProvider{mu: &sync.RWMutex{}, dbs: byName}
}
代码语言:javascript复制
// databaseProvider is the map-backed value returned by NewDatabaseProvider.
type databaseProvider struct {
  // dbs maps a lower-cased database name to its Database.
  dbs map[string]Database
  // mu is allocated alongside dbs by NewDatabaseProvider.
  // NOTE(review): presumably guards dbs; the locking code is not shown here.
  mu  *sync.RWMutex
}

github.com/dolthub/go-mysql-server@v0.14.0/sql/core.go

代码语言:javascript复制
// DatabaseProvider is the lookup interface for a set of databases: fetch one
// by name, test for existence, or list them all.
type DatabaseProvider interface {
  // Database gets a Database from the provider.
  Database(ctx *Context, name string) (Database, error)


  // HasDatabase checks if the Database exists in the provider.
  HasDatabase(ctx *Context, name string) bool


  // AllDatabases returns a slice of all Databases in the provider.
  AllDatabases(ctx *Context) []Database
}
代码语言:javascript复制
// RowInserter is a TableEditor that can insert rows and is closed once the
// insert operation is finished. Table.Inserter returns a value of this type.
type RowInserter interface {
  TableEditor
  // Insert inserts the row given, returning an error if it cannot. Insert will be called once for each row to process
  // for the insert operation, which may involve many rows. After all rows in an operation have been processed, Close
  // is called.
  Insert(*Context, Row) error
  // Close finalizes the insert operation, persisting its result.
  Closer
}

初始化provider的参数是我们自定义的创建数据库函数的返回值,即一个内存数据库对象:

代码语言:javascript复制
  func createTestDatabase() *memory.Database {

github.com/dolthub/go-mysql-server@v0.14.0/memory/database.go

代码语言:javascript复制
// Database is a BaseDatabase extended with a views map.
type Database struct {
  *BaseDatabase
  // NOTE(review): presumably view name -> view definition text; confirm.
  views map[string]string
}
代码语言:javascript复制
// BaseDatabase is the in-memory database state without view support; it is
// what NewViewlessDatabase returns.
type BaseDatabase struct {
  name              string
  // tables is keyed by the exact name passed to AddTable.
  tables            map[string]sql.Table
  // fkColl is created with newForeignKeyCollection by NewViewlessDatabase.
  fkColl            *ForeignKeyCollection
  triggers          []sql.TriggerDefinition
  storedProcedures  []sql.StoredProcedureDetails
  primaryKeyIndexes bool
  collation         sql.CollationID
}  
代码语言:javascript复制
// NewDatabase returns a Database with the given name: a viewless base
// database plus an empty, initialised views map.
func NewDatabase(name string) *Database {
  db := new(Database)
  db.BaseDatabase = NewViewlessDatabase(name)
  db.views = map[string]string{}
  return db
}
代码语言:javascript复制
// NewViewlessDatabase returns a BaseDatabase with the given name, an empty
// table map and a new foreign key collection. All other fields keep their
// zero values.
func NewViewlessDatabase(name string) *BaseDatabase {
  base := new(BaseDatabase)
  base.name = name
  base.tables = make(map[string]sql.Table)
  base.fkColl = newForeignKeyCollection()
  return base
}
代码语言:javascript复制
// AddTable stores t under name in the database's table map, replacing any
// table already stored under exactly that name.
func (d *BaseDatabase) AddTable(name string, t sql.Table) {
  registry := d.tables
  registry[name] = t
}

通过map实现根据表名定位表的数据。下面看下如何新建一个数据库:

代码语言:javascript复制
// Create the in-memory database, then a table that uses a fresh foreign key
// collection.
db := memory.NewDatabase(dbName)
// NOTE(review): the schema here is empty (sql.Schema{}), while the row
// inserted later in this article carries four values; a runnable example
// presumably needs matching column definitions. Confirm against the
// library's own example.
table := memory.NewTable(tableName, sql.NewPrimaryKeySchema(sql.Schema{}), &memory.ForeignKeyCollection{})

github.com/dolthub/go-mysql-server@v0.14.0/memory/table.go

代码语言:javascript复制
// NewTable creates a Table with the given name, schema and foreign key
// collection by delegating to NewPartitionedTableWithCollation with a
// partition count of 0 and the default collation.
func NewTable(name string, schema sql.PrimaryKeySchema, fkColl *ForeignKeyCollection) *Table {
  const numPartitions = 0
  return NewPartitionedTableWithCollation(name, schema, fkColl, numPartitions, sql.Collation_Default)
}
代码语言:javascript复制
func NewPartitionedTableWithCollation(name string, schema sql.PrimaryKeySchema, fkColl *ForeignKeyCollection, numPartitions int, collation sql.CollationID) *Table {
   for i := 0; i < numPartitions; i++ {
    key := strconv.Itoa(i)
    keys = append(keys, []byte(key))
    partitions[key] = []sql.Row{}
  }
    return &Table{
    name:          name,
    schema:        schema,
    fkColl:        fkColl,
    collation:     collation,
    partitions:    partitions,
    partitionKeys: keys,
    autoIncVal:    autoIncVal,
    autoColIdx:    autoIncIdx,
  }

其中table的定义如下:

代码语言:javascript复制
// Table is the in-memory table: schema metadata, the row data held per
// partition, and bookkeeping for inserts, lookups and AUTO_INCREMENT.
type Table struct {
  // Schema and related info
  name             string
  schema           sql.PrimaryKeySchema
  indexes          map[string]sql.Index
  fkColl           *ForeignKeyCollection
  checks           []sql.CheckDefinition
  collation        sql.CollationID
  pkIndexesEnabled bool


  // pushdown info
  filters         []sql.Expression // currently unused, filter pushdown is significantly broken right now
  projection      []string
  projectedSchema sql.Schema
  columns         []int


  // Data storage
  // partitions maps a partition key (as a string) to that partition's rows;
  // partitionKeys holds the same keys as byte slices.
  partitions    map[string][]sql.Row
  partitionKeys [][]byte


  // Insert bookkeeping
  insertPartIdx int


  // Indexed lookups
  lookup sql.DriverIndexLookup


  // AUTO_INCREMENT bookkeeping
  autoIncVal uint64
  autoColIdx int


  tableStats *TableStatistics
}

重点看下它的插入方法,看下如何插入数据。

代码语言:javascript复制
// Insert adds a single row to the table: it obtains an inserter, inserts the
// row through it and, if that succeeded, closes the inserter and returns the
// result of Close. An insert error is returned as is, without calling Close.
func (t *Table) Insert(ctx *sql.Context, row sql.Row) error {
  ed := t.Inserter(ctx)
  err := ed.Insert(ctx, row)
  if err != nil {
    return err
  }
  return ed.Close(ctx)
}
代码语言:javascript复制
// Inserter returns a new table editor for t as a sql.RowInserter. The
// context argument is not used.
func (t *Table) Inserter(*sql.Context) sql.RowInserter {
  editor := t.newTableEditor()
  return editor
}
代码语言:javascript复制
func (t *Table) newTableEditor() *tableEditor {
   for _, idx := range t.indexes {
    if !idx.IsUnique() {
      continue
    }
    var colNames []string
    expressions := idx.(*Index).Exprs
    for _, exp := range expressions {
      colNames = append(colNames, exp.(*expression.GetField).Name())
    }
    colIdxs, err := t.columnIndexes(colNames)
    if err != nil {
      panic("failed to get column indexes")
    }
    uniqIdxCols = append(uniqIdxCols, colIdxs)
  }
  return &tableEditor{

github.com/dolthub/go-mysql-server@v0.14.0/sql/schema.go

代码语言:javascript复制
// NewPrimaryKeySchema pairs a Schema with the ordinals of its primary key
// columns. When pkOrds is given it is used as is; otherwise the ordinals are
// the positions of the columns whose PrimaryKey flag is set, in schema order
// (an empty, non-nil slice if there are none).
func NewPrimaryKeySchema(s Schema, pkOrds ...int) PrimaryKeySchema {
  if len(pkOrds) > 0 {
    return PrimaryKeySchema{Schema: s, PkOrdinals: pkOrds}
  }

  ordinals := []int{}
  for idx := range s {
    if s[idx].PrimaryKey {
      ordinals = append(ordinals, idx)
    }
  }
  return PrimaryKeySchema{Schema: s, PkOrdinals: ordinals}
}

定义完table后进行数据插入:

代码语言:javascript复制
  // Register the table with the database, then insert one sample row.
  db.AddTable(tableName, table)
  ctx := sql.NewEmptyContext()
  // Table.Insert returns an error; check it rather than discarding it, so a
  // rejected row is noticed immediately.
  if err := table.Insert(ctx, sql.NewRow("John Doe", "john@doe.com", sql.JSONDocument{Val: []string{"555-555-555"}}, time.Now())); err != nil {
    panic(err)
  }

github.com/dolthub/go-mysql-server@v0.14.0/memory/table_editor.go

代码语言:javascript复制
// tableEditor is the editor returned by Table.newTableEditor; its Insert
// method checks the row and records it through the edit accumulator ea.
type tableEditor struct {
  // table is the table being edited.
  table             *Table
  // NOTE(review): the initial* fields look like snapshots taken when the
  // editor is created; confirm in newTableEditor.
  initialAutoIncVal uint64
  initialPartitions map[string][]sql.Row
  // ea receives the Get/GetByCols/Insert calls made by Insert.
  ea                tableEditAccumulator
  initialInsert     int
  // array of key ordinals for each unique index defined on the table
  uniqueIdxCols [][]int
  fkTable       *Table
}  
代码语言:javascript复制
func (t *tableEditor) Insert(ctx *sql.Context, row sql.Row) error {
  if err := checkRow(t.table.schema.Schema, row); err != nil {
    t.table.verifyRowTypes(row)


  partitionRow, added, err := t.ea.Get(row)
          
  if added {
    pkColIdxes := t.pkColumnIndexes()
    return sql.NewUniqueKeyErr(formatRow(row, pkColIdxes), true, partitionRow)
          for _, cols := range t.uniqueIdxCols {
    if hasNullForAnyCols(row, cols) {
      continue
    }
    existing, found, err := t.ea.GetByCols(row, cols)
    if err != nil {
      return err
    }


    if found {
      return sql.NewUniqueKeyErr(formatRow(row, cols), false, existing)
    }
  }


  err = t.ea.Insert(row)

具体存储每一行数据还是通过map结构:key是由主键各列的值依次拼接而成的字符串(见getRowKey),value就是行数据。

代码语言:javascript复制
func (pke *pkTableEditAccumulator) Insert(value sql.Row) error {
rowKey := pke.getRowKey(value)
delete(pke.deletes, rowKey)
pke.adds[rowKey] = value  
代码语言:javascript复制
func (pke *pkTableEditAccumulator) getRowKey(r sql.Row) string {
  var rowKey strings.Builder
  for _, i := range pke.table.schema.PkOrdinals {
    rowKey.WriteString(fmt.Sprintf("%v", r[i]))  
代码语言:javascript复制
// pkTableEditAccumulator records pending row changes for a table, keyed by
// the row key that getRowKey builds from the primary key columns.
type pkTableEditAccumulator struct {
  // table supplies the primary key ordinals used to build row keys.
  table   *Table
  // adds holds rows recorded by Insert, by row key.
  adds    map[string]sql.Row
  // deletes holds rows by row key; Insert removes the row's key from it.
  deletes map[string]sql.Row
}

行数据的定义位于github.com/dolthub/go-mysql-server@v0.14.0/sql/row.go,其构造函数如下:

代码语言:javascript复制
// NewRow builds a Row holding the given values in order. The values are
// copied into a new slice, so the result does not alias the caller's slice.
func NewRow(values ...interface{}) Row {
  fields := make([]interface{}, 0, len(values))
  fields = append(fields, values...)
  return fields
}

接着看下,如何定义用户

代码语言:javascript复制
engine.Analyzer.Catalog.MySQLDb.AddRootAccount()

github.com/dolthub/go-mysql-server@v0.14.0/sql/mysql_db/mysql_db.go

代码语言:javascript复制
// AddRootAccount marks the database as enabled, adds the super user
// "root"@"localhost" with an empty password to the user table, and then
// clears the cache.
func (db *MySQLDb) AddRootAccount() {
  const (
    rootName     = "root"
    rootHost     = "localhost"
    rootPassword = ""
  )

  db.Enabled = true
  addSuperUser(db.user, rootName, rootHost, rootPassword)
  db.clearCache()
}

github.com/dolthub/go-mysql-server@v0.14.0/sql/mysql_db/user_table.go

代码语言:javascript复制
func addSuperUser(userTable *mysqlTable, username string, host string, password string) {
   err := userTable.data.Put(sql.NewEmptyContext(), &User{
    User:                username,
    Host:                host,
    PrivilegeSet:        newPrivilegeSetWithAllPrivileges(),
    Plugin:              "mysql_native_password",

github.com/dolthub/go-mysql-server@v0.14.0/sql/analyzer/catalog.go

代码语言:javascript复制
// Catalog is the analyzer's registry: it exposes the MySQLDb (on which
// accounts such as root are added) and holds the database provider and the
// built-in function registry.
type Catalog struct {
  MySQLDb *mysql_db.MySQLDb


  provider         sql.DatabaseProvider
  builtInFunctions function.Registry
  // NOTE(review): presumably guards the fields above; confirm in catalog.go.
  mu               sync.RWMutex
  locks            sessionLocks
}

定义完用户后,开始定义服务器的配置:

代码语言:javascript复制
// Listen on TCP at localhost:3306; every other Config field keeps its zero
// value.
config := server.Config{Protocol: "tcp", Address: "localhost:3306"}

github.com/dolthub/go-mysql-server@v0.14.0/server/server_config.go

代码语言:javascript复制
// Config holds the settings used to create a server: where to listen,
// timeouts, connection limits, TLS and protocol options.
type Config struct {
  // Protocol for the connection.
  Protocol string
  // Address of the server.
  Address string
  // Tracer to use in the server. By default, a noop tracer will be used if
  // no tracer is provided.
  Tracer trace.Tracer
  // Version string to advertise in running server
  Version string
  // ConnReadTimeout is the server's read timeout
  ConnReadTimeout time.Duration
  // ConnWriteTimeout is the server's write timeout
  ConnWriteTimeout time.Duration
  // MaxConnections is the maximum number of simultaneous connections that the server will allow.
  MaxConnections uint64
  // TLSConfig is the configuration for TLS on this server. If |nil|, TLS is not supported.
  TLSConfig *tls.Config
  // RequireSecureTransport will require incoming connections to be TLS. Requires non-|nil| TLSConfig.
  RequireSecureTransport bool
  // DisableClientMultiStatements will prevent processing of incoming
  // queries as if they contain more than one query. This processing
  // currently works in some simple cases, but breaks in the presence of
  // statements (such as in CREATE TRIGGER queries). Configuring the
  // server to disable processing these is one option for users to get
  // support back for single queries that contain statements, at the cost
  // of not supporting the CLIENT_MULTI_STATEMENTS option on the server.
  DisableClientMultiStatements bool
  // NoDefaults prevents using persisted configuration for new server sessions
  NoDefaults bool
  // Socket is a path to unix socket file
  Socket                   string
  // NOTE(review): presumably allows clear-text password authentication on
  // connections that do not use TLS; confirm against the server code.
  AllowClearTextWithoutTLS bool
}
      // Server groups the MySQL listener with the handler and the session
      // manager.
      // NOTE(review): how the three are wired together is decided in
      // NewDefaultServer, which is not shown in this excerpt.
      type Server struct {
  Listener   *mysql.Listener
  handler    mysql.Handler
  sessionMgr *SessionManager
}

0 人点赞