插件与扩展
第十二章:插件与扩展
12.1 插件机制概述
GORM 提供了强大的插件机制,允许开发者在数据库操作的不同阶段注入自定义逻辑。
12.2 常用官方插件
Prometheus 监控
import "github.com/wei840222/gorm-prom"
db.Use(gormprom.New(gormprom.Config{
DBName: "myapp", // 数据库名称
StartServer: true, // 启动 HTTP 服务
HTTPServerPort: 8080, // 服务端口
}))
// 访问 http://localhost:8080/metrics 查看指标
乐观锁
import "gorm.io/plugin/optimisticlock"
type Product struct {
ID uint
Name string
Version optimisticlock.Version // 版本号
Stock int
}
// 更新时自动检查版本号
result := db.Model(&product).Update("stock", product.Stock-1)
if result.RowsAffected == 0 {
// 版本冲突,需要重试
}
读写分离
import "gorm.io/plugin/dbresolver"
db.Use(dbresolver.Register(dbresolver.Config{
Sources: []gorm.Dialector{mysql.Open("write_dsn")},
Replicas: []gorm.Dialector{
mysql.Open("read1_dsn"),
mysql.Open("read2_dsn"),
},
Policy: dbresolver.RandomPolicy{},
// TraceResolverMode: true, // 打印使用的数据源
}))
// 指定使用主库
db.Clauses(dbresolver.Write).First(&user)
12.3 编写自定义插件
插件结构
type Plugin interface {
Name() string
Initialize(*gorm.DB) error
}
完整插件示例
package myplugin
import (
"gorm.io/gorm"
)
type MyPlugin struct{}
func (p *MyPlugin) Name() string {
return "my_plugin"
}
func (p *MyPlugin) Initialize(db *gorm.DB) error {
// 注册回调
db.Callback().Create().Before("gorm:create").Register("my_plugin:before_create", beforeCreate)
db.Callback().Create().After("gorm:create").Register("my_plugin:after_create", afterCreate)
db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery)
db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate)
db.Callback().Delete().Before("gorm:delete").Register("my_plugin:before_delete", beforeDelete)
return nil
}
func beforeCreate(db *gorm.DB) {
// 创建前的逻辑
if db.Error != nil {
return
}
// 实现逻辑
}
func afterCreate(db *gorm.DB) {
// 创建后的逻辑
}
func afterQuery(db *gorm.DB) {
// 查询后的逻辑
}
func beforeUpdate(db *gorm.DB) {
// 更新前的逻辑
}
func beforeDelete(db *gorm.DB) {
// 删除前的逻辑
}
// 使用
db.Use(&MyPlugin{})
12.4 审计日志插件
type AuditPlugin struct {
LogChannel chan AuditLog
}
type AuditLog struct {
Action string
Table string
RecordID interface{}
OldData interface{}
NewData interface{}
UserID uint
Timestamp time.Time
}
func (p *AuditPlugin) Name() string {
return "audit_plugin"
}
func (p *AuditPlugin) Initialize(db *gorm.DB) error {
// 创建前记录
db.Callback().Create().After("gorm:create").Register("audit:create", func(db *gorm.DB) {
p.log(db, "CREATE", nil, db.Statement.ReflectValue.Interface())
})
// 更新前查询旧数据
db.Callback().Update().Before("gorm:update").Register("audit:update:before", func(db *gorm.DB) {
if db.Statement.Schema != nil {
var oldData interface{}
db.Session(&gorm.Session{}).First(&oldData, db.Statement.Dest)
db.Statement.Set("audit:old_data", oldData)
}
})
db.Callback().Update().After("gorm:update").Register("audit:update:after", func(db *gorm.DB) {
oldData, _ := db.Statement.Get("audit:old_data")
p.log(db, "UPDATE", oldData, db.Statement.ReflectValue.Interface())
})
// 删除前记录
db.Callback().Delete().Before("gorm:delete").Register("audit:delete", func(db *gorm.DB) {
p.log(db, "DELETE", db.Statement.ReflectValue.Interface(), nil)
})
return nil
}
func (p *AuditPlugin) log(db *gorm.DB, action string, oldData, newData interface{}) {
userID, _ := db.Statement.Context.Value("userID").(uint)
log := AuditLog{
Action: action,
Table: db.Statement.Table,
RecordID: db.Statement.PrimaryKeyValue,
OldData: oldData,
NewData: newData,
UserID: userID,
Timestamp: time.Now(),
}
select {
case p.LogChannel <- log:
default:
// 通道满时丢弃或记录警告
}
}
12.5 软删除增强插件
type SoftDeletePlugin struct {
Field string // 删除标记字段名
TimeField string // 删除时间字段名
UserField string // 删除人字段名
}
func (p *SoftDeletePlugin) Name() string {
return "soft_delete"
}
func (p *SoftDeletePlugin) Initialize(db *gorm.DB) error {
// 替换默认删除行为
db.Callback().Delete().Replace("gorm:delete", func(db *gorm.DB) {
if db.Error != nil {
return
}
// 检查是否是软删除模型
if !db.Statement.Schema.FieldsByDBName[p.Field] {
// 不是软删除模型,执行物理删除
db.Exec("DELETE FROM ? WHERE ?", db.Statement.Table, db.Statement.Where)
return
}
// 执行软删除
updates := map[string]interface{}{
p.Field: 1,
p.TimeField: time.Now(),
}
if userID := db.Statement.Context.Value("userID"); userID != nil {
updates[p.UserField] = userID
}
db.Model(db.Statement.ReflectValue.Addr().Interface()).Updates(updates)
})
return nil
}
12.6 查询缓存插件
type CachePlugin struct {
redis *redis.Client
ttl time.Duration
}
func (p *CachePlugin) Name() string {
return "query_cache"
}
func (p *CachePlugin) Initialize(db *gorm.DB) error {
// 查询前检查缓存
db.Callback().Query().Before("gorm:query").Register("cache:before_query", func(db *gorm.DB) {
cacheKey := p.generateCacheKey(db)
if data, err := p.redis.Get(db.Statement.Context, cacheKey).Result(); err == nil {
// 缓存命中
if err := json.Unmarshal([]byte(data), db.Statement.Dest); err == nil {
db.Statement.SkipHooks = true // 跳过后续查询
}
}
})
// 查询后写入缓存
db.Callback().Query().After("gorm:query").Register("cache:after_query", func(db *gorm.DB) {
if db.Error != nil {
return
}
cacheKey := p.generateCacheKey(db)
data, _ := json.Marshal(db.Statement.Dest)
p.redis.Set(db.Statement.Context, cacheKey, data, p.ttl)
})
// 更新后清除缓存
db.Callback().Update().After("gorm:update").Register("cache:invalidate", func(db *gorm.DB) {
p.invalidateCache(db.Statement.Table, db.Statement.PrimaryKeyValue)
})
return nil
}
func (p *CachePlugin) generateCacheKey(db *gorm.DB) string {
// 基于表名、SQL 和参数生成缓存键
return fmt.Sprintf("gorm:%s:%x", db.Statement.Table, md5.Sum([]byte(db.Statement.SQL.String())))
}
12.7 字段加密插件
type EncryptPlugin struct {
key []byte
}
func (p *EncryptPlugin) Name() string {
return "field_encrypt"
}
func (p *EncryptPlugin) Initialize(db *gorm.DB) error {
// 创建前加密
db.Callback().Create().Before("gorm:create").Register("encrypt:before_create", p.encrypt)
db.Callback().Update().Before("gorm:update").Register("encrypt:before_update", p.encrypt)
// 查询后解密
db.Callback().Query().After("gorm:query").Register("encrypt:after_query", p.decrypt)
return nil
}
func (p *EncryptPlugin) encrypt(db *gorm.DB) {
if db.Statement.Schema == nil {
return
}
for _, field := range db.Statement.Schema.Fields {
if tag := field.Tag.Get("encrypt"); tag == "true" {
// 获取字段值并加密
if value, ok := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); ok {
if str, ok := value.(string); ok && str != "" {
encrypted, _ := p.aesEncrypt(str)
field.Set(db.Statement.Context, db.Statement.ReflectValue, encrypted)
}
}
}
}
}
func (p *EncryptPlugin) decrypt(db *gorm.DB) {
// 解密逻辑类似
}
12.8 插件组合使用
func setupDB(dsn string) (*gorm.DB, error) {
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
return nil, err
}
// 启用多个插件
db.Use(gormprom.New(gormprom.Config{DBName: "myapp"}))
db.Use(dbresolver.Register(dbresolver.Config{...}))
db.Use(&AuditPlugin{LogChannel: make(chan AuditLog, 100)})
db.Use(&CachePlugin{redis: redisClient, ttl: 5 * time.Minute})
return db, nil
}
12.9 最佳实践
1. 插件顺序
// 注意插件执行顺序
db.Use(plugin1) // 先注册的先执行
db.Use(plugin2)
2. 错误处理
func (p *MyPlugin) Initialize(db *gorm.DB) error {
// 所有回调都要检查 db.Error
db.Callback().Create().Before("gorm:create").Register("my_plugin", func(db *gorm.DB) {
if db.Error != nil {
return
}
// 业务逻辑
})
return nil
}
3. 性能考虑
// 避免在回调中执行耗时操作
db.Callback().Query().After("gorm:query").Register("my_plugin", func(db *gorm.DB) {
// 不要在这里:
// - 调用外部 HTTP 服务
// - 执行复杂计算
// - 大量数据库操作
// 可以:
// - 简单的数据转换
// - 发送异步消息
// - 更新内存缓存
})
12.10 练习题
- 编写一个租户隔离插件,自动为所有查询添加 tenant_id 过滤
- 实现一个 SQL 注入检测插件,拦截可疑查询
- 创建一个数据脱敏插件,对敏感字段自动脱敏
12.11 小结
本章介绍了 GORM 的插件机制,包括常用插件的使用和自定义插件的开发。插件是扩展 GORM 功能的强大方式,合理使用可以大大简化业务代码。
本文代码地址:https://github.com/LittleMoreInteresting/gorm_study
欢迎关注公众号,一起学习进步!