pkg/validator/validator.go

105 lines
2.4 KiB
Go

// Package validator validates structs via `validate` struct tags and localizes the resulting error messages (English or Chinese).
package validator
import (
"context"
"reflect"
"github.com/go-playground/locales/en"
"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
en_translations "github.com/go-playground/validator/v10/translations/en"
zh_translations "github.com/go-playground/validator/v10/translations/zh"
"github.com/samber/lo"
"git.ifooth.com/common/pkg/i18n"
"git.ifooth.com/common/pkg/util"
)
// Package-level validation state; every variable is populated once by init.
var (
	// validate is the shared validator instance used by Struct.
	validate *validator.Validate

	// uni holds the registered locales: English (also the fallback) and Chinese.
	uni *ut.UniversalTranslator

	// zhTrans renders validation errors in Chinese.
	zhTrans ut.Translator
	// defaultTrans renders validation errors in English; it is used whenever
	// the context does not select Chinese.
	defaultTrans ut.Translator
)
// ValidationError is returned by Struct when tag-based validation fails.
// It keeps the raw validator error together with the context it was produced
// under, so the message language can be chosen lazily when Error is called.
type ValidationError struct {
	rawErr error           // error returned by validate.StructCtx
	ctx    context.Context // used to look up the caller's language
}
// Validator is implemented by types that need checks beyond struct tags.
// Struct calls Validate only after tag-based validation has passed.
type Validator interface{ Validate() error }
// getTranslator selects the translator for this error's context. It returns
// the Chinese translator when i18n.GetLang reports exactly "zh", and the
// default (English) translator otherwise.
// NOTE(review): tags such as "zh-CN" fall back to English unless i18n.GetLang
// already normalizes them — confirm.
func (e *ValidationError) getTranslator() ut.Translator {
	lang, found := i18n.GetLang(e.ctx)
	if !found || lang != "zh" {
		return defaultTrans
	}
	return zhTrans
}
// Error implements the error interface.
//
// When the wrapped error is a non-empty validator.ValidationErrors, only the
// first field error is reported, translated with the translator chosen from
// the stored context. Any other wrapped error (for example
// *validator.InvalidValidationError) is reported verbatim.
func (e *ValidationError) Error() string {
	// Checked assertion: the previous unchecked form panicked whenever rawErr
	// held an error type other than ValidationErrors/InvalidValidationError.
	if errs, ok := e.rawErr.(validator.ValidationErrors); ok && len(errs) > 0 {
		// Only a single error is returned, to keep messages short.
		return errs[0].Translate(e.getTranslator())
	}
	return e.rawErr.Error()
}

// Unwrap returns the underlying validator error so callers can inspect it
// with errors.As / errors.Is (e.g. to obtain the full
// validator.ValidationErrors list rather than just the first message).
func (e *ValidationError) Unwrap() error {
	return e.rawErr
}
// Struct validates s using its `validate` struct tags.
//
// If tag validation fails, the returned error is a *ValidationError. If it
// passes and s implements Validator, s.Validate() is called and its result is
// returned unchanged. For Validate methods with a pointer receiver, s must be
// passed as a pointer, otherwise the Validator check does not match.
func Struct(ctx context.Context, s any) error {
	if err := validate.StructCtx(ctx, s); err != nil {
		return &ValidationError{ctx: ctx, rawErr: err}
	}

	custom, implemented := s.(Validator)
	if !implemented {
		return nil
	}
	return custom.Validate()
}
// readableTagName returns the field name used in validation messages. It
// takes the name from the `json` tag, falling back to the `req` tag. It
// returns "" when neither tag yields a name other than "" or "-".
// Uniqueness of these names is checked by the codec, not here.
func readableTagName(field reflect.StructField) string {
	for _, key := range [...]string{"json", "req"} {
		if name := util.GetTagName(field, key); name != "" && name != "-" {
			return name
		}
	}
	return ""
}
// init builds the shared validator, registers readableTagName for field names
// in messages, and registers English (default/fallback) and Chinese
// translations. Any setup failure panics at program start.
func init() {
	validate = validator.New(validator.WithRequiredStructEnabled())
	validate.RegisterTagNameFunc(readableTagName)

	// English is both the fallback and a supported locale; Chinese is added.
	// Distinct local names avoid shadowing the imported en/zh packages.
	enLocale := en.New()
	zhLocale := zh.New()
	uni = ut.New(enLocale, enLocale, zhLocale)

	// uni.GetTranslator returns the fallback translator with found=false for
	// an unknown locale. lo.Must turns that into a startup panic instead of
	// silently registering translations for the wrong language.
	defaultTrans = lo.Must(uni.GetTranslator("en"))
	lo.Must0(en_translations.RegisterDefaultTranslations(validate, defaultTrans))

	zhTrans = lo.Must(uni.GetTranslator("zh"))
	lo.Must0(zh_translations.RegisterDefaultTranslations(validate, zhTrans))
}