package etcd import ( "bytes" "context" "encoding/json" "errors" "fmt" "math" "runtime" "sync" "time" "github.com/RichardKnop/machinery/v2/brokers/errs" "github.com/RichardKnop/machinery/v2/brokers/iface" "github.com/RichardKnop/machinery/v2/common" "github.com/RichardKnop/machinery/v2/config" "github.com/RichardKnop/machinery/v2/log" "github.com/RichardKnop/machinery/v2/tasks" clientv3 "go.etcd.io/etcd/client/v3" "golang.org/x/sync/errgroup" ) var haveNoTaskErr = errors.New("have no task") // Signature .. type Signature struct { *tasks.Signature CreateAt time.Time `json:"CreatedAt"` Score int64 `json:"Score"` // 写入时间, 队列排序用 body []byte `json:"-"` // 序列号后的值 } func (s *Signature) Body() ([]byte, error) { if len(s.body) > 0 { return s.body, nil } body, err := json.Marshal(s) if err != nil { return nil, err } s.body = body return s.body, nil } func (s *Signature) String() string { body, _ := s.Body() return string(body) } type etcdBroker struct { common.Broker globalConf *config.Config conf clientv3.Config cli *clientv3.Client wg sync.WaitGroup } // New .. func New(cnf *config.Config, endpoint string) (iface.Broker, error) { etcdConf := clientv3.Config{Endpoints: []string{endpoint}} cli, err := clientv3.New(etcdConf) if err != nil { return nil, err } broker := etcdBroker{ Broker: common.NewBroker(cnf), globalConf: cnf, conf: etcdConf, cli: cli, } return &broker, nil } // StartConsuming .. func (b *etcdBroker) StartConsuming(consumerTag string, concurrency int, taskProcessor iface.TaskProcessor) (bool, error) { if concurrency < 1 { concurrency = runtime.NumCPU() } b.Broker.StartConsuming(consumerTag, concurrency, taskProcessor) log.INFO.Printf("[*] Waiting for messages, concurrency=%d. To exit press CTRL+C", concurrency) // Channel to which we will push tasks ready for processing by worker deliveries := make(chan *Signature) // A receiving goroutine keeps popping messages from the queue by BLPOP // If the message is valid and can be unmarshaled into a proper structure // we send it to the deliveries channel b.wg.Add(1) go func() { defer b.wg.Done() for { select { // A way to stop this goroutine from b.StopConsuming case <-b.GetStopChan(): close(deliveries) return default: if !taskProcessor.PreConsumeHandler() { continue } task, err := b.nextTask(getQueue(b.GetConfig(), taskProcessor)) if err != nil { if !errors.Is(err, haveNoTaskErr) { log.ERROR.Print(err) } continue } if task != nil { deliveries <- task } } } }() // A goroutine to watch for delayed tasks and push them to deliveries // channel for consumption by the worker b.wg.Add(1) go func() { defer b.wg.Done() for { select { // A way to stop this goroutine from b.StopConsuming case <-b.GetStopChan(): return default: task, err := b.nextDelayedTask() if err != nil { if !errors.Is(err, haveNoTaskErr) { log.ERROR.Print(err) } continue } if err := b.Publish(context.Background(), task.Signature); err != nil { log.ERROR.Print(err) } } } }() if err := b.consume(deliveries, concurrency, taskProcessor); err != nil { return b.GetRetry(), err } b.wg.Wait() return b.GetRetry(), nil } // consume takes delivered messages from the channel and manages a worker pool // to process tasks concurrently func (b *etcdBroker) consume(deliveries <-chan *Signature, concurrency int, taskProcessor iface.TaskProcessor) error { eg, ctx := errgroup.WithContext(context.Background()) eg.SetLimit(concurrency) for i := 0; i < concurrency; i++ { eg.Go(func() error { for { select { case <-ctx.Done(): return ctx.Err() case t, ok := <-deliveries: if !ok { return nil } if err := b.consumeOne(t, taskProcessor); err != nil { return err } } } }) } return eg.Wait() } // consumeOne processes a single message using TaskProcessor func (b *etcdBroker) consumeOne(signature *Signature, taskProcessor iface.TaskProcessor) error { // If the task is not registered, we requeue it, // there might be different workers for processing specific tasks if !b.IsTaskRegistered(signature.Name) { if signature.IgnoreWhenTaskNotRegistered { return nil } log.INFO.Printf("Task not registered with this worker. Requeuing message: %s", signature) b.requeueMessage(signature, taskProcessor) return nil } log.DEBUG.Printf("Received new message: %s", signature) return taskProcessor.Process(signature.Signature) } // StopConsuming 停止 func (b *etcdBroker) StopConsuming() { b.Broker.StopConsuming() b.wg.Wait() } // Publish put kvs to etcd stor func (b *etcdBroker) Publish(ctx context.Context, signature *tasks.Signature) error { // Adjust routing key (this decides which queue the message will be published to) b.Broker.AdjustRoutingKey(signature) now := time.Now() s := Signature{ Signature: signature, CreateAt: now, Score: now.UnixMilli(), } msg, err := json.Marshal(s) if err != nil { return fmt.Errorf("JSON marshal error: %s", err) } key := fmt.Sprintf("/machinery/v2/broker/pending_tasks/%s/%s", s.RoutingKey, s.UUID) // Check the ETA signature field, if it is set and it is in the future, // delay the task if s.ETA != nil { if s.ETA.After(now) { // score := signature.ETA.UnixNano() key = fmt.Sprintf("/machinery/v2/broker/delayed_tasks/%s", s.UUID) _, err = b.cli.Put(ctx, key, string(msg)) return err } } _, err = b.cli.Put(ctx, key, string(msg)) return err } func (b *etcdBroker) getTasks(ctx context.Context, key string) ([]*Signature, error) { resp, err := b.cli.Get(ctx, key, clientv3.WithPrefix()) if err != nil { return nil, err } result := make([]*Signature, 0, len(resp.Kvs)) for _, kvs := range resp.Kvs { signature := new(Signature) decoder := json.NewDecoder(bytes.NewReader(kvs.Value)) decoder.UseNumber() if err := decoder.Decode(signature); err != nil { return nil, errs.NewErrCouldNotUnmarshalTaskSignature(kvs.Value, err) } result = append(result, signature) } return result, nil } // GetPendingTasks 获取执行队列, 任务统计可使用 func (b *etcdBroker) GetPendingTasks(queue string) ([]*tasks.Signature, error) { if queue == "" { queue = b.GetConfig().DefaultQueue } key := fmt.Sprintf("/machinery/v2/broker/pending_tasks/%s", queue) items, err := b.getTasks(context.Background(), key) if err != nil { return nil, err } rawTasks := make([]*tasks.Signature, 0, len(items)) for _, v := range items { rawTasks = append(rawTasks, v.Signature) } return rawTasks, nil } // GetDelayedTasks 任务统计可使用 func (b *etcdBroker) GetDelayedTasks() ([]*tasks.Signature, error) { key := "/machinery/v2/broker/delayed_tasks" items, err := b.getTasks(context.Background(), key) if err != nil { return nil, err } rawTasks := make([]*tasks.Signature, 0, len(items)) for _, v := range items { rawTasks = append(rawTasks, v.Signature) } return rawTasks, nil } func (b *etcdBroker) nextTask(queue string) (*Signature, error) { key := fmt.Sprintf("/machinery/v2/broker/pending_tasks/%s", queue) items, err := b.getTasks(context.Background(), key) if err != nil { return nil, err } defer time.Sleep(time.Second) if len(items) == 0 { return nil, haveNoTaskErr } var t *Signature score := int64(math.MaxInt64) for _, v := range items { if v.Score < score { t = v score = v.Score } } k := fmt.Sprintf("/machinery/v2/broker/pending_tasks/%s/%s", queue, t.UUID) _, _ = b.cli.Delete(context.Background(), k) return t, nil } func (b *etcdBroker) nextDelayedTask() (*Signature, error) { key := "/machinery/v2/broker/delayed_tasks" items, err := b.getTasks(context.Background(), key) if err != nil { return nil, err } if len(items) == 0 { return nil, haveNoTaskErr } var task *Signature now := time.Now() earliest := now for _, t := range items { // 还没有到时间 if t.ETA.After(now) { continue } // 选择最早的时间 if t.ETA.Before(earliest) { earliest = *t.ETA task = t } } if task != nil { k := fmt.Sprintf("/machinery/v2/broker/delayed_tasks/%s", task.UUID) _, err = b.cli.Delete(context.Background(), k) if err != nil { log.ERROR.Print(err) } return task, nil } return nil, haveNoTaskErr } func (b *etcdBroker) requeueMessage(delivery *Signature, taskProcessor iface.TaskProcessor) { queue := getQueue(b.GetConfig(), taskProcessor) key := fmt.Sprintf("/machinery/v2/broker/pending_tasks/%s", queue) now := time.Now() s := Signature{ Signature: delivery.Signature, CreateAt: now, Score: now.UnixMilli(), } body, err := json.Marshal(s) if err != nil { log.ERROR.Print(err) return } _, err = b.cli.KV.Put(b.conf.Context, key, string(body)) if err != nil { log.ERROR.Print(err) } } func getQueue(config *config.Config, taskProcessor iface.TaskProcessor) string { customQueue := taskProcessor.CustomQueue() if customQueue == "" { return config.DefaultQueue } return customQueue }