DBResolver

gorm.io/plugin/dbresolver@v1.1.0/dbresolver.go

// DBResolver is the dbresolver plugin. It implements gorm's Plugin interface
// (Name/Initialize) and, via callbacks, routes each statement to a source or
// replica connection pool.
type DBResolver struct {
   // Set by Initialize; its Callback() is used to register the routing callbacks.
   *gorm.DB
   // Every Config passed to Register, in registration order; compiled by compile().
   configs          []Config
   // Resolvers keyed by a `using` name, a string from Register's datas, or a model's table name.
   resolvers        map[string]*resolver
   // Resolver from a Config registered without datas; used when no keyed resolver matches.
   global           *resolver
   // Prepared-statement state per raw pool, read by resolver.resolve when PrepareStmt is on.
   // NOTE(review): the code that populates this map is not shown here — presumably a compile callback.
   prepareStmtStore map[gorm.ConnPool]*gorm.PreparedStmtDB
   // Run against each newly compiled resolver (via resolver.call) at the end of compileConfig.
   compileCallbacks []func(gorm.ConnPool) error
}

// Name implements gorm.Plugin. DB.Use stores the plugin under this key and
// rejects a second plugin registered under the same name.
func (dr *DBResolver) Name() string {
	const pluginName = "gorm:db_resolver"
	return pluginName
}

// Initialize implements gorm.Plugin and is called by DB.Use. It remembers db,
// hooks the routing callbacks into db's callback chains, and then compiles
// every Config registered so far. The first compile error is returned.
func (dr *DBResolver) Initialize(db *gorm.DB) error {
	dr.DB = db
	dr.registerCallbacks(db)

	if err := dr.compile(); err != nil {
		return err
	}
	return nil
}

// Register creates a fresh DBResolver and registers config on it. The
// returned plugin is meant to be installed with db.Use.
func Register(config Config, datas ...interface{}) *DBResolver {
	dr := new(DBResolver)
	return dr.Register(config, datas...)
}

// Register adds a resolver configuration and returns dr so calls can be chained.
//
// datas selects which statements the config applies to: string entries are
// used as lookup keys directly, other values are parsed as models and keyed by
// their table name. With no datas the config becomes the global resolver.
// A nil Config.Policy is replaced with RandomPolicy.
//
// Before Initialize (dr.DB == nil) the config is only recorded. Initialize
// compiles it later. After Initialize it is compiled immediately.
func (dr *DBResolver) Register(config Config, datas ...interface{}) *DBResolver {
	if dr.resolvers == nil {
		dr.resolvers = map[string]*resolver{}
	}
	if dr.prepareStmtStore == nil {
		dr.prepareStmtStore = map[gorm.ConnPool]*gorm.PreparedStmtDB{}
	}

	if config.Policy == nil {
		config.Policy = RandomPolicy{}
	}
	config.datas = datas
	dr.configs = append(dr.configs, config)

	if dr.DB == nil {
		return dr
	}
	// NOTE(review): the compile error is discarded, so a bad late-registered
	// config (e.g. a second global resolver) fails silently. The signature
	// has no error result to report it through.
	_ = dr.compileConfig(config)
	return dr
}

DBResolver 维护了 configs(已注册的配置)、resolvers(按名称或表名索引的 resolver)以及 global(全局 resolver);它实现了 gorm Plugin 接口的 Name 和 Initialize 方法,其中 Initialize 在保存 db 之后依次执行 dr.registerCallbacks(db) 和 dr.compile()。

// Plugin GORM plugin interface
// Plugin is the GORM plugin interface. DB.Use stores a plugin under the key
// returned by Name and calls Initialize once with the *DB it is installed on.
type Plugin interface {
	// Name returns the key the plugin is stored under in DB.Plugins.
	Name() string
	// Initialize sets the plugin up on db; a non-nil error aborts registration.
	Initialize(db *DB) error
}

registerCallbacks

// registerCallbacks installs the connection-switching hooks ahead of every
// other callback ("*") on each processor:
//   - Create/Update/Delete always go to a source (switchSource);
//   - Query/Row prefer a replica (switchReplica);
//   - Raw is classified from its SQL text (switchGuess).
func (dr *DBResolver) registerCallbacks(db *gorm.DB) {
	const hookName = "gorm:db_resolver"
	cbs := dr.Callback()

	cbs.Create().Before("*").Register(hookName, dr.switchSource)
	cbs.Query().Before("*").Register(hookName, dr.switchReplica)
	cbs.Update().Before("*").Register(hookName, dr.switchSource)
	cbs.Delete().Before("*").Register(hookName, dr.switchSource)
	cbs.Row().Before("*").Register(hookName, dr.switchReplica)
	cbs.Raw().Before("*").Register(hookName, dr.switchGuess)
}

registerCallbacks 方法通过 Before("*") 把回调注册在各类操作的最前面:针对 Create、Update、Delete 注册了 dr.switchSource;针对 Query、Row 注册了 dr.switchReplica;针对 Raw 注册了 dr.switchGuess(根据 SQL 内容猜测读写)。

switchSource

// switchSource routes Create/Update/Delete statements to a source pool.
// Inside a transaction the connection is left as-is so the whole transaction
// stays on one connection.
func (dr *DBResolver) switchSource(db *gorm.DB) {
	if isTransaction(db.Statement.ConnPool) {
		return
	}
	db.Statement.ConnPool = dr.resolve(db.Statement, Write)
}

// isTransaction reports whether connPool is an open transaction, i.e. whether
// it implements gorm.TxCommitter. The switch* callbacks never replace such a
// pool, so every statement of a transaction stays on the same connection.
func isTransaction(connPool gorm.ConnPool) bool {
	switch connPool.(type) {
	case gorm.TxCommitter:
		return true
	default:
		return false
	}
}

switchSource 方法在当前连接没有开启事务时执行 dr.resolve(db.Statement, Write),并把结果赋给 db.Statement.ConnPool;若已处于事务中,则不切换连接。

switchReplica

// writeName is the Statement.Clauses key that forces a statement onto a source
// pool; switchReplica and switchGuess check for it before any read routing.
const writeName = "gorm:db_resolver:write"

// switchReplica routes Query and Row statements. Inside a transaction it does
// nothing. Otherwise it proceeds as follows:
//   - a statement that already carries raw SQL is delegated to switchGuess;
//   - a statement marked with the writeName clause, or carrying a "FOR"
//     (locking) clause, goes to a source;
//   - everything else goes to a replica.
func (dr *DBResolver) switchReplica(db *gorm.DB) {
	stmt := db.Statement
	if isTransaction(stmt.ConnPool) {
		return
	}

	if stmt.SQL.Len() > 0 {
		dr.switchGuess(db)
		return
	}

	_, markedWrite := stmt.Clauses[writeName]
	_, hasLockingClause := stmt.Clauses["FOR"]
	if markedWrite || hasLockingClause {
		stmt.ConnPool = dr.resolve(stmt, Write)
	} else {
		stmt.ConnPool = dr.resolve(stmt, Read)
	}
}

switchReplica 方法在当前连接没有开启事务时:若 db.Statement.SQL 中已有原生 SQL(长度大于 0),则交给 switchGuess 处理;否则检查语句的子句,若语句被显式标记为写操作(Statement.Clauses 中存在 gorm:db_resolver:write)或带有 FOR 加锁子句,则执行 dr.resolve(db.Statement, Write),否则执行 dr.resolve(db.Statement, Read)。

switchGuess

// switchGuess picks a connection for statements whose read/write nature is not
// known from the callback type. These are Raw statements, and Query/Row
// statements that carry raw SQL. Outside a transaction it routes to:
//   - a source, if the statement was marked with the writeName clause;
//   - a replica, if the SQL is a plain, non-locking SELECT (isReadOnlySelect);
//   - a source otherwise. This is the safe default for anything it cannot classify.
//
// Inside a transaction the connection is left untouched.
func (dr *DBResolver) switchGuess(db *gorm.DB) {
	if isTransaction(db.Statement.ConnPool) {
		return
	}

	if _, forced := db.Statement.Clauses[writeName]; !forced && isReadOnlySelect(db.Statement.SQL.String()) {
		db.Statement.ConnPool = dr.resolve(db.Statement, Read)
		return
	}
	db.Statement.ConnPool = dr.resolve(db.Statement, Write)
}

// lockingReadClauses are lower-cased SQL fragments that turn a SELECT into a
// locking read (MySQL and PostgreSQL syntax). Such a statement must run on a
// source so the lock is taken where the writes happen.
var lockingReadClauses = []string{
	"for update",
	"for no key update",
	"for share",
	"for key share",
	"lock in share mode",
}

// isReadOnlySelect reports whether rawSQL looks like a SELECT that may be sent
// to a replica. It is deliberately conservative: any locking clause anywhere in
// the text (even inside a string literal) makes it return false, which only
// costs a trip to the primary.
//
// Unlike a suffix-only check, it copes with trailing ";" terminators,
// NOWAIT / SKIP LOCKED modifiers, and "FOR UPDATE OF t". Short SELECTs such as
// "select 1" are also accepted.
// NOTE(review): whitespace other than a single space between locking keywords
// (e.g. "FOR\nUPDATE") is not recognised.
func isReadOnlySelect(rawSQL string) bool {
	sql := strings.TrimRight(strings.TrimSpace(rawSQL), "; \t\r\n")

	const keyword = "select"
	if len(sql) < len(keyword) || !strings.EqualFold(sql[:len(keyword)], keyword) {
		return false
	}

	lower := strings.ToLower(sql)
	for _, clause := range lockingReadClauses {
		if strings.Contains(lower, clause) {
			return false
		}
	}
	return true
}

resolve

// resolve returns the connection pool to use for stmt and op.
//
// A keyed resolver is looked up in this order, and the first match decides:
//  1. the name in a `using` clause;
//  2. stmt.Table (when non-empty);
//  3. stmt.Schema.Table (when a schema is parsed);
//  4. the table name extracted from the raw SQL (when there is raw SQL).
//
// Without a match the global resolver is used. If there is no global resolver
// either, stmt keeps its current ConnPool.
func (dr *DBResolver) resolve(stmt *gorm.Statement, op Operation) gorm.ConnPool {
	var matched *resolver

	if len(dr.resolvers) > 0 {
		if u, isUsing := stmt.Clauses[usingName].Expression.(using); isUsing && u.Use != "" {
			matched = dr.resolvers[u.Use]
		}
		if matched == nil && stmt.Table != "" {
			matched = dr.resolvers[stmt.Table]
		}
		if matched == nil && stmt.Schema != nil {
			matched = dr.resolvers[stmt.Schema.Table]
		}
		if matched == nil {
			if rawSQL := stmt.SQL.String(); rawSQL != "" {
				matched = dr.resolvers[getTableFromRawSQL(rawSQL)]
			}
		}
	}

	switch {
	case matched != nil:
		return matched.resolve(stmt, op)
	case dr.global != nil:
		return dr.global.resolve(stmt, op)
	default:
		return stmt.ConnPool
	}
}

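resolve 方法按以下顺序查找 resolver:using 子句中指定的名称、stmt.Table、stmt.Schema.Table、从原生 SQL 中解析出的表名;找到后调用该 resolver 的 resolve,否则使用 dr.global;若 global 也不存在,则保留 stmt.ConnPool 不变。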

compile

// compile builds a resolver for every registered Config, in registration
// order, and stops at the first error.
func (dr *DBResolver) compile() error {
	for i := range dr.configs {
		if err := dr.compileConfig(dr.configs[i]); err != nil {
			return err
		}
	}
	return nil
}

// compileConfig turns one Config into a resolver and registers it.
//
// Defaults:
//   - Sources default to the DB's own ConnPool, unwrapped from a
//     *gorm.PreparedStmtDB if necessary;
//   - Replicas default to the sources.
//
// The resolver is stored in dr.resolvers under each entry of config.datas.
// Strings are used as-is; other values are keyed by the table name gorm parses
// for them. With no datas the resolver becomes dr.global, and a second global
// config is rejected. Finally every compile callback is run against the new
// resolver.
func (dr *DBResolver) compileConfig(config Config) (err error) {
	basePool := dr.DB.Config.ConnPool
	if wrapped, isPrepared := basePool.(*gorm.PreparedStmtDB); isPrepared {
		// Keep the raw pool: prepared-statement wrapping is re-applied per
		// statement in resolver.resolve.
		basePool = wrapped.ConnPool
	}

	r := resolver{dbResolver: dr, policy: config.Policy}

	if len(config.Sources) > 0 {
		if r.sources, err = dr.convertToConnPool(config.Sources); err != nil {
			return err
		}
	} else {
		r.sources = []gorm.ConnPool{basePool}
	}

	if len(config.Replicas) > 0 {
		if r.replicas, err = dr.convertToConnPool(config.Replicas); err != nil {
			return err
		}
	} else {
		r.replicas = r.sources
	}

	if len(config.datas) == 0 {
		if dr.global != nil {
			return errors.New("conflicted global resolver")
		}
		dr.global = &r
	}
	for _, data := range config.datas {
		if table, isName := data.(string); isName {
			dr.resolvers[table] = &r
			continue
		}
		stmt := &gorm.Statement{DB: dr.DB}
		if parseErr := stmt.Parse(data); parseErr != nil {
			return parseErr
		}
		dr.resolvers[stmt.Table] = &r
	}

	for _, callback := range dr.compileCallbacks {
		if err = r.call(callback); err != nil {
			return err
		}
	}
	return nil
}

作者:xc_oo
链接:https://juejin.cn/post/7186912576830177338
来源:稀土掘金
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

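compile 方法遍历 dr.configs,逐个执行 dr.compileConfig(config)。compileConfig 使用 config.Policy 创建 resolver:未配置 Sources 时使用 DB 自身的连接池(若为 PreparedStmtDB 则取其底层 ConnPool);未配置 Replicas 时复用 sources。若注册时传入了 datas,则按字符串或模型的表名登记到 dr.resolvers;否则作为全局 resolver,若全局 resolver 已存在则返回 "conflicted global resolver" 错误。最后对新 resolver 依次执行 compileCallbacks。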

resolver

gorm.io/plugin/dbresolver@v1.1.0/resolver.go

// resolver holds the connection pools compiled from one Config and picks one
// of them for each statement (see resolver.resolve).
type resolver struct {
   // Pools used for Write; compileConfig defaults this to the DB's own ConnPool.
   sources    []gorm.ConnPool
   // Pools used for Read; compileConfig reuses sources when no replicas are configured.
   replicas   []gorm.ConnPool
   // Chooses a pool when the candidate list does not have exactly one entry.
   policy     Policy
   // Owning plugin; used to reach its prepareStmtStore.
   dbResolver *DBResolver
}

作者:xc_oo
链接:https://juejin.cn/post/7186912576830177338
来源:稀土掘金
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

resolve

// resolve picks one pool for op: replicas for Read, sources for anything else.
// A single candidate is returned directly; otherwise r.policy chooses.
//
// When the statement's DB has PrepareStmt enabled and the plugin holds cached
// prepared-statement state for the chosen pool, the pool is wrapped in a
// *gorm.PreparedStmtDB that shares that cached state.
func (r *resolver) resolve(stmt *gorm.Statement, op Operation) (connPool gorm.ConnPool) {
	candidates := r.sources
	if op == Read {
		candidates = r.replicas
	}

	switch len(candidates) {
	case 1:
		connPool = candidates[0]
	default:
		connPool = r.policy.Resolve(candidates)
	}

	if !stmt.DB.PrepareStmt {
		return connPool
	}
	cached, ok := r.dbResolver.prepareStmtStore[connPool]
	if !ok {
		return connPool
	}
	return &gorm.PreparedStmtDB{
		ConnPool: connPool,
		Mux:      cached.Mux,
		Stmts:    cached.Stmts,
	}
}

resolver 的 resolve 在 Operation 为 Read 时使用 r.replicas:若只有 1 个 replica 则直接返回,若有多个则通过 r.policy.Resolve(r.replicas) 选一个。Operation 为 Write 时使用 r.sources:同样只有一个时直接返回,有多个时通过 r.policy.Resolve(r.sources) 选择。若开启了 PrepareStmt,且 prepareStmtStore 中有该连接池的缓存,则返回共享该缓存的 *gorm.PreparedStmtDB。

Policy

gorm.io/plugin/dbresolver@v1.1.0/policy.go

// Policy selects one connection pool from a resolver's candidate sources or
// replicas. resolver.resolve consults it only when the candidate list does not
// have exactly one entry.
type Policy interface {
	Resolve(connPools []gorm.ConnPool) gorm.ConnPool
}

// RandomPolicy is the default Policy: Register installs it when Config.Policy
// is nil. It carries no state and picks a pool uniformly at random.
type RandomPolicy struct{}

// Resolve implements Policy by returning a uniformly random element of
// connPools. It panics if connPools is empty, because rand.Intn requires n > 0.
func (RandomPolicy) Resolve(connPools []gorm.ConnPool) gorm.ConnPool {
	pick := rand.Intn(len(connPools))
	return connPools[pick]
}

Policy 接口定义了 Resolve 方法来选取数据源。默认提供了 RandomPolicy(Register 在未指定 Policy 时使用它),通过 rand.Intn 随机选取。

gorm.io/gorm@v1.23.8/gorm.go

// Use use plugin
// Use installs plugin on db. It returns ErrRegistered if a plugin with the
// same Name is already installed. Otherwise it runs plugin.Initialize and
// records the plugin only if initialization succeeded.
func (db *DB) Use(plugin Plugin) error {
	name := plugin.Name()
	if _, taken := db.Plugins[name]; taken {
		return ErrRegistered
	}

	err := plugin.Initialize(db)
	if err == nil {
		db.Plugins[name] = plugin
	}
	return err
}

示例

// dbResolverDemo shows a minimal read/write-splitting setup. Writes go to the
// single source, and reads are spread randomly over two replicas.
//
// An example has no caller to return errors to, so it panics instead of
// silently continuing with a nil or half-configured *gorm.DB.
func dbResolverDemo() {
	db, err := gorm.Open(mysql.Open("master_dsn"), &gorm.Config{})
	if err != nil {
		panic("dbresolver demo: open master: " + err.Error())
	}

	dbResolverCfg := dbresolver.Config{
		Sources:  []gorm.Dialector{mysql.Open("master_dsn")},
		Replicas: []gorm.Dialector{mysql.Open("replica_a_dsn"), mysql.Open("replica_b_dsn")},
		Policy:   dbresolver.RandomPolicy{},
	}
	readWritePlugin := dbresolver.Register(dbResolverCfg)

	// Install the DBResolver plugin. Use fails with ErrRegistered on a
	// duplicate name, or with the error returned by the plugin's Initialize.
	if err := db.Use(readWritePlugin); err != nil {
		panic("dbresolver demo: install plugin: " + err.Error())
	}
}

小结

gorm 的 dbresolver 实现了 Plugin 接口:

  • 它针对 Create、Update、Delete 方法注册了 dr.switchSource;
  • 针对 Query、Row 注册了 dr.switchReplica,针对 Raw 注册了 dr.switchGuess;
  • switchSource 始终选择 Write,switchReplica 和 switchGuess 在当前连接没有开启事务时动态判断 Operation 是 Read 还是 Write;三者在已开启事务时都不切换连接,继续使用事务所在的连接;
  • resolver 的 resolve 根据 Operation 来进行数据源的切换。

作者:xc_oo
链接:https://juejin.cn/post/7186912576830177338
来源:稀土掘金
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

最后编辑: Simon  文档更新时间: 2024-10-13 15:47   作者:Simon