diff options
Diffstat (limited to 'internal/processing/admin')
-rw-r--r-- | internal/processing/admin/workertask.go | 426 | ||||
-rw-r--r-- | internal/processing/admin/workertask_test.go | 421 |
2 files changed, 847 insertions, 0 deletions
diff --git a/internal/processing/admin/workertask.go b/internal/processing/admin/workertask.go new file mode 100644 index 000000000..6d7cc7b7a --- /dev/null +++ b/internal/processing/admin/workertask.go @@ -0,0 +1,426 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package admin + +import ( + "context" + "fmt" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/transport" + "github.com/superseriousbusiness/gotosocial/internal/transport/delivery" +) + +// NOTE: +// Having these functions in the processor, which is +// usually the intermediary that performs *processing* +// between the HTTP route handlers and the underlying +// database / storage layers is a little odd, so this +// may be subject to change! +// +// For now at least, this is a useful place that has +// access to the underlying database, workers and +// causes no dependency cycles with this use case! + +// FillWorkerQueues recovers all serialized worker tasks from the database +// (if any!), and pushes them to each of their relevant worker queues. +func (p *Processor) FillWorkerQueues(ctx context.Context) error { + log.Info(ctx, "rehydrate!") + + // Get all persisted worker tasks from db. + // + // (database returns these as ASCENDING, i.e. + // returned in the order they were inserted). + tasks, err := p.state.DB.GetWorkerTasks(ctx) + if err != nil { + return gtserror.Newf("error fetching worker tasks from db: %w", err) + } + + var ( + // Counts of each task type + // successfully recovered. + delivery int + federator int + client int + + // Failed recoveries. + errors int + ) + +loop: + + // Handle each persisted task, removing + // all those we can't handle. Leaving us + // with a slice of tasks we can safely + // delete from being persisted in the DB. + for i := 0; i < len(tasks); { + var err error + + // Task at index. + task := tasks[i] + + // Appropriate task count + // pointer to increment. + var counter *int + + // Attempt to recovery persisted + // task depending on worker type. + switch task.WorkerType { + case gtsmodel.DeliveryWorker: + err = p.pushDelivery(ctx, task) + counter = &delivery + case gtsmodel.FederatorWorker: + err = p.pushFederator(ctx, task) + counter = &federator + case gtsmodel.ClientWorker: + err = p.pushClient(ctx, task) + counter = &client + default: + err = fmt.Errorf("invalid worker type %d", task.WorkerType) + } + + if err != nil { + log.Errorf(ctx, "error pushing task %d: %v", task.ID, err) + + // Drop error'd task from slice. + tasks = slices.Delete(tasks, i, i+1) + + // Incr errors. + errors++ + continue loop + } + + // Increment slice + // index & counter. + (*counter)++ + i++ + } + + // Tasks that worker successfully pushed + // to their appropriate workers, we can + // safely now remove from the database. + for _, task := range tasks { + if err := p.state.DB.DeleteWorkerTaskByID(ctx, task.ID); err != nil { + log.Errorf(ctx, "error deleting task from db: %v", err) + } + } + + // Log recovered tasks. + log.WithContext(ctx). + WithField("delivery", delivery). + WithField("federator", federator). + WithField("client", client). + WithField("errors", errors). + Info("recovered queued tasks") + + return nil +} + +// PersistWorkerQueues pops all queued worker tasks (that are themselves persistable, i.e. not +// dereference tasks which are just function ptrs), serializes and persists them to the database. +func (p *Processor) PersistWorkerQueues(ctx context.Context) error { + log.Info(ctx, "dehydrate!") + + var ( + // Counts of each task type + // successfully persisted. + delivery int + federator int + client int + + // Failed persists. + errors int + + // Serialized tasks to persist. + tasks []*gtsmodel.WorkerTask + ) + + for { + // Pop all queued deliveries. + task, err := p.popDelivery() + if err != nil { + log.Errorf(ctx, "error popping delivery: %v", err) + errors++ // incr error count. + continue + } + + if task == nil { + // No more queue + // tasks to pop! + break + } + + // Append serialized task. + tasks = append(tasks, task) + delivery++ // incr count + } + + for { + // Pop queued federator msgs. + task, err := p.popFederator() + if err != nil { + log.Errorf(ctx, "error popping federator message: %v", err) + errors++ // incr count + continue + } + + if task == nil { + // No more queue + // tasks to pop! + break + } + + // Append serialized task. + tasks = append(tasks, task) + federator++ // incr count + } + + for { + // Pop queued client msgs. + task, err := p.popClient() + if err != nil { + log.Errorf(ctx, "error popping client message: %v", err) + continue + } + + if task == nil { + // No more queue + // tasks to pop! + break + } + + // Append serialized task. + tasks = append(tasks, task) + client++ // incr count + } + + // Persist all serialized queued worker tasks to database. + if err := p.state.DB.PutWorkerTasks(ctx, tasks); err != nil { + return gtserror.Newf("error putting tasks in db: %w", err) + } + + // Log recovered tasks. + log.WithContext(ctx). + WithField("delivery", delivery). + WithField("federator", federator). + WithField("client", client). + WithField("errors", errors). + Info("persisted queued tasks") + + return nil +} + +// pushDelivery parses a valid delivery.Delivery{} from serialized task data and pushes to queue. +func (p *Processor) pushDelivery(ctx context.Context, task *gtsmodel.WorkerTask) error { + dlv := new(delivery.Delivery) + + // Deserialize the raw worker task data into delivery. + if err := dlv.Deserialize(task.TaskData); err != nil { + return gtserror.Newf("error deserializing delivery: %w", err) + } + + var tsport transport.Transport + + if uri := dlv.ActorID; uri != "" { + // Fetch the actor account by provided URI from db. + account, err := p.state.DB.GetAccountByURI(ctx, uri) + if err != nil { + return gtserror.Newf("error getting actor account %s from db: %w", uri, err) + } + + // Fetch a transport for request signing for actor's account username. + tsport, err = p.transport.NewTransportForUsername(ctx, account.Username) + if err != nil { + return gtserror.Newf("error getting transport for actor %s: %w", uri, err) + } + } else { + var err error + + // No actor was given, will be signed by instance account. + tsport, err = p.transport.NewTransportForUsername(ctx, "") + if err != nil { + return gtserror.Newf("error getting instance account transport: %w", err) + } + } + + // Using transport, add actor signature to delivery. + if err := tsport.SignDelivery(dlv); err != nil { + return gtserror.Newf("error signing delivery: %w", err) + } + + // Push deserialized task to delivery queue. + p.state.Workers.Delivery.Queue.Push(dlv) + + return nil +} + +// popDelivery pops delivery.Delivery{} from queue and serializes as valid task data. +func (p *Processor) popDelivery() (*gtsmodel.WorkerTask, error) { + + // Pop waiting delivery from the delivery worker. + delivery, ok := p.state.Workers.Delivery.Queue.Pop() + if !ok { + return nil, nil + } + + // Serialize the delivery task data. + data, err := delivery.Serialize() + if err != nil { + return nil, gtserror.Newf("error serializing delivery: %w", err) + } + + return >smodel.WorkerTask{ + // ID is autoincrement + WorkerType: gtsmodel.DeliveryWorker, + TaskData: data, + CreatedAt: time.Now(), + }, nil +} + +// pushClient parses a valid messages.FromFediAPI{} from serialized task data and pushes to queue. +func (p *Processor) pushFederator(ctx context.Context, task *gtsmodel.WorkerTask) error { + var msg messages.FromFediAPI + + // Deserialize the raw worker task data into message. + if err := msg.Deserialize(task.TaskData); err != nil { + return gtserror.Newf("error deserializing federator message: %w", err) + } + + if rcv := msg.Receiving; rcv != nil { + // Only a placeholder receiving account will be populated, + // fetch the actual model from database by persisted ID. + account, err := p.state.DB.GetAccountByID(ctx, rcv.ID) + if err != nil { + return gtserror.Newf("error fetching receiving account %s from db: %w", rcv.ID, err) + } + + // Set the now populated + // receiving account model. + msg.Receiving = account + } + + if req := msg.Requesting; req != nil { + // Only a placeholder requesting account will be populated, + // fetch the actual model from database by persisted ID. + account, err := p.state.DB.GetAccountByID(ctx, req.ID) + if err != nil { + return gtserror.Newf("error fetching requesting account %s from db: %w", req.ID, err) + } + + // Set the now populated + // requesting account model. + msg.Requesting = account + } + + // Push populated task to the federator queue. + p.state.Workers.Federator.Queue.Push(&msg) + + return nil +} + +// popFederator pops messages.FromFediAPI{} from queue and serializes as valid task data. +func (p *Processor) popFederator() (*gtsmodel.WorkerTask, error) { + + // Pop waiting message from the federator worker. + msg, ok := p.state.Workers.Federator.Queue.Pop() + if !ok { + return nil, nil + } + + // Serialize message task data. + data, err := msg.Serialize() + if err != nil { + return nil, gtserror.Newf("error serializing federator message: %w", err) + } + + return >smodel.WorkerTask{ + // ID is autoincrement + WorkerType: gtsmodel.FederatorWorker, + TaskData: data, + CreatedAt: time.Now(), + }, nil +} + +// pushClient parses a valid messages.FromClientAPI{} from serialized task data and pushes to queue. +func (p *Processor) pushClient(ctx context.Context, task *gtsmodel.WorkerTask) error { + var msg messages.FromClientAPI + + // Deserialize the raw worker task data into message. + if err := msg.Deserialize(task.TaskData); err != nil { + return gtserror.Newf("error deserializing client message: %w", err) + } + + if org := msg.Origin; org != nil { + // Only a placeholder origin account will be populated, + // fetch the actual model from database by persisted ID. + account, err := p.state.DB.GetAccountByID(ctx, org.ID) + if err != nil { + return gtserror.Newf("error fetching origin account %s from db: %w", org.ID, err) + } + + // Set the now populated + // origin account model. + msg.Origin = account + } + + if trg := msg.Target; trg != nil { + // Only a placeholder target account will be populated, + // fetch the actual model from database by persisted ID. + account, err := p.state.DB.GetAccountByID(ctx, trg.ID) + if err != nil { + return gtserror.Newf("error fetching target account %s from db: %w", trg.ID, err) + } + + // Set the now populated + // target account model. + msg.Target = account + } + + // Push populated task to the federator queue. + p.state.Workers.Client.Queue.Push(&msg) + + return nil +} + +// popClient pops messages.FromClientAPI{} from queue and serializes as valid task data. +func (p *Processor) popClient() (*gtsmodel.WorkerTask, error) { + + // Pop waiting message from the client worker. + msg, ok := p.state.Workers.Client.Queue.Pop() + if !ok { + return nil, nil + } + + // Serialize message task data. + data, err := msg.Serialize() + if err != nil { + return nil, gtserror.Newf("error serializing client message: %w", err) + } + + return >smodel.WorkerTask{ + // ID is autoincrement + WorkerType: gtsmodel.ClientWorker, + TaskData: data, + CreatedAt: time.Now(), + }, nil +} diff --git a/internal/processing/admin/workertask_test.go b/internal/processing/admin/workertask_test.go new file mode 100644 index 000000000..bf326bafd --- /dev/null +++ b/internal/processing/admin/workertask_test.go @@ -0,0 +1,421 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package admin_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/ap" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/httpclient" + "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/transport/delivery" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +var ( + // TODO: move these test values into + // the testrig test models area. They'll + // need to be as both WorkerTask and as + // the raw types themselves. + + testDeliveries = []*delivery.Delivery{ + { + ObjectID: "https://google.com/users/bigboy/follow/1", + TargetID: "https://askjeeves.com/users/smallboy", + Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Host": {"https://askjeeves.com"}}), + }, + { + Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), http.Header{"Host": {"https://google.com"}}), + }, + } + + testFederatorMsgs = []*messages.FromFediAPI{ + { + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityCreate, + TargetURI: "https://gotosocial.org", + Requesting: >smodel.Account{ID: "654321"}, + Receiving: >smodel.Account{ID: "123456"}, + }, + { + APObjectType: ap.ObjectProfile, + APActivityType: ap.ActivityUpdate, + TargetURI: "https://uk-queen-is-dead.org", + Requesting: >smodel.Account{ID: "123456"}, + Receiving: >smodel.Account{ID: "654321"}, + }, + } + + testClientMsgs = []*messages.FromClientAPI{ + { + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityCreate, + TargetURI: "https://gotosocial.org", + Origin: >smodel.Account{ID: "654321"}, + Target: >smodel.Account{ID: "123456"}, + }, + { + APObjectType: ap.ObjectProfile, + APActivityType: ap.ActivityUpdate, + TargetURI: "https://uk-queen-is-dead.org", + Origin: >smodel.Account{ID: "123456"}, + Target: >smodel.Account{ID: "654321"}, + }, + } +) + +type WorkerTaskTestSuite struct { + AdminStandardTestSuite +} + +func (suite *WorkerTaskTestSuite) TestFillWorkerQueues() { + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + var tasks []*gtsmodel.WorkerTask + + for _, dlv := range testDeliveries { + // Serialize all test deliveries. + data, err := dlv.Serialize() + if err != nil { + panic(err) + } + + // Append each serialized delivery to tasks. + tasks = append(tasks, >smodel.WorkerTask{ + WorkerType: gtsmodel.DeliveryWorker, + TaskData: data, + }) + } + + for _, msg := range testFederatorMsgs { + // Serialize all test messages. + data, err := msg.Serialize() + if err != nil { + panic(err) + } + + if msg.Receiving != nil { + // Quick hack to bypass database errors for non-existing + // accounts, instead we just insert this into cache ;). + suite.state.Caches.DB.Account.Put(msg.Receiving) + suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ + AccountID: msg.Receiving.ID, + }) + } + + if msg.Requesting != nil { + // Quick hack to bypass database errors for non-existing + // accounts, instead we just insert this into cache ;). + suite.state.Caches.DB.Account.Put(msg.Requesting) + suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ + AccountID: msg.Requesting.ID, + }) + } + + // Append each serialized message to tasks. + tasks = append(tasks, >smodel.WorkerTask{ + WorkerType: gtsmodel.FederatorWorker, + TaskData: data, + }) + } + + for _, msg := range testClientMsgs { + // Serialize all test messages. + data, err := msg.Serialize() + if err != nil { + panic(err) + } + + if msg.Origin != nil { + // Quick hack to bypass database errors for non-existing + // accounts, instead we just insert this into cache ;). + suite.state.Caches.DB.Account.Put(msg.Origin) + suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ + AccountID: msg.Origin.ID, + }) + } + + if msg.Target != nil { + // Quick hack to bypass database errors for non-existing + // accounts, instead we just insert this into cache ;). + suite.state.Caches.DB.Account.Put(msg.Target) + suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ + AccountID: msg.Target.ID, + }) + } + + // Append each serialized message to tasks. + tasks = append(tasks, >smodel.WorkerTask{ + WorkerType: gtsmodel.ClientWorker, + TaskData: data, + }) + } + + // Persist all test worker tasks to the database. + err := suite.state.DB.PutWorkerTasks(ctx, tasks) + suite.NoError(err) + + // Fill the worker queues from persisted task data. + err = suite.adminProcessor.FillWorkerQueues(ctx) + suite.NoError(err) + + var ( + // Recovered + // task counts. + ndelivery int + nfederator int + nclient int + ) + + // Fetch current gotosocial instance account, for later checks. + instanceAcc, err := suite.state.DB.GetInstanceAccount(ctx, "") + suite.NoError(err) + + for { + // Pop all queued delivery tasks from worker queue. + dlv, ok := suite.state.Workers.Delivery.Queue.Pop() + if !ok { + break + } + + // Incr count. + ndelivery++ + + // Check that we have this message in slice. + err = containsSerializable(testDeliveries, dlv) + suite.NoError(err) + + // Check that delivery request context has instance account pubkey. + pubKeyID := gtscontext.OutgoingPublicKeyID(dlv.Request.Context()) + suite.Equal(instanceAcc.PublicKeyURI, pubKeyID) + signfn := gtscontext.HTTPClientSignFunc(dlv.Request.Context()) + suite.NotNil(signfn) + } + + for { + // Pop all queued federator messages from worker queue. + msg, ok := suite.state.Workers.Federator.Queue.Pop() + if !ok { + break + } + + // Incr count. + nfederator++ + + // Check that we have this message in slice. + err = containsSerializable(testFederatorMsgs, msg) + suite.NoError(err) + } + + for { + // Pop all queued client messages from worker queue. + msg, ok := suite.state.Workers.Client.Queue.Pop() + if !ok { + break + } + + // Incr count. + nclient++ + + // Check that we have this message in slice. + err = containsSerializable(testClientMsgs, msg) + suite.NoError(err) + } + + // Ensure recovered task counts as expected. + suite.Equal(len(testDeliveries), ndelivery) + suite.Equal(len(testFederatorMsgs), nfederator) + suite.Equal(len(testClientMsgs), nclient) +} + +func (suite *WorkerTaskTestSuite) TestPersistWorkerQueues() { + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Push all test worker tasks to their respective queues. + suite.state.Workers.Delivery.Queue.Push(testDeliveries...) + suite.state.Workers.Federator.Queue.Push(testFederatorMsgs...) + suite.state.Workers.Client.Queue.Push(testClientMsgs...) + + // Persist the worker queued tasks to database. + err := suite.adminProcessor.PersistWorkerQueues(ctx) + suite.NoError(err) + + // Fetch all the persisted tasks from database. + tasks, err := suite.state.DB.GetWorkerTasks(ctx) + suite.NoError(err) + + var ( + // Persisted + // task counts. + ndelivery int + nfederator int + nclient int + ) + + // Check persisted task data. + for _, task := range tasks { + switch task.WorkerType { + case gtsmodel.DeliveryWorker: + var dlv delivery.Delivery + + // Incr count. + ndelivery++ + + // Deserialize the persisted task data. + err := dlv.Deserialize(task.TaskData) + suite.NoError(err) + + // Check that we have this delivery in slice. + err = containsSerializable(testDeliveries, &dlv) + suite.NoError(err) + + case gtsmodel.FederatorWorker: + var msg messages.FromFediAPI + + // Incr count. + nfederator++ + + // Deserialize the persisted task data. + err := msg.Deserialize(task.TaskData) + suite.NoError(err) + + // Check that we have this message in slice. + err = containsSerializable(testFederatorMsgs, &msg) + suite.NoError(err) + + case gtsmodel.ClientWorker: + var msg messages.FromClientAPI + + // Incr count. + nclient++ + + // Deserialize the persisted task data. + err := msg.Deserialize(task.TaskData) + suite.NoError(err) + + // Check that we have this message in slice. + err = containsSerializable(testClientMsgs, &msg) + suite.NoError(err) + + default: + suite.T().Errorf("unexpected worker type: %d", task.WorkerType) + } + } + + // Ensure persisted task counts as expected. + suite.Equal(len(testDeliveries), ndelivery) + suite.Equal(len(testFederatorMsgs), nfederator) + suite.Equal(len(testClientMsgs), nclient) +} + +func (suite *WorkerTaskTestSuite) SetupTest() { + suite.AdminStandardTestSuite.SetupTest() + // we don't want workers running + testrig.StopWorkers(&suite.state) +} + +func TestWorkerTaskTestSuite(t *testing.T) { + suite.Run(t, new(WorkerTaskTestSuite)) +} + +// containsSerializeable returns whether slice of serializables contains given serializable entry. +func containsSerializable[T interface{ Serialize() ([]byte, error) }](expect []T, have T) error { + // Serialize wanted value. + bh, err := have.Serialize() + if err != nil { + panic(err) + } + + var strings []string + + for _, t := range expect { + // Serialize expected value. + be, err := t.Serialize() + if err != nil { + panic(err) + } + + // Alloc as string. + se := string(be) + + if se == string(bh) { + // We have this entry! + return nil + } + + // Add to serialized strings. + strings = append(strings, se) + } + + return fmt.Errorf("could not find %s in %s", string(bh), strings) +} + +// urlStr simply returns u.String() or "" if nil. +func urlStr(u *url.URL) string { + if u == nil { + return "" + } + return u.String() +} + +// accountID simply returns account.ID or "" if nil. +func accountID(account *gtsmodel.Account) string { + if account == nil { + return "" + } + return account.ID +} + +// toRequest creates httpclient.Request from HTTP method, URL and body data. +func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request { + var rbody io.Reader + if body != nil { + rbody = bytes.NewReader(body) + } + req, err := http.NewRequest(method, url, rbody) + if err != nil { + panic(err) + } + for key, values := range hdr { + for _, value := range values { + req.Header.Add(key, value) + } + } + return httpclient.WrapRequest(req) +} + +// toJSON marshals input type as JSON data. +func toJSON(a any) []byte { + b, err := json.Marshal(a) + if err != nil { + panic(err) + } + return b +} |