插件与扩展


第十二章:插件与扩展

12.1 插件机制概述

GORM 提供了强大的插件机制,允许开发者在数据库操作的不同阶段注入自定义逻辑。

12.2 常用官方插件

Prometheus 监控

import "github.com/wei840222/gorm-prom"

db.Use(gormprom.New(gormprom.Config{
    DBName:    "myapp",           // 数据库名称
    StartServer: true,            // 启动 HTTP 服务
    HTTPServerPort: 8080,         // 服务端口
}))

// 访问 http://localhost:8080/metrics 查看指标

乐观锁

import "gorm.io/plugin/optimisticlock"

type Product struct {
    ID      uint
    Name    string
    Version optimisticlock.Version  // 版本号
    Stock   int
}

// 更新时自动检查版本号
result := db.Model(&product).Update("stock", product.Stock-1)
if result.RowsAffected == 0 {
    // 版本冲突,需要重试
}

读写分离

import "gorm.io/plugin/dbresolver"

db.Use(dbresolver.Register(dbresolver.Config{
    Sources: []gorm.Dialector{mysql.Open("write_dsn")},
    Replicas: []gorm.Dialector{
        mysql.Open("read1_dsn"),
        mysql.Open("read2_dsn"),
    },
    Policy: dbresolver.RandomPolicy{},
    // TraceResolverMode: true,  // 打印使用的数据源
}))

// 指定使用主库
db.Clauses(dbresolver.Write).First(&user)

12.3 编写自定义插件

插件结构

type Plugin interface {
    Name() string
    Initialize(*gorm.DB) error
}

完整插件示例

package myplugin

import (
    "gorm.io/gorm"
)

type MyPlugin struct{}

func (p *MyPlugin) Name() string {
    return "my_plugin"
}

func (p *MyPlugin) Initialize(db *gorm.DB) error {
    // 注册回调
    db.Callback().Create().Before("gorm:create").Register("my_plugin:before_create", beforeCreate)
    db.Callback().Create().After("gorm:create").Register("my_plugin:after_create", afterCreate)
    db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery)
    db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate)
    db.Callback().Delete().Before("gorm:delete").Register("my_plugin:before_delete", beforeDelete)
    
    return nil
}

func beforeCreate(db *gorm.DB) {
    // 创建前的逻辑
    if db.Error != nil {
        return
    }
    // 实现逻辑
}

func afterCreate(db *gorm.DB) {
    // 创建后的逻辑
}

func afterQuery(db *gorm.DB) {
    // 查询后的逻辑
}

func beforeUpdate(db *gorm.DB) {
    // 更新前的逻辑
}

func beforeDelete(db *gorm.DB) {
    // 删除前的逻辑
}

// 使用
db.Use(&MyPlugin{})

12.4 审计日志插件

type AuditPlugin struct {
    LogChannel chan AuditLog
}

type AuditLog struct {
    Action    string
    Table     string
    RecordID  interface{}
    OldData   interface{}
    NewData   interface{}
    UserID    uint
    Timestamp time.Time
}

func (p *AuditPlugin) Name() string {
    return "audit_plugin"
}

func (p *AuditPlugin) Initialize(db *gorm.DB) error {
    // 创建前记录
    db.Callback().Create().After("gorm:create").Register("audit:create", func(db *gorm.DB) {
        p.log(db, "CREATE", nil, db.Statement.ReflectValue.Interface())
    })
    
    // 更新前查询旧数据
    db.Callback().Update().Before("gorm:update").Register("audit:update:before", func(db *gorm.DB) {
        if db.Statement.Schema != nil {
            var oldData interface{}
            db.Session(&gorm.Session{}).First(&oldData, db.Statement.Dest)
            db.Statement.Set("audit:old_data", oldData)
        }
    })
    
    db.Callback().Update().After("gorm:update").Register("audit:update:after", func(db *gorm.DB) {
        oldData, _ := db.Statement.Get("audit:old_data")
        p.log(db, "UPDATE", oldData, db.Statement.ReflectValue.Interface())
    })
    
    // 删除前记录
    db.Callback().Delete().Before("gorm:delete").Register("audit:delete", func(db *gorm.DB) {
        p.log(db, "DELETE", db.Statement.ReflectValue.Interface(), nil)
    })
    
    return nil
}

func (p *AuditPlugin) log(db *gorm.DB, action string, oldData, newData interface{}) {
    userID, _ := db.Statement.Context.Value("userID").(uint)
    
    log := AuditLog{
        Action:    action,
        Table:     db.Statement.Table,
        RecordID:  db.Statement.PrimaryKeyValue,
        OldData:   oldData,
        NewData:   newData,
        UserID:    userID,
        Timestamp: time.Now(),
    }
    
    select {
    case p.LogChannel <- log:
    default:
        // 通道满时丢弃或记录警告
    }
}

12.5 软删除增强插件

type SoftDeletePlugin struct {
    Field     string  // 删除标记字段名
    TimeField string  // 删除时间字段名
    UserField string  // 删除人字段名
}

func (p *SoftDeletePlugin) Name() string {
    return "soft_delete"
}

