diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/gtsmodel/workertask.go | 41 | ||||
-rw-r--r-- | internal/httpclient/transport.go (renamed from internal/httpclient/sign.go) | 7 | ||||
-rw-r--r-- | internal/messages/messages.go | 290 | ||||
-rw-r--r-- | internal/messages/messages_test.go | 292 | ||||
-rw-r--r-- | internal/transport/deliver.go | 19 | ||||
-rw-r--r-- | internal/transport/delivery/delivery.go | 331 | ||||
-rw-r--r-- | internal/transport/delivery/delivery_test.go | 265 | ||||
-rw-r--r-- | internal/transport/delivery/worker.go | 298 | ||||
-rw-r--r-- | internal/transport/delivery/worker_test.go | 220 |
9 files changed, 1328 insertions, 435 deletions
diff --git a/internal/gtsmodel/workertask.go b/internal/gtsmodel/workertask.go new file mode 100644 index 000000000..cc8433199 --- /dev/null +++ b/internal/gtsmodel/workertask.go @@ -0,0 +1,41 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package gtsmodel + +import "time" + +type WorkerType uint8 + +const ( + DeliveryWorker WorkerType = 1 + FederatorWorker WorkerType = 2 + ClientWorker WorkerType = 3 +) + +// WorkerTask represents a queued worker task +// that was persisted to the database on shutdown. +// This is only ever used on startup to pickup +// where we left off, and on shutdown to prevent +// queued tasks from being lost. It is simply a +// means to store a blob of serialized task data. +type WorkerTask struct { + ID uint `bun:""` + WorkerType uint8 `bun:""` + TaskData []byte `bun:""` + CreatedAt time.Time `bun:""` +} diff --git a/internal/httpclient/sign.go b/internal/httpclient/transport.go index eff20be49..350d24fab 100644 --- a/internal/httpclient/sign.go +++ b/internal/httpclient/transport.go @@ -21,7 +21,6 @@ import ( "net/http" "time" - "codeberg.org/gruf/go-byteutil" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" ) @@ -45,12 +44,6 @@ func (t *signingtransport) RoundTrip(r *http.Request) (*http.Response, error) { r.Header.Del("Signature") r.Header.Del("Digest") - // Rewind body reader and content-length if set. - if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { - rc.Rewind() // set len AFTER rewind - r.ContentLength = int64(rc.Len()) - } - // Sign the outgoing request. if err := sign(r); err != nil { return nil, err diff --git a/internal/messages/messages.go b/internal/messages/messages.go index c5488d586..7779633ba 100644 --- a/internal/messages/messages.go +++ b/internal/messages/messages.go @@ -18,9 +18,15 @@ package messages import ( + "context" + "encoding/json" "net/url" + "reflect" "codeberg.org/gruf/go-structr" + "github.com/superseriousbusiness/activity/streams" + "github.com/superseriousbusiness/activity/streams/vocab" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -50,6 +56,105 @@ type FromClientAPI struct { Target *gtsmodel.Account } +// fromClientAPI is an internal type +// for FromClientAPI that provides a +// json serialize / deserialize -able +// shape that minimizes required data. +type fromClientAPI struct { + APObjectType string `json:"ap_object_type,omitempty"` + APActivityType string `json:"ap_activity_type,omitempty"` + GTSModel json.RawMessage `json:"gts_model,omitempty"` + GTSModelType string `json:"gts_model_type,omitempty"` + TargetURI string `json:"target_uri,omitempty"` + OriginID string `json:"origin_id,omitempty"` + TargetID string `json:"target_id,omitempty"` +} + +// Serialize will serialize the worker data as data blob for storage, +// note that this will flatten some of the data e.g. only account IDs. +func (msg *FromClientAPI) Serialize() ([]byte, error) { + var ( + modelType string + originID string + targetID string + ) + + // Set database model type if any provided. + if t := reflect.TypeOf(msg.GTSModel); t != nil { + modelType = t.String() + } + + // Set origin account ID. + if msg.Origin != nil { + originID = msg.Origin.ID + } + + // Set target account ID. + if msg.Target != nil { + targetID = msg.Target.ID + } + + // Marshal GTS model as raw JSON block. + modelJSON, err := json.Marshal(msg.GTSModel) + if err != nil { + return nil, err + } + + // Marshal as internal JSON type. + return json.Marshal(fromClientAPI{ + APObjectType: msg.APObjectType, + APActivityType: msg.APActivityType, + GTSModel: modelJSON, + GTSModelType: modelType, + TargetURI: msg.TargetURI, + OriginID: originID, + TargetID: targetID, + }) +} + +// Deserialize will attempt to deserialize a blob of task data, +// which will involve unflattening previously serialized data and +// leave some message structures as placeholders to holding IDs. +func (msg *FromClientAPI) Deserialize(data []byte) error { + var imsg fromClientAPI + + // Unmarshal as internal JSON type. + err := json.Unmarshal(data, &imsg) + if err != nil { + return err + } + + // Copy over the simplest fields. + msg.APObjectType = imsg.APObjectType + msg.APActivityType = imsg.APActivityType + msg.TargetURI = imsg.TargetURI + + // Resolve Go type from JSON data. + msg.GTSModel, err = resolveGTSModel( + imsg.GTSModelType, + imsg.GTSModel, + ) + if err != nil { + return err + } + + if imsg.OriginID != "" { + // Set origin account ID using a + // barebones model (later filled in). + msg.Origin = new(gtsmodel.Account) + msg.Origin.ID = imsg.OriginID + } + + if imsg.TargetID != "" { + // Set target account ID using a + // barebones model (later filled in). + msg.Target = new(gtsmodel.Account) + msg.Target.ID = imsg.TargetID + } + + return nil +} + // ClientMsgIndices defines queue indices this // message type should be accessible / stored under. func ClientMsgIndices() []structr.IndexConfig { @@ -91,6 +196,133 @@ type FromFediAPI struct { Receiving *gtsmodel.Account } +// fromFediAPI is an internal type +// for FromFediAPI that provides a +// json serialize / deserialize -able +// shape that minimizes required data. +type fromFediAPI struct { + APObjectType string `json:"ap_object_type,omitempty"` + APActivityType string `json:"ap_activity_type,omitempty"` + APIRI string `json:"ap_iri,omitempty"` + APObject map[string]interface{} `json:"ap_object,omitempty"` + GTSModel json.RawMessage `json:"gts_model,omitempty"` + GTSModelType string `json:"gts_model_type,omitempty"` + TargetURI string `json:"target_uri,omitempty"` + RequestingID string `json:"requesting_id,omitempty"` + ReceivingID string `json:"receiving_id,omitempty"` +} + +// Serialize will serialize the worker data as data blob for storage, +// note that this will flatten some of the data e.g. only account IDs. +func (msg *FromFediAPI) Serialize() ([]byte, error) { + var ( + gtsModelType string + apIRI string + apObject map[string]interface{} + requestingID string + receivingID string + ) + + // Set AP IRI string. + if msg.APIRI != nil { + apIRI = msg.APIRI.String() + } + + // Set serialized AP object data if set. + if t, ok := msg.APObject.(vocab.Type); ok { + obj, err := t.Serialize() + if err != nil { + return nil, err + } + apObject = obj + } + + // Set database model type if any provided. + if t := reflect.TypeOf(msg.GTSModel); t != nil { + gtsModelType = t.String() + } + + // Set requesting account ID. + if msg.Requesting != nil { + requestingID = msg.Requesting.ID + } + + // Set receiving account ID. + if msg.Receiving != nil { + receivingID = msg.Receiving.ID + } + + // Marshal GTS model as raw JSON block. + modelJSON, err := json.Marshal(msg.GTSModel) + if err != nil { + return nil, err + } + + // Marshal as internal JSON type. + return json.Marshal(fromFediAPI{ + APObjectType: msg.APObjectType, + APActivityType: msg.APActivityType, + APIRI: apIRI, + APObject: apObject, + GTSModel: modelJSON, + GTSModelType: gtsModelType, + TargetURI: msg.TargetURI, + RequestingID: requestingID, + ReceivingID: receivingID, + }) +} + +// Deserialize will attempt to deserialize a blob of task data, +// which will involve unflattening previously serialized data and +// leave some message structures as placeholders to holding IDs. +func (msg *FromFediAPI) Deserialize(data []byte) error { + var imsg fromFediAPI + + // Unmarshal as internal JSON type. + err := json.Unmarshal(data, &imsg) + if err != nil { + return err + } + + // Copy over the simplest fields. + msg.APObjectType = imsg.APObjectType + msg.APActivityType = imsg.APActivityType + msg.TargetURI = imsg.TargetURI + + // Resolve AP object from JSON data. + msg.APObject, err = resolveAPObject( + imsg.APObject, + ) + if err != nil { + return err + } + + // Resolve Go type from JSON data. + msg.GTSModel, err = resolveGTSModel( + imsg.GTSModelType, + imsg.GTSModel, + ) + if err != nil { + return err + } + + if imsg.RequestingID != "" { + // Set requesting account ID using a + // barebones model (later filled in). + msg.Requesting = new(gtsmodel.Account) + msg.Requesting.ID = imsg.RequestingID + } + + if imsg.ReceivingID != "" { + // Set target account ID using a + // barebones model (later filled in). + msg.Receiving = new(gtsmodel.Account) + msg.Receiving.ID = imsg.ReceivingID + } + + return nil +} + // FederatorMsgIndices defines queue indices this // message type should be accessible / stored under. func FederatorMsgIndices() []structr.IndexConfig { @@ -101,3 +333,61 @@ func FederatorMsgIndices() []structr.IndexConfig { {Fields: "Receiving.ID", Multiple: true}, } } + +// resolveAPObject resolves an ActivityPub object from its "serialized" JSON map +// (yes the terminology here is weird, but that's how go-fed/activity is written). +func resolveAPObject(data map[string]interface{}) (interface{}, error) { + if len(data) == 0 { + // No data given. + return nil, nil + } + + // Resolve vocab.Type from "raw" input data map. + return streams.ToType(context.Background(), data) +} + +// resolveGTSModel is unfortunately where things get messy... our data is stored as JSON +// in the database, which serializes struct types as key-value pairs surrounded by curly +// braces. Deserializing from that gives us back a data blob of key-value pairs, which +// we then need to wrangle back into the original type. So we also store the type name +// and use this to determine the appropriate Go structure type to unmarshal into to. +func resolveGTSModel(typ string, data []byte) (interface{}, error) { + if typ == "" && data == nil { + // No data given. + return nil, nil + } + + var value interface{} + + switch typ { + case reflect.TypeOf((*gtsmodel.Account)(nil)).String(): + value = new(gtsmodel.Account) + case reflect.TypeOf((*gtsmodel.Block)(nil)).String(): + value = new(gtsmodel.Block) + case reflect.TypeOf((*gtsmodel.Follow)(nil)).String(): + value = new(gtsmodel.Follow) + case reflect.TypeOf((*gtsmodel.FollowRequest)(nil)).String(): + value = new(gtsmodel.FollowRequest) + case reflect.TypeOf((*gtsmodel.Move)(nil)).String(): + value = new(gtsmodel.Move) + case reflect.TypeOf((*gtsmodel.Poll)(nil)).String(): + value = new(gtsmodel.Poll) + case reflect.TypeOf((*gtsmodel.PollVote)(nil)).String(): + value = new(*gtsmodel.PollVote) + case reflect.TypeOf((*gtsmodel.Report)(nil)).String(): + value = new(gtsmodel.Report) + case reflect.TypeOf((*gtsmodel.Status)(nil)).String(): + value = new(gtsmodel.Status) + case reflect.TypeOf((*gtsmodel.StatusFave)(nil)).String(): + value = new(gtsmodel.StatusFave) + default: + return nil, gtserror.Newf("unknown type: %s", typ) + } + + // Attempt to unmarshal value JSON into destination. + if err := json.Unmarshal(data, &value); err != nil { + return nil, gtserror.Newf("error unmarshaling %s value data: %w", typ, err) + } + + return value, nil +} diff --git a/internal/messages/messages_test.go b/internal/messages/messages_test.go new file mode 100644 index 000000000..e5b2a2841 --- /dev/null +++ b/internal/messages/messages_test.go @@ -0,0 +1,292 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package messages_test + +import ( + "bytes" + "encoding/json" + "net/url" + "testing" + + "github.com/superseriousbusiness/gotosocial/internal/ap" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/testrig" + + "github.com/google/go-cmp/cmp" +) + +var testStatus = testrig.NewTestStatuses()["admin_account_status_1"] + +var testAccount = testrig.NewTestAccounts()["admin_account"] + +var fromClientAPICases = []struct { + msg messages.FromClientAPI + data []byte +}{ + { + msg: messages.FromClientAPI{ + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityCreate, + GTSModel: testStatus, + TargetURI: "https://gotosocial.org", + Origin: >smodel.Account{ID: "654321"}, + Target: >smodel.Account{ID: "123456"}, + }, + data: toJSON(map[string]any{ + "ap_object_type": ap.ObjectNote, + "ap_activity_type": ap.ActivityCreate, + "gts_model": json.RawMessage(toJSON(testStatus)), + "gts_model_type": "*gtsmodel.Status", + "target_uri": "https://gotosocial.org", + "origin_id": "654321", + "target_id": "123456", + }), + }, + { + msg: messages.FromClientAPI{ + APObjectType: ap.ObjectProfile, + APActivityType: ap.ActivityUpdate, + GTSModel: testAccount, + TargetURI: "https://uk-queen-is-dead.org", + Origin: >smodel.Account{ID: "123456"}, + Target: >smodel.Account{ID: "654321"}, + }, + data: toJSON(map[string]any{ + "ap_object_type": ap.ObjectProfile, + "ap_activity_type": ap.ActivityUpdate, + "gts_model": json.RawMessage(toJSON(testAccount)), + "gts_model_type": "*gtsmodel.Account", + "target_uri": "https://uk-queen-is-dead.org", + "origin_id": "123456", + "target_id": "654321", + }), + }, +} + +var fromFediAPICases = []struct { + msg messages.FromFediAPI + data []byte +}{ + { + msg: messages.FromFediAPI{ + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityCreate, + GTSModel: testStatus, + TargetURI: "https://gotosocial.org", + Requesting: >smodel.Account{ID: "654321"}, + Receiving: >smodel.Account{ID: "123456"}, + }, + data: toJSON(map[string]any{ + "ap_object_type": ap.ObjectNote, + "ap_activity_type": ap.ActivityCreate, + "gts_model": json.RawMessage(toJSON(testStatus)), + "gts_model_type": "*gtsmodel.Status", + "target_uri": "https://gotosocial.org", + "requesting_id": "654321", + "receiving_id": "123456", + }), + }, + { + msg: messages.FromFediAPI{ + APObjectType: ap.ObjectProfile, + APActivityType: ap.ActivityUpdate, + GTSModel: testAccount, + TargetURI: "https://uk-queen-is-dead.org", + Requesting: >smodel.Account{ID: "123456"}, + Receiving: >smodel.Account{ID: "654321"}, + }, + data: toJSON(map[string]any{ + "ap_object_type": ap.ObjectProfile, + "ap_activity_type": ap.ActivityUpdate, + "gts_model": json.RawMessage(toJSON(testAccount)), + "gts_model_type": "*gtsmodel.Account", + "target_uri": "https://uk-queen-is-dead.org", + "requesting_id": "123456", + "receiving_id": "654321", + }), + }, +} + +func TestSerializeFromClientAPI(t *testing.T) { + for _, test := range fromClientAPICases { + // Serialize test message to blob. + data, err := test.msg.Serialize() + if err != nil { + t.Fatal(err) + } + + // Check serialized JSON data as expected. + assertJSONEqual(t, test.data, data) + } +} + +func TestDeserializeFromClientAPI(t *testing.T) { + for _, test := range fromClientAPICases { + var msg messages.FromClientAPI + + // Deserialize test message blob. + err := msg.Deserialize(test.data) + if err != nil { + t.Fatal(err) + } + + // Check that msg is as expected. + assertEqual(t, test.msg.APActivityType, msg.APActivityType) + assertEqual(t, test.msg.APObjectType, msg.APObjectType) + assertEqual(t, test.msg.GTSModel, msg.GTSModel) + assertEqual(t, test.msg.TargetURI, msg.TargetURI) + assertEqual(t, accountID(test.msg.Origin), accountID(msg.Origin)) + assertEqual(t, accountID(test.msg.Target), accountID(msg.Target)) + + // Perform final check to ensure + // account model keys deserialized. + assertEqualRSA(t, test.msg.GTSModel, msg.GTSModel) + } +} + +func TestSerializeFromFediAPI(t *testing.T) { + for _, test := range fromFediAPICases { + // Serialize test message to blob. + data, err := test.msg.Serialize() + if err != nil { + t.Fatal(err) + } + + // Check serialized JSON data as expected. + assertJSONEqual(t, test.data, data) + } +} + +func TestDeserializeFromFediAPI(t *testing.T) { + for _, test := range fromFediAPICases { + var msg messages.FromFediAPI + + // Deserialize test message blob. + err := msg.Deserialize(test.data) + if err != nil { + t.Fatal(err) + } + + // Check that msg is as expected. + assertEqual(t, test.msg.APActivityType, msg.APActivityType) + assertEqual(t, test.msg.APObjectType, msg.APObjectType) + assertEqual(t, urlStr(test.msg.APIRI), urlStr(msg.APIRI)) + assertEqual(t, test.msg.APObject, msg.APObject) + assertEqual(t, test.msg.GTSModel, msg.GTSModel) + assertEqual(t, test.msg.TargetURI, msg.TargetURI) + assertEqual(t, accountID(test.msg.Receiving), accountID(msg.Receiving)) + assertEqual(t, accountID(test.msg.Requesting), accountID(msg.Requesting)) + + // Perform final check to ensure + // account model keys deserialized. + assertEqualRSA(t, test.msg.GTSModel, msg.GTSModel) + } +} + +// assertEqualRSA asserts that test account model RSA keys are equal. +func assertEqualRSA(t *testing.T, expect, receive any) bool { + t.Helper() + + account1, ok1 := expect.(*gtsmodel.Account) + + account2, ok2 := receive.(*gtsmodel.Account) + + if ok1 != ok2 { + t.Errorf("different model types: expect=%T receive=%T", expect, receive) + return false + } else if !ok1 { + return true + } + + if !account1.PublicKey.Equal(account2.PublicKey) { + t.Error("public keys do not match") + return false + } + + t.Logf("publickey=%v", account1.PublicKey) + + if !account1.PrivateKey.Equal(account2.PrivateKey) { + t.Error("private keys do not match") + return false + } + + t.Logf("privatekey=%v", account1.PrivateKey) + + return true +} + +// assertEqual asserts that two values (of any type!) are equal, +// note we use the 'cmp' library here as it's much more useful in +// outputting debug information than testify, and handles more complex +// types like rsa public / private key comparisons correctly. +func assertEqual(t *testing.T, expect, receive any) bool { + t.Helper() + if diff := cmp.Diff(expect, receive); diff != "" { + t.Error(diff) + return false + } + return true +} + +// assertJSONEqual asserts that two slices of JSON data are equal. +func assertJSONEqual(t *testing.T, expect, receive []byte) bool { + t.Helper() + return assertEqual(t, fromJSON(expect), fromJSON(receive)) +} + +// urlStr returns url as string, or empty. +func urlStr(url *url.URL) string { + if url == nil { + return "" + } + return url.String() +} + +// accountID returns account's ID, or empty. +func accountID(account *gtsmodel.Account) string { + if account == nil { + return "" + } + return account.ID +} + +// fromJSON unmarshals input data as JSON. +func fromJSON(b []byte) any { + r := bytes.NewReader(b) + d := json.NewDecoder(r) + d.UseNumber() + var a any + err := d.Decode(&a) + if err != nil { + panic(err) + } + if d.More() { + panic("multiple json values in b") + } + return a +} + +// toJSON marshals input type as JSON data. +func toJSON(a any) []byte { + b, err := json.Marshal(a) + if err != nil { + panic(err) + } + return b +} diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index a7e73465d..30435b86f 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -18,12 +18,12 @@ package transport import ( + "bytes" "context" "encoding/json" "net/http" "net/url" - "codeberg.org/gruf/go-byteutil" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -130,25 +130,28 @@ func (t *transport) prepare( *delivery.Delivery, error, ) { - url := to.String() - - // Use rewindable reader for body. - var body byteutil.ReadNopCloser - body.Reset(data) - // Prepare POST signer. sign := t.signPOST(data) + // Use *bytes.Reader for request body, + // as NewRequest() automatically will + // set .GetBody and content-length. + // (this handles necessary rewinding). + body := bytes.NewReader(data) + // Update to-be-used request context with signing details. ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID) ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign) // Prepare a new request with data body directed at URL. - r, err := http.NewRequestWithContext(ctx, "POST", url, &body) + r, err := http.NewRequestWithContext(ctx, "POST", to.String(), body) if err != nil { return nil, gtserror.Newf("error preparing request: %w", err) } + // Set our predefined controller user-agent. + r.Header.Set("User-Agent", t.controller.userAgent) + // Set the standard ActivityPub content-type + charset headers. r.Header.Add("Content-Type", string(apiutil.AppActivityLDJSON)) r.Header.Add("Accept-Charset", "utf-8") diff --git a/internal/transport/delivery/delivery.go b/internal/transport/delivery/delivery.go index 1e9126b2e..1e3ebb054 100644 --- a/internal/transport/delivery/delivery.go +++ b/internal/transport/delivery/delivery.go @@ -18,16 +18,13 @@ package delivery import ( - "context" - "slices" + "bytes" + "encoding/json" + "io" + "net/http" "time" - "codeberg.org/gruf/go-runners" - "codeberg.org/gruf/go-structr" - "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/httpclient" - "github.com/superseriousbusiness/gotosocial/internal/queue" - "github.com/superseriousbusiness/gotosocial/internal/util" ) // Delivery wraps an httpclient.Request{} @@ -36,6 +33,9 @@ import ( // be indexed (and so, dropped from queue) // by any of these possible ID IRIs. type Delivery struct { + // PubKeyID is the signing public key + // ID of the actor performing request. + PubKeyID string // ActorID contains the ActivityPub // actor ID IRI (if any) of the activity @@ -61,273 +61,98 @@ type Delivery struct { next time.Time } -func (dlv *Delivery) backoff() time.Duration { - if dlv.next.IsZero() { - return 0 - } - return time.Until(dlv.next) +// delivery is an internal type +// for Delivery{} that provides +// a json serialize / deserialize +// able shape that minimizes data. +type delivery struct { + PubKeyID string `json:"pub_key_id,omitempty"` + ActorID string `json:"actor_id,omitempty"` + ObjectID string `json:"object_id,omitempty"` + TargetID string `json:"target_id,omitempty"` + Method string `json:"method,omitempty"` + Header map[string][]string `json:"header,omitempty"` + URL string `json:"url,omitempty"` + Body []byte `json:"body,omitempty"` } -// WorkerPool wraps multiple Worker{}s in -// a singular struct for easy multi start/stop. -type WorkerPool struct { +// Serialize will serialize the delivery data as data blob for storage, +// note that this will flatten some of the data, dropping signing funcs. +func (dlv *Delivery) Serialize() ([]byte, error) { + var body []byte - // Client defines httpclient.Client{} - // passed to each of delivery pool Worker{}s. - Client *httpclient.Client + if dlv.Request.GetBody != nil { + // Fetch a fresh copy of request body. + rbody, err := dlv.Request.GetBody() + if err != nil { + return nil, err + } - // Queue is the embedded queue.StructQueue{} - // passed to each of delivery pool Worker{}s. - Queue queue.StructQueue[*Delivery] + // Read request body into memory. + body, err = io.ReadAll(rbody) - // internal fields. - workers []*Worker -} + // Done with body. + _ = rbody.Close() -// Init will initialize the Worker{} pool -// with given http client, request queue to pull -// from and number of delivery workers to spawn. -func (p *WorkerPool) Init(client *httpclient.Client) { - p.Client = client - p.Queue.Init(structr.QueueConfig[*Delivery]{ - Indices: []structr.IndexConfig{ - {Fields: "ActorID", Multiple: true}, - {Fields: "ObjectID", Multiple: true}, - {Fields: "TargetID", Multiple: true}, - }, - }) -} - -// Start will attempt to start 'n' Worker{}s. -func (p *WorkerPool) Start(n int) { - // Check whether workers are - // set (is already running). - ok := (len(p.workers) > 0) - if ok { - return + if err != nil { + return nil, err + } } - // Allocate new workers slice. - p.workers = make([]*Worker, n) - for i := range p.workers { - - // Allocate new Worker{}. - p.workers[i] = new(Worker) - p.workers[i].Client = p.Client - p.workers[i].Queue = &p.Queue - - // Attempt to start worker. - // Return bool not useful - // here, as true = started, - // false = already running. - _ = p.workers[i].Start() - } + // Marshal as internal JSON type. + return json.Marshal(delivery{ + PubKeyID: dlv.PubKeyID, + ActorID: dlv.ActorID, + ObjectID: dlv.ObjectID, + TargetID: dlv.TargetID, + Method: dlv.Request.Method, + Header: dlv.Request.Header, + URL: dlv.Request.URL.String(), + Body: body, + }) } -// Stop will attempt to stop contained Worker{}s. -func (p *WorkerPool) Stop() { - // Check whether workers are - // set (is currently running). - ok := (len(p.workers) == 0) - if ok { - return - } - - // Stop all running workers. - for i := range p.workers { +// Deserialize will attempt to deserialize a blob of task data, +// which will involve unflattening previously serialized data and +// leave delivery incomplete, still requiring signing func setup. +func (dlv *Delivery) Deserialize(data []byte) error { + var idlv delivery - // return bool not useful - // here, as true = stopped, - // false = never running. - _ = p.workers[i].Stop() + // Unmarshal as internal JSON type. + err := json.Unmarshal(data, &idlv) + if err != nil { + return err } - // Unset workers slice. - p.workers = p.workers[:0] -} - -// Worker wraps an httpclient.Client{} to feed -// from queue.StructQueue{} for ActivityPub reqs -// to deliver. It does so while prioritizing new -// queued requests over backlogged retries. -type Worker struct { - - // Client is the httpclient.Client{} that - // delivery worker will use for requests. - Client *httpclient.Client - - // Queue is the Delivery{} message queue - // that delivery worker will feed from. - Queue *queue.StructQueue[*Delivery] + // Copy over simplest fields. + dlv.PubKeyID = idlv.PubKeyID + dlv.ActorID = idlv.ActorID + dlv.ObjectID = idlv.ObjectID + dlv.TargetID = idlv.TargetID - // internal fields. - backlog []*Delivery - service runners.Service -} - -// Start will attempt to start the Worker{}. -func (w *Worker) Start() bool { - return w.service.GoRun(w.run) -} + var body io.Reader -// Stop will attempt to stop the Worker{}. -func (w *Worker) Stop() bool { - return w.service.Stop() -} - -// run wraps process to restart on any panic. -func (w *Worker) run(ctx context.Context) { - if w.Client == nil || w.Queue == nil { - panic("not yet initialized") - } - util.Must(func() { w.process(ctx) }) -} - -// process is the main delivery worker processing routine. -func (w *Worker) process(ctx context.Context) bool { - if w.Client == nil || w.Queue == nil { - // we perform this check here just - // to ensure the compiler knows these - // variables aren't nil in the loop, - // even if already checked by caller. - panic("not yet initialized") + if idlv.Body != nil { + // Create new body reader from data. + body = bytes.NewReader(idlv.Body) } -loop: - for { - // Get next delivery. - dlv, ok := w.next(ctx) - if !ok { - return true - } - - // Check whether backoff required. - const min = 100 * time.Millisecond - if d := dlv.backoff(); d > min { - - // Start backoff sleep timer. - backoff := time.NewTimer(d) - - select { - case <-ctx.Done(): - // Main ctx - // cancelled. - backoff.Stop() - return true - - case <-w.Queue.Wait(): - // A new message was - // queued, re-add this - // to backlog + retry. - w.pushBacklog(dlv) - backoff.Stop() - continue loop - - case <-backoff.C: - // success! - } - } - - // Attempt delivery of AP request. - rsp, retry, err := w.Client.DoOnce( - &dlv.Request, - ) - - if err == nil { - // Ensure body closed. - _ = rsp.Body.Close() - continue loop - } - - if !retry { - // Drop deliveries when no - // retry requested, or they - // reached max (either). - continue loop - } - - // Determine next delivery attempt. - backoff := dlv.Request.BackOff() - dlv.next = time.Now().Add(backoff) - - // Push to backlog. - w.pushBacklog(dlv) + // Create a new request object from unmarshaled details. + r, err := http.NewRequest(idlv.Method, idlv.URL, body) + if err != nil { + return err } -} - -// next gets the next available delivery, blocking until available if necessary. -func (w *Worker) next(ctx context.Context) (*Delivery, bool) { -loop: - for { - // Try pop next queued. - dlv, ok := w.Queue.Pop() - - if !ok { - // Check the backlog. - if len(w.backlog) > 0 { - - // Sort by 'next' time. - sortDeliveries(w.backlog) - - // Pop next delivery. - dlv := w.popBacklog() - - return dlv, true - } - select { - // Backlog is empty, we MUST - // block until next enqueued. - case <-w.Queue.Wait(): - continue loop + // Wrap request in httpclient type. + dlv.Request = httpclient.WrapRequest(r) - // Worker was stopped. - case <-ctx.Done(): - return nil, false - } - } - - // Replace request context for worker state canceling. - ctx := gtscontext.WithValues(ctx, dlv.Request.Context()) - dlv.Request.Request = dlv.Request.Request.WithContext(ctx) - - return dlv, true - } + return nil } -// popBacklog pops next available from the backlog. -func (w *Worker) popBacklog() *Delivery { - if len(w.backlog) == 0 { - return nil +// backoff returns a valid (>= 0) backoff duration. +func (dlv *Delivery) backoff() time.Duration { + if dlv.next.IsZero() { + return 0 } - - // Pop from backlog. - dlv := w.backlog[0] - - // Shift backlog down by one. - copy(w.backlog, w.backlog[1:]) - w.backlog = w.backlog[:len(w.backlog)-1] - - return dlv -} - -// pushBacklog pushes the given delivery to backlog. -func (w *Worker) pushBacklog(dlv *Delivery) { - w.backlog = append(w.backlog, dlv) -} - -// sortDeliveries sorts deliveries according -// to when is the first requiring re-attempt. -func sortDeliveries(d []*Delivery) { - slices.SortFunc(d, func(a, b *Delivery) int { - const k = +1 - switch { - case a.next.Before(b.next): - return +k - case b.next.Before(a.next): - return -k - default: - return 0 - } - }) + return time.Until(dlv.next) } diff --git a/internal/transport/delivery/delivery_test.go b/internal/transport/delivery/delivery_test.go index 48831f098..e9eaf8fd1 100644 --- a/internal/transport/delivery/delivery_test.go +++ b/internal/transport/delivery/delivery_test.go @@ -1,203 +1,134 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + package delivery_test import ( - "fmt" + "bytes" + "encoding/json" "io" - "math/rand" - "net" "net/http" - "strconv" - "strings" "testing" - "codeberg.org/gruf/go-byteutil" - "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/stretchr/testify/assert" "github.com/superseriousbusiness/gotosocial/internal/httpclient" - "github.com/superseriousbusiness/gotosocial/internal/queue" "github.com/superseriousbusiness/gotosocial/internal/transport/delivery" ) -func TestDeliveryWorkerPool(t *testing.T) { - for _, i := range []int{1, 2, 4, 8, 16, 32} { - t.Run("size="+strconv.Itoa(i), func(t *testing.T) { - testDeliveryWorkerPool(t, i, generateInput(100*i)) - }) - } -} - -func testDeliveryWorkerPool(t *testing.T, sz int, input []*testrequest) { - wp := new(delivery.WorkerPool) - wp.Init(httpclient.New(httpclient.Config{ - AllowRanges: config.MustParseIPPrefixes([]string{ - "127.0.0.0/8", +var deliveryCases = []struct { + msg delivery.Delivery + data []byte +}{ + { + msg: delivery.Delivery{ + PubKeyID: "https://google.com/users/bigboy#pubkey", + ActorID: "https://google.com/users/bigboy", + ObjectID: "https://google.com/users/bigboy/follow/1", + TargetID: "https://askjeeves.com/users/smallboy", + Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!")), + }, + data: toJSON(map[string]any{ + "pub_key_id": "https://google.com/users/bigboy#pubkey", + "actor_id": "https://google.com/users/bigboy", + "object_id": "https://google.com/users/bigboy/follow/1", + "target_id": "https://askjeeves.com/users/smallboy", + "method": "POST", + "url": "https://askjeeves.com/users/smallboy/inbox", + "body": []byte("data!"), + // "header": map[string][]string{}, + }), + }, + { + msg: delivery.Delivery{ + Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin")), + }, + data: toJSON(map[string]any{ + "method": "GET", + "url": "https://google.com", + "body": []byte("uwu im just a wittle seawch engwin"), + // "header": map[string][]string{}, }), - })) - wp.Start(sz) - defer wp.Stop() - test(t, &wp.Queue, input) + }, } -func test( - t *testing.T, - queue *queue.StructQueue[*delivery.Delivery], - input []*testrequest, -) { - expect := make(chan *testrequest) - errors := make(chan error) - - // Prepare an HTTP test handler that ensures expected delivery is received. - handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - errors <- (<-expect).Equal(r) - }) - - // Start new HTTP test server listener. - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer l.Close() - - // Start the HTTP server. - // - // specifically not using httptest.Server{} here as httptest - // links that server with its own http.Client{}, whereas we're - // using an httpclient.Client{} (well, delivery routine is). - srv := new(http.Server) - srv.Addr = "http://" + l.Addr().String() - srv.Handler = handler - go srv.Serve(l) - defer srv.Close() - - // Range over test input. - for _, test := range input { - - // Generate req for input. - req := test.Generate(srv.Addr) - r := httpclient.WrapRequest(req) - - // Wrap the request in delivery. - dlv := new(delivery.Delivery) - dlv.Request = r - - // Enqueue delivery! - queue.Push(dlv) - expect <- test - - // Wait for errors from handler. - if err := <-errors; err != nil { - t.Error(err) +func TestSerializeDelivery(t *testing.T) { + for _, test := range deliveryCases { + // Serialize test message to blob. + data, err := test.msg.Serialize() + if err != nil { + t.Fatal(err) } - } -} - -type testrequest struct { - method string - uri string - body []byte -} -// generateInput generates 'n' many testrequest cases. -func generateInput(n int) []*testrequest { - tests := make([]*testrequest, n) - for i := range tests { - tests[i] = new(testrequest) - tests[i].method = randomMethod() - tests[i].uri = randomURI() - tests[i].body = randomBody(tests[i].method) + // Check that serialized JSON data is as expected. + assert.JSONEq(t, string(test.data), string(data)) } - return tests } -var methods = []string{ - http.MethodConnect, - http.MethodDelete, - http.MethodGet, - http.MethodHead, - http.MethodOptions, - http.MethodPatch, - http.MethodPost, - http.MethodPut, - http.MethodTrace, -} +func TestDeserializeDelivery(t *testing.T) { + for _, test := range deliveryCases { + var msg delivery.Delivery -// randomMethod generates a random http method. -func randomMethod() string { - return methods[rand.Intn(len(methods))] -} + // Deserialize test message blob. + err := msg.Deserialize(test.data) + if err != nil { + t.Fatal(err) + } -// randomURI generates a random http uri. -func randomURI() string { - n := rand.Intn(5) - p := make([]string, n) - for i := range p { - p[i] = strconv.Itoa(rand.Int()) + // Check that delivery fields are as expected. + assert.Equal(t, test.msg.PubKeyID, msg.PubKeyID) + assert.Equal(t, test.msg.ActorID, msg.ActorID) + assert.Equal(t, test.msg.ObjectID, msg.ObjectID) + assert.Equal(t, test.msg.TargetID, msg.TargetID) + assert.Equal(t, test.msg.Request.Method, msg.Request.Method) + assert.Equal(t, test.msg.Request.URL, msg.Request.URL) + assert.Equal(t, readBody(test.msg.Request.Body), readBody(msg.Request.Body)) } - return "/" + strings.Join(p, "/") } -// randomBody generates a random http body DEPENDING on method. -func randomBody(method string) []byte { - if requiresBody(method) { - return []byte(method + " " + randomURI()) +// toRequest creates httpclient.Request from HTTP method, URL and body data. +func toRequest(method string, url string, body []byte) httpclient.Request { + var rbody io.Reader + if body != nil { + rbody = bytes.NewReader(body) } - return nil -} - -// requiresBody returns whether method requires body. -func requiresBody(method string) bool { - switch method { - case http.MethodPatch, - http.MethodPost, - http.MethodPut: - return true - default: - return false + req, err := http.NewRequest(method, url, rbody) + if err != nil { + panic(err) } + return httpclient.WrapRequest(req) } -// Generate will generate a real http.Request{} from test data. -func (t *testrequest) Generate(addr string) *http.Request { - var body io.ReadCloser - if t.body != nil { - var b byteutil.ReadNopCloser - b.Reset(t.body) - body = &b +// readBody reads the content of body io.ReadCloser into memory as byte slice. +func readBody(r io.ReadCloser) []byte { + if r == nil { + return nil } - req, err := http.NewRequest(t.method, addr+t.uri, body) + b, err := io.ReadAll(r) if err != nil { panic(err) } - return req + return b } -// Equal checks if request matches receiving test request. -func (t *testrequest) Equal(r *http.Request) error { - // Ensure methods match. - if t.method != r.Method { - return fmt.Errorf("differing request methods: t=%q r=%q", t.method, r.Method) - } - - // Ensure request URIs match. - if t.uri != r.URL.RequestURI() { - return fmt.Errorf("differing request urls: t=%q r=%q", t.uri, r.URL.RequestURI()) - } - - // Ensure body cases match. - if requiresBody(t.method) { - - // Read request into memory. - b, err := io.ReadAll(r.Body) - if err != nil { - return fmt.Errorf("error reading request body: %v", err) - } - - // Compare the request bodies. - st := strings.TrimSpace(string(t.body)) - sr := strings.TrimSpace(string(b)) - if st != sr { - return fmt.Errorf("differing request bodies: t=%q r=%q", st, sr) - } +// toJSON marshals input type as JSON data. +func toJSON(a any) []byte { + b, err := json.Marshal(a) + if err != nil { + panic(err) } - - return nil + return b } diff --git a/internal/transport/delivery/worker.go b/internal/transport/delivery/worker.go new file mode 100644 index 000000000..1ed974e84 --- /dev/null +++ b/internal/transport/delivery/worker.go @@ -0,0 +1,298 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package delivery + +import ( + "context" + "slices" + "time" + + "codeberg.org/gruf/go-runners" + "codeberg.org/gruf/go-structr" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/httpclient" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/queue" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// WorkerPool wraps multiple Worker{}s in +// a singular struct for easy multi start/stop. +type WorkerPool struct { + + // Client defines httpclient.Client{} + // passed to each of delivery pool Worker{}s. + Client *httpclient.Client + + // Queue is the embedded queue.StructQueue{} + // passed to each of delivery pool Worker{}s. + Queue queue.StructQueue[*Delivery] + + // internal fields. + workers []*Worker +} + +// Init will initialize the Worker{} pool +// with given http client, request queue to pull +// from and number of delivery workers to spawn. +func (p *WorkerPool) Init(client *httpclient.Client) { + p.Client = client + p.Queue.Init(structr.QueueConfig[*Delivery]{ + Indices: []structr.IndexConfig{ + {Fields: "ActorID", Multiple: true}, + {Fields: "ObjectID", Multiple: true}, + {Fields: "TargetID", Multiple: true}, + }, + }) +} + +// Start will attempt to start 'n' Worker{}s. +func (p *WorkerPool) Start(n int) { + // Check whether workers are + // set (is already running). + ok := (len(p.workers) > 0) + if ok { + return + } + + // Allocate new workers slice. + p.workers = make([]*Worker, n) + for i := range p.workers { + + // Allocate new Worker{}. + p.workers[i] = new(Worker) + p.workers[i].Client = p.Client + p.workers[i].Queue = &p.Queue + + // Attempt to start worker. + // Return bool not useful + // here, as true = started, + // false = already running. + _ = p.workers[i].Start() + } +} + +// Stop will attempt to stop contained Worker{}s. +func (p *WorkerPool) Stop() { + // Check whether workers are + // set (is currently running). + ok := (len(p.workers) == 0) + if ok { + return + } + + // Stop all running workers. + for i := range p.workers { + + // return bool not useful + // here, as true = stopped, + // false = never running. + _ = p.workers[i].Stop() + } + + // Unset workers slice. + p.workers = p.workers[:0] +} + +// Worker wraps an httpclient.Client{} to feed +// from queue.StructQueue{} for ActivityPub reqs +// to deliver. It does so while prioritizing new +// queued requests over backlogged retries. +type Worker struct { + + // Client is the httpclient.Client{} that + // delivery worker will use for requests. + Client *httpclient.Client + + // Queue is the Delivery{} message queue + // that delivery worker will feed from. + Queue *queue.StructQueue[*Delivery] + + // internal fields. + backlog []*Delivery + service runners.Service +} + +// Start will attempt to start the Worker{}. +func (w *Worker) Start() bool { + return w.service.GoRun(w.run) +} + +// Stop will attempt to stop the Worker{}. +func (w *Worker) Stop() bool { + return w.service.Stop() +} + +// run wraps process to restart on any panic. +func (w *Worker) run(ctx context.Context) { + if w.Client == nil || w.Queue == nil { + panic("not yet initialized") + } + log.Infof(ctx, "%p: starting worker", w) + defer log.Infof(ctx, "%p: stopped worker", w) + util.Must(func() { w.process(ctx) }) +} + +// process is the main delivery worker processing routine. +func (w *Worker) process(ctx context.Context) bool { + if w.Client == nil || w.Queue == nil { + // we perform this check here just + // to ensure the compiler knows these + // variables aren't nil in the loop, + // even if already checked by caller. + panic("not yet initialized") + } + +loop: + for { + // Get next delivery. + dlv, ok := w.next(ctx) + if !ok { + return true + } + + // Check whether backoff required. + const min = 100 * time.Millisecond + if d := dlv.backoff(); d > min { + + // Start backoff sleep timer. + backoff := time.NewTimer(d) + + select { + case <-ctx.Done(): + // Main ctx + // cancelled. + backoff.Stop() + return true + + case <-w.Queue.Wait(): + // A new message was + // queued, re-add this + // to backlog + retry. + w.pushBacklog(dlv) + backoff.Stop() + continue loop + + case <-backoff.C: + // success! + } + } + + // Attempt delivery of AP request. + rsp, retry, err := w.Client.DoOnce( + &dlv.Request, + ) + + if err == nil { + // Ensure body closed. + _ = rsp.Body.Close() + continue loop + } + + if !retry { + // Drop deliveries when no + // retry requested, or they + // reached max (either). + continue loop + } + + // Determine next delivery attempt. + backoff := dlv.Request.BackOff() + dlv.next = time.Now().Add(backoff) + + // Push to backlog. + w.pushBacklog(dlv) + } +} + +// next gets the next available delivery, blocking until available if necessary. +func (w *Worker) next(ctx context.Context) (*Delivery, bool) { +loop: + for { + // Try pop next queued. + dlv, ok := w.Queue.Pop() + + if !ok { + // Check the backlog. + if len(w.backlog) > 0 { + + // Sort by 'next' time. + sortDeliveries(w.backlog) + + // Pop next delivery. + dlv := w.popBacklog() + + return dlv, true + } + + select { + // Backlog is empty, we MUST + // block until next enqueued. + case <-w.Queue.Wait(): + continue loop + + // Worker was stopped. + case <-ctx.Done(): + return nil, false + } + } + + // Replace request context for worker state canceling. + ctx := gtscontext.WithValues(ctx, dlv.Request.Context()) + dlv.Request.Request = dlv.Request.Request.WithContext(ctx) + + return dlv, true + } +} + +// popBacklog pops next available from the backlog. +func (w *Worker) popBacklog() *Delivery { + if len(w.backlog) == 0 { + return nil + } + + // Pop from backlog. + dlv := w.backlog[0] + + // Shift backlog down by one. + copy(w.backlog, w.backlog[1:]) + w.backlog = w.backlog[:len(w.backlog)-1] + + return dlv +} + +// pushBacklog pushes the given delivery to backlog. +func (w *Worker) pushBacklog(dlv *Delivery) { + w.backlog = append(w.backlog, dlv) +} + +// sortDeliveries sorts deliveries according +// to when is the first requiring re-attempt. +func sortDeliveries(d []*Delivery) { + slices.SortFunc(d, func(a, b *Delivery) int { + const k = +1 + switch { + case a.next.Before(b.next): + return +k + case b.next.Before(a.next): + return -k + default: + return 0 + } + }) +} diff --git a/internal/transport/delivery/worker_test.go b/internal/transport/delivery/worker_test.go new file mode 100644 index 000000000..936ce6e1d --- /dev/null +++ b/internal/transport/delivery/worker_test.go @@ -0,0 +1,220 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package delivery_test + +import ( + "fmt" + "io" + "math/rand" + "net" + "net/http" + "strconv" + "strings" + "testing" + + "codeberg.org/gruf/go-byteutil" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/httpclient" + "github.com/superseriousbusiness/gotosocial/internal/queue" + "github.com/superseriousbusiness/gotosocial/internal/transport/delivery" +) + +func TestDeliveryWorkerPool(t *testing.T) { + for _, i := range []int{1, 2, 4, 8, 16, 32} { + t.Run("size="+strconv.Itoa(i), func(t *testing.T) { + testDeliveryWorkerPool(t, i, generateInput(100*i)) + }) + } +} + +func testDeliveryWorkerPool(t *testing.T, sz int, input []*testrequest) { + wp := new(delivery.WorkerPool) + wp.Init(httpclient.New(httpclient.Config{ + AllowRanges: config.MustParseIPPrefixes([]string{ + "127.0.0.0/8", + }), + })) + wp.Start(sz) + defer wp.Stop() + test(t, &wp.Queue, input) +} + +func test( + t *testing.T, + queue *queue.StructQueue[*delivery.Delivery], + input []*testrequest, +) { + expect := make(chan *testrequest) + errors := make(chan error) + + // Prepare an HTTP test handler that ensures expected delivery is received. + handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + errors <- (<-expect).Equal(r) + }) + + // Start new HTTP test server listener. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + // Start the HTTP server. + // + // specifically not using httptest.Server{} here as httptest + // links that server with its own http.Client{}, whereas we're + // using an httpclient.Client{} (well, delivery routine is). + srv := new(http.Server) + srv.Addr = "http://" + l.Addr().String() + srv.Handler = handler + go srv.Serve(l) + defer srv.Close() + + // Range over test input. + for _, test := range input { + + // Generate req for input. + req := test.Generate(srv.Addr) + r := httpclient.WrapRequest(req) + + // Wrap the request in delivery. + dlv := new(delivery.Delivery) + dlv.Request = r + + // Enqueue delivery! + queue.Push(dlv) + expect <- test + + // Wait for errors from handler. + if err := <-errors; err != nil { + t.Error(err) + } + } +} + +type testrequest struct { + method string + uri string + body []byte +} + +// generateInput generates 'n' many testrequest cases. +func generateInput(n int) []*testrequest { + tests := make([]*testrequest, n) + for i := range tests { + tests[i] = new(testrequest) + tests[i].method = randomMethod() + tests[i].uri = randomURI() + tests[i].body = randomBody(tests[i].method) + } + return tests +} + +var methods = []string{ + http.MethodConnect, + http.MethodDelete, + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + http.MethodPut, + http.MethodTrace, +} + +// randomMethod generates a random http method. +func randomMethod() string { + return methods[rand.Intn(len(methods))] +} + +// randomURI generates a random http uri. +func randomURI() string { + n := rand.Intn(5) + p := make([]string, n) + for i := range p { + p[i] = strconv.Itoa(rand.Int()) + } + return "/" + strings.Join(p, "/") +} + +// randomBody generates a random http body DEPENDING on method. +func randomBody(method string) []byte { + if requiresBody(method) { + return []byte(method + " " + randomURI()) + } + return nil +} + +// requiresBody returns whether method requires body. +func requiresBody(method string) bool { + switch method { + case http.MethodPatch, + http.MethodPost, + http.MethodPut: + return true + default: + return false + } +} + +// Generate will generate a real http.Request{} from test data. +func (t *testrequest) Generate(addr string) *http.Request { + var body io.ReadCloser + if t.body != nil { + var b byteutil.ReadNopCloser + b.Reset(t.body) + body = &b + } + req, err := http.NewRequest(t.method, addr+t.uri, body) + if err != nil { + panic(err) + } + return req +} + +// Equal checks if request matches receiving test request. +func (t *testrequest) Equal(r *http.Request) error { + // Ensure methods match. + if t.method != r.Method { + return fmt.Errorf("differing request methods: t=%q r=%q", t.method, r.Method) + } + + // Ensure request URIs match. + if t.uri != r.URL.RequestURI() { + return fmt.Errorf("differing request urls: t=%q r=%q", t.uri, r.URL.RequestURI()) + } + + // Ensure body cases match. + if requiresBody(t.method) { + + // Read request into memory. + b, err := io.ReadAll(r.Body) + if err != nil { + return fmt.Errorf("error reading request body: %v", err) + } + + // Compare the request bodies. + st := strings.TrimSpace(string(t.body)) + sr := strings.TrimSpace(string(b)) + if st != sr { + return fmt.Errorf("differing request bodies: t=%q r=%q", st, sr) + } + } + + return nil +} |