diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/gtsmodel/workertask.go | 41 | ||||
| -rw-r--r-- | internal/httpclient/transport.go (renamed from internal/httpclient/sign.go) | 7 | ||||
| -rw-r--r-- | internal/messages/messages.go | 290 | ||||
| -rw-r--r-- | internal/messages/messages_test.go | 292 | ||||
| -rw-r--r-- | internal/transport/deliver.go | 19 | ||||
| -rw-r--r-- | internal/transport/delivery/delivery.go | 331 | ||||
| -rw-r--r-- | internal/transport/delivery/delivery_test.go | 265 | ||||
| -rw-r--r-- | internal/transport/delivery/worker.go | 298 | ||||
| -rw-r--r-- | internal/transport/delivery/worker_test.go | 220 | 
9 files changed, 1328 insertions, 435 deletions
diff --git a/internal/gtsmodel/workertask.go b/internal/gtsmodel/workertask.go new file mode 100644 index 000000000..cc8433199 --- /dev/null +++ b/internal/gtsmodel/workertask.go @@ -0,0 +1,41 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package gtsmodel + +import "time" + +type WorkerType uint8 + +const ( +	DeliveryWorker  WorkerType = 1 +	FederatorWorker WorkerType = 2 +	ClientWorker    WorkerType = 3 +) + +// WorkerTask represents a queued worker task +// that was persisted to the database on shutdown. +// This is only ever used on startup to pickup +// where we left off, and on shutdown to prevent +// queued tasks from being lost. It is simply a +// means to store a blob of serialized task data. 
+type WorkerTask struct { +	ID         uint      `bun:""` +	WorkerType uint8     `bun:""` +	TaskData   []byte    `bun:""` +	CreatedAt  time.Time `bun:""` +} diff --git a/internal/httpclient/sign.go b/internal/httpclient/transport.go index eff20be49..350d24fab 100644 --- a/internal/httpclient/sign.go +++ b/internal/httpclient/transport.go @@ -21,7 +21,6 @@ import (  	"net/http"  	"time" -	"codeberg.org/gruf/go-byteutil"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  ) @@ -45,12 +44,6 @@ func (t *signingtransport) RoundTrip(r *http.Request) (*http.Response, error) {  		r.Header.Del("Signature")  		r.Header.Del("Digest") -		// Rewind body reader and content-length if set. -		if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { -			rc.Rewind() // set len AFTER rewind -			r.ContentLength = int64(rc.Len()) -		} -  		// Sign the outgoing request.  		if err := sign(r); err != nil {  			return nil, err diff --git a/internal/messages/messages.go b/internal/messages/messages.go index c5488d586..7779633ba 100644 --- a/internal/messages/messages.go +++ b/internal/messages/messages.go @@ -18,9 +18,15 @@  package messages  import ( +	"context" +	"encoding/json"  	"net/url" +	"reflect"  	"codeberg.org/gruf/go-structr" +	"github.com/superseriousbusiness/activity/streams" +	"github.com/superseriousbusiness/activity/streams/vocab" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  ) @@ -50,6 +56,105 @@ type FromClientAPI struct {  	Target *gtsmodel.Account  } +// fromClientAPI is an internal type +// for FromClientAPI that provides a +// json serialize / deserialize -able +// shape that minimizes required data. 
+type fromClientAPI struct { +	APObjectType   string          `json:"ap_object_type,omitempty"` +	APActivityType string          `json:"ap_activity_type,omitempty"` +	GTSModel       json.RawMessage `json:"gts_model,omitempty"` +	GTSModelType   string          `json:"gts_model_type,omitempty"` +	TargetURI      string          `json:"target_uri,omitempty"` +	OriginID       string          `json:"origin_id,omitempty"` +	TargetID       string          `json:"target_id,omitempty"` +} + +// Serialize will serialize the worker data as data blob for storage, +// note that this will flatten some of the data e.g. only account IDs. +func (msg *FromClientAPI) Serialize() ([]byte, error) { +	var ( +		modelType string +		originID  string +		targetID  string +	) + +	// Set database model type if any provided. +	if t := reflect.TypeOf(msg.GTSModel); t != nil { +		modelType = t.String() +	} + +	// Set origin account ID. +	if msg.Origin != nil { +		originID = msg.Origin.ID +	} + +	// Set target account ID. +	if msg.Target != nil { +		targetID = msg.Target.ID +	} + +	// Marshal GTS model as raw JSON block. +	modelJSON, err := json.Marshal(msg.GTSModel) +	if err != nil { +		return nil, err +	} + +	// Marshal as internal JSON type. +	return json.Marshal(fromClientAPI{ +		APObjectType:   msg.APObjectType, +		APActivityType: msg.APActivityType, +		GTSModel:       modelJSON, +		GTSModelType:   modelType, +		TargetURI:      msg.TargetURI, +		OriginID:       originID, +		TargetID:       targetID, +	}) +} + +// Deserialize will attempt to deserialize a blob of task data, +// which will involve unflattening previously serialized data and +// leave some message structures as placeholders to holding IDs. +func (msg *FromClientAPI) Deserialize(data []byte) error { +	var imsg fromClientAPI + +	// Unmarshal as internal JSON type. +	err := json.Unmarshal(data, &imsg) +	if err != nil { +		return err +	} + +	// Copy over the simplest fields. 
+	msg.APObjectType = imsg.APObjectType +	msg.APActivityType = imsg.APActivityType +	msg.TargetURI = imsg.TargetURI + +	// Resolve Go type from JSON data. +	msg.GTSModel, err = resolveGTSModel( +		imsg.GTSModelType, +		imsg.GTSModel, +	) +	if err != nil { +		return err +	} + +	if imsg.OriginID != "" { +		// Set origin account ID using a +		// barebones model (later filled in). +		msg.Origin = new(gtsmodel.Account) +		msg.Origin.ID = imsg.OriginID +	} + +	if imsg.TargetID != "" { +		// Set target account ID using a +		// barebones model (later filled in). +		msg.Target = new(gtsmodel.Account) +		msg.Target.ID = imsg.TargetID +	} + +	return nil +} +  // ClientMsgIndices defines queue indices this  // message type should be accessible / stored under.  func ClientMsgIndices() []structr.IndexConfig { @@ -91,6 +196,133 @@ type FromFediAPI struct {  	Receiving *gtsmodel.Account  } +// fromFediAPI is an internal type +// for FromFediAPI that provides a +// json serialize / deserialize -able +// shape that minimizes required data. +type fromFediAPI struct { +	APObjectType   string                 `json:"ap_object_type,omitempty"` +	APActivityType string                 `json:"ap_activity_type,omitempty"` +	APIRI          string                 `json:"ap_iri,omitempty"` +	APObject       map[string]interface{} `json:"ap_object,omitempty"` +	GTSModel       json.RawMessage        `json:"gts_model,omitempty"` +	GTSModelType   string                 `json:"gts_model_type,omitempty"` +	TargetURI      string                 `json:"target_uri,omitempty"` +	RequestingID   string                 `json:"requesting_id,omitempty"` +	ReceivingID    string                 `json:"receiving_id,omitempty"` +} + +// Serialize will serialize the worker data as data blob for storage, +// note that this will flatten some of the data e.g. only account IDs. 
+func (msg *FromFediAPI) Serialize() ([]byte, error) { +	var ( +		gtsModelType string +		apIRI        string +		apObject     map[string]interface{} +		requestingID string +		receivingID  string +	) + +	// Set AP IRI string. +	if msg.APIRI != nil { +		apIRI = msg.APIRI.String() +	} + +	// Set serialized AP object data if set. +	if t, ok := msg.APObject.(vocab.Type); ok { +		obj, err := t.Serialize() +		if err != nil { +			return nil, err +		} +		apObject = obj +	} + +	// Set database model type if any provided. +	if t := reflect.TypeOf(msg.GTSModel); t != nil { +		gtsModelType = t.String() +	} + +	// Set requesting account ID. +	if msg.Requesting != nil { +		requestingID = msg.Requesting.ID +	} + +	// Set receiving account ID. +	if msg.Receiving != nil { +		receivingID = msg.Receiving.ID +	} + +	// Marshal GTS model as raw JSON block. +	modelJSON, err := json.Marshal(msg.GTSModel) +	if err != nil { +		return nil, err +	} + +	// Marshal as internal JSON type. +	return json.Marshal(fromFediAPI{ +		APObjectType:   msg.APObjectType, +		APActivityType: msg.APActivityType, +		APIRI:          apIRI, +		APObject:       apObject, +		GTSModel:       modelJSON, +		GTSModelType:   gtsModelType, +		TargetURI:      msg.TargetURI, +		RequestingID:   requestingID, +		ReceivingID:    receivingID, +	}) +} + +// Deserialize will attempt to deserialize a blob of task data, +// which will involve unflattening previously serialized data and +// leave some message structures as placeholders to holding IDs. +func (msg *FromFediAPI) Deserialize(data []byte) error { +	var imsg fromFediAPI + +	// Unmarshal as internal JSON type. +	err := json.Unmarshal(data, &imsg) +	if err != nil { +		return err +	} + +	// Copy over the simplest fields. +	msg.APObjectType = imsg.APObjectType +	msg.APActivityType = imsg.APActivityType +	msg.TargetURI = imsg.TargetURI + +	// Resolve AP object from JSON data. 
+	msg.APObject, err = resolveAPObject( +		imsg.APObject, +	) +	if err != nil { +		return err +	} + +	// Resolve Go type from JSON data. +	msg.GTSModel, err = resolveGTSModel( +		imsg.GTSModelType, +		imsg.GTSModel, +	) +	if err != nil { +		return err +	} + +	if imsg.RequestingID != "" { +		// Set requesting account ID using a +		// barebones model (later filled in). +		msg.Requesting = new(gtsmodel.Account) +		msg.Requesting.ID = imsg.RequestingID +	} + +	if imsg.ReceivingID != "" { +		// Set target account ID using a +		// barebones model (later filled in). +		msg.Receiving = new(gtsmodel.Account) +		msg.Receiving.ID = imsg.ReceivingID +	} + +	return nil +} +  // FederatorMsgIndices defines queue indices this  // message type should be accessible / stored under.  func FederatorMsgIndices() []structr.IndexConfig { @@ -101,3 +333,61 @@ func FederatorMsgIndices() []structr.IndexConfig {  		{Fields: "Receiving.ID", Multiple: true},  	}  } + +// resolveAPObject resolves an ActivityPub object from its "serialized" JSON map +// (yes the terminology here is weird, but that's how go-fed/activity is written). +func resolveAPObject(data map[string]interface{}) (interface{}, error) { +	if len(data) == 0 { +		// No data given. +		return nil, nil +	} + +	// Resolve vocab.Type from "raw" input data map. +	return streams.ToType(context.Background(), data) +} + +// resolveGTSModel is unfortunately where things get messy... our data is stored as JSON +// in the database, which serializes struct types as key-value pairs surrounded by curly +// braces. Deserializing from that gives us back a data blob of key-value pairs, which +// we then need to wrangle back into the original type. So we also store the type name +// and use this to determine the appropriate Go structure type to unmarshal into to. +func resolveGTSModel(typ string, data []byte) (interface{}, error) { +	if typ == "" && data == nil { +		// No data given. 
+		return nil, nil +	} + +	var value interface{} + +	switch typ { +	case reflect.TypeOf((*gtsmodel.Account)(nil)).String(): +		value = new(gtsmodel.Account) +	case reflect.TypeOf((*gtsmodel.Block)(nil)).String(): +		value = new(gtsmodel.Block) +	case reflect.TypeOf((*gtsmodel.Follow)(nil)).String(): +		value = new(gtsmodel.Follow) +	case reflect.TypeOf((*gtsmodel.FollowRequest)(nil)).String(): +		value = new(gtsmodel.FollowRequest) +	case reflect.TypeOf((*gtsmodel.Move)(nil)).String(): +		value = new(gtsmodel.Move) +	case reflect.TypeOf((*gtsmodel.Poll)(nil)).String(): +		value = new(gtsmodel.Poll) +	case reflect.TypeOf((*gtsmodel.PollVote)(nil)).String(): +		value = new(gtsmodel.PollVote) +	case reflect.TypeOf((*gtsmodel.Report)(nil)).String(): +		value = new(gtsmodel.Report) +	case reflect.TypeOf((*gtsmodel.Status)(nil)).String(): +		value = new(gtsmodel.Status) +	case reflect.TypeOf((*gtsmodel.StatusFave)(nil)).String(): +		value = new(gtsmodel.StatusFave) +	default: +		return nil, gtserror.Newf("unknown type: %s", typ) +	} + +	// Attempt to unmarshal value JSON into destination. +	if err := json.Unmarshal(data, &value); err != nil { +		return nil, gtserror.Newf("error unmarshaling %s value data: %w", typ, err) +	} + +	return value, nil +} diff --git a/internal/messages/messages_test.go b/internal/messages/messages_test.go new file mode 100644 index 000000000..e5b2a2841 --- /dev/null +++ b/internal/messages/messages_test.go @@ -0,0 +1,292 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package messages_test + +import ( +	"bytes" +	"encoding/json" +	"net/url" +	"testing" + +	"github.com/superseriousbusiness/gotosocial/internal/ap" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/testrig" + +	"github.com/google/go-cmp/cmp" +) + +var testStatus = testrig.NewTestStatuses()["admin_account_status_1"] + +var testAccount = testrig.NewTestAccounts()["admin_account"] + +var fromClientAPICases = []struct { +	msg  messages.FromClientAPI +	data []byte +}{ +	{ +		msg: messages.FromClientAPI{ +			APObjectType:   ap.ObjectNote, +			APActivityType: ap.ActivityCreate, +			GTSModel:       testStatus, +			TargetURI:      "https://gotosocial.org", +			Origin:         &gtsmodel.Account{ID: "654321"}, +			Target:         &gtsmodel.Account{ID: "123456"}, +		}, +		data: toJSON(map[string]any{ +			"ap_object_type":   ap.ObjectNote, +			"ap_activity_type": ap.ActivityCreate, +			"gts_model":        json.RawMessage(toJSON(testStatus)), +			"gts_model_type":   "*gtsmodel.Status", +			"target_uri":       "https://gotosocial.org", +			"origin_id":        "654321", +			"target_id":        "123456", +		}), +	}, +	{ +		msg: messages.FromClientAPI{ +			APObjectType:   ap.ObjectProfile, +			APActivityType: ap.ActivityUpdate, +			GTSModel:       testAccount, +			TargetURI:      "https://uk-queen-is-dead.org", +			Origin:         &gtsmodel.Account{ID: "123456"}, +			Target:         &gtsmodel.Account{ID: "654321"}, +		}, +		data: toJSON(map[string]any{ +			
"ap_object_type":   ap.ObjectProfile, +			"ap_activity_type": ap.ActivityUpdate, +			"gts_model":        json.RawMessage(toJSON(testAccount)), +			"gts_model_type":   "*gtsmodel.Account", +			"target_uri":       "https://uk-queen-is-dead.org", +			"origin_id":        "123456", +			"target_id":        "654321", +		}), +	}, +} + +var fromFediAPICases = []struct { +	msg  messages.FromFediAPI +	data []byte +}{ +	{ +		msg: messages.FromFediAPI{ +			APObjectType:   ap.ObjectNote, +			APActivityType: ap.ActivityCreate, +			GTSModel:       testStatus, +			TargetURI:      "https://gotosocial.org", +			Requesting:     &gtsmodel.Account{ID: "654321"}, +			Receiving:      &gtsmodel.Account{ID: "123456"}, +		}, +		data: toJSON(map[string]any{ +			"ap_object_type":   ap.ObjectNote, +			"ap_activity_type": ap.ActivityCreate, +			"gts_model":        json.RawMessage(toJSON(testStatus)), +			"gts_model_type":   "*gtsmodel.Status", +			"target_uri":       "https://gotosocial.org", +			"requesting_id":    "654321", +			"receiving_id":     "123456", +		}), +	}, +	{ +		msg: messages.FromFediAPI{ +			APObjectType:   ap.ObjectProfile, +			APActivityType: ap.ActivityUpdate, +			GTSModel:       testAccount, +			TargetURI:      "https://uk-queen-is-dead.org", +			Requesting:     &gtsmodel.Account{ID: "123456"}, +			Receiving:      &gtsmodel.Account{ID: "654321"}, +		}, +		data: toJSON(map[string]any{ +			"ap_object_type":   ap.ObjectProfile, +			"ap_activity_type": ap.ActivityUpdate, +			"gts_model":        json.RawMessage(toJSON(testAccount)), +			"gts_model_type":   "*gtsmodel.Account", +			"target_uri":       "https://uk-queen-is-dead.org", +			"requesting_id":    "123456", +			"receiving_id":     "654321", +		}), +	}, +} + +func TestSerializeFromClientAPI(t *testing.T) { +	for _, test := range fromClientAPICases { +		// Serialize test message to blob. +		data, err := test.msg.Serialize() +		if err != nil { +			t.Fatal(err) +		} + +		// Check serialized JSON data as expected. 
+		assertJSONEqual(t, test.data, data) +	} +} + +func TestDeserializeFromClientAPI(t *testing.T) { +	for _, test := range fromClientAPICases { +		var msg messages.FromClientAPI + +		// Deserialize test message blob. +		err := msg.Deserialize(test.data) +		if err != nil { +			t.Fatal(err) +		} + +		// Check that msg is as expected. +		assertEqual(t, test.msg.APActivityType, msg.APActivityType) +		assertEqual(t, test.msg.APObjectType, msg.APObjectType) +		assertEqual(t, test.msg.GTSModel, msg.GTSModel) +		assertEqual(t, test.msg.TargetURI, msg.TargetURI) +		assertEqual(t, accountID(test.msg.Origin), accountID(msg.Origin)) +		assertEqual(t, accountID(test.msg.Target), accountID(msg.Target)) + +		// Perform final check to ensure +		// account model keys deserialized. +		assertEqualRSA(t, test.msg.GTSModel, msg.GTSModel) +	} +} + +func TestSerializeFromFediAPI(t *testing.T) { +	for _, test := range fromFediAPICases { +		// Serialize test message to blob. +		data, err := test.msg.Serialize() +		if err != nil { +			t.Fatal(err) +		} + +		// Check serialized JSON data as expected. +		assertJSONEqual(t, test.data, data) +	} +} + +func TestDeserializeFromFediAPI(t *testing.T) { +	for _, test := range fromFediAPICases { +		var msg messages.FromFediAPI + +		// Deserialize test message blob. +		err := msg.Deserialize(test.data) +		if err != nil { +			t.Fatal(err) +		} + +		// Check that msg is as expected. +		assertEqual(t, test.msg.APActivityType, msg.APActivityType) +		assertEqual(t, test.msg.APObjectType, msg.APObjectType) +		assertEqual(t, urlStr(test.msg.APIRI), urlStr(msg.APIRI)) +		assertEqual(t, test.msg.APObject, msg.APObject) +		assertEqual(t, test.msg.GTSModel, msg.GTSModel) +		assertEqual(t, test.msg.TargetURI, msg.TargetURI) +		assertEqual(t, accountID(test.msg.Receiving), accountID(msg.Receiving)) +		assertEqual(t, accountID(test.msg.Requesting), accountID(msg.Requesting)) + +		// Perform final check to ensure +		// account model keys deserialized. 
+		assertEqualRSA(t, test.msg.GTSModel, msg.GTSModel) +	} +} + +// assertEqualRSA asserts that test account model RSA keys are equal. +func assertEqualRSA(t *testing.T, expect, receive any) bool { +	t.Helper() + +	account1, ok1 := expect.(*gtsmodel.Account) + +	account2, ok2 := receive.(*gtsmodel.Account) + +	if ok1 != ok2 { +		t.Errorf("different model types: expect=%T receive=%T", expect, receive) +		return false +	} else if !ok1 { +		return true +	} + +	if !account1.PublicKey.Equal(account2.PublicKey) { +		t.Error("public keys do not match") +		return false +	} + +	t.Logf("publickey=%v", account1.PublicKey) + +	if !account1.PrivateKey.Equal(account2.PrivateKey) { +		t.Error("private keys do not match") +		return false +	} + +	t.Logf("privatekey=%v", account1.PrivateKey) + +	return true +} + +// assertEqual asserts that two values (of any type!) are equal, +// note we use the 'cmp' library here as it's much more useful in +// outputting debug information than testify, and handles more complex +// types like rsa public / private key comparisons correctly. +func assertEqual(t *testing.T, expect, receive any) bool { +	t.Helper() +	if diff := cmp.Diff(expect, receive); diff != "" { +		t.Error(diff) +		return false +	} +	return true +} + +// assertJSONEqual asserts that two slices of JSON data are equal. +func assertJSONEqual(t *testing.T, expect, receive []byte) bool { +	t.Helper() +	return assertEqual(t, fromJSON(expect), fromJSON(receive)) +} + +// urlStr returns url as string, or empty. +func urlStr(url *url.URL) string { +	if url == nil { +		return "" +	} +	return url.String() +} + +// accountID returns account's ID, or empty. +func accountID(account *gtsmodel.Account) string { +	if account == nil { +		return "" +	} +	return account.ID +} + +// fromJSON unmarshals input data as JSON. 
+func fromJSON(b []byte) any { +	r := bytes.NewReader(b) +	d := json.NewDecoder(r) +	d.UseNumber() +	var a any +	err := d.Decode(&a) +	if err != nil { +		panic(err) +	} +	if d.More() { +		panic("multiple json values in b") +	} +	return a +} + +// toJSON marshals input type as JSON data. +func toJSON(a any) []byte { +	b, err := json.Marshal(a) +	if err != nil { +		panic(err) +	} +	return b +} diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index a7e73465d..30435b86f 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -18,12 +18,12 @@  package transport  import ( +	"bytes"  	"context"  	"encoding/json"  	"net/http"  	"net/url" -	"codeberg.org/gruf/go-byteutil"  	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -130,25 +130,28 @@ func (t *transport) prepare(  	*delivery.Delivery,  	error,  ) { -	url := to.String() - -	// Use rewindable reader for body. -	var body byteutil.ReadNopCloser -	body.Reset(data) -  	// Prepare POST signer.  	sign := t.signPOST(data) +	// Use *bytes.Reader for request body, +	// as NewRequest() automatically will +	// set .GetBody and content-length. +	// (this handles necessary rewinding). +	body := bytes.NewReader(data) +  	// Update to-be-used request context with signing details.  	ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID)  	ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign)  	// Prepare a new request with data body directed at URL. -	r, err := http.NewRequestWithContext(ctx, "POST", url, &body) +	r, err := http.NewRequestWithContext(ctx, "POST", to.String(), body)  	if err != nil {  		return nil, gtserror.Newf("error preparing request: %w", err)  	} +	// Set our predefined controller user-agent. +	r.Header.Set("User-Agent", t.controller.userAgent) +  	// Set the standard ActivityPub content-type + charset headers.  	
r.Header.Add("Content-Type", string(apiutil.AppActivityLDJSON))  	r.Header.Add("Accept-Charset", "utf-8") diff --git a/internal/transport/delivery/delivery.go b/internal/transport/delivery/delivery.go index 1e9126b2e..1e3ebb054 100644 --- a/internal/transport/delivery/delivery.go +++ b/internal/transport/delivery/delivery.go @@ -18,16 +18,13 @@  package delivery  import ( -	"context" -	"slices" +	"bytes" +	"encoding/json" +	"io" +	"net/http"  	"time" -	"codeberg.org/gruf/go-runners" -	"codeberg.org/gruf/go-structr" -	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/httpclient" -	"github.com/superseriousbusiness/gotosocial/internal/queue" -	"github.com/superseriousbusiness/gotosocial/internal/util"  )  // Delivery wraps an httpclient.Request{} @@ -36,6 +33,9 @@ import (  // be indexed (and so, dropped from queue)  // by any of these possible ID IRIs.  type Delivery struct { +	// PubKeyID is the signing public key +	// ID of the actor performing request. +	PubKeyID string  	// ActorID contains the ActivityPub  	// actor ID IRI (if any) of the activity @@ -61,273 +61,98 @@ type Delivery struct {  	next time.Time  } -func (dlv *Delivery) backoff() time.Duration { -	if dlv.next.IsZero() { -		return 0 -	} -	return time.Until(dlv.next) +// delivery is an internal type +// for Delivery{} that provides +// a json serialize / deserialize +// able shape that minimizes data. 
+type delivery struct { +	PubKeyID string              `json:"pub_key_id,omitempty"` +	ActorID  string              `json:"actor_id,omitempty"` +	ObjectID string              `json:"object_id,omitempty"` +	TargetID string              `json:"target_id,omitempty"` +	Method   string              `json:"method,omitempty"` +	Header   map[string][]string `json:"header,omitempty"` +	URL      string              `json:"url,omitempty"` +	Body     []byte              `json:"body,omitempty"`  } -// WorkerPool wraps multiple Worker{}s in -// a singular struct for easy multi start/stop. -type WorkerPool struct { +// Serialize will serialize the delivery data as data blob for storage, +// note that this will flatten some of the data, dropping signing funcs. +func (dlv *Delivery) Serialize() ([]byte, error) { +	var body []byte -	// Client defines httpclient.Client{} -	// passed to each of delivery pool Worker{}s. -	Client *httpclient.Client +	if dlv.Request.GetBody != nil { +		// Fetch a fresh copy of request body. +		rbody, err := dlv.Request.GetBody() +		if err != nil { +			return nil, err +		} -	// Queue is the embedded queue.StructQueue{} -	// passed to each of delivery pool Worker{}s. -	Queue queue.StructQueue[*Delivery] +		// Read request body into memory. +		body, err = io.ReadAll(rbody) -	// internal fields. -	workers []*Worker -} +		// Done with body. +		_ = rbody.Close() -// Init will initialize the Worker{} pool -// with given http client, request queue to pull -// from and number of delivery workers to spawn. -func (p *WorkerPool) Init(client *httpclient.Client) { -	p.Client = client -	p.Queue.Init(structr.QueueConfig[*Delivery]{ -		Indices: []structr.IndexConfig{ -			{Fields: "ActorID", Multiple: true}, -			{Fields: "ObjectID", Multiple: true}, -			{Fields: "TargetID", Multiple: true}, -		}, -	}) -} - -// Start will attempt to start 'n' Worker{}s. -func (p *WorkerPool) Start(n int) { -	// Check whether workers are -	// set (is already running). 
-	ok := (len(p.workers) > 0) -	if ok { -		return +		if err != nil { +			return nil, err +		}  	} -	// Allocate new workers slice. -	p.workers = make([]*Worker, n) -	for i := range p.workers { - -		// Allocate new Worker{}. -		p.workers[i] = new(Worker) -		p.workers[i].Client = p.Client -		p.workers[i].Queue = &p.Queue - -		// Attempt to start worker. -		// Return bool not useful -		// here, as true = started, -		// false = already running. -		_ = p.workers[i].Start() -	} +	// Marshal as internal JSON type. +	return json.Marshal(delivery{ +		PubKeyID: dlv.PubKeyID, +		ActorID:  dlv.ActorID, +		ObjectID: dlv.ObjectID, +		TargetID: dlv.TargetID, +		Method:   dlv.Request.Method, +		Header:   dlv.Request.Header, +		URL:      dlv.Request.URL.String(), +		Body:     body, +	})  } -// Stop will attempt to stop contained Worker{}s. -func (p *WorkerPool) Stop() { -	// Check whether workers are -	// set (is currently running). -	ok := (len(p.workers) == 0) -	if ok { -		return -	} - -	// Stop all running workers. -	for i := range p.workers { +// Deserialize will attempt to deserialize a blob of task data, +// which will involve unflattening previously serialized data and +// leave delivery incomplete, still requiring signing func setup. +func (dlv *Delivery) Deserialize(data []byte) error { +	var idlv delivery -		// return bool not useful -		// here, as true = stopped, -		// false = never running. -		_ = p.workers[i].Stop() +	// Unmarshal as internal JSON type. +	err := json.Unmarshal(data, &idlv) +	if err != nil { +		return err  	} -	// Unset workers slice. -	p.workers = p.workers[:0] -} - -// Worker wraps an httpclient.Client{} to feed -// from queue.StructQueue{} for ActivityPub reqs -// to deliver. It does so while prioritizing new -// queued requests over backlogged retries. -type Worker struct { - -	// Client is the httpclient.Client{} that -	// delivery worker will use for requests. 
-	Client *httpclient.Client - -	// Queue is the Delivery{} message queue -	// that delivery worker will feed from. -	Queue *queue.StructQueue[*Delivery] +	// Copy over simplest fields. +	dlv.PubKeyID = idlv.PubKeyID +	dlv.ActorID = idlv.ActorID +	dlv.ObjectID = idlv.ObjectID +	dlv.TargetID = idlv.TargetID -	// internal fields. -	backlog []*Delivery -	service runners.Service -} - -// Start will attempt to start the Worker{}. -func (w *Worker) Start() bool { -	return w.service.GoRun(w.run) -} +	var body io.Reader -// Stop will attempt to stop the Worker{}. -func (w *Worker) Stop() bool { -	return w.service.Stop() -} - -// run wraps process to restart on any panic. -func (w *Worker) run(ctx context.Context) { -	if w.Client == nil || w.Queue == nil { -		panic("not yet initialized") -	} -	util.Must(func() { w.process(ctx) }) -} - -// process is the main delivery worker processing routine. -func (w *Worker) process(ctx context.Context) bool { -	if w.Client == nil || w.Queue == nil { -		// we perform this check here just -		// to ensure the compiler knows these -		// variables aren't nil in the loop, -		// even if already checked by caller. -		panic("not yet initialized") +	if idlv.Body != nil { +		// Create new body reader from data. +		body = bytes.NewReader(idlv.Body)  	} -loop: -	for { -		// Get next delivery. -		dlv, ok := w.next(ctx) -		if !ok { -			return true -		} - -		// Check whether backoff required. -		const min = 100 * time.Millisecond -		if d := dlv.backoff(); d > min { - -			// Start backoff sleep timer. -			backoff := time.NewTimer(d) - -			select { -			case <-ctx.Done(): -				// Main ctx -				// cancelled. -				backoff.Stop() -				return true - -			case <-w.Queue.Wait(): -				// A new message was -				// queued, re-add this -				// to backlog + retry. -				w.pushBacklog(dlv) -				backoff.Stop() -				continue loop - -			case <-backoff.C: -				// success! -			} -		} - -		// Attempt delivery of AP request. 
-		rsp, retry, err := w.Client.DoOnce( -			&dlv.Request, -		) - -		if err == nil { -			// Ensure body closed. -			_ = rsp.Body.Close() -			continue loop -		} - -		if !retry { -			// Drop deliveries when no -			// retry requested, or they -			// reached max (either). -			continue loop -		} - -		// Determine next delivery attempt. -		backoff := dlv.Request.BackOff() -		dlv.next = time.Now().Add(backoff) - -		// Push to backlog. -		w.pushBacklog(dlv) +	// Create a new request object from unmarshaled details. +	r, err := http.NewRequest(idlv.Method, idlv.URL, body) +	if err != nil { +		return err  	} -} - -// next gets the next available delivery, blocking until available if necessary. -func (w *Worker) next(ctx context.Context) (*Delivery, bool) { -loop: -	for { -		// Try pop next queued. -		dlv, ok := w.Queue.Pop() - -		if !ok { -			// Check the backlog. -			if len(w.backlog) > 0 { - -				// Sort by 'next' time. -				sortDeliveries(w.backlog) - -				// Pop next delivery. -				dlv := w.popBacklog() - -				return dlv, true -			} -			select { -			// Backlog is empty, we MUST -			// block until next enqueued. -			case <-w.Queue.Wait(): -				continue loop +	// Wrap request in httpclient type. +	dlv.Request = httpclient.WrapRequest(r) -			// Worker was stopped. -			case <-ctx.Done(): -				return nil, false -			} -		} - -		// Replace request context for worker state canceling. -		ctx := gtscontext.WithValues(ctx, dlv.Request.Context()) -		dlv.Request.Request = dlv.Request.Request.WithContext(ctx) - -		return dlv, true -	} +	return nil  } -// popBacklog pops next available from the backlog. -func (w *Worker) popBacklog() *Delivery { -	if len(w.backlog) == 0 { -		return nil +// backoff returns a valid (>= 0) backoff duration. +func (dlv *Delivery) backoff() time.Duration { +	if dlv.next.IsZero() { +		return 0  	} - -	// Pop from backlog. -	dlv := w.backlog[0] - -	// Shift backlog down by one. 
-	copy(w.backlog, w.backlog[1:]) -	w.backlog = w.backlog[:len(w.backlog)-1] - -	return dlv -} - -// pushBacklog pushes the given delivery to backlog. -func (w *Worker) pushBacklog(dlv *Delivery) { -	w.backlog = append(w.backlog, dlv) -} - -// sortDeliveries sorts deliveries according -// to when is the first requiring re-attempt. -func sortDeliveries(d []*Delivery) { -	slices.SortFunc(d, func(a, b *Delivery) int { -		const k = +1 -		switch { -		case a.next.Before(b.next): -			return +k -		case b.next.Before(a.next): -			return -k -		default: -			return 0 -		} -	}) +	return time.Until(dlv.next)  } diff --git a/internal/transport/delivery/delivery_test.go b/internal/transport/delivery/delivery_test.go index 48831f098..e9eaf8fd1 100644 --- a/internal/transport/delivery/delivery_test.go +++ b/internal/transport/delivery/delivery_test.go @@ -1,203 +1,134 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. 
+  package delivery_test  import ( -	"fmt" +	"bytes" +	"encoding/json"  	"io" -	"math/rand" -	"net"  	"net/http" -	"strconv" -	"strings"  	"testing" -	"codeberg.org/gruf/go-byteutil" -	"github.com/superseriousbusiness/gotosocial/internal/config" +	"github.com/stretchr/testify/assert"  	"github.com/superseriousbusiness/gotosocial/internal/httpclient" -	"github.com/superseriousbusiness/gotosocial/internal/queue"  	"github.com/superseriousbusiness/gotosocial/internal/transport/delivery"  ) -func TestDeliveryWorkerPool(t *testing.T) { -	for _, i := range []int{1, 2, 4, 8, 16, 32} { -		t.Run("size="+strconv.Itoa(i), func(t *testing.T) { -			testDeliveryWorkerPool(t, i, generateInput(100*i)) -		}) -	} -} - -func testDeliveryWorkerPool(t *testing.T, sz int, input []*testrequest) { -	wp := new(delivery.WorkerPool) -	wp.Init(httpclient.New(httpclient.Config{ -		AllowRanges: config.MustParseIPPrefixes([]string{ -			"127.0.0.0/8", +var deliveryCases = []struct { +	msg  delivery.Delivery +	data []byte +}{ +	{ +		msg: delivery.Delivery{ +			PubKeyID: "https://google.com/users/bigboy#pubkey", +			ActorID:  "https://google.com/users/bigboy", +			ObjectID: "https://google.com/users/bigboy/follow/1", +			TargetID: "https://askjeeves.com/users/smallboy", +			Request:  toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!")), +		}, +		data: toJSON(map[string]any{ +			"pub_key_id": "https://google.com/users/bigboy#pubkey", +			"actor_id":   "https://google.com/users/bigboy", +			"object_id":  "https://google.com/users/bigboy/follow/1", +			"target_id":  "https://askjeeves.com/users/smallboy", +			"method":     "POST", +			"url":        "https://askjeeves.com/users/smallboy/inbox", +			"body":       []byte("data!"), +			// "header":     map[string][]string{}, +		}), +	}, +	{ +		msg: delivery.Delivery{ +			Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin")), +		}, +		data: toJSON(map[string]any{ +			"method": "GET", +			
"url":    "https://google.com", +			"body":   []byte("uwu im just a wittle seawch engwin"), +			// "header":     map[string][]string{},  		}), -	})) -	wp.Start(sz) -	defer wp.Stop() -	test(t, &wp.Queue, input) +	},  } -func test( -	t *testing.T, -	queue *queue.StructQueue[*delivery.Delivery], -	input []*testrequest, -) { -	expect := make(chan *testrequest) -	errors := make(chan error) - -	// Prepare an HTTP test handler that ensures expected delivery is received. -	handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { -		errors <- (<-expect).Equal(r) -	}) - -	// Start new HTTP test server listener. -	l, err := net.Listen("tcp", "127.0.0.1:0") -	if err != nil { -		t.Fatal(err) -	} -	defer l.Close() - -	// Start the HTTP server. -	// -	// specifically not using httptest.Server{} here as httptest -	// links that server with its own http.Client{}, whereas we're -	// using an httpclient.Client{} (well, delivery routine is). -	srv := new(http.Server) -	srv.Addr = "http://" + l.Addr().String() -	srv.Handler = handler -	go srv.Serve(l) -	defer srv.Close() - -	// Range over test input. -	for _, test := range input { - -		// Generate req for input. -		req := test.Generate(srv.Addr) -		r := httpclient.WrapRequest(req) - -		// Wrap the request in delivery. -		dlv := new(delivery.Delivery) -		dlv.Request = r - -		// Enqueue delivery! -		queue.Push(dlv) -		expect <- test - -		// Wait for errors from handler. -		if err := <-errors; err != nil { -			t.Error(err) +func TestSerializeDelivery(t *testing.T) { +	for _, test := range deliveryCases { +		// Serialize test message to blob. +		data, err := test.msg.Serialize() +		if err != nil { +			t.Fatal(err)  		} -	} -} - -type testrequest struct { -	method string -	uri    string -	body   []byte -} -// generateInput generates 'n' many testrequest cases. 
-func generateInput(n int) []*testrequest { -	tests := make([]*testrequest, n) -	for i := range tests { -		tests[i] = new(testrequest) -		tests[i].method = randomMethod() -		tests[i].uri = randomURI() -		tests[i].body = randomBody(tests[i].method) +		// Check that serialized JSON data is as expected. +		assert.JSONEq(t, string(test.data), string(data))  	} -	return tests  } -var methods = []string{ -	http.MethodConnect, -	http.MethodDelete, -	http.MethodGet, -	http.MethodHead, -	http.MethodOptions, -	http.MethodPatch, -	http.MethodPost, -	http.MethodPut, -	http.MethodTrace, -} +func TestDeserializeDelivery(t *testing.T) { +	for _, test := range deliveryCases { +		var msg delivery.Delivery -// randomMethod generates a random http method. -func randomMethod() string { -	return methods[rand.Intn(len(methods))] -} +		// Deserialize test message blob. +		err := msg.Deserialize(test.data) +		if err != nil { +			t.Fatal(err) +		} -// randomURI generates a random http uri. -func randomURI() string { -	n := rand.Intn(5) -	p := make([]string, n) -	for i := range p { -		p[i] = strconv.Itoa(rand.Int()) +		// Check that delivery fields are as expected. +		assert.Equal(t, test.msg.PubKeyID, msg.PubKeyID) +		assert.Equal(t, test.msg.ActorID, msg.ActorID) +		assert.Equal(t, test.msg.ObjectID, msg.ObjectID) +		assert.Equal(t, test.msg.TargetID, msg.TargetID) +		assert.Equal(t, test.msg.Request.Method, msg.Request.Method) +		assert.Equal(t, test.msg.Request.URL, msg.Request.URL) +		assert.Equal(t, readBody(test.msg.Request.Body), readBody(msg.Request.Body))  	} -	return "/" + strings.Join(p, "/")  } -// randomBody generates a random http body DEPENDING on method. -func randomBody(method string) []byte { -	if requiresBody(method) { -		return []byte(method + " " + randomURI()) +// toRequest creates httpclient.Request from HTTP method, URL and body data. 
+func toRequest(method string, url string, body []byte) httpclient.Request { +	var rbody io.Reader +	if body != nil { +		rbody = bytes.NewReader(body)  	} -	return nil -} - -// requiresBody returns whether method requires body. -func requiresBody(method string) bool { -	switch method { -	case http.MethodPatch, -		http.MethodPost, -		http.MethodPut: -		return true -	default: -		return false +	req, err := http.NewRequest(method, url, rbody) +	if err != nil { +		panic(err)  	} +	return httpclient.WrapRequest(req)  } -// Generate will generate a real http.Request{} from test data. -func (t *testrequest) Generate(addr string) *http.Request { -	var body io.ReadCloser -	if t.body != nil { -		var b byteutil.ReadNopCloser -		b.Reset(t.body) -		body = &b +// readBody reads the content of body io.ReadCloser into memory as byte slice. +func readBody(r io.ReadCloser) []byte { +	if r == nil { +		return nil  	} -	req, err := http.NewRequest(t.method, addr+t.uri, body) +	b, err := io.ReadAll(r)  	if err != nil {  		panic(err)  	} -	return req +	return b  } -// Equal checks if request matches receiving test request. -func (t *testrequest) Equal(r *http.Request) error { -	// Ensure methods match. -	if t.method != r.Method { -		return fmt.Errorf("differing request methods: t=%q r=%q", t.method, r.Method) -	} - -	// Ensure request URIs match. -	if t.uri != r.URL.RequestURI() { -		return fmt.Errorf("differing request urls: t=%q r=%q", t.uri, r.URL.RequestURI()) -	} - -	// Ensure body cases match. -	if requiresBody(t.method) { - -		// Read request into memory. -		b, err := io.ReadAll(r.Body) -		if err != nil { -			return fmt.Errorf("error reading request body: %v", err) -		} - -		// Compare the request bodies. -		st := strings.TrimSpace(string(t.body)) -		sr := strings.TrimSpace(string(b)) -		if st != sr { -			return fmt.Errorf("differing request bodies: t=%q r=%q", st, sr) -		} +// toJSON marshals input type as JSON data. 
+func toJSON(a any) []byte { +	b, err := json.Marshal(a) +	if err != nil { +		panic(err)  	} - -	return nil +	return b  } diff --git a/internal/transport/delivery/worker.go b/internal/transport/delivery/worker.go new file mode 100644 index 000000000..1ed974e84 --- /dev/null +++ b/internal/transport/delivery/worker.go @@ -0,0 +1,298 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package delivery + +import ( +	"context" +	"slices" +	"time" + +	"codeberg.org/gruf/go-runners" +	"codeberg.org/gruf/go-structr" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/httpclient" +	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/queue" +	"github.com/superseriousbusiness/gotosocial/internal/util" +) + +// WorkerPool wraps multiple Worker{}s in +// a singular struct for easy multi start/stop. +type WorkerPool struct { + +	// Client defines httpclient.Client{} +	// passed to each of delivery pool Worker{}s. +	Client *httpclient.Client + +	// Queue is the embedded queue.StructQueue{} +	// passed to each of delivery pool Worker{}s. +	Queue queue.StructQueue[*Delivery] + +	// internal fields. 
+	workers []*Worker +} + +// Init will initialize the Worker{} pool +// with given http client, request queue to pull +// from and number of delivery workers to spawn. +func (p *WorkerPool) Init(client *httpclient.Client) { +	p.Client = client +	p.Queue.Init(structr.QueueConfig[*Delivery]{ +		Indices: []structr.IndexConfig{ +			{Fields: "ActorID", Multiple: true}, +			{Fields: "ObjectID", Multiple: true}, +			{Fields: "TargetID", Multiple: true}, +		}, +	}) +} + +// Start will attempt to start 'n' Worker{}s. +func (p *WorkerPool) Start(n int) { +	// Check whether workers are +	// set (is already running). +	ok := (len(p.workers) > 0) +	if ok { +		return +	} + +	// Allocate new workers slice. +	p.workers = make([]*Worker, n) +	for i := range p.workers { + +		// Allocate new Worker{}. +		p.workers[i] = new(Worker) +		p.workers[i].Client = p.Client +		p.workers[i].Queue = &p.Queue + +		// Attempt to start worker. +		// Return bool not useful +		// here, as true = started, +		// false = already running. +		_ = p.workers[i].Start() +	} +} + +// Stop will attempt to stop contained Worker{}s. +func (p *WorkerPool) Stop() { +	// Check whether workers are +	// set (is currently running). +	ok := (len(p.workers) == 0) +	if ok { +		return +	} + +	// Stop all running workers. +	for i := range p.workers { + +		// return bool not useful +		// here, as true = stopped, +		// false = never running. +		_ = p.workers[i].Stop() +	} + +	// Unset workers slice. +	p.workers = p.workers[:0] +} + +// Worker wraps an httpclient.Client{} to feed +// from queue.StructQueue{} for ActivityPub reqs +// to deliver. It does so while prioritizing new +// queued requests over backlogged retries. +type Worker struct { + +	// Client is the httpclient.Client{} that +	// delivery worker will use for requests. +	Client *httpclient.Client + +	// Queue is the Delivery{} message queue +	// that delivery worker will feed from. +	Queue *queue.StructQueue[*Delivery] + +	// internal fields. 
+	backlog []*Delivery +	service runners.Service +} + +// Start will attempt to start the Worker{}. +func (w *Worker) Start() bool { +	return w.service.GoRun(w.run) +} + +// Stop will attempt to stop the Worker{}. +func (w *Worker) Stop() bool { +	return w.service.Stop() +} + +// run wraps process to restart on any panic. +func (w *Worker) run(ctx context.Context) { +	if w.Client == nil || w.Queue == nil { +		panic("not yet initialized") +	} +	log.Infof(ctx, "%p: starting worker", w) +	defer log.Infof(ctx, "%p: stopped worker", w) +	util.Must(func() { w.process(ctx) }) +} + +// process is the main delivery worker processing routine. +func (w *Worker) process(ctx context.Context) bool { +	if w.Client == nil || w.Queue == nil { +		// we perform this check here just +		// to ensure the compiler knows these +		// variables aren't nil in the loop, +		// even if already checked by caller. +		panic("not yet initialized") +	} + +loop: +	for { +		// Get next delivery. +		dlv, ok := w.next(ctx) +		if !ok { +			return true +		} + +		// Check whether backoff required. +		const min = 100 * time.Millisecond +		if d := dlv.backoff(); d > min { + +			// Start backoff sleep timer. +			backoff := time.NewTimer(d) + +			select { +			case <-ctx.Done(): +				// Main ctx +				// cancelled. +				backoff.Stop() +				return true + +			case <-w.Queue.Wait(): +				// A new message was +				// queued, re-add this +				// to backlog + retry. +				w.pushBacklog(dlv) +				backoff.Stop() +				continue loop + +			case <-backoff.C: +				// success! +			} +		} + +		// Attempt delivery of AP request. +		rsp, retry, err := w.Client.DoOnce( +			&dlv.Request, +		) + +		if err == nil { +			// Ensure body closed. +			_ = rsp.Body.Close() +			continue loop +		} + +		if !retry { +			// Drop deliveries when no +			// retry requested, or they +			// reached max (either). +			continue loop +		} + +		// Determine next delivery attempt. 
+		backoff := dlv.Request.BackOff() +		dlv.next = time.Now().Add(backoff) + +		// Push to backlog. +		w.pushBacklog(dlv) +	} +} + +// next gets the next available delivery, blocking until available if necessary. +func (w *Worker) next(ctx context.Context) (*Delivery, bool) { +loop: +	for { +		// Try pop next queued. +		dlv, ok := w.Queue.Pop() + +		if !ok { +			// Check the backlog. +			if len(w.backlog) > 0 { + +				// Sort by 'next' time. +				sortDeliveries(w.backlog) + +				// Pop next delivery. +				dlv := w.popBacklog() + +				return dlv, true +			} + +			select { +			// Backlog is empty, we MUST +			// block until next enqueued. +			case <-w.Queue.Wait(): +				continue loop + +			// Worker was stopped. +			case <-ctx.Done(): +				return nil, false +			} +		} + +		// Replace request context for worker state canceling. +		ctx := gtscontext.WithValues(ctx, dlv.Request.Context()) +		dlv.Request.Request = dlv.Request.Request.WithContext(ctx) + +		return dlv, true +	} +} + +// popBacklog pops next available from the backlog. +func (w *Worker) popBacklog() *Delivery { +	if len(w.backlog) == 0 { +		return nil +	} + +	// Pop from backlog. +	dlv := w.backlog[0] + +	// Shift backlog down by one. +	copy(w.backlog, w.backlog[1:]) +	w.backlog = w.backlog[:len(w.backlog)-1] + +	return dlv +} + +// pushBacklog pushes the given delivery to backlog. +func (w *Worker) pushBacklog(dlv *Delivery) { +	w.backlog = append(w.backlog, dlv) +} + +// sortDeliveries sorts deliveries according +// to when is the first requiring re-attempt. 
+func sortDeliveries(d []*Delivery) { +	slices.SortFunc(d, func(a, b *Delivery) int { +		const k = +1 +		switch { +		case a.next.Before(b.next): +			return +k +		case b.next.Before(a.next): +			return -k +		default: +			return 0 +		} +	}) +} diff --git a/internal/transport/delivery/worker_test.go b/internal/transport/delivery/worker_test.go new file mode 100644 index 000000000..936ce6e1d --- /dev/null +++ b/internal/transport/delivery/worker_test.go @@ -0,0 +1,220 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. 
+ +package delivery_test + +import ( +	"fmt" +	"io" +	"math/rand" +	"net" +	"net/http" +	"strconv" +	"strings" +	"testing" + +	"codeberg.org/gruf/go-byteutil" +	"github.com/superseriousbusiness/gotosocial/internal/config" +	"github.com/superseriousbusiness/gotosocial/internal/httpclient" +	"github.com/superseriousbusiness/gotosocial/internal/queue" +	"github.com/superseriousbusiness/gotosocial/internal/transport/delivery" +) + +func TestDeliveryWorkerPool(t *testing.T) { +	for _, i := range []int{1, 2, 4, 8, 16, 32} { +		t.Run("size="+strconv.Itoa(i), func(t *testing.T) { +			testDeliveryWorkerPool(t, i, generateInput(100*i)) +		}) +	} +} + +func testDeliveryWorkerPool(t *testing.T, sz int, input []*testrequest) { +	wp := new(delivery.WorkerPool) +	wp.Init(httpclient.New(httpclient.Config{ +		AllowRanges: config.MustParseIPPrefixes([]string{ +			"127.0.0.0/8", +		}), +	})) +	wp.Start(sz) +	defer wp.Stop() +	test(t, &wp.Queue, input) +} + +func test( +	t *testing.T, +	queue *queue.StructQueue[*delivery.Delivery], +	input []*testrequest, +) { +	expect := make(chan *testrequest) +	errors := make(chan error) + +	// Prepare an HTTP test handler that ensures expected delivery is received. +	handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { +		errors <- (<-expect).Equal(r) +	}) + +	// Start new HTTP test server listener. +	l, err := net.Listen("tcp", "127.0.0.1:0") +	if err != nil { +		t.Fatal(err) +	} +	defer l.Close() + +	// Start the HTTP server. +	// +	// specifically not using httptest.Server{} here as httptest +	// links that server with its own http.Client{}, whereas we're +	// using an httpclient.Client{} (well, delivery routine is). +	srv := new(http.Server) +	srv.Addr = "http://" + l.Addr().String() +	srv.Handler = handler +	go srv.Serve(l) +	defer srv.Close() + +	// Range over test input. +	for _, test := range input { + +		// Generate req for input. 
+		req := test.Generate(srv.Addr) +		r := httpclient.WrapRequest(req) + +		// Wrap the request in delivery. +		dlv := new(delivery.Delivery) +		dlv.Request = r + +		// Enqueue delivery! +		queue.Push(dlv) +		expect <- test + +		// Wait for errors from handler. +		if err := <-errors; err != nil { +			t.Error(err) +		} +	} +} + +type testrequest struct { +	method string +	uri    string +	body   []byte +} + +// generateInput generates 'n' many testrequest cases. +func generateInput(n int) []*testrequest { +	tests := make([]*testrequest, n) +	for i := range tests { +		tests[i] = new(testrequest) +		tests[i].method = randomMethod() +		tests[i].uri = randomURI() +		tests[i].body = randomBody(tests[i].method) +	} +	return tests +} + +var methods = []string{ +	http.MethodConnect, +	http.MethodDelete, +	http.MethodGet, +	http.MethodHead, +	http.MethodOptions, +	http.MethodPatch, +	http.MethodPost, +	http.MethodPut, +	http.MethodTrace, +} + +// randomMethod generates a random http method. +func randomMethod() string { +	return methods[rand.Intn(len(methods))] +} + +// randomURI generates a random http uri. +func randomURI() string { +	n := rand.Intn(5) +	p := make([]string, n) +	for i := range p { +		p[i] = strconv.Itoa(rand.Int()) +	} +	return "/" + strings.Join(p, "/") +} + +// randomBody generates a random http body DEPENDING on method. +func randomBody(method string) []byte { +	if requiresBody(method) { +		return []byte(method + " " + randomURI()) +	} +	return nil +} + +// requiresBody returns whether method requires body. +func requiresBody(method string) bool { +	switch method { +	case http.MethodPatch, +		http.MethodPost, +		http.MethodPut: +		return true +	default: +		return false +	} +} + +// Generate will generate a real http.Request{} from test data. 
+func (t *testrequest) Generate(addr string) *http.Request { +	var body io.ReadCloser +	if t.body != nil { +		var b byteutil.ReadNopCloser +		b.Reset(t.body) +		body = &b +	} +	req, err := http.NewRequest(t.method, addr+t.uri, body) +	if err != nil { +		panic(err) +	} +	return req +} + +// Equal checks if request matches receiving test request. +func (t *testrequest) Equal(r *http.Request) error { +	// Ensure methods match. +	if t.method != r.Method { +		return fmt.Errorf("differing request methods: t=%q r=%q", t.method, r.Method) +	} + +	// Ensure request URIs match. +	if t.uri != r.URL.RequestURI() { +		return fmt.Errorf("differing request urls: t=%q r=%q", t.uri, r.URL.RequestURI()) +	} + +	// Ensure body cases match. +	if requiresBody(t.method) { + +		// Read request into memory. +		b, err := io.ReadAll(r.Body) +		if err != nil { +			return fmt.Errorf("error reading request body: %v", err) +		} + +		// Compare the request bodies. +		st := strings.TrimSpace(string(t.body)) +		sr := strings.TrimSpace(string(b)) +		if st != sr { +			return fmt.Errorf("differing request bodies: t=%q r=%q", st, sr) +		} +	} + +	return nil +}  | 