func (p *SoftDeletePlugin) Initialize(db *gorm.DB) error {
    // 替换默认删除行为
    db.Callback().Delete().Replace("gorm:delete", func(db *gorm.DB) {
        if db.Error != nil {
            return
        }
        
        // 检查是否是软删除模型
        if !db.Statement.Schema.FieldsByDBName[p.Field] {
            // 不是软删除模型,执行物理删除
            db.Exec("DELETE FROM ? WHERE ?", db.Statement.Table, db.Statement.Where)
            return
        }
        
        // 执行软删除
        updates := map[string]interface{}{
            p.Field:     1,
            p.TimeField: time.Now(),
        }
        
        if userID := db.Statement.Context.Value("userID"); userID != nil {
            updates[p.UserField] = userID
        }
        
        db.Model(db.Statement.ReflectValue.Addr().Interface()).Updates(updates)
    })
    
    return nil
}

12.6 查询缓存插件

type CachePlugin struct {
    redis *redis.Client
    ttl   time.Duration
}

func (p *CachePlugin) Name() string {
    return "query_cache"
}

func (p *CachePlugin) Initialize(db *gorm.DB) error {
    // 查询前检查缓存
    db.Callback().Query().Before("gorm:query").Register("cache:before_query", func(db *gorm.DB) {
        cacheKey := p.generateCacheKey(db)
        
        if data, err := p.redis.Get(db.Statement.Context, cacheKey).Result(); err == nil {
            // 缓存命中
            if err := json.Unmarshal([]byte(data), db.Statement.Dest); err == nil {
                db.Statement.SkipHooks = true  // 跳过后续查询
            }
        }
    })
    
    // 查询后写入缓存
    db.Callback().Query().After("gorm:query").Register("cache:after_query", func(db *gorm.DB) {
        if db.Error != nil {
            return
        }
        
        cacheKey := p.generateCacheKey(db)
        data, _ := json.Marshal(db.Statement.Dest)
        p.redis.Set(db.Statement.Context, cacheKey, data, p.ttl)
    })
    
    // 更新后清除缓存
    db.Callback().Update().After("gorm:update").Register("cache:invalidate", func(db *gorm.DB) {
        p.invalidateCache(db.Statement.Table, db.Statement.PrimaryKeyValue)
    })
    
    return nil
}

func (p *CachePlugin) generateCacheKey(db *gorm.DB) string {
    // 基于表名、SQL 和参数生成缓存键
    return fmt.Sprintf("gorm:%s:%x", db.Statement.Table, md5.Sum([]byte(db.Statement.SQL.String())))
}

12.7 字段加密插件

type EncryptPlugin struct {
    key []byte
}

func (p *EncryptPlugin) Name() string {
    return "field_encrypt"
}

func (p *EncryptPlugin) Initialize(db *gorm.DB) error {
    // 创建前加密
    db.Callback().Create().Before("gorm:create").Register("encrypt:before_create", p.encrypt)
    db.Callback().Update().Before("gorm:update").Register("encrypt:before_update", p.encrypt)
    
    // 查询后解密
    db.Callback().Query().After("gorm:query").Register("encrypt:after_query", p.decrypt)
    
    return nil
}

func (p *EncryptPlugin) encrypt(db *gorm.DB) {
    if db.Statement.Schema == nil {
        return
    }
    
    for _, field := range db.Statement.Schema.Fields {
        if tag := field.Tag.Get("encrypt"); tag == "true" {
            // 获取字段值并加密
            if value, ok := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); ok {
                if str, ok := value.(string); ok && str != "" {
                    encrypted, _ := p.aesEncrypt(str)
                    field.Set(db.Statement.Context, db.Statement.ReflectValue, encrypted)
                }
            }
        }
    }
}

func (p *EncryptPlugin) decrypt(db *gorm.DB) {
    // 解密逻辑类似
}

12.8 插件组合使用

func setupDB(dsn string) (*gorm.DB, error) {
    db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
    if err != nil {
        return nil, err
    }
    
    // 启用多个插件
    db.Use(gormprom.New(gormprom.Config{DBName: "myapp"}))
    db.Use(dbresolver.Register(dbresolver.Config{...}))
    db.Use(&AuditPlugin{LogChannel: make(chan AuditLog, 100)})
    db.Use(&CachePlugin{redis: redisClient, ttl: 5 * time.Minute})
    
    return db, nil
}

12.9 最佳实践

1. 插件顺序

// 注意插件执行顺序
db.Use(plugin1)  // 先注册的先执行
db.Use(plugin2)

2. 错误处理

func (p *MyPlugin) Initialize(db *gorm.DB) error {
    // 所有回调都要检查 db.Error
    db.Callback().Create().Before("gorm:create").Register("my_plugin", func(db *gorm.DB) {
        if db.Error != nil {
            return
        }
        // 业务逻辑
    })
    return nil
}

3. 性能考虑

// 避免在回调中执行耗时操作
db.Callback().Query().After("gorm:query").Register("my_plugin", func(db *gorm.DB) {
    // 不要在这里:
    // - 调用外部 HTTP 服务
    // - 执行复杂计算
    // - 大量数据库操作
    
    // 可以:
    // - 简单的数据转换
    // - 发送异步消息
    // - 更新内存缓存
})

12.10 练习题

  1. 编写一个租户隔离插件,自动为所有查询添加 tenant_id 过滤
  2. 实现一个 SQL 注入检测插件,拦截可疑查询
  3. 创建一个数据脱敏插件,对敏感字段自动脱敏

12.11 小结

本章介绍了 GORM 的插件机制,包括常用插件的使用和自定义插件的开发。插件是扩展 GORM 功能的强大方式,合理使用可以大大简化业务代码。


本文代码地址:https://github.com/LittleMoreInteresting/gorm_study

欢迎关注公众号,一起学习进步!

如有疑问关注公众号给我留言
wx

关注公众号

©2017-2023 鲁ICP备17023316号-1 Powered by Hugo