// Package etcd implements a machinery result backend that stores task states
// and group metadata as JSON documents in etcd.
package etcd

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"log"
	"time"

	"github.com/RichardKnop/machinery/v2/backends/iface"
	"github.com/RichardKnop/machinery/v2/common"
	"github.com/RichardKnop/machinery/v2/config"
	"github.com/RichardKnop/machinery/v2/tasks"
	clientv3 "go.etcd.io/etcd/client/v3"
	"golang.org/x/sync/errgroup"
)

// etcdBackend is an iface.Backend backed by an etcd cluster. Task states and
// group metadata share a single key namespace (see storeKey) and are keyed by
// their UUID.
type etcdBackend struct {
	common.Backend
	// ctx is the parent context for every etcd request; the iface.Backend
	// methods take no context of their own.
	ctx    context.Context
	conf   *config.Config
	client *clientv3.Client
}

// New creates an etcd result backend. conf.ResultBackend is used verbatim as
// the single etcd endpoint and conf.TLSConfig as the client TLS settings.
func New(ctx context.Context, conf *config.Config) (iface.Backend, error) {
	etcdConf := clientv3.Config{
		Endpoints:   []string{conf.ResultBackend},
		Context:     ctx,
		DialTimeout: time.Second * 5,
		TLS:         conf.TLSConfig,
	}

	client, err := clientv3.New(etcdConf)
	if err != nil {
		return nil, err
	}

	backend := etcdBackend{
		Backend: common.NewBackend(conf),
		ctx:     ctx,
		client:  client,
		conf:    conf,
	}

	return &backend, nil
}

// storeKey returns the etcd key under which the task state or group metadata
// identified by uuid is stored.
func (b *etcdBackend) storeKey(uuid string) string {
	return "/machinery/v2/backend/" + uuid
}

// getJSON fetches key and decodes its JSON value into out, keeping numbers as
// json.Number. found is false (with a nil error) when the key does not exist.
// The key's ModRevision is returned so callers can do compare-and-swap writes.
func (b *etcdBackend) getJSON(ctx context.Context, key string, out interface{}) (modRevision int64, found bool, err error) {
	resp, err := b.client.Get(ctx, key)
	if err != nil {
		return 0, false, err
	}
	if len(resp.Kvs) == 0 {
		return 0, false, nil
	}

	kv := resp.Kvs[0]
	decoder := json.NewDecoder(bytes.NewReader(kv.Value))
	decoder.UseNumber()
	if err := decoder.Decode(out); err != nil {
		return 0, false, err
	}
	return kv.ModRevision, true, nil
}

// Group related functions

// InitGroup stores the metadata for a new group of tasks.
func (b *etcdBackend) InitGroup(groupUUID string, taskUUIDs []string) error {
	groupMeta := &tasks.GroupMeta{
		GroupUUID: groupUUID,
		TaskUUIDs: taskUUIDs,
		CreatedAt: time.Now().UTC(),
	}

	encoded, err := json.Marshal(groupMeta)
	if err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(b.ctx, time.Second*2)
	defer cancel()

	_, err = b.client.KV.Put(ctx, b.storeKey(groupUUID), string(encoded))
	return err
}

// GroupCompleted reports whether the number of tasks in the group that have
// reached a completed state (as defined by TaskState.IsCompleted) equals
// groupTaskCount. It fails if the group metadata or any member's state is
// missing.
func (b *etcdBackend) GroupCompleted(groupUUID string, groupTaskCount int) (bool, error) {
	groupMeta, err := b.getGroupMeta(groupUUID)
	if err != nil {
		return false, err
	}

	taskStates, err := b.getStates(groupMeta.TaskUUIDs...)
	if err != nil {
		return false, err
	}

	completed := 0
	for _, taskState := range taskStates {
		if taskState.IsCompleted() {
			completed++
		}
	}

	return completed == groupTaskCount, nil
}

// getGroupMeta loads the metadata stored by InitGroup for groupUUID.
func (b *etcdBackend) getGroupMeta(groupUUID string) (*tasks.GroupMeta, error) {
	meta, _, err := b.loadGroupMeta(b.ctx, groupUUID)
	return meta, err
}

// loadGroupMeta loads the group metadata together with its etcd ModRevision.
func (b *etcdBackend) loadGroupMeta(ctx context.Context, groupUUID string) (*tasks.GroupMeta, int64, error) {
	meta := new(tasks.GroupMeta)
	rev, found, err := b.getJSON(ctx, b.storeKey(groupUUID), meta)
	if err != nil {
		return nil, 0, err
	}
	if !found {
		return nil, 0, fmt.Errorf("group %s not exist", groupUUID)
	}
	return meta, rev, nil
}

// GroupTaskStates returns the states of every task in the group, in the same
// order as the group's TaskUUIDs.
func (b *etcdBackend) GroupTaskStates(groupUUID string, groupTaskCount int) ([]*tasks.TaskState, error) {
	groupMeta, err := b.getGroupMeta(groupUUID)
	if err != nil {
		return []*tasks.TaskState{}, err
	}

	return b.getStates(groupMeta.TaskUUIDs...)
}

// TriggerChord marks the group's chord callback as triggered and reports
// whether this caller is the one that did so. It returns false with a nil
// error if the chord was already triggered, or if another caller modified the
// group metadata between our read and our write (i.e. lost the race).
func (b *etcdBackend) TriggerChord(groupUUID string) (bool, error) {
	ctx, cancel := context.WithTimeout(b.ctx, time.Second*2)
	defer cancel()

	meta, rev, err := b.loadGroupMeta(ctx, groupUUID)
	if err != nil {
		return false, err
	}
	if meta.ChordTriggered {
		return false, nil
	}

	meta.ChordTriggered = true
	encoded, err := json.Marshal(meta)
	if err != nil {
		return false, err
	}

	key := b.storeKey(groupUUID)
	// Write only if the key is unchanged since we read it, so concurrent
	// workers cannot both observe ChordTriggered == false and both win.
	resp, err := b.client.Txn(ctx).
		If(clientv3.Compare(clientv3.ModRevision(key), "=", rev)).
		Then(clientv3.OpPut(key, string(encoded))).
		Commit()
	if err != nil {
		return false, err
	}
	return resp.Succeeded, nil
}

