machinery-plugins/brokers/etcd/etcd.go

383 lines
8.9 KiB
Go

package etcd
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"runtime"
"sync"
"time"
"github.com/RichardKnop/machinery/v2/brokers/errs"
"github.com/RichardKnop/machinery/v2/brokers/iface"
"github.com/RichardKnop/machinery/v2/common"
"github.com/RichardKnop/machinery/v2/config"
"github.com/RichardKnop/machinery/v2/log"
"github.com/RichardKnop/machinery/v2/tasks"
clientv3 "go.etcd.io/etcd/client/v3"
"golang.org/x/sync/errgroup"
)
// haveNoTaskErr is the sentinel meaning "there is currently no task to take".
// The polling loops in StartConsuming match it with errors.Is and treat it as
// an idle condition rather than a failure worth logging.
var (
	haveNoTaskErr = errors.New("have no task")
)
// Signature is the envelope stored in etcd for every task: the machinery task
// signature plus broker bookkeeping. Publish and requeueMessage build it and
// JSON-encode it; getTasks, nextTask and nextDelayedTask decode it.
type Signature struct {
// Embedded, so encoding/json promotes the task's exported fields to the top
// level of the stored JSON object.
*tasks.Signature
// CreateAt is the time the envelope was (re)written to etcd. Note that it is
// serialized under the key "CreatedAt", not "CreateAt".
CreateAt time.Time `json:"CreatedAt"`
Score int64 `json:"Score"` // write time in Unix milliseconds (see Publish); intended for ordering the queue
body []byte `json:"-"` // cached JSON encoding produced by Body; never serialized
}
// Body returns the JSON encoding of s. The encoding is computed on first use
// and cached in s.body. Later calls return the cached bytes even if s was
// modified in between, so call it only once the signature is final.
func (s *Signature) Body() ([]byte, error) {
	if len(s.body) == 0 {
		encoded, err := json.Marshal(s)
		if err != nil {
			return nil, err
		}
		s.body = encoded
	}
	return s.body, nil
}
// String renders the signature as its JSON encoding (see Body), which is what
// the log lines in this file print for a *Signature. If encoding fails it
// returns the empty string.
func (s *Signature) String() string {
	encoded, err := s.Body()
	if err != nil {
		// Best effort: String cannot report errors, and Body yields no bytes on failure.
		return ""
	}
	return string(encoded)
}
// etcdBroker is the iface.Broker implementation returned by New. Tasks are
// stored as JSON-encoded Signature values in an etcd v3 key space.
type etcdBroker struct {
common.Broker
ctx context.Context // the ctx passed to New (also the etcd client's Context); used by requeueMessage's Put
client *clientv3.Client
wg sync.WaitGroup // tracks the polling goroutines started by StartConsuming; waited on by StopConsuming
}
// New connects to etcd and returns a broker backed by it.
//
// ctx becomes the etcd client's context and is also kept on the broker. The
// etcd dial timeout is 5 seconds, and TLS settings come from conf.TLSConfig.
//
// NOTE(review): the single endpoint is taken from conf.Lock rather than
// conf.Broker, and comma-separated endpoint lists are not split. Confirm that
// this is intentional.
func New(ctx context.Context, conf *config.Config) (iface.Broker, error) {
	client, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{conf.Lock},
		Context:     ctx,
		DialTimeout: 5 * time.Second,
		TLS:         conf.TLSConfig,
	})
	if err != nil {
		return nil, err
	}
	return &etcdBroker{
		Broker: common.NewBroker(conf),
		ctx:    ctx,
		client: client,
	}, nil
}
// Tuning for the polling loops started by StartConsuming.
const (
	// idlePollInterval is how long a polling loop waits before asking etcd
	// again when there was nothing to fetch, or when PreConsumeHandler
	// declined to take more work.
	idlePollInterval = 100 * time.Millisecond
	// errorPollInterval is how long a polling loop waits after a failed fetch.
	// Without it, a persistent failure (etcd unreachable, an undecodable
	// payload) would be retried and logged in a tight loop.
	errorPollInterval = time.Second
)

// StartConsuming starts consuming tasks and blocks until consumption ends.
//
// It runs three parts:
//   - a goroutine that moves tasks from this worker's pending queue (see
//     getQueue) into an unbuffered deliveries channel;
//   - a goroutine that hands tasks returned by nextDelayedTask to Publish;
//   - a pool of concurrency workers (runtime.NumCPU() when concurrency < 1)
//     that process deliveries (see consume).
//
// It returns when StopConsuming is called (normally with a nil error), or when
// processing a task fails (with that error). In both cases the two polling
// goroutines have exited before it returns, so a retried call does not leave
// stale pollers running. The bool result is the retry flag of the embedded
// Broker.
func (b *etcdBroker) StartConsuming(consumerTag string, concurrency int, taskProcessor iface.TaskProcessor) (bool, error) {
	if concurrency < 1 {
		concurrency = runtime.NumCPU()
	}
	b.Broker.StartConsuming(consumerTag, concurrency, taskProcessor)
	log.INFO.Printf("[*] Waiting for messages, concurrency=%d. To exit press CTRL+C", concurrency)

	// Channel to which we push tasks ready for processing by the workers.
	deliveries := make(chan *Signature)
	// done is closed once the worker pool has returned. Without it, a
	// processing error would leave the pending poller blocked forever on
	// deliveries, and the delayed poller running next to the goroutines of
	// the next (retried) StartConsuming call.
	done := make(chan struct{})

	b.wg.Add(2)
	go func() {
		defer b.wg.Done()
		b.pollPendingTasks(deliveries, done, taskProcessor)
	}()
	go func() {
		defer b.wg.Done()
		b.pollDelayedTasks(done)
	}()

	err := b.consume(deliveries, concurrency, taskProcessor)
	close(done)
	b.wg.Wait()
	return b.GetRetry(), err
}

// pollPendingTasks feeds deliveries from the pending queue until the broker is
// stopped or done is closed. It closes deliveries when it returns, so the
// worker pool drains and exits.
func (b *etcdBroker) pollPendingTasks(deliveries chan<- *Signature, done <-chan struct{}, taskProcessor iface.TaskProcessor) {
	defer close(deliveries)
	for {
		select {
		case <-b.GetStopChan():
			return
		case <-done:
			return
		default:
		}

		if !taskProcessor.PreConsumeHandler() {
			// The worker asked not to receive work right now; wait instead of spinning.
			if !b.sleepOrStop(idlePollInterval, done) {
				return
			}
			continue
		}

		task, err := b.nextTask(getQueue(b.GetConfig(), taskProcessor))
		if err != nil {
			if !b.pauseAfterFetchError(err, done) {
				return
			}
			continue
		}
		if task == nil {
			continue
		}

		// The stop channel is deliberately not watched here. After a stop,
		// the workers keep reading until deliveries is closed, so the task in
		// hand is still processed. Only when the pool itself is gone is the
		// task handed back to etcd.
		select {
		case deliveries <- task:
		case <-done:
			b.handBackTask(task)
			return
		}
	}
}

// pollDelayedTasks passes each task returned by nextDelayedTask to Publish
// until the broker is stopped or done is closed.
func (b *etcdBroker) pollDelayedTasks(done <-chan struct{}) {
	for {
		select {
		case <-b.GetStopChan():
			return
		case <-done:
			return
		default:
		}

		task, err := b.nextDelayedTask()
		if err != nil {
			if !b.pauseAfterFetchError(err, done) {
				return
			}
			continue
		}
		if err := b.Publish(context.Background(), task.Signature); err != nil {
			log.ERROR.Print(err)
		}
	}
}

// pauseAfterFetchError logs err unless it only means "no task available". It
// then waits before the next poll. It reports false if polling should stop.
func (b *etcdBroker) pauseAfterFetchError(err error, done <-chan struct{}) bool {
	if errors.Is(err, haveNoTaskErr) {
		return b.sleepOrStop(idlePollInterval, done)
	}
	log.ERROR.Print(err)
	return b.sleepOrStop(errorPollInterval, done)
}

// sleepOrStop waits for d. It reports false early if the broker is stopped or
// done is closed in the meantime.
func (b *etcdBroker) sleepOrStop(d time.Duration, done <-chan struct{}) bool {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-timer.C:
		return true
	case <-b.GetStopChan():
		return false
	case <-done:
		return false
	}
}

// handBackTask republishes a task that was fetched but could not be delivered
// because the worker pool had already exited, so the task is not silently lost.
// NOTE(review): assumes the fetch removed the key. If it did not, this merely
// rewrites the same pending key.
func (b *etcdBroker) handBackTask(task *Signature) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := b.Publish(ctx, task.Signature); err != nil {
		log.ERROR.Print(err)
	}
}
// consume runs concurrency workers that take tasks from deliveries and process
// them with consumeOne.
//
// It returns nil once deliveries is closed and drained. If any worker's
// consumeOne returns an error, the shared group context is cancelled, the
// remaining workers stop, and that first error is returned.
func (b *etcdBroker) consume(deliveries <-chan *Signature, concurrency int, taskProcessor iface.TaskProcessor) error {
	group, groupCtx := errgroup.WithContext(context.Background())
	group.SetLimit(concurrency)

	worker := func() error {
		for {
			select {
			case <-groupCtx.Done():
				return groupCtx.Err()
			case sig, open := <-deliveries:
				if !open {
					return nil
				}
				if err := b.consumeOne(sig, taskProcessor); err != nil {
					return err
				}
			}
		}
	}

	for n := concurrency; n > 0; n-- {
		group.Go(worker)
	}
	return group.Wait()
}
// consumeOne processes a single delivered task with taskProcessor.
//
// A task whose name is not registered with this worker is handled in one of
// two ways:
//   - dropped, when IgnoreWhenTaskNotRegistered is set;
//   - written back to this worker's queue via requeueMessage, since a
//     different worker may be able to run it.
//
// Both cases return nil, so the worker pool keeps running. Otherwise the
// error from taskProcessor.Process is returned unchanged; in consume, a
// non-nil error stops the whole pool.
func (b *etcdBroker) consumeOne(signature *Signature, taskProcessor iface.TaskProcessor) error {
	if !b.IsTaskRegistered(signature.Name) {
		if signature.IgnoreWhenTaskNotRegistered {
			return nil
		}
		log.INFO.Printf("Task not registered with this worker. Requeuing message: %s", signature)
		// This call used to be commented out. The task was then dropped even
		// though the log line above claims it was requeued.
		b.requeueMessage(signature, taskProcessor)
		return nil
	}

	log.DEBUG.Printf("Received new message: %s", signature)
	// NOTE(review): there is no explicit ack/delete here. Presumably the key is
	// removed when it is fetched (see getFirstItem); confirm.
	return taskProcessor.Process(signature.Signature)
}
// StopConsuming stops consumption. It first calls the embedded Broker's
// StopConsuming, which presumably closes the channel returned by
// GetStopChan that the polling loops watch (confirm in common.Broker). It then
// blocks until every goroutine tracked by b.wg has exited.
func (b *etcdBroker) StopConsuming() {
b.Broker.StopConsuming()
b.wg.Wait()
}
// Publish stores signature in etcd so that a consumer can pick it up.
//
// The routing key is first adjusted via AdjustRoutingKey. The value written
// is the JSON encoding of a Signature envelope stamped with the current time.
// The key depends on the ETA:
//   - ETA set and in the future: .../delayed_tasks/t<ETA unix nanos>-<uuid>
//   - otherwise:                 .../pending_tasks/<routing key>/<uuid>
//
// Marshalling errors are wrapped, so callers can inspect them with errors.Is
// and errors.As; etcd Put errors are returned unchanged.
func (b *etcdBroker) Publish(ctx context.Context, signature *tasks.Signature) error {
	// Adjust routing key (this decides which queue the message will be published to)
	b.Broker.AdjustRoutingKey(signature)

	now := time.Now()
	s := Signature{
		Signature: signature,
		CreateAt:  now,
		Score:     now.UnixMilli(),
	}
	msg, err := json.Marshal(s)
	if err != nil {
		return fmt.Errorf("JSON marshal error: %w", err)
	}

	key := fmt.Sprintf("/machinery/v2/broker/pending_tasks/%s/%s", s.RoutingKey, s.UUID)
	// Delay the task if its ETA is set and still in the future.
	if s.ETA != nil && s.ETA.After(now) {
		key = fmt.Sprintf("/machinery/v2/broker/delayed_tasks/t%d-%s", s.ETA.UnixNano(), s.UUID)
	}

	_, err = b.client.Put(ctx, key, string(msg))
	return err
}
// getTasks reads every key under the prefix key and decodes each value as a
// Signature, in the order etcd returns them. Nothing is removed from etcd.
// A payload that is not valid JSON aborts the whole call with an
// errs.NewErrCouldNotUnmarshalTaskSignature error. UseNumber keeps numbers in
// interface{}-typed fields as json.Number instead of float64.
func (b *etcdBroker) getTasks(ctx context.Context, key string) ([]*Signature, error) {
	resp, err := b.client.Get(ctx, key, clientv3.WithPrefix())
	if err != nil {
		return nil, err
	}

	signatures := make([]*Signature, len(resp.Kvs))
	for i, kv := range resp.Kvs {
		sig := &Signature{}
		dec := json.NewDecoder(bytes.NewReader(kv.Value))
		dec.UseNumber()
		if err := dec.Decode(sig); err != nil {
			return nil, errs.NewErrCouldNotUnmarshalTaskSignature(kv.Value, err)
		}
		signatures[i] = sig
	}
	return signatures, nil
}
// GetPendingTasks returns the tasks currently stored in the pending queue
// named queue (the configured DefaultQueue when queue is empty). It is meant
// for inspection or statistics and does not remove anything.
func (b *etcdBroker) GetPendingTasks(queue string) ([]*tasks.Signature, error) {
	if queue == "" {
		queue = b.GetConfig().DefaultQueue
	}
	// Publish stores tasks under ".../pending_tasks/<queue>/<uuid>". The
	// trailing slash keeps the prefix query from also matching queues whose
	// name merely starts with queue (e.g. "emails" vs "emails_bulk").
	key := fmt.Sprintf("/machinery/v2/broker/pending_tasks/%s/", queue)
	items, err := b.getTasks(context.Background(), key)
	if err != nil {
		return nil, err
	}

	rawTasks := make([]*tasks.Signature, 0, len(items))
	for _, v := range items {
		rawTasks = append(rawTasks, v.Signature)
	}
	return rawTasks, nil
}
// GetDelayedTasks returns every task currently stored under the delayed-task
// prefix (tasks published with a future ETA). It is meant for inspection or
// statistics and does not remove anything.
func (b *etcdBroker) GetDelayedTasks() ([]*tasks.Signature, error) {
	const delayedPrefix = "/machinery/v2/broker/delayed_tasks"

	items, err := b.getTasks(context.Background(), delayedPrefix)
	if err != nil {
		return nil, err
	}

	out := make([]*tasks.Signature, len(items))
	for i := range items {
		out[i] = items[i].Signature
	}
	return out, nil
}
// nextTask fetches one task from the pending queue named queue (via
// getFirstItem) and decodes it.
//
// The lookup is bounded by a 2-second timeout. Errors from getFirstItem are
// returned unchanged; the polling loop compares them against haveNoTaskErr.
// An undecodable payload yields an errs.NewErrCouldNotUnmarshalTaskSignature
// error.
//
// NOTE(review): the prefix has no trailing "/". If getFirstItem does a plain
// prefix match, queue "a" would also take tasks of queue "ab". Confirm before
// adding the slash, because getFirstItem may rely on the exact prefix.
func (b *etcdBroker) nextTask(queue string) (*Signature, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	raw, err := getFirstItem(ctx, b.client, "/machinery/v2/broker/pending_tasks/"+queue)
	if err != nil {
		return nil, err
	}

	sig := &Signature{}
	dec := json.NewDecoder(bytes.NewReader(raw))
	dec.UseNumber()
	if err := dec.Decode(sig); err != nil {
		return nil, errs.NewErrCouldNotUnmarshalTaskSignature(raw, err)
	}
	return sig, nil
}
// nextDelayedTask fetches one task from the delayed-task prefix (via
// getFirstETAItem) and decodes it.
//
// The lookup is bounded by a 2-second timeout, and errors from getFirstETAItem
// are returned unchanged. The caller hands the result to Publish, which puts
// the task back under the delayed prefix if its ETA is still in the future.
// NOTE(review): presumably getFirstETAItem only returns tasks whose ETA has
// passed; confirm.
func (b *etcdBroker) nextDelayedTask() (*Signature, error) {
	const delayedPrefix = "/machinery/v2/broker/delayed_tasks"

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	raw, err := getFirstETAItem(ctx, b.client, delayedPrefix)
	if err != nil {
		return nil, err
	}

	sig := &Signature{}
	dec := json.NewDecoder(bytes.NewReader(raw))
	dec.UseNumber()
	if err := dec.Decode(sig); err != nil {
		return nil, errs.NewErrCouldNotUnmarshalTaskSignature(raw, err)
	}
	return sig, nil
}
// requeueMessage writes delivery back to the pending queue this worker
// consumes from (see getQueue), with a fresh CreateAt/Score. Failures are
// logged, not returned, because callers treat requeueing as best effort.
func (b *etcdBroker) requeueMessage(delivery *Signature, taskProcessor iface.TaskProcessor) {
	queue := getQueue(b.GetConfig(), taskProcessor)
	// Use the same "<queue>/<uuid>" layout as Publish. Without the UUID,
	// every requeued task would overwrite one key, and that key is the bare
	// queue prefix itself.
	key := fmt.Sprintf("/machinery/v2/broker/pending_tasks/%s/%s", queue, delivery.UUID)

	now := time.Now()
	body, err := json.Marshal(Signature{
		Signature: delivery.Signature,
		CreateAt:  now,
		Score:     now.UnixMilli(),
	})
	if err != nil {
		log.ERROR.Print(err)
		return
	}

	if _, err := b.client.Put(b.ctx, key, string(body)); err != nil {
		log.ERROR.Print(err)
	}
}
// getQueue returns the queue a worker consumes from: the processor's custom
// queue when one is set, otherwise the configured DefaultQueue.
func getQueue(config *config.Config, taskProcessor iface.TaskProcessor) string {
	if q := taskProcessor.CustomQueue(); q != "" {
		return q
	}
	return config.DefaultQueue
}