| author | 2024-07-30 11:58:31 +0000 |
|---|---|
| committer | 2024-07-30 13:58:31 +0200 |
| commit | 87cff71af95d2cef095a5feea40e48b40576b3d0 (patch) |
| tree | 9725ac3ab67d050e78016a2246d2b020635edcb7 |
| parent | [chore] replace UniqueStrings with Deduplicate (#3154) (diff) |
| download | gotosocial-87cff71af95d2cef095a5feea40e48b40576b3d0.tar.xz |
[feature] persist worker queues to db (#3042)
* persist queued worker tasks to database on shutdown, fill worker queues from database on startup
* ensure the tasks are sorted by creation time before pushing them
* add migration to insert WorkerTask{} into database, add test for worker task persistence
* add test for recovering worker queues from database
* quick tweak
* whoops we ended up with double cleaner job scheduling
* insert each task separately, because bun is throwing some reflection error??
* add specific checking of cancelled worker contexts
* add http request signing to deliveries recovered from database
* add test for outgoing public key ID being correctly set on delivery
* replace select with Queue.PopCtx()
* get rid of loop now we don't use it
* remove field now we don't use it
* ensure that signing func is set
* header values weren't being copied over :facepalm:
* use ptr for httpclient.Request in delivery
* move worker queue filling to later in server init process
* fix rebase issues
* make logging less shouty
* use slices.Delete() instead of copying / reslicing
* have database return tasks in ascending order instead of sorting them
* add a 1 minute timeout to persisting worker queues (see the shutdown sketch after this list)
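
The shutdown half of this change hinges on a detail visible in the server.go diff below: by the time the deferred cleanup runs, the main context has almost certainly been canceled already, so persisting the queues has to happen on a context detached from that cancellation but still bounded by the new one-minute timeout. A minimal sketch of the pattern follows; the `persistOnShutdown` wrapper and `persist` callback are illustrative names, not code from the commit:

```go
package main

import (
	"context"
	"errors"
	"log"
	"time"
)

// persistOnShutdown sketches the deferred-shutdown pattern from
// server.go in this commit: detach from the (likely already canceled)
// parent context, then bound the persist call to a one-minute timeout.
func persistOnShutdown(parent context.Context, persist func(context.Context) error) {
	const timeout = time.Minute

	// Detach from parent cancellation while keeping its values.
	// (context.WithoutCancel is available from Go 1.21 onward.)
	ctx := context.WithoutCancel(parent)
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// Persist whatever is still queued; on failure there is nothing
	// left to do but log, since the process is shutting down anyway.
	if err := persist(ctx); err != nil {
		log.Printf("error persisting worker queues: %v", err)
	}
}

func main() {
	// Simulate a canceled main context, as during shutdown.
	mainCtx, cancel := context.WithCancel(context.Background())
	cancel()

	persistOnShutdown(mainCtx, func(ctx context.Context) error {
		if err := ctx.Err(); err != nil {
			return err // would fire if we had NOT detached
		}
		return errors.New("pretend the database write failed")
	})
}
```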
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | cmd/gotosocial/action/server/server.go | 67 |
| -rw-r--r-- | internal/db/bundb/bundb.go | 4 |
| -rw-r--r-- | internal/db/bundb/migrations/20240617134210_add_worker_tasks_table.go | 51 |
| -rw-r--r-- | internal/db/bundb/workertask.go | 58 |
| -rw-r--r-- | internal/db/db.go | 1 |
| -rw-r--r-- | internal/db/workertask.go | 35 |
| -rw-r--r-- | internal/gtsmodel/workertask.go | 8 |
| -rw-r--r-- | internal/httpclient/client.go | 4 |
| -rw-r--r-- | internal/httpclient/request.go | 4 |
| -rw-r--r-- | internal/messages/messages.go | 2 |
| -rw-r--r-- | internal/processing/admin/workertask.go | 426 |
| -rw-r--r-- | internal/processing/admin/workertask_test.go | 421 |
| -rw-r--r-- | internal/transport/deliver.go | 33 |
| -rw-r--r-- | internal/transport/delivery/delivery.go | 16 |
| -rw-r--r-- | internal/transport/delivery/delivery_test.go | 31 |
| -rw-r--r-- | internal/transport/delivery/worker.go | 80 |
| -rw-r--r-- | internal/transport/transport.go | 5 |
| -rw-r--r-- | internal/workers/worker_msg.go | 21 |
| -rw-r--r-- | internal/workers/workers.go | 10 |
| -rw-r--r-- | testrig/db.go | 5 |
20 files changed, 1190 insertions, 92 deletions
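
Before the full diff, it may help to see the shape of the persistence model the diffstat points at (internal/gtsmodel/workertask.go and friends): each queued message is flattened to a blob plus a worker-type tag, so recovery is just a switch on the type followed by deserialization. Here is a rough, self-contained sketch of that round trip; the `message` payload and the `dehydrate`/`rehydrate` names are illustrative stand-ins, not code from the commit:

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// WorkerType tags which queue a persisted task belongs to, standing
// in for the gtsmodel constants referenced in the diff below.
type WorkerType uint8

const (
	DeliveryWorker WorkerType = iota + 1
	FederatorWorker
	ClientWorker
)

// WorkerTask mirrors the persisted shape in gtsmodel/workertask.go:
// no foreign keys or semantics, just a tagged blob of task data.
type WorkerTask struct {
	ID         uint
	WorkerType WorkerType
	TaskData   []byte
	CreatedAt  time.Time
}

// message is an illustrative stand-in for the real queue payloads
// (deliveries, federator messages, client messages).
type message struct {
	TargetURI string `json:"target_uri"`
}

// dehydrate flattens a popped queue message into a persistable task,
// as PersistWorkerQueues does on shutdown.
func dehydrate(msg *message, typ WorkerType) (*WorkerTask, error) {
	data, err := json.Marshal(msg)
	if err != nil {
		return nil, fmt.Errorf("error serializing message: %w", err)
	}
	return &WorkerTask{WorkerType: typ, TaskData: data, CreatedAt: time.Now()}, nil
}

// rehydrate recovers a queue message from a persisted task, switching
// on the worker type as FillWorkerQueues does on startup.
func rehydrate(task *WorkerTask) (*message, error) {
	switch task.WorkerType {
	case DeliveryWorker, FederatorWorker, ClientWorker:
		msg := new(message)
		if err := json.Unmarshal(task.TaskData, msg); err != nil {
			return nil, fmt.Errorf("error deserializing message: %w", err)
		}
		return msg, nil
	default:
		return nil, fmt.Errorf("invalid worker type %d", task.WorkerType)
	}
}

func main() {
	task, _ := dehydrate(&message{TargetURI: "https://example.org"}, ClientWorker)
	msg, _ := rehydrate(task)
	fmt.Println(msg.TargetURI) // https://example.org
}
```

Keeping the table this dumb is what lets the migration stay a single CreateTable call, and why the model's doc comment stresses it is "simply a means to store a blob of serialized task data".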
diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go index 42cbf318b..68b039d0c 100644 --- a/cmd/gotosocial/action/server/server.go +++ b/cmd/gotosocial/action/server/server.go @@ -87,9 +87,9 @@ var Start action.GTSAction = func(ctx context.Context) error {  		// defer function for safe shutdown  		// depending on what services were  		// managed to be started. - -		state = new(state.State) -		route *router.Router +		state   = new(state.State) +		route   *router.Router +		process *processing.Processor  	)  	defer func() { @@ -125,6 +125,23 @@ var Start action.GTSAction = func(ctx context.Context) error {  			}  		} +		if process != nil { +			const timeout = time.Minute + +			// Use a new timeout context to ensure +			// persisting queued tasks does not fail! +			// The main ctx is very likely canceled. +			ctx := context.WithoutCancel(ctx) +			ctx, cncl := context.WithTimeout(ctx, timeout) +			defer cncl() + +			// Now that all the "moving" components have been stopped, +			// persist any remaining queued worker tasks to the database. +			if err := process.Admin().PersistWorkerQueues(ctx); err != nil { +				log.Errorf(ctx, "error persisting worker queues: %v", err) +			} +		} +  		if state.DB != nil {  			// Lastly, if database service was started,  			// ensure it gets closed now all else stopped. @@ -270,7 +287,7 @@ var Start action.GTSAction = func(ctx context.Context) error {  	// Create the processor using all the  	// other services we've created so far. -	processor := processing.NewProcessor( +	process = processing.NewProcessor(  		cleaner,  		typeConverter,  		federator, @@ -286,14 +303,14 @@ var Start action.GTSAction = func(ctx context.Context) error {  	state.Workers.Client.Init(messages.ClientMsgIndices())  	state.Workers.Federator.Init(messages.FederatorMsgIndices())  	state.Workers.Delivery.Init(client) -	state.Workers.Client.Process = processor.Workers().ProcessFromClientAPI -	state.Workers.Federator.Process = processor.Workers().ProcessFromFediAPI +	state.Workers.Client.Process = process.Workers().ProcessFromClientAPI +	state.Workers.Federator.Process = process.Workers().ProcessFromFediAPI  	// Now start workers!  	state.Workers.Start()  	// Schedule notif tasks for all existing poll expiries. -	if err := processor.Polls().ScheduleAll(ctx); err != nil { +	if err := process.Polls().ScheduleAll(ctx); err != nil {  		return fmt.Errorf("error scheduling poll expiries: %w", err)  	} @@ -303,7 +320,7 @@ var Start action.GTSAction = func(ctx context.Context) error {  	}  	// Run advanced migrations. 
-	if err := processor.AdvancedMigrations().Migrate(ctx); err != nil { +	if err := process.AdvancedMigrations().Migrate(ctx); err != nil {  		return err  	} @@ -370,7 +387,7 @@ var Start action.GTSAction = func(ctx context.Context) error {  	// attach global no route / 404 handler to the router  	route.AttachNoRouteHandler(func(c *gin.Context) { -		apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), processor.InstanceGetV1) +		apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), process.InstanceGetV1)  	})  	// build router modules @@ -393,15 +410,15 @@ var Start action.GTSAction = func(ctx context.Context) error {  	}  	var ( -		authModule        = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths -		clientModule      = api.NewClient(state, processor)                                    // api client endpoints -		metricsModule     = api.NewMetrics()                                                   // Metrics endpoints -		healthModule      = api.NewHealth(dbService.Ready)                                     // Health check endpoints -		fileserverModule  = api.NewFileserver(processor)                                       // fileserver endpoints -		wellKnownModule   = api.NewWellKnown(processor)                                        // .well-known endpoints -		nodeInfoModule    = api.NewNodeInfo(processor)                                         // nodeinfo endpoint -		activityPubModule = api.NewActivityPub(dbService, processor)                           // ActivityPub endpoints -		webModule         = web.New(dbService, processor)                                      // web pages + user profiles + settings panels etc +		authModule        = api.NewAuth(dbService, process, idp, routerSession, sessionName) // auth/oauth paths +		clientModule      = api.NewClient(state, process)                                    // api client endpoints +		metricsModule     = api.NewMetrics()                                                 // Metrics endpoints +		healthModule      = api.NewHealth(dbService.Ready)                                   // Health check endpoints +		fileserverModule  = api.NewFileserver(process)                                       // fileserver endpoints +		wellKnownModule   = api.NewWellKnown(process)                                        // .well-known endpoints +		nodeInfoModule    = api.NewNodeInfo(process)                                         // nodeinfo endpoint +		activityPubModule = api.NewActivityPub(dbService, process)                           // ActivityPub endpoints +		webModule         = web.New(dbService, process)                                      // web pages + user profiles + settings panels etc  	)  	// create required middleware @@ -416,10 +433,11 @@ var Start action.GTSAction = func(ctx context.Context) error {  	// throttling  	cpuMultiplier := config.GetAdvancedThrottlingMultiplier()  	retryAfter := config.GetAdvancedThrottlingRetryAfter() -	clThrottle := middleware.Throttle(cpuMultiplier, retryAfter)  // client api -	s2sThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // server-to-server (AP) -	fsThrottle := middleware.Throttle(cpuMultiplier, retryAfter)  // fileserver / web templates / emojis -	pkThrottle := middleware.Throttle(cpuMultiplier, retryAfter)  // throttle public key endpoint separately +	clThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // client api +	s2sThrottle := 
middleware.Throttle(cpuMultiplier, retryAfter) +	// server-to-server (AP) +	fsThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // fileserver / web templates / emojis +	pkThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // throttle public key endpoint separately  	gzip := middleware.Gzip() // applied to all except fileserver @@ -442,6 +460,11 @@ var Start action.GTSAction = func(ctx context.Context) error {  		return fmt.Errorf("error starting router: %w", err)  	} +	// Fill worker queues from persisted task data in database. +	if err := process.Admin().FillWorkerQueues(ctx); err != nil { +		return fmt.Errorf("error filling worker queues: %w", err) +	} +  	// catch shutdown signals from the operating system  	sigs := make(chan os.Signal, 1)  	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 070d4eb91..d5071d141 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -84,6 +84,7 @@ type DBService struct {  	db.Timeline  	db.User  	db.Tombstone +	db.WorkerTask  	db *bun.DB  } @@ -302,6 +303,9 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  			db:    db,  			state: state,  		}, +		WorkerTask: &workerTaskDB{ +			db: db, +		},  		db: db,  	} diff --git a/internal/db/bundb/migrations/20240617134210_add_worker_tasks_table.go b/internal/db/bundb/migrations/20240617134210_add_worker_tasks_table.go new file mode 100644 index 000000000..3b0ebcfd8 --- /dev/null +++ b/internal/db/bundb/migrations/20240617134210_add_worker_tasks_table.go @@ -0,0 +1,51 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun" +) + +func init() { +	up := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			// WorkerTask table. +			if _, err := tx. +				NewCreateTable(). +				Model(>smodel.WorkerTask{}). +				IfNotExists(). 
+				Exec(ctx); err != nil { +				return err +			} +			return nil +		}) +	} + +	down := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			return nil +		}) +	} + +	if err := Migrations.Register(up, down); err != nil { +		panic(err) +	} +} diff --git a/internal/db/bundb/workertask.go b/internal/db/bundb/workertask.go new file mode 100644 index 000000000..eec51530d --- /dev/null +++ b/internal/db/bundb/workertask.go @@ -0,0 +1,58 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( +	"context" +	"errors" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun" +) + +type workerTaskDB struct{ db *bun.DB } + +func (w *workerTaskDB) GetWorkerTasks(ctx context.Context) ([]*gtsmodel.WorkerTask, error) { +	var tasks []*gtsmodel.WorkerTask +	if err := w.db.NewSelect(). +		Model(&tasks). +		OrderExpr("? ASC", bun.Ident("created_at")). +		Scan(ctx); err != nil { +		return nil, err +	} +	return tasks, nil +} + +func (w *workerTaskDB) PutWorkerTasks(ctx context.Context, tasks []*gtsmodel.WorkerTask) error { +	var errs []error +	for _, task := range tasks { +		_, err := w.db.NewInsert().Model(task).Exec(ctx) +		if err != nil { +			errs = append(errs, err) +		} +	} +	return errors.Join(errs...) +} + +func (w *workerTaskDB) DeleteWorkerTaskByID(ctx context.Context, id uint) error { +	_, err := w.db.NewDelete(). +		Table("worker_tasks"). +		Where("? = ?", bun.Ident("id"), id). +		Exec(ctx) +	return err +} diff --git a/internal/db/db.go b/internal/db/db.go index 4b2152732..cd621871a 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -56,4 +56,5 @@ type DB interface {  	Timeline  	User  	Tombstone +	WorkerTask  } diff --git a/internal/db/workertask.go b/internal/db/workertask.go new file mode 100644 index 000000000..0276f231a --- /dev/null +++ b/internal/db/workertask.go @@ -0,0 +1,35 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. 
+ +package db + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type WorkerTask interface { +	// GetWorkerTasks fetches all persisted worker tasks from the database. +	GetWorkerTasks(ctx context.Context) ([]*gtsmodel.WorkerTask, error) + +	// PutWorkerTasks persists the given worker tasks to the database. +	PutWorkerTasks(ctx context.Context, tasks []*gtsmodel.WorkerTask) error + +	// DeleteWorkerTask deletes worker task with given ID from database. +	DeleteWorkerTaskByID(ctx context.Context, id uint) error +} diff --git a/internal/gtsmodel/workertask.go b/internal/gtsmodel/workertask.go index cc8433199..758fc4cd7 100644 --- a/internal/gtsmodel/workertask.go +++ b/internal/gtsmodel/workertask.go @@ -34,8 +34,8 @@ const (  // queued tasks from being lost. It is simply a  // means to store a blob of serialized task data.  type WorkerTask struct { -	ID         uint      `bun:""` -	WorkerType uint8     `bun:""` -	TaskData   []byte    `bun:""` -	CreatedAt  time.Time `bun:""` +	ID         uint       `bun:",pk,autoincrement"` +	WorkerType WorkerType `bun:",notnull"` +	TaskData   []byte     `bun:",nullzero,notnull"` +	CreatedAt  time.Time  `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"`  } diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go index b78dbc2d9..30ef0b04d 100644 --- a/internal/httpclient/client.go +++ b/internal/httpclient/client.go @@ -197,7 +197,7 @@ func (c *Client) Do(r *http.Request) (rsp *http.Response, err error) {  		// If the fast-fail flag was set, just  		// attempt a single iteration instead of  		// following the below retry-backoff loop. -		rsp, _, err = c.DoOnce(&req) +		rsp, _, err = c.DoOnce(req)  		if err != nil {  			return nil, fmt.Errorf("%w (fast fail)", err)  		} @@ -208,7 +208,7 @@ func (c *Client) Do(r *http.Request) (rsp *http.Response, err error) {  		var retry bool  		// Perform the http request. -		rsp, retry, err = c.DoOnce(&req) +		rsp, retry, err = c.DoOnce(req)  		if err == nil {  			return rsp, nil  		} diff --git a/internal/httpclient/request.go b/internal/httpclient/request.go index e5a7f44d3..dfe51b160 100644 --- a/internal/httpclient/request.go +++ b/internal/httpclient/request.go @@ -47,8 +47,8 @@ type Request struct {  // WrapRequest wraps an existing http.Request within  // our own httpclient.Request with retry / backoff tracking. -func WrapRequest(r *http.Request) Request { -	var rr Request +func WrapRequest(r *http.Request) *Request { +	rr := new(Request)  	rr.Request = r  	entry := log.WithContext(r.Context())  	entry = entry.WithField("method", r.Method) diff --git a/internal/messages/messages.go b/internal/messages/messages.go index 7779633ba..d652c0c5c 100644 --- a/internal/messages/messages.go +++ b/internal/messages/messages.go @@ -352,7 +352,7 @@ func resolveAPObject(data map[string]interface{}) (interface{}, error) {  // we then need to wrangle back into the original type. So we also store the type name  // and use this to determine the appropriate Go structure type to unmarshal into to.  func resolveGTSModel(typ string, data []byte) (interface{}, error) { -	if typ == "" && data == nil { +	if typ == "" {  		// No data given.  		
return nil, nil  	} diff --git a/internal/processing/admin/workertask.go b/internal/processing/admin/workertask.go new file mode 100644 index 000000000..6d7cc7b7a --- /dev/null +++ b/internal/processing/admin/workertask.go @@ -0,0 +1,426 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package admin + +import ( +	"context" +	"fmt" +	"slices" +	"time" + +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/internal/transport" +	"github.com/superseriousbusiness/gotosocial/internal/transport/delivery" +) + +// NOTE: +// Having these functions in the processor, which is +// usually the intermediary that performs *processing* +// between the HTTP route handlers and the underlying +// database / storage layers is a little odd, so this +// may be subject to change! +// +// For now at least, this is a useful place that has +// access to the underlying database, workers and +// causes no dependency cycles with this use case! + +// FillWorkerQueues recovers all serialized worker tasks from the database +// (if any!), and pushes them to each of their relevant worker queues. +func (p *Processor) FillWorkerQueues(ctx context.Context) error { +	log.Info(ctx, "rehydrate!") + +	// Get all persisted worker tasks from db. +	// +	// (database returns these as ASCENDING, i.e. +	// returned in the order they were inserted). +	tasks, err := p.state.DB.GetWorkerTasks(ctx) +	if err != nil { +		return gtserror.Newf("error fetching worker tasks from db: %w", err) +	} + +	var ( +		// Counts of each task type +		// successfully recovered. +		delivery  int +		federator int +		client    int + +		// Failed recoveries. +		errors int +	) + +loop: + +	// Handle each persisted task, removing +	// all those we can't handle. Leaving us +	// with a slice of tasks we can safely +	// delete from being persisted in the DB. +	for i := 0; i < len(tasks); { +		var err error + +		// Task at index. +		task := tasks[i] + +		// Appropriate task count +		// pointer to increment. +		var counter *int + +		// Attempt to recovery persisted +		// task depending on worker type. 
+		switch task.WorkerType { +		case gtsmodel.DeliveryWorker: +			err = p.pushDelivery(ctx, task) +			counter = &delivery +		case gtsmodel.FederatorWorker: +			err = p.pushFederator(ctx, task) +			counter = &federator +		case gtsmodel.ClientWorker: +			err = p.pushClient(ctx, task) +			counter = &client +		default: +			err = fmt.Errorf("invalid worker type %d", task.WorkerType) +		} + +		if err != nil { +			log.Errorf(ctx, "error pushing task %d: %v", task.ID, err) + +			// Drop error'd task from slice. +			tasks = slices.Delete(tasks, i, i+1) + +			// Incr errors. +			errors++ +			continue loop +		} + +		// Increment slice +		// index & counter. +		(*counter)++ +		i++ +	} + +	// Tasks that worker successfully pushed +	// to their appropriate workers, we can +	// safely now remove from the database. +	for _, task := range tasks { +		if err := p.state.DB.DeleteWorkerTaskByID(ctx, task.ID); err != nil { +			log.Errorf(ctx, "error deleting task from db: %v", err) +		} +	} + +	// Log recovered tasks. +	log.WithContext(ctx). +		WithField("delivery", delivery). +		WithField("federator", federator). +		WithField("client", client). +		WithField("errors", errors). +		Info("recovered queued tasks") + +	return nil +} + +// PersistWorkerQueues pops all queued worker tasks (that are themselves persistable, i.e. not +// dereference tasks which are just function ptrs), serializes and persists them to the database. +func (p *Processor) PersistWorkerQueues(ctx context.Context) error { +	log.Info(ctx, "dehydrate!") + +	var ( +		// Counts of each task type +		// successfully persisted. +		delivery  int +		federator int +		client    int + +		// Failed persists. +		errors int + +		// Serialized tasks to persist. +		tasks []*gtsmodel.WorkerTask +	) + +	for { +		// Pop all queued deliveries. +		task, err := p.popDelivery() +		if err != nil { +			log.Errorf(ctx, "error popping delivery: %v", err) +			errors++ // incr error count. +			continue +		} + +		if task == nil { +			// No more queue +			// tasks to pop! +			break +		} + +		// Append serialized task. +		tasks = append(tasks, task) +		delivery++ // incr count +	} + +	for { +		// Pop queued federator msgs. +		task, err := p.popFederator() +		if err != nil { +			log.Errorf(ctx, "error popping federator message: %v", err) +			errors++ // incr count +			continue +		} + +		if task == nil { +			// No more queue +			// tasks to pop! +			break +		} + +		// Append serialized task. +		tasks = append(tasks, task) +		federator++ // incr count +	} + +	for { +		// Pop queued client msgs. +		task, err := p.popClient() +		if err != nil { +			log.Errorf(ctx, "error popping client message: %v", err) +			continue +		} + +		if task == nil { +			// No more queue +			// tasks to pop! +			break +		} + +		// Append serialized task. +		tasks = append(tasks, task) +		client++ // incr count +	} + +	// Persist all serialized queued worker tasks to database. +	if err := p.state.DB.PutWorkerTasks(ctx, tasks); err != nil { +		return gtserror.Newf("error putting tasks in db: %w", err) +	} + +	// Log recovered tasks. +	log.WithContext(ctx). +		WithField("delivery", delivery). +		WithField("federator", federator). +		WithField("client", client). +		WithField("errors", errors). +		Info("persisted queued tasks") + +	return nil +} + +// pushDelivery parses a valid delivery.Delivery{} from serialized task data and pushes to queue. 
+func (p *Processor) pushDelivery(ctx context.Context, task *gtsmodel.WorkerTask) error { +	dlv := new(delivery.Delivery) + +	// Deserialize the raw worker task data into delivery. +	if err := dlv.Deserialize(task.TaskData); err != nil { +		return gtserror.Newf("error deserializing delivery: %w", err) +	} + +	var tsport transport.Transport + +	if uri := dlv.ActorID; uri != "" { +		// Fetch the actor account by provided URI from db. +		account, err := p.state.DB.GetAccountByURI(ctx, uri) +		if err != nil { +			return gtserror.Newf("error getting actor account %s from db: %w", uri, err) +		} + +		// Fetch a transport for request signing for actor's account username. +		tsport, err = p.transport.NewTransportForUsername(ctx, account.Username) +		if err != nil { +			return gtserror.Newf("error getting transport for actor %s: %w", uri, err) +		} +	} else { +		var err error + +		// No actor was given, will be signed by instance account. +		tsport, err = p.transport.NewTransportForUsername(ctx, "") +		if err != nil { +			return gtserror.Newf("error getting instance account transport: %w", err) +		} +	} + +	// Using transport, add actor signature to delivery. +	if err := tsport.SignDelivery(dlv); err != nil { +		return gtserror.Newf("error signing delivery: %w", err) +	} + +	// Push deserialized task to delivery queue. +	p.state.Workers.Delivery.Queue.Push(dlv) + +	return nil +} + +// popDelivery pops delivery.Delivery{} from queue and serializes as valid task data. +func (p *Processor) popDelivery() (*gtsmodel.WorkerTask, error) { + +	// Pop waiting delivery from the delivery worker. +	delivery, ok := p.state.Workers.Delivery.Queue.Pop() +	if !ok { +		return nil, nil +	} + +	// Serialize the delivery task data. +	data, err := delivery.Serialize() +	if err != nil { +		return nil, gtserror.Newf("error serializing delivery: %w", err) +	} + +	return >smodel.WorkerTask{ +		// ID is autoincrement +		WorkerType: gtsmodel.DeliveryWorker, +		TaskData:   data, +		CreatedAt:  time.Now(), +	}, nil +} + +// pushClient parses a valid messages.FromFediAPI{} from serialized task data and pushes to queue. +func (p *Processor) pushFederator(ctx context.Context, task *gtsmodel.WorkerTask) error { +	var msg messages.FromFediAPI + +	// Deserialize the raw worker task data into message. +	if err := msg.Deserialize(task.TaskData); err != nil { +		return gtserror.Newf("error deserializing federator message: %w", err) +	} + +	if rcv := msg.Receiving; rcv != nil { +		// Only a placeholder receiving account will be populated, +		// fetch the actual model from database by persisted ID. +		account, err := p.state.DB.GetAccountByID(ctx, rcv.ID) +		if err != nil { +			return gtserror.Newf("error fetching receiving account %s from db: %w", rcv.ID, err) +		} + +		// Set the now populated +		// receiving account model. +		msg.Receiving = account +	} + +	if req := msg.Requesting; req != nil { +		// Only a placeholder requesting account will be populated, +		// fetch the actual model from database by persisted ID. +		account, err := p.state.DB.GetAccountByID(ctx, req.ID) +		if err != nil { +			return gtserror.Newf("error fetching requesting account %s from db: %w", req.ID, err) +		} + +		// Set the now populated +		// requesting account model. +		msg.Requesting = account +	} + +	// Push populated task to the federator queue. +	p.state.Workers.Federator.Queue.Push(&msg) + +	return nil +} + +// popFederator pops messages.FromFediAPI{} from queue and serializes as valid task data. 
+func (p *Processor) popFederator() (*gtsmodel.WorkerTask, error) { + +	// Pop waiting message from the federator worker. +	msg, ok := p.state.Workers.Federator.Queue.Pop() +	if !ok { +		return nil, nil +	} + +	// Serialize message task data. +	data, err := msg.Serialize() +	if err != nil { +		return nil, gtserror.Newf("error serializing federator message: %w", err) +	} + +	return >smodel.WorkerTask{ +		// ID is autoincrement +		WorkerType: gtsmodel.FederatorWorker, +		TaskData:   data, +		CreatedAt:  time.Now(), +	}, nil +} + +// pushClient parses a valid messages.FromClientAPI{} from serialized task data and pushes to queue. +func (p *Processor) pushClient(ctx context.Context, task *gtsmodel.WorkerTask) error { +	var msg messages.FromClientAPI + +	// Deserialize the raw worker task data into message. +	if err := msg.Deserialize(task.TaskData); err != nil { +		return gtserror.Newf("error deserializing client message: %w", err) +	} + +	if org := msg.Origin; org != nil { +		// Only a placeholder origin account will be populated, +		// fetch the actual model from database by persisted ID. +		account, err := p.state.DB.GetAccountByID(ctx, org.ID) +		if err != nil { +			return gtserror.Newf("error fetching origin account %s from db: %w", org.ID, err) +		} + +		// Set the now populated +		// origin account model. +		msg.Origin = account +	} + +	if trg := msg.Target; trg != nil { +		// Only a placeholder target account will be populated, +		// fetch the actual model from database by persisted ID. +		account, err := p.state.DB.GetAccountByID(ctx, trg.ID) +		if err != nil { +			return gtserror.Newf("error fetching target account %s from db: %w", trg.ID, err) +		} + +		// Set the now populated +		// target account model. +		msg.Target = account +	} + +	// Push populated task to the federator queue. +	p.state.Workers.Client.Queue.Push(&msg) + +	return nil +} + +// popClient pops messages.FromClientAPI{} from queue and serializes as valid task data. +func (p *Processor) popClient() (*gtsmodel.WorkerTask, error) { + +	// Pop waiting message from the client worker. +	msg, ok := p.state.Workers.Client.Queue.Pop() +	if !ok { +		return nil, nil +	} + +	// Serialize message task data. +	data, err := msg.Serialize() +	if err != nil { +		return nil, gtserror.Newf("error serializing client message: %w", err) +	} + +	return >smodel.WorkerTask{ +		// ID is autoincrement +		WorkerType: gtsmodel.ClientWorker, +		TaskData:   data, +		CreatedAt:  time.Now(), +	}, nil +} diff --git a/internal/processing/admin/workertask_test.go b/internal/processing/admin/workertask_test.go new file mode 100644 index 000000000..bf326bafd --- /dev/null +++ b/internal/processing/admin/workertask_test.go @@ -0,0 +1,421 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. 
+ +package admin_test + +import ( +	"bytes" +	"context" +	"encoding/json" +	"fmt" +	"io" +	"net/http" +	"net/url" +	"testing" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/ap" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/httpclient" +	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/internal/transport/delivery" +	"github.com/superseriousbusiness/gotosocial/testrig" +) + +var ( +	// TODO: move these test values into +	// the testrig test models area. They'll +	// need to be as both WorkerTask and as +	// the raw types themselves. + +	testDeliveries = []*delivery.Delivery{ +		{ +			ObjectID: "https://google.com/users/bigboy/follow/1", +			TargetID: "https://askjeeves.com/users/smallboy", +			Request:  toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Host": {"https://askjeeves.com"}}), +		}, +		{ +			Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), http.Header{"Host": {"https://google.com"}}), +		}, +	} + +	testFederatorMsgs = []*messages.FromFediAPI{ +		{ +			APObjectType:   ap.ObjectNote, +			APActivityType: ap.ActivityCreate, +			TargetURI:      "https://gotosocial.org", +			Requesting:     >smodel.Account{ID: "654321"}, +			Receiving:      >smodel.Account{ID: "123456"}, +		}, +		{ +			APObjectType:   ap.ObjectProfile, +			APActivityType: ap.ActivityUpdate, +			TargetURI:      "https://uk-queen-is-dead.org", +			Requesting:     >smodel.Account{ID: "123456"}, +			Receiving:      >smodel.Account{ID: "654321"}, +		}, +	} + +	testClientMsgs = []*messages.FromClientAPI{ +		{ +			APObjectType:   ap.ObjectNote, +			APActivityType: ap.ActivityCreate, +			TargetURI:      "https://gotosocial.org", +			Origin:         >smodel.Account{ID: "654321"}, +			Target:         >smodel.Account{ID: "123456"}, +		}, +		{ +			APObjectType:   ap.ObjectProfile, +			APActivityType: ap.ActivityUpdate, +			TargetURI:      "https://uk-queen-is-dead.org", +			Origin:         >smodel.Account{ID: "123456"}, +			Target:         >smodel.Account{ID: "654321"}, +		}, +	} +) + +type WorkerTaskTestSuite struct { +	AdminStandardTestSuite +} + +func (suite *WorkerTaskTestSuite) TestFillWorkerQueues() { +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	var tasks []*gtsmodel.WorkerTask + +	for _, dlv := range testDeliveries { +		// Serialize all test deliveries. +		data, err := dlv.Serialize() +		if err != nil { +			panic(err) +		} + +		// Append each serialized delivery to tasks. +		tasks = append(tasks, >smodel.WorkerTask{ +			WorkerType: gtsmodel.DeliveryWorker, +			TaskData:   data, +		}) +	} + +	for _, msg := range testFederatorMsgs { +		// Serialize all test messages. +		data, err := msg.Serialize() +		if err != nil { +			panic(err) +		} + +		if msg.Receiving != nil { +			// Quick hack to bypass database errors for non-existing +			// accounts, instead we just insert this into cache ;). +			suite.state.Caches.DB.Account.Put(msg.Receiving) +			suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ +				AccountID: msg.Receiving.ID, +			}) +		} + +		if msg.Requesting != nil { +			// Quick hack to bypass database errors for non-existing +			// accounts, instead we just insert this into cache ;). 
+			suite.state.Caches.DB.Account.Put(msg.Requesting) +			suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ +				AccountID: msg.Requesting.ID, +			}) +		} + +		// Append each serialized message to tasks. +		tasks = append(tasks, >smodel.WorkerTask{ +			WorkerType: gtsmodel.FederatorWorker, +			TaskData:   data, +		}) +	} + +	for _, msg := range testClientMsgs { +		// Serialize all test messages. +		data, err := msg.Serialize() +		if err != nil { +			panic(err) +		} + +		if msg.Origin != nil { +			// Quick hack to bypass database errors for non-existing +			// accounts, instead we just insert this into cache ;). +			suite.state.Caches.DB.Account.Put(msg.Origin) +			suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ +				AccountID: msg.Origin.ID, +			}) +		} + +		if msg.Target != nil { +			// Quick hack to bypass database errors for non-existing +			// accounts, instead we just insert this into cache ;). +			suite.state.Caches.DB.Account.Put(msg.Target) +			suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ +				AccountID: msg.Target.ID, +			}) +		} + +		// Append each serialized message to tasks. +		tasks = append(tasks, >smodel.WorkerTask{ +			WorkerType: gtsmodel.ClientWorker, +			TaskData:   data, +		}) +	} + +	// Persist all test worker tasks to the database. +	err := suite.state.DB.PutWorkerTasks(ctx, tasks) +	suite.NoError(err) + +	// Fill the worker queues from persisted task data. +	err = suite.adminProcessor.FillWorkerQueues(ctx) +	suite.NoError(err) + +	var ( +		// Recovered +		// task counts. +		ndelivery  int +		nfederator int +		nclient    int +	) + +	// Fetch current gotosocial instance account, for later checks. +	instanceAcc, err := suite.state.DB.GetInstanceAccount(ctx, "") +	suite.NoError(err) + +	for { +		// Pop all queued delivery tasks from worker queue. +		dlv, ok := suite.state.Workers.Delivery.Queue.Pop() +		if !ok { +			break +		} + +		// Incr count. +		ndelivery++ + +		// Check that we have this message in slice. +		err = containsSerializable(testDeliveries, dlv) +		suite.NoError(err) + +		// Check that delivery request context has instance account pubkey. +		pubKeyID := gtscontext.OutgoingPublicKeyID(dlv.Request.Context()) +		suite.Equal(instanceAcc.PublicKeyURI, pubKeyID) +		signfn := gtscontext.HTTPClientSignFunc(dlv.Request.Context()) +		suite.NotNil(signfn) +	} + +	for { +		// Pop all queued federator messages from worker queue. +		msg, ok := suite.state.Workers.Federator.Queue.Pop() +		if !ok { +			break +		} + +		// Incr count. +		nfederator++ + +		// Check that we have this message in slice. +		err = containsSerializable(testFederatorMsgs, msg) +		suite.NoError(err) +	} + +	for { +		// Pop all queued client messages from worker queue. +		msg, ok := suite.state.Workers.Client.Queue.Pop() +		if !ok { +			break +		} + +		// Incr count. +		nclient++ + +		// Check that we have this message in slice. +		err = containsSerializable(testClientMsgs, msg) +		suite.NoError(err) +	} + +	// Ensure recovered task counts as expected. +	suite.Equal(len(testDeliveries), ndelivery) +	suite.Equal(len(testFederatorMsgs), nfederator) +	suite.Equal(len(testClientMsgs), nclient) +} + +func (suite *WorkerTaskTestSuite) TestPersistWorkerQueues() { +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	// Push all test worker tasks to their respective queues. +	suite.state.Workers.Delivery.Queue.Push(testDeliveries...) +	suite.state.Workers.Federator.Queue.Push(testFederatorMsgs...) 
+	suite.state.Workers.Client.Queue.Push(testClientMsgs...) + +	// Persist the worker queued tasks to database. +	err := suite.adminProcessor.PersistWorkerQueues(ctx) +	suite.NoError(err) + +	// Fetch all the persisted tasks from database. +	tasks, err := suite.state.DB.GetWorkerTasks(ctx) +	suite.NoError(err) + +	var ( +		// Persisted +		// task counts. +		ndelivery  int +		nfederator int +		nclient    int +	) + +	// Check persisted task data. +	for _, task := range tasks { +		switch task.WorkerType { +		case gtsmodel.DeliveryWorker: +			var dlv delivery.Delivery + +			// Incr count. +			ndelivery++ + +			// Deserialize the persisted task data. +			err := dlv.Deserialize(task.TaskData) +			suite.NoError(err) + +			// Check that we have this delivery in slice. +			err = containsSerializable(testDeliveries, &dlv) +			suite.NoError(err) + +		case gtsmodel.FederatorWorker: +			var msg messages.FromFediAPI + +			// Incr count. +			nfederator++ + +			// Deserialize the persisted task data. +			err := msg.Deserialize(task.TaskData) +			suite.NoError(err) + +			// Check that we have this message in slice. +			err = containsSerializable(testFederatorMsgs, &msg) +			suite.NoError(err) + +		case gtsmodel.ClientWorker: +			var msg messages.FromClientAPI + +			// Incr count. +			nclient++ + +			// Deserialize the persisted task data. +			err := msg.Deserialize(task.TaskData) +			suite.NoError(err) + +			// Check that we have this message in slice. +			err = containsSerializable(testClientMsgs, &msg) +			suite.NoError(err) + +		default: +			suite.T().Errorf("unexpected worker type: %d", task.WorkerType) +		} +	} + +	// Ensure persisted task counts as expected. +	suite.Equal(len(testDeliveries), ndelivery) +	suite.Equal(len(testFederatorMsgs), nfederator) +	suite.Equal(len(testClientMsgs), nclient) +} + +func (suite *WorkerTaskTestSuite) SetupTest() { +	suite.AdminStandardTestSuite.SetupTest() +	// we don't want workers running +	testrig.StopWorkers(&suite.state) +} + +func TestWorkerTaskTestSuite(t *testing.T) { +	suite.Run(t, new(WorkerTaskTestSuite)) +} + +// containsSerializeable returns whether slice of serializables contains given serializable entry. +func containsSerializable[T interface{ Serialize() ([]byte, error) }](expect []T, have T) error { +	// Serialize wanted value. +	bh, err := have.Serialize() +	if err != nil { +		panic(err) +	} + +	var strings []string + +	for _, t := range expect { +		// Serialize expected value. +		be, err := t.Serialize() +		if err != nil { +			panic(err) +		} + +		// Alloc as string. +		se := string(be) + +		if se == string(bh) { +			// We have this entry! +			return nil +		} + +		// Add to serialized strings. +		strings = append(strings, se) +	} + +	return fmt.Errorf("could not find %s in %s", string(bh), strings) +} + +// urlStr simply returns u.String() or "" if nil. +func urlStr(u *url.URL) string { +	if u == nil { +		return "" +	} +	return u.String() +} + +// accountID simply returns account.ID or "" if nil. +func accountID(account *gtsmodel.Account) string { +	if account == nil { +		return "" +	} +	return account.ID +} + +// toRequest creates httpclient.Request from HTTP method, URL and body data. 
+func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request { +	var rbody io.Reader +	if body != nil { +		rbody = bytes.NewReader(body) +	} +	req, err := http.NewRequest(method, url, rbody) +	if err != nil { +		panic(err) +	} +	for key, values := range hdr { +		for _, value := range values { +			req.Header.Add(key, value) +		} +	} +	return httpclient.WrapRequest(req) +} + +// toJSON marshals input type as JSON data. +func toJSON(a any) []byte { +	b, err := json.Marshal(a) +	if err != nil { +		panic(err) +	} +	return b +} diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index 30435b86f..36ad6f015 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -21,6 +21,7 @@ import (  	"bytes"  	"context"  	"encoding/json" +	"io"  	"net/http"  	"net/url" @@ -169,6 +170,38 @@ func (t *transport) prepare(  	}, nil  } +func (t *transport) SignDelivery(dlv *delivery.Delivery) error { +	if dlv.Request.GetBody == nil { +		return gtserror.New("delivery request body not rewindable") +	} + +	// Get a new copy of the request body. +	body, err := dlv.Request.GetBody() +	if err != nil { +		return gtserror.Newf("error getting request body: %w", err) +	} + +	// Read body data into memory. +	data, err := io.ReadAll(body) +	if err != nil { +		return gtserror.Newf("error reading request body: %w", err) +	} + +	// Get signing function for POST data. +	// (note that delivery is ALWAYS POST). +	sign := t.signPOST(data) + +	// Extract delivery context. +	ctx := dlv.Request.Context() + +	// Update delivery request context with signing details. +	ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID) +	ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign) +	dlv.Request.Request = dlv.Request.Request.WithContext(ctx) + +	return nil +} +  // getObjectID extracts an object ID from 'serialized' ActivityPub object map.  func getObjectID(obj map[string]interface{}) string {  	switch t := obj["object"].(type) { diff --git a/internal/transport/delivery/delivery.go b/internal/transport/delivery/delivery.go index 1e3ebb054..e11eea83c 100644 --- a/internal/transport/delivery/delivery.go +++ b/internal/transport/delivery/delivery.go @@ -33,10 +33,6 @@ import (  // be indexed (and so, dropped from queue)  // by any of these possible ID IRIs.  type Delivery struct { -	// PubKeyID is the signing public key -	// ID of the actor performing request. -	PubKeyID string -  	// ActorID contains the ActivityPub  	// actor ID IRI (if any) of the activity  	// being sent out by this request. @@ -55,7 +51,7 @@ type Delivery struct {  	// Request is the prepared (+ wrapped)  	// httpclient.Client{} request that  	// constitutes this ActivtyPub delivery. -	Request httpclient.Request +	Request *httpclient.Request  	// internal fields.  	next time.Time @@ -66,7 +62,6 @@ type Delivery struct {  // a json serialize / deserialize  // able shape that minimizes data.  type delivery struct { -	PubKeyID string              `json:"pub_key_id,omitempty"`  	ActorID  string              `json:"actor_id,omitempty"`  	ObjectID string              `json:"object_id,omitempty"`  	TargetID string              `json:"target_id,omitempty"` @@ -101,7 +96,6 @@ func (dlv *Delivery) Serialize() ([]byte, error) {  	// Marshal as internal JSON type.  	return json.Marshal(delivery{ -		PubKeyID: dlv.PubKeyID,  		ActorID:  dlv.ActorID,  		ObjectID: dlv.ObjectID,  		TargetID: dlv.TargetID, @@ -125,7 +119,6 @@ func (dlv *Delivery) Deserialize(data []byte) error {  	}  	// Copy over simplest fields. 
-	dlv.PubKeyID = idlv.PubKeyID  	dlv.ActorID = idlv.ActorID  	dlv.ObjectID = idlv.ObjectID  	dlv.TargetID = idlv.TargetID @@ -143,6 +136,13 @@ func (dlv *Delivery) Deserialize(data []byte) error {  		return err  	} +	// Copy over any stored header values. +	for key, values := range idlv.Header { +		for _, value := range values { +			r.Header.Add(key, value) +		} +	} +  	// Wrap request in httpclient type.  	dlv.Request = httpclient.WrapRequest(r) diff --git a/internal/transport/delivery/delivery_test.go b/internal/transport/delivery/delivery_test.go index e9eaf8fd1..81f32d5f8 100644 --- a/internal/transport/delivery/delivery_test.go +++ b/internal/transport/delivery/delivery_test.go @@ -35,32 +35,30 @@ var deliveryCases = []struct {  }{  	{  		msg: delivery.Delivery{ -			PubKeyID: "https://google.com/users/bigboy#pubkey",  			ActorID:  "https://google.com/users/bigboy",  			ObjectID: "https://google.com/users/bigboy/follow/1",  			TargetID: "https://askjeeves.com/users/smallboy", -			Request:  toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!")), +			Request:  toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Hello": {"world1", "world2"}}),  		},  		data: toJSON(map[string]any{ -			"pub_key_id": "https://google.com/users/bigboy#pubkey", -			"actor_id":   "https://google.com/users/bigboy", -			"object_id":  "https://google.com/users/bigboy/follow/1", -			"target_id":  "https://askjeeves.com/users/smallboy", -			"method":     "POST", -			"url":        "https://askjeeves.com/users/smallboy/inbox", -			"body":       []byte("data!"), -			// "header":     map[string][]string{}, +			"actor_id":  "https://google.com/users/bigboy", +			"object_id": "https://google.com/users/bigboy/follow/1", +			"target_id": "https://askjeeves.com/users/smallboy", +			"method":    "POST", +			"url":       "https://askjeeves.com/users/smallboy/inbox", +			"body":      []byte("data!"), +			"header":    map[string][]string{"Hello": {"world1", "world2"}},  		}),  	},  	{  		msg: delivery.Delivery{ -			Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin")), +			Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), nil),  		},  		data: toJSON(map[string]any{  			"method": "GET",  			"url":    "https://google.com",  			"body":   []byte("uwu im just a wittle seawch engwin"), -			// "header":     map[string][]string{}, +			// "header": map[string][]string{},  		}),  	},  } @@ -89,18 +87,18 @@ func TestDeserializeDelivery(t *testing.T) {  		}  		// Check that delivery fields are as expected. -		assert.Equal(t, test.msg.PubKeyID, msg.PubKeyID)  		assert.Equal(t, test.msg.ActorID, msg.ActorID)  		assert.Equal(t, test.msg.ObjectID, msg.ObjectID)  		assert.Equal(t, test.msg.TargetID, msg.TargetID)  		assert.Equal(t, test.msg.Request.Method, msg.Request.Method)  		assert.Equal(t, test.msg.Request.URL, msg.Request.URL)  		assert.Equal(t, readBody(test.msg.Request.Body), readBody(msg.Request.Body)) +		assert.Equal(t, test.msg.Request.Header, msg.Request.Header)  	}  }  // toRequest creates httpclient.Request from HTTP method, URL and body data. 
-func toRequest(method string, url string, body []byte) httpclient.Request { +func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request {  	var rbody io.Reader  	if body != nil {  		rbody = bytes.NewReader(body) @@ -109,6 +107,11 @@ func toRequest(method string, url string, body []byte) httpclient.Request {  	if err != nil {  		panic(err)  	} +	for key, values := range hdr { +		for _, value := range values { +			req.Header.Add(key, value) +		} +	}  	return httpclient.WrapRequest(req)  } diff --git a/internal/transport/delivery/worker.go b/internal/transport/delivery/worker.go index ef31e94a6..d6d253769 100644 --- a/internal/transport/delivery/worker.go +++ b/internal/transport/delivery/worker.go @@ -19,6 +19,7 @@ package delivery  import (  	"context" +	"errors"  	"slices"  	"time" @@ -160,6 +161,13 @@ func (w *Worker) process(ctx context.Context) bool {  loop:  	for { +		// Before trying to get +		// next delivery, check +		// context still valid. +		if ctx.Err() != nil { +			return true +		} +  		// Get next delivery.  		dlv, ok := w.next(ctx)  		if !ok { @@ -195,16 +203,30 @@ loop:  		// Attempt delivery of AP request.  		rsp, retry, err := w.Client.DoOnce( -			&dlv.Request, +			dlv.Request,  		) -		if err == nil { +		switch { +		case err == nil:  			// Ensure body closed.  			_ = rsp.Body.Close()  			continue loop -		} -		if !retry { +		case errors.Is(err, context.Canceled) && +			ctx.Err() != nil: +			// In the case of our own context +			// being cancelled, push delivery +			// back onto queue for persisting. +			// +			// Note we specifically check against +			// context.Canceled here as it will +			// be faster than the mutex lock of +			// ctx.Err(), so gives an initial +			// faster check in the if-clause. +			w.Queue.Push(dlv) +			continue loop + +		case !retry:  			// Drop deliveries when no  			// retry requested, or they  			// reached max (either). @@ -222,42 +244,36 @@ loop:  // next gets the next available delivery, blocking until available if necessary.  func (w *Worker) next(ctx context.Context) (*Delivery, bool) { -loop: -	for { -		// Try pop next queued. -		dlv, ok := w.Queue.Pop() +	// Try a fast-pop of queued +	// delivery before anything. +	dlv, ok := w.Queue.Pop() -		if !ok { -			// Check the backlog. -			if len(w.backlog) > 0 { +	if !ok { +		// Check the backlog. +		if len(w.backlog) > 0 { -				// Sort by 'next' time. -				sortDeliveries(w.backlog) +			// Sort by 'next' time. +			sortDeliveries(w.backlog) -				// Pop next delivery. -				dlv := w.popBacklog() +			// Pop next delivery. +			dlv := w.popBacklog() -				return dlv, true -			} - -			select { -			// Backlog is empty, we MUST -			// block until next enqueued. -			case <-w.Queue.Wait(): -				continue loop +			return dlv, true +		} -			// Worker was stopped. -			case <-ctx.Done(): -				return nil, false -			} +		// Block on next delivery push +		// OR worker context canceled. +		dlv, ok = w.Queue.PopCtx(ctx) +		if !ok { +			return nil, false  		} +	} -		// Replace request context for worker state canceling. -		ctx := gtscontext.WithValues(ctx, dlv.Request.Context()) -		dlv.Request.Request = dlv.Request.Request.WithContext(ctx) +	// Replace request context for worker state canceling. +	ctx = gtscontext.WithValues(ctx, dlv.Request.Context()) +	dlv.Request.Request = dlv.Request.Request.WithContext(ctx) -		return dlv, true -	} +	return dlv, true  }  // popBacklog pops next available from the backlog. 
diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 2971ca603..7f7e985fc 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -30,6 +30,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/httpclient" +	"github.com/superseriousbusiness/gotosocial/internal/transport/delivery"  	"github.com/superseriousbusiness/httpsig"  ) @@ -50,6 +51,10 @@ type Transport interface {  	// transport client, retrying on certain preset errors.  	POST(*http.Request, []byte) (*http.Response, error) +	// SignDelivery adds HTTP request signing client "middleware" +	// to the request context within given delivery.Delivery{}. +	SignDelivery(*delivery.Delivery) error +  	// Deliver sends an ActivityStreams object.  	Deliver(ctx context.Context, obj map[string]interface{}, to *url.URL) error diff --git a/internal/workers/worker_msg.go b/internal/workers/worker_msg.go index 92180651a..c7dc568d7 100644 --- a/internal/workers/worker_msg.go +++ b/internal/workers/worker_msg.go @@ -19,6 +19,7 @@ package workers  import (  	"context" +	"errors"  	"codeberg.org/gruf/go-runners"  	"codeberg.org/gruf/go-structr" @@ -147,9 +148,25 @@ func (w *MsgWorker[T]) process(ctx context.Context) {  			return  		} -		// Attempt to process popped message type. -		if err := w.Process(ctx, msg); err != nil { +		// Attempt to process message. +		err := w.Process(ctx, msg) +		if err != nil {  			log.Errorf(ctx, "%p: error processing: %v", w, err) + +			if errors.Is(err, context.Canceled) && +				ctx.Err() != nil { +				// In the case of our own context +				// being cancelled, push message +				// back onto queue for persisting. +				// +				// Note we specifically check against +				// context.Canceled here as it will +				// be faster than the mutex lock of +				// ctx.Err(), so gives an initial +				// faster check in the if-clause. +				w.Queue.Push(msg) +				break +			}  		}  	}  } diff --git a/internal/workers/workers.go b/internal/workers/workers.go index 4d2b146b6..377a9d899 100644 --- a/internal/workers/workers.go +++ b/internal/workers/workers.go @@ -55,7 +55,8 @@ type Workers struct {  // StartScheduler starts the job scheduler.  func (w *Workers) StartScheduler() { -	_ = w.Scheduler.Start() // false = already running +	_ = w.Scheduler.Start() +	// false = already running  	log.Info(nil, "started scheduler")  } @@ -82,9 +83,12 @@ func (w *Workers) Start() {  	log.Infof(nil, "started %d dereference workers", n)  } -// Stop will stop all of the contained worker pools (and global scheduler). +// Stop will stop all of the contained +// worker pools (and global scheduler).  
func (w *Workers) Stop() { -	_ = w.Scheduler.Stop() // false = not running +	_ = w.Scheduler.Stop() +	// false = not running +	log.Info(nil, "stopped scheduler")  	w.Delivery.Stop()  	log.Info(nil, "stopped delivery workers") diff --git a/testrig/db.go b/testrig/db.go index 67a7e2439..e6b40c846 100644 --- a/testrig/db.go +++ b/testrig/db.go @@ -29,6 +29,8 @@ import (  var testModels = []interface{}{  	>smodel.Account{}, +	>smodel.AccountNote{}, +	>smodel.AccountSettings{},  	>smodel.AccountToEmoji{},  	>smodel.Application{},  	>smodel.Block{}, @@ -67,8 +69,7 @@ var testModels = []interface{}{  	>smodel.Tombstone{},  	>smodel.Report{},  	>smodel.Rule{}, -	>smodel.AccountNote{}, -	>smodel.AccountSettings{}, +	>smodel.WorkerTask{},  }  // NewTestDB returns a new initialized, empty database for testing.  | 
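
One pattern recurs in both the delivery worker and the generic message worker diffs above: when processing fails because the worker's own context was canceled (i.e. shutdown), the message is pushed back onto the queue so the later PersistWorkerQueues pass can save it rather than letting it drop. A self-contained sketch of that check, with hypothetical `process` and `push` callbacks standing in for the real worker plumbing:

```go
package main

import (
	"context"
	"errors"
	"log"
)

// requeueOnShutdown sketches the cancellation check added to the
// worker loops: process one message, and if processing failed because
// our own context was canceled (i.e. shutdown), push the message back
// onto the queue so it can be persisted later instead of being lost.
func requeueOnShutdown[T any](
	ctx context.Context,
	msg T,
	process func(context.Context, T) error,
	push func(T),
) (stop bool) {
	err := process(ctx, msg)
	if err == nil {
		return false
	}
	log.Printf("error processing: %v", err)

	// Check errors.Is(err, context.Canceled) first: it is cheaper
	// than the mutex-guarded ctx.Err(), which is then consulted to
	// confirm the cancellation came from our own context rather than
	// some unrelated downstream cancellation.
	if errors.Is(err, context.Canceled) && ctx.Err() != nil {
		push(msg) // requeue for persistence
		return true
	}
	return false
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // simulate shutdown mid-processing

	var requeued []string
	stopped := requeueOnShutdown(ctx, "msg-1",
		func(ctx context.Context, s string) error { return ctx.Err() },
		func(s string) { requeued = append(requeued, s) },
	)
	log.Printf("stopped=%v requeued=%v", stopped, requeued)
}
```

This matches the rationale in the in-tree comment: the context.Canceled comparison gives a fast initial check, and ctx.Err() only confirms ownership of the cancellation afterwards.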