// Setting / getting task state

// SetStatePending updates task state to PENDING
func (b *etcdBackend) SetStatePending(signature *tasks.Signature) error {
	taskState := tasks.NewPendingTaskState(signature)
	return b.updateState(taskState)
}

// SetStateReceived updates task state to RECEIVED
func (b *etcdBackend) SetStateReceived(signature *tasks.Signature) error {
	taskState := tasks.NewReceivedTaskState(signature)
	b.mergeNewTaskState(taskState)
	return b.updateState(taskState)
}

// SetStateStarted updates task state to STARTED
func (b *etcdBackend) SetStateStarted(signature *tasks.Signature) error {
	taskState := tasks.NewStartedTaskState(signature)
	b.mergeNewTaskState(taskState)
	return b.updateState(taskState)
}

// SetStateRetry updates task state to RETRY
func (b *etcdBackend) SetStateRetry(signature *tasks.Signature) error {
	taskState := tasks.NewRetryTaskState(signature)
	b.mergeNewTaskState(taskState)
	return b.updateState(taskState)
}

// SetStateSuccess updates task state to SUCCESS
func (b *etcdBackend) SetStateSuccess(signature *tasks.Signature, results []*tasks.TaskResult) error {
	taskState := tasks.NewSuccessTaskState(signature, results)
	b.mergeNewTaskState(taskState)
	return b.updateState(taskState)
}

// SetStateFailure updates task state to FAILURE
func (b *etcdBackend) SetStateFailure(signature *tasks.Signature, err string) error {
	taskState := tasks.NewFailureTaskState(signature, err)
	b.mergeNewTaskState(taskState)
	return b.updateState(taskState)
}

// GetState returns the latest stored state of the task.
func (b *etcdBackend) GetState(taskUUID string) (*tasks.TaskState, error) {
	return b.getState(b.ctx, taskUUID)
}

// getState loads a task state using ctx, so callers can cancel the lookup.
func (b *etcdBackend) getState(ctx context.Context, taskUUID string) (*tasks.TaskState, error) {
	state := new(tasks.TaskState)
	_, found, err := b.getJSON(ctx, b.storeKey(taskUUID), state)
	if err != nil {
		return nil, err
	}
	if !found {
		return nil, fmt.Errorf("task %s not exist", taskUUID)
	}
	return state, nil
}

// IsAMQP returns false: this backend is not AMQP-based.
func (b *etcdBackend) IsAMQP() bool {
	return false
}

// mergeNewTaskState carries CreatedAt and TaskName over from the previously
// stored state, if one exists, into newState. Lookup errors are deliberately
// ignored: with no prior state there is nothing to merge.
func (b *etcdBackend) mergeNewTaskState(newState *tasks.TaskState) {
	state, err := b.GetState(newState.TaskUUID)
	if err == nil {
		newState.CreatedAt = state.CreatedAt
		newState.TaskName = state.TaskName
	}
}

// Purging stored task states and group metadata

// PurgeState deletes the stored state of the task.
func (b *etcdBackend) PurgeState(taskUUID string) error {
	_, err := b.client.KV.Delete(b.ctx, b.storeKey(taskUUID))
	return err
}

// PurgeGroupMeta deletes the stored metadata of the group.
func (b *etcdBackend) PurgeGroupMeta(groupUUID string) error {
	_, err := b.client.KV.Delete(b.ctx, b.storeKey(groupUUID))
	return err
}

// getStates fetches the states of taskUUIDs concurrently (at most 10 requests
// in flight) and returns them in the same order as taskUUIDs. The first error
// cancels the remaining lookups and is returned.
func (b *etcdBackend) getStates(taskUUIDs ...string) ([]*tasks.TaskState, error) {
	eg, ctx := errgroup.WithContext(b.ctx)
	eg.SetLimit(10)

	taskStates := make([]*tasks.TaskState, len(taskUUIDs))
	for i, taskUUID := range taskUUIDs {
		i, taskUUID := i, taskUUID
		eg.Go(func() error {
			state, err := b.getState(ctx, taskUUID)
			if err != nil {
				return err
			}
			// Each goroutine writes a distinct index, so no lock is needed.
			taskStates[i] = state
			return nil
		})
	}

	if err := eg.Wait(); err != nil {
		return nil, err
	}

	return taskStates, nil
}

// updateState saves current task state
func (b *etcdBackend) updateState(taskState *tasks.TaskState) error {
	encoded, err := json.Marshal(taskState)
	if err != nil {
		return err
	}

	_, err = b.client.Put(b.ctx, b.storeKey(taskState.TaskUUID), string(encoded))
	if err != nil {
		return err
	}
	log.Default().Printf("update taskstate %s %s, %s", taskState.TaskName, taskState.TaskUUID, encoded)

	return nil
}

// getExpiration returns expiration for a stored task state.
// NOTE(review): not referenced anywhere in this file, so stored states are
// written without a lease and do not expire — confirm whether that is intended.
func (b *etcdBackend) getExpiration() time.Duration {
	expiresIn := b.GetConfig().ResultsExpireIn
	if expiresIn == 0 {
		// expire results after 1 hour by default
		expiresIn = config.DefaultResultsExpireIn
	}
	return time.Duration(expiresIn) * time.Second
}