99 lines
2.3 KiB
Go
99 lines
2.3 KiB
Go
|
// Package validator validates structs via their `validate` tags and returns translated (English or Chinese) error messages.
|
||
|
package validator
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/go-playground/locales/en"
|
||
|
"github.com/go-playground/locales/zh"
|
||
|
ut "github.com/go-playground/universal-translator"
|
||
|
"github.com/go-playground/validator/v10"
|
||
|
en_translations "github.com/go-playground/validator/v10/translations/en"
|
||
|
zh_translations "github.com/go-playground/validator/v10/translations/zh"
|
||
|
"github.com/samber/lo"
|
||
|
|
||
|
"git.ifooth.com/common/pkg/i18n"
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
validate *validator.Validate
|
||
|
uni *ut.UniversalTranslator
|
||
|
defaultTrans ut.Translator
|
||
|
zhTrans ut.Translator
|
||
|
)
|
||
|
|
||
|
// ValidationError wraps an error produced by the validator together with the
// context of the call that produced it.
//
// The context is kept (rather than a pre-resolved translator) so that the
// message language can be chosen lazily, via i18n.GetLang, when the error is
// rendered.
type ValidationError struct {
	// ctx is the context that was passed to Struct; in this file it is read
	// only when choosing the translator for the error message.
	ctx context.Context

	// rawErr is the error returned by validate.StructCtx.
	rawErr error
}
|
||
|
|
||
|
// Validator can be implemented by a validated value to add custom checks on
// top of its `validate` tags. Struct calls Validate only after tag validation
// has succeeded, and returns its result unchanged.
type Validator interface {
	// Validate returns a non-nil error when the value is invalid.
	Validate() error
}
|
||
|
|
||
|
func (e *ValidationError) getTranslator() ut.Translator {
|
||
|
acceptLang, ok := i18n.GetLang(e.ctx)
|
||
|
if ok && acceptLang == "zh" {
|
||
|
return zhTrans
|
||
|
}
|
||
|
return defaultTrans
|
||
|
}
|
||
|
|
||
|
// Error error iface
|
||
|
func (e *ValidationError) Error() string {
|
||
|
if _, ok := e.rawErr.(*validator.InvalidValidationError); ok {
|
||
|
return e.rawErr.Error()
|
||
|
}
|
||
|
|
||
|
errs := e.rawErr.(validator.ValidationErrors)
|
||
|
// 只返回单个错误
|
||
|
for _, ve := range errs {
|
||
|
return ve.Translate(e.getTranslator())
|
||
|
}
|
||
|
|
||
|
return e.rawErr.Error()
|
||
|
}
|
||
|
|
||
|
// Struct 通过 validate tag 校验结构体, Validate 校验需要传入指针类型
|
||
|
func Struct(ctx context.Context, s any) error {
|
||
|
err := validate.StructCtx(ctx, s)
|
||
|
if err != nil {
|
||
|
return &ValidationError{ctx: ctx, rawErr: err}
|
||
|
}
|
||
|
|
||
|
// 实现了 Validate 接口自定义调用
|
||
|
if v, ok := s.(Validator); ok {
|
||
|
return v.Validate()
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// tagNameFunc returns the name the validator should report for a struct
// field. It uses the name part of the field's `json` tag, i.e. the text
// before the first comma. It returns "" when that name is "-" or missing,
// meaning no custom name is supplied.
func tagNameFunc(fld reflect.StructField) string {
	jsonName, _, _ := strings.Cut(fld.Tag.Get("json"), ",")
	switch jsonName {
	case "-":
		return ""
	default:
		return jsonName
	}
}
|
||
|
|
||
|
func init() {
|
||
|
validate = validator.New(validator.WithRequiredStructEnabled())
|
||
|
validate.RegisterTagNameFunc(tagNameFunc)
|
||
|
|
||
|
// 默认使用英文
|
||
|
en := en.New()
|
||
|
zh := zh.New()
|
||
|
uni = ut.New(en, en, zh)
|
||
|
defaultTrans, _ = uni.GetTranslator("en")
|
||
|
lo.Must0(en_translations.RegisterDefaultTranslations(validate, defaultTrans))
|
||
|
|
||
|
zhTrans, _ = uni.GetTranslator("zh")
|
||
|
lo.Must0(zh_translations.RegisterDefaultTranslations(validate, zhTrans))
|
||
|
}
|