diff options
Diffstat (limited to 'internal')
23 files changed, 1145 insertions, 296 deletions
diff --git a/internal/api/activitypub/users/inboxpost.go b/internal/api/activitypub/users/inboxpost.go index 03ba5c5a6..b0a9a49ee 100644 --- a/internal/api/activitypub/users/inboxpost.go +++ b/internal/api/activitypub/users/inboxpost.go @@ -18,13 +18,14 @@  package users  import ( -	"errors"  	"net/http"  	"github.com/gin-gonic/gin"  	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/log" + +	errorsv2 "codeberg.org/gruf/go-errors/v2"  )  // InboxPOSTHandler deals with incoming POST requests to an actor's inbox. @@ -32,18 +33,18 @@ import (  func (m *Module) InboxPOSTHandler(c *gin.Context) {  	_, err := m.processor.Fedi().InboxPost(c.Request.Context(), c.Writer, c.Request)  	if err != nil { -		errWithCode := new(gtserror.WithCode) +		errWithCode := errorsv2.AsV2[gtserror.WithCode](err) -		if !errors.As(err, errWithCode) { +		if errWithCode == nil {  			// Something else went wrong, and someone forgot to return  			// an errWithCode! It's chill though. Log the error but don't  			// return it as-is to the caller, to avoid leaking internals.  			log.Errorf(c.Request.Context(), "returning Bad Request to caller, err was: %q", err) -			*errWithCode = gtserror.NewErrorBadRequest(err) +			errWithCode = gtserror.NewErrorBadRequest(err)  		}  		// Pass along confirmed error with code to the main error handler -		apiutil.ErrorHandler(c, *errWithCode, m.processor.InstanceGetV1) +		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)  		return  	} diff --git a/internal/cache/cache.go b/internal/cache/cache.go index fd715b8e6..3aa21cdd0 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -25,6 +25,7 @@ import (  )  type Caches struct { +  	// GTS provides access to the collection of  	// gtsmodel object caches. (used by the database).  	GTS GTSCaches diff --git a/internal/federation/federatingactor.go b/internal/federation/federatingactor.go index 18cee1666..b9b2c8001 100644 --- a/internal/federation/federatingactor.go +++ b/internal/federation/federatingactor.go @@ -89,7 +89,7 @@ func (f *federatingActor) PostInboxScheme(ctx context.Context, w http.ResponseWr  	// so we specifically have to check for already wrapped with code.  	//  	ctx, authenticated, err := f.sideEffectActor.AuthenticatePostInbox(ctx, w, r) -	if errors.As(err, new(gtserror.WithCode)) { +	if errorsv2.AsV2[gtserror.WithCode](err) != nil {  		// If it was already wrapped with an  		// HTTP code then don't bother rewrapping  		// it, just return it as-is for caller to @@ -131,7 +131,7 @@ func (f *federatingActor) PostInboxScheme(ctx context.Context, w http.ResponseWr  	// Check authorization of the activity; this will include blocks.  	authorized, err := f.sideEffectActor.AuthorizePostInbox(ctx, w, activity)  	if err != nil { -		if errors.As(err, new(errOtherIRIBlocked)) { +		if errorsv2.AsV2[*errOtherIRIBlocked](err) != nil {  			// There's no direct block between requester(s) and  			// receiver. However, one or more of the other IRIs  			// involved in the request (account replied to, note @@ -139,7 +139,7 @@ func (f *federatingActor) PostInboxScheme(ctx context.Context, w http.ResponseWr  			// by the receiver. We don't need to return 403 here,  			// instead, just return 202 accepted but don't do any  			// further processing of the activity. -			return true, nil +			return true, nil //nolint  		}  		// Real error has occurred. diff --git a/internal/federation/federatingactor_test.go b/internal/federation/federatingactor_test.go index 0c805a2c6..b5b65827b 100644 --- a/internal/federation/federatingactor_test.go +++ b/internal/federation/federatingactor_test.go @@ -21,6 +21,7 @@ import (  	"bytes"  	"context"  	"encoding/json" +	"io"  	"net/url"  	"testing"  	"time" @@ -129,23 +130,27 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() {  	suite.NotNil(activity)  	// because we added 1 remote follower for zork, there should be a url in sentMessage -	var sent [][]byte +	var sent []byte  	if !testrig.WaitFor(func() bool { -		sentI, ok := httpClient.SentMessages.Load(*testRemoteAccount.SharedInboxURI) -		if ok { -			sent, ok = sentI.([][]byte) -			if !ok { -				panic("SentMessages entry was not []byte") -			} -			return true +		delivery, ok := suite.state.Workers.Delivery.Queue.Pop() +		if !ok { +			return false  		} -		return false +		if !testrig.EqualRequestURIs(delivery.Request.URL, *testRemoteAccount.SharedInboxURI) { +			panic("differing request uris") +		} +		sent, err = io.ReadAll(delivery.Request.Body) +		if err != nil { +			panic("error reading body: " + err.Error()) +		} +		return true +  	}) {  		suite.FailNow("timed out waiting for message")  	}  	dst := new(bytes.Buffer) -	err = json.Indent(dst, sent[0], "", "  ") +	err = json.Indent(dst, sent, "", "  ")  	suite.NoError(err)  	suite.Equal(`{    "@context": "https://www.w3.org/ns/activitystreams", diff --git a/internal/federation/federatingdb/delete.go b/internal/federation/federatingdb/delete.go index 7390cf0f5..384291463 100644 --- a/internal/federation/federatingdb/delete.go +++ b/internal/federation/federatingdb/delete.go @@ -51,7 +51,7 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {  	// in a delete we only get the URI, we can't know if we have a status or a profile or something else,  	// so we have to try a few different things...  	if s, err := f.state.DB.GetStatusByURI(ctx, id.String()); err == nil && requestingAcct.ID == s.AccountID { -		l.Debugf("uri is for STATUS with id: %s", s.ID) +		l.Debugf("deleting status: %s", s.ID)  		f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{  			APObjectType:     ap.ObjectNote,  			APActivityType:   ap.ActivityDelete, @@ -61,7 +61,7 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {  	}  	if a, err := f.state.DB.GetAccountByURI(ctx, id.String()); err == nil && requestingAcct.ID == a.ID { -		l.Debugf("uri is for ACCOUNT with id %s", a.ID) +		l.Debugf("deleting account: %s", a.ID)  		f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{  			APObjectType:     ap.ObjectProfile,  			APActivityType:   ap.ActivityDelete, diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go index f8e5b4c09..1a655994c 100644 --- a/internal/federation/federatingprotocol.go +++ b/internal/federation/federatingprotocol.go @@ -44,7 +44,7 @@ type errOtherIRIBlocked struct {  	iriStrs     []string  } -func (e errOtherIRIBlocked) Error() string { +func (e *errOtherIRIBlocked) Error() string {  	iriStrsNice := "[" + strings.Join(e.iriStrs, ", ") + "]"  	if e.domainBlock {  		return "domain block exists for one or more of " + iriStrsNice @@ -67,7 +67,7 @@ func newErrOtherIRIBlocked(  		e.iriStrs = append(e.iriStrs, iri.String())  	} -	return e +	return &e  }  /* diff --git a/internal/federation/federatingprotocol_test.go b/internal/federation/federatingprotocol_test.go index f975cd7d6..085d6c474 100644 --- a/internal/federation/federatingprotocol_test.go +++ b/internal/federation/federatingprotocol_test.go @@ -21,13 +21,13 @@ import (  	"bytes"  	"context"  	"encoding/json" -	"errors"  	"io"  	"net/http"  	"net/http/httptest"  	"net/url"  	"testing" +	errorsv2 "codeberg.org/gruf/go-errors/v2"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/ap"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -101,8 +101,8 @@ func (suite *FederatingProtocolTestSuite) authenticatePostInbox(  	recorder := httptest.NewRecorder()  	newContext, authed, err := suite.federator.AuthenticatePostInbox(ctx, recorder, request) -	if withCode := new(gtserror.WithCode); (errors.As(err, withCode) && -		(*withCode).Code() >= 500) || (err != nil && (*withCode) == nil) { +	if withCode := errorsv2.AsV2[gtserror.WithCode](err); // nocollapse +	(withCode != nil && withCode.Code() >= 500) || (err != nil && withCode == nil) {  		// NOTE: the behaviour here is a little strange as we have  		// the competing code styles of the go-fed interface expecting  		// that any err is a no-go, but authed bool is intended to be diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go index 6c2427372..31c6df7d0 100644 --- a/internal/httpclient/client.go +++ b/internal/httpclient/client.go @@ -35,7 +35,6 @@ import (  	"codeberg.org/gruf/go-cache/v3"  	errorsv2 "codeberg.org/gruf/go-errors/v2"  	"codeberg.org/gruf/go-iotools" -	"codeberg.org/gruf/go-kv"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/log" @@ -109,11 +108,13 @@ type Client struct {  	client   http.Client  	badHosts cache.TTLCache[string, struct{}]  	bodyMax  int64 +	retries  uint  }  // New returns a new instance of Client initialized using configuration.  func New(cfg Config) *Client {  	var c Client +	c.retries = 5  	d := &net.Dialer{  		Timeout:   15 * time.Second, @@ -177,7 +178,7 @@ func New(cfg Config) *Client {  	}}  	// Initiate outgoing bad hosts lookup cache. -	c.badHosts = cache.NewTTL[string, struct{}](0, 1000, 0) +	c.badHosts = cache.NewTTL[string, struct{}](0, 512, 0)  	c.badHosts.SetTTL(time.Hour, false)  	if !c.badHosts.Start(time.Minute) {  		log.Panic(nil, "failed to start transport controller cache") @@ -187,154 +188,184 @@ func New(cfg Config) *Client {  }  // Do will essentially perform http.Client{}.Do() with retry-backoff functionality. -func (c *Client) Do(r *http.Request) (*http.Response, error) { -	return c.DoSigned(r, func(r *http.Request) error { -		return nil // no request signing -	}) -} - -// DoSigned will essentially perform http.Client{}.Do() with retry-backoff functionality and requesting signing.. -func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, err error) { -	const ( -		// max no. attempts. -		maxRetries = 5 - -		// starting backoff duration. -		baseBackoff = 2 * time.Second -	) +func (c *Client) Do(r *http.Request) (rsp *http.Response, err error) {  	// First validate incoming request.  	if err := ValidateRequest(r); err != nil {  		return nil, err  	} -	// Get request hostname. -	host := r.URL.Hostname() - -	// Check whether request should fast fail. -	fastFail := gtscontext.IsFastfail(r.Context()) -	if !fastFail { -		// Check if recently reached max retries for this host -		// so we don't bother with a retry-backoff loop. The only -		// errors that are retried upon are server failure, TLS -		// and domain resolution type errors, so this cached result -		// indicates this server is likely having issues. -		fastFail = c.badHosts.Has(host) -		defer func() { -			if err != nil { -				// On error return mark as bad-host. -				c.badHosts.Set(host, struct{}{}) -			} -		}() +	// Wrap in our own request +	// type for retry-backoff. +	req := WrapRequest(r) + +	if gtscontext.IsFastfail(r.Context()) { +		// If the fast-fail flag was set, just +		// attempt a single iteration instead of +		// following the below retry-backoff loop. +		rsp, _, err = c.DoOnce(&req) +		if err != nil { +			return nil, fmt.Errorf("%w (fast fail)", err) +		} +		return rsp, nil  	} -	// Start a log entry for this request -	l := log.WithContext(r.Context()). -		WithFields(kv.Fields{ -			{"method", r.Method}, -			{"url", r.URL.String()}, -		}...) +	for { +		var retry bool -	for i := 0; i < maxRetries; i++ { -		var backoff time.Duration +		// Perform the http request. +		rsp, retry, err = c.DoOnce(&req) +		if err == nil { +			return rsp, nil +		} -		l.Info("performing request") +		if !retry { +			// reached max retries, don't further backoff +			return nil, fmt.Errorf("%w (max retries)", err) +		} -		// Perform the request. -		rsp, err = c.do(r) -		if err == nil { //nolint:gocritic +		// Start new backoff sleep timer. +		backoff := time.NewTimer(req.BackOff()) -			// TooManyRequest means we need to slow -			// down and retry our request. Codes over -			// 500 generally indicate temp. outages. -			if code := rsp.StatusCode; code < 500 && -				code != http.StatusTooManyRequests { -				return rsp, nil -			} +		select { +		// Request ctx cancelled. +		case <-r.Context().Done(): +			backoff.Stop() -			// Create loggable error from response status code. -			err = fmt.Errorf(`http response: %s`, rsp.Status) +			// Return context error. +			err = r.Context().Err() +			return nil, err -			// Search for a provided "Retry-After" header value. -			if after := rsp.Header.Get("Retry-After"); after != "" { +		// Backoff for time. +		case <-backoff.C: +		} +	} +} -				// Get current time. -				now := time.Now() +// DoOnce wraps an underlying http.Client{}.Do() to perform our wrapped request type: +// rewinding response body to permit reuse, signing request data when SignFunc provided, +// marking erroring hosts, updating retry attempt counts and setting backoff from header. +func (c *Client) DoOnce(r *Request) (rsp *http.Response, retry bool, err error) { +	if r.attempts > c.retries { +		// Ensure request hasn't reached max number of attempts. +		err = fmt.Errorf("httpclient: reached max retries (%d)", c.retries) +		return +	} -				if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { -					// An integer number of backoff seconds was provided. -					backoff = time.Duration(u) * time.Second -				} else if at, _ := http.ParseTime(after); !at.Before(now) { -					// An HTTP formatted future date-time was provided. -					backoff = at.Sub(now) -				} +	// Update no. +	// attempts. +	r.attempts++ -				// Don't let their provided backoff exceed our max. -				if max := baseBackoff * maxRetries; backoff > max { -					backoff = max -				} -			} +	// Reset backoff. +	r.backoff = 0 + +	// Perform main routine. +	rsp, retry, err = c.do(r) -			// Close + unset rsp. -			_ = rsp.Body.Close() -			rsp = nil +	if rsp != nil { +		// Log successful rsp. +		r.Entry.Info(rsp.Status) +		return +	} -		} else if errorsv2.IsV2(err, +	// Log any errors. +	r.Entry.Error(err) + +	switch { +	case !retry: +		// If they were told not to +		// retry, also set number of +		// attempts to prevent retry. +		r.attempts = c.retries + 1 + +	case r.attempts > c.retries: +		// On max retries, mark this as +		// a "badhost", i.e. is erroring. +		c.badHosts.Set(r.Host, struct{}{}) + +		// Ensure retry flag is unset +		// when reached max attempts. +		retry = false + +	case c.badHosts.Has(r.Host): +		// When retry is still permitted, +		// check host hasn't been marked +		// as a "badhost", i.e. erroring. +		r.attempts = c.retries + 1 +		retry = false +	} + +	return +} + +// do performs the "meat" of DoOnce(), but it's separated out to allow +// easier wrapping of the response, retry, error returns with further logic. +func (c *Client) do(r *Request) (rsp *http.Response, retry bool, err error) { +	// Perform the HTTP request. +	rsp, err = c.client.Do(r.Request) +	if err != nil { + +		if errorsv2.IsV2(err,  			context.DeadlineExceeded,  			context.Canceled,  			ErrBodyTooLarge,  			ErrReservedAddr,  		) {  			// Non-retryable errors. -			return nil, err -		} else if errstr := err.Error(); // nocollapse +			return nil, false, err +		} + +		if errstr := err.Error(); //  		strings.Contains(errstr, "stopped after 10 redirects") ||  			strings.Contains(errstr, "tls: ") ||  			strings.Contains(errstr, "x509: ") {  			// These error types aren't wrapped  			// so we have to check the error string.  			// All are unrecoverable! -			return nil, err -		} else if dnserr := (*net.DNSError)(nil); // nocollapse -		errors.As(err, &dnserr) && dnserr.IsNotFound { -			// DNS lookup failure, this domain does not exist -			return nil, gtserror.SetNotFound(err) +			return nil, false, err  		} -		if fastFail { -			// on fast-fail, don't bother backoff/retry -			return nil, fmt.Errorf("%w (fast fail)", err) +		if dnserr := errorsv2.AsV2[*net.DNSError](err); // +		dnserr != nil && dnserr.IsNotFound { +			// DNS lookup failure, this domain does not exist +			return nil, false, gtserror.SetNotFound(err)  		} -		if backoff == 0 { -			// No retry-after found, set our predefined -			// backoff according to a multiplier of 2^n. -			backoff = baseBackoff * 1 << (i + 1) -		} +		// A retryable error. +		return nil, true, err -		l.Errorf("backing off for %s after http request error: %v", backoff, err) +	} else if rsp.StatusCode > 500 || +		rsp.StatusCode == http.StatusTooManyRequests { -		select { -		// Request ctx cancelled -		case <-r.Context().Done(): -			return nil, r.Context().Err() +		// Codes over 500 (and 429: too many requests) +		// are generally temporary errors. For these +		// we replace the response with a loggable error. +		err = fmt.Errorf(`http response: %s`, rsp.Status) -		// Backoff for some time -		case <-time.After(backoff): -		} -	} +		// Search for a provided "Retry-After" header value. +		if after := rsp.Header.Get("Retry-After"); after != "" { -	// Set error return to trigger setting "bad host". -	err = errors.New("transport reached max retries") -	return -} +			// Get cur time. +			now := time.Now() -// do wraps http.Client{}.Do() to provide safely limited response bodies. -func (c *Client) do(req *http.Request) (*http.Response, error) { -	// Perform the HTTP request. -	rsp, err := c.client.Do(req) -	if err != nil { -		return nil, err +			if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { +				// An integer no. of backoff seconds was provided. +				r.backoff = time.Duration(u) * time.Second +			} else if at, _ := http.ParseTime(after); !at.Before(now) { +				// An HTTP formatted future date-time was provided. +				r.backoff = at.Sub(now) +			} + +			// Don't let their provided backoff exceed our max. +			if max := baseBackoff * time.Duration(c.retries); // +			r.backoff > max { +				r.backoff = max +			} +		} + +		// Unset + close rsp. +		_ = rsp.Body.Close() +		return nil, true, err  	}  	// Seperate the body implementers. @@ -364,11 +395,10 @@ func (c *Client) do(req *http.Request) (*http.Response, error) {  	// Check response body not too large.  	if rsp.ContentLength > c.bodyMax { -		_ = rsp.Body.Close() -		return nil, ErrBodyTooLarge +		return nil, false, ErrBodyTooLarge  	} -	return rsp, nil +	return rsp, true, nil  }  // cast discard writer to full interface it supports. diff --git a/internal/httpclient/request.go b/internal/httpclient/request.go new file mode 100644 index 000000000..0df9211e7 --- /dev/null +++ b/internal/httpclient/request.go @@ -0,0 +1,69 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package httpclient + +import ( +	"net/http" +	"time" + +	"github.com/superseriousbusiness/gotosocial/internal/log" +) + +const ( +	// starting backoff duration. +	baseBackoff = 2 * time.Second +) + +// Request wraps an HTTP request +// to add our own retry / backoff. +type Request struct { +	// Current backoff dur. +	backoff time.Duration + +	// Delivery attempts. +	attempts uint + +	// log fields. +	log.Entry + +	// underlying request. +	*http.Request +} + +// WrapRequest wraps an existing http.Request within +// our own httpclient.Request with retry / backoff tracking. +func WrapRequest(r *http.Request) Request { +	var rr Request +	rr.Request = r +	rr.Entry = log.WithContext(r.Context()). +		WithField("method", r.Method). +		WithField("url", r.URL.String()). +		WithField("contentType", r.Header.Get("Content-Type")) +	return rr +} + +// GetBackOff returns the currently set backoff duration, +// (using a default according to no. attempts if needed). +func (r *Request) BackOff() time.Duration { +	if r.backoff <= 0 { +		// No backoff dur found, set our predefined +		// backoff according to a multiplier of 2^n. +		r.backoff = baseBackoff * 1 << (r.attempts + 1) +	} +	return r.backoff +} diff --git a/internal/httpclient/sign.go b/internal/httpclient/sign.go index 6b561c45a..eff20be49 100644 --- a/internal/httpclient/sign.go +++ b/internal/httpclient/sign.go @@ -32,9 +32,7 @@ type SignFunc func(r *http.Request) error  // (RoundTripper implementer) to check request  // context for a signing function and using for  // all subsequent trips through RoundTrip(). -type signingtransport struct { -	http.Transport // underlying transport -} +type signingtransport struct{ http.Transport }  func (t *signingtransport) RoundTrip(r *http.Request) (*http.Response, error) {  	// Ensure updated host always set. diff --git a/internal/httpclient/validate.go b/internal/httpclient/validate.go index 881d3f699..5a6257288 100644 --- a/internal/httpclient/validate.go +++ b/internal/httpclient/validate.go @@ -38,7 +38,7 @@ func ValidateRequest(r *http.Request) error {  		return fmt.Errorf("%w: empty url host", ErrInvalidRequest)  	case r.URL.Scheme != "http" && r.URL.Scheme != "https":  		return fmt.Errorf("%w: unsupported protocol %q", ErrInvalidRequest, r.URL.Scheme) -	case strings.IndexFunc(r.Method, func(r rune) bool { return !httpguts.IsTokenRune(r) }) != -1: +	case strings.IndexFunc(r.Method, isNotTokenRune) != -1:  		return fmt.Errorf("%w: invalid method %q", ErrInvalidRequest, r.Method)  	} @@ -60,3 +60,8 @@ func ValidateRequest(r *http.Request) error {  	return nil  } + +// isNotTokenRune wraps IsTokenRune to inverse result. +func isNotTokenRune(r rune) bool { +	return !httpguts.IsTokenRune(r) +} diff --git a/internal/media/manager_test.go b/internal/media/manager_test.go index dbc9c634a..ac4286c73 100644 --- a/internal/media/manager_test.go +++ b/internal/media/manager_test.go @@ -33,6 +33,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/media"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	gtsstorage "github.com/superseriousbusiness/gotosocial/internal/storage" +	"github.com/superseriousbusiness/gotosocial/testrig"  )  type ManagerTestSuite struct { @@ -1197,8 +1198,8 @@ func (suite *ManagerTestSuite) TestSimpleJpegProcessBlockingWithDiskStorage() {  	var state state.State -	state.Workers.Start() -	defer state.Workers.Stop() +	testrig.StartNoopWorkers(&state) +	defer testrig.StopWorkers(&state)  	storage := >sstorage.Driver{  		Storage: disk, diff --git a/internal/processing/account_test.go b/internal/processing/account_test.go index 83eebedba..82c28115e 100644 --- a/internal/processing/account_test.go +++ b/internal/processing/account_test.go @@ -21,6 +21,7 @@ import (  	"context"  	"encoding/json"  	"fmt" +	"io"  	"testing"  	"time" @@ -55,7 +56,7 @@ func (suite *AccountTestSuite) TestAccountDeleteLocal() {  	suite.NoError(errWithCode)  	// the delete should be federated outwards to the following account's inbox -	var sent [][]byte +	var sent []byte  	delete := new(struct {  		Actor  string `json:"actor"`  		ID     string `json:"id"` @@ -66,16 +67,22 @@ func (suite *AccountTestSuite) TestAccountDeleteLocal() {  	})  	if !testrig.WaitFor(func() bool { -		sentI, ok := suite.httpClient.SentMessages.Load(*followingAccount.SharedInboxURI) -		if ok { -			sent, ok = sentI.([][]byte) -			if !ok { -				panic("SentMessages entry was not [][]byte") -			} -			err = json.Unmarshal(sent[0], delete) -			return err == nil +		delivery, ok := suite.state.Workers.Delivery.Queue.Pop() +		if !ok { +			return false  		} -		return false +		if !testrig.EqualRequestURIs(delivery.Request.URL, *followingAccount.SharedInboxURI) { +			panic("differing request uris") +		} +		sent, err = io.ReadAll(delivery.Request.Body) +		if err != nil { +			panic("error reading body: " + err.Error()) +		} +		err = json.Unmarshal(sent, delete) +		if err != nil { +			panic("error unmarshaling json: " + err.Error()) +		} +		return true  	}) {  		suite.FailNow("timed out waiting for message")  	} diff --git a/internal/processing/followrequest_test.go b/internal/processing/followrequest_test.go index 4c089be4a..db0419522 100644 --- a/internal/processing/followrequest_test.go +++ b/internal/processing/followrequest_test.go @@ -21,6 +21,7 @@ import (  	"context"  	"encoding/json"  	"fmt" +	"io"  	"testing"  	"time" @@ -77,22 +78,6 @@ func (suite *FollowRequestTestSuite) TestFollowRequestAccept() {  		Note:                "",  	}, relationship) -	// accept should be sent to Some_User -	var sent [][]byte -	if !testrig.WaitFor(func() bool { -		sentI, ok := suite.httpClient.SentMessages.Load(targetAccount.InboxURI) -		if ok { -			sent, ok = sentI.([][]byte) -			if !ok { -				panic("SentMessages entry was not []byte") -			} -			return true -		} -		return false -	}) { -		suite.FailNow("timed out waiting for message") -	} -  	accept := &struct {  		Actor  string `json:"actor"`  		ID     string `json:"id"` @@ -106,8 +91,29 @@ func (suite *FollowRequestTestSuite) TestFollowRequestAccept() {  		To   string `json:"to"`  		Type string `json:"type"`  	}{} -	err = json.Unmarshal(sent[0], accept) -	suite.NoError(err) + +	// accept should be sent to Some_User +	var sent []byte +	if !testrig.WaitFor(func() bool { +		delivery, ok := suite.state.Workers.Delivery.Queue.Pop() +		if !ok { +			return false +		} +		if !testrig.EqualRequestURIs(delivery.Request.URL, targetAccount.InboxURI) { +			panic("differing request uris") +		} +		sent, err = io.ReadAll(delivery.Request.Body) +		if err != nil { +			panic("error reading body: " + err.Error()) +		} +		err = json.Unmarshal(sent, accept) +		if err != nil { +			panic("error unmarshaling json: " + err.Error()) +		} +		return true +	}) { +		suite.FailNow("timed out waiting for message") +	}  	suite.Equal(requestingAccount.URI, accept.Actor)  	suite.Equal(targetAccount.URI, accept.Object.Actor) @@ -144,22 +150,6 @@ func (suite *FollowRequestTestSuite) TestFollowRequestReject() {  	suite.NoError(errWithCode)  	suite.EqualValues(&apimodel.Relationship{ID: "01FHMQX3GAABWSM0S2VZEC2SWC", Following: false, ShowingReblogs: false, Notifying: false, FollowedBy: false, Blocking: false, BlockedBy: false, Muting: false, MutingNotifications: false, Requested: false, DomainBlocking: false, Endorsed: false, Note: ""}, relationship) -	// reject should be sent to Some_User -	var sent [][]byte -	if !testrig.WaitFor(func() bool { -		sentI, ok := suite.httpClient.SentMessages.Load(targetAccount.InboxURI) -		if ok { -			sent, ok = sentI.([][]byte) -			if !ok { -				panic("SentMessages entry was not []byte") -			} -			return true -		} -		return false -	}) { -		suite.FailNow("timed out waiting for message") -	} -  	reject := &struct {  		Actor  string `json:"actor"`  		ID     string `json:"id"` @@ -173,8 +163,29 @@ func (suite *FollowRequestTestSuite) TestFollowRequestReject() {  		To   string `json:"to"`  		Type string `json:"type"`  	}{} -	err = json.Unmarshal(sent[0], reject) -	suite.NoError(err) + +	// reject should be sent to Some_User +	var sent []byte +	if !testrig.WaitFor(func() bool { +		delivery, ok := suite.state.Workers.Delivery.Queue.Pop() +		if !ok { +			return false +		} +		if !testrig.EqualRequestURIs(delivery.Request.URL, targetAccount.InboxURI) { +			panic("differing request uris") +		} +		sent, err = io.ReadAll(delivery.Request.Body) +		if err != nil { +			panic("error reading body: " + err.Error()) +		} +		err = json.Unmarshal(sent, reject) +		if err != nil { +			panic("error unmarshaling json: " + err.Error()) +		} +		return true +	}) { +		suite.FailNow("timed out waiting for message") +	}  	suite.Equal(requestingAccount.URI, reject.Actor)  	suite.Equal(targetAccount.URI, reject.Object.Actor) diff --git a/internal/processing/workers/federate.go b/internal/processing/workers/federate.go index 9fdb8f662..e737513f5 100644 --- a/internal/processing/workers/federate.go +++ b/internal/processing/workers/federate.go @@ -75,6 +75,12 @@ func (f *federate) DeleteAccount(ctx context.Context, account *gtsmodel.Account)  		return nil  	} +	// Drop any queued outgoing AP requests to / from account, +	// (this stops any queued likes, boosts, creates etc). +	f.state.Workers.Delivery.Queue.Delete("ActorID", account.URI) +	f.state.Workers.Delivery.Queue.Delete("ObjectID", account.URI) +	f.state.Workers.Delivery.Queue.Delete("TargetID", account.URI) +  	// Parse relevant URI(s).  	outboxIRI, err := parseURI(account.OutboxURI)  	if err != nil { @@ -222,6 +228,11 @@ func (f *federate) DeleteStatus(ctx context.Context, status *gtsmodel.Status) er  		return nil  	} +	// Drop any queued outgoing http requests for status, +	// (this stops any queued likes, boosts, creates etc). +	f.state.Workers.Delivery.Queue.Delete("ObjectID", status.URI) +	f.state.Workers.Delivery.Queue.Delete("TargetID", status.URI) +  	// Ensure the status model is fully populated.  	if err := f.state.DB.PopulateStatus(ctx, status); err != nil {  		return gtserror.Newf("error populating status: %w", err) diff --git a/internal/processing/workers/fromfediapi_test.go b/internal/processing/workers/fromfediapi_test.go index 1dbefca84..51f61bd12 100644 --- a/internal/processing/workers/fromfediapi_test.go +++ b/internal/processing/workers/fromfediapi_test.go @@ -21,6 +21,7 @@ import (  	"context"  	"encoding/json"  	"fmt" +	"io"  	"testing"  	"time" @@ -457,22 +458,6 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {  	})  	suite.NoError(err) -	// an accept message should be sent to satan's inbox -	var sent [][]byte -	if !testrig.WaitFor(func() bool { -		sentI, ok := suite.httpClient.SentMessages.Load(*originAccount.SharedInboxURI) -		if ok { -			sent, ok = sentI.([][]byte) -			if !ok { -				panic("SentMessages entry was not []byte") -			} -			return true -		} -		return false -	}) { -		suite.FailNow("timed out waiting for message") -	} -  	accept := &struct {  		Actor  string `json:"actor"`  		ID     string `json:"id"` @@ -486,8 +471,29 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {  		To   string `json:"to"`  		Type string `json:"type"`  	}{} -	err = json.Unmarshal(sent[0], accept) -	suite.NoError(err) + +	// an accept message should be sent to satan's inbox +	var sent []byte +	if !testrig.WaitFor(func() bool { +		delivery, ok := suite.state.Workers.Delivery.Queue.Pop() +		if !ok { +			return false +		} +		if !testrig.EqualRequestURIs(delivery.Request.URL, *originAccount.SharedInboxURI) { +			panic("differing request uris") +		} +		sent, err = io.ReadAll(delivery.Request.Body) +		if err != nil { +			panic("error reading body: " + err.Error()) +		} +		err = json.Unmarshal(sent, accept) +		if err != nil { +			panic("error unmarshaling json: " + err.Error()) +		} +		return true +	}) { +		suite.FailNow("timed out waiting for message") +	}  	suite.Equal(targetAccount.URI, accept.Actor)  	suite.Equal(originAccount.URI, accept.Object.Actor) diff --git a/internal/queue/wrappers.go b/internal/queue/wrappers.go new file mode 100644 index 000000000..e07984f84 --- /dev/null +++ b/internal/queue/wrappers.go @@ -0,0 +1,96 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package queue + +import ( +	"sync/atomic" + +	"codeberg.org/gruf/go-structr" +) + +// StructQueue wraps a structr.Queue{} to +// provide simple index caching by name. +type StructQueue[StructType any] struct { +	queue structr.Queue[StructType] +	index map[string]*structr.Index +	wait  atomic.Pointer[chan struct{}] +} + +// Init initializes queue with structr.QueueConfig{}. +func (q *StructQueue[T]) Init(config structr.QueueConfig[T]) { +	q.index = make(map[string]*structr.Index, len(config.Indices)) +	q.queue = structr.Queue[T]{} +	q.queue.Init(config) +	for _, cfg := range config.Indices { +		q.index[cfg.Fields] = q.queue.Index(cfg.Fields) +	} +} + +// Pop: see structr.Queue{}.PopFront(). +func (q *StructQueue[T]) Pop() (value T, ok bool) { +	return q.queue.PopFront() +} + +// Push wraps structr.Queue{}.PushBack() to awaken those blocking on <-.Wait(). +func (q *StructQueue[T]) Push(values ...T) { +	q.queue.PushBack(values...) +	q.broadcast() +} + +// Delete pops (and drops!) all queued entries under index with key. +func (q *StructQueue[T]) Delete(index string, key ...any) { +	i := q.index[index] +	_ = q.queue.Pop(i, i.Key(key...)) +} + +// Len: see structr.Queue{}.Len(). +func (q *StructQueue[T]) Len() int { +	return q.queue.Len() +} + +// Wait returns current wait channel, which may be +// blocked on to awaken when new value pushed to queue. +func (q *StructQueue[T]) Wait() <-chan struct{} { +	var ch chan struct{} + +	for { +		// Get channel ptr. +		ptr := q.wait.Load() +		if ptr != nil { +			return *ptr +		} + +		if ch == nil { +			// Allocate new channel. +			ch = make(chan struct{}) +		} + +		// Try set the new wait channel ptr. +		if q.wait.CompareAndSwap(ptr, &ch) { +			return ch +		} +	} +} + +// broadcast safely closes wait channel if +// currently set, releasing waiting goroutines. +func (q *StructQueue[T]) broadcast() { +	if ptr := q.wait.Swap(nil); ptr != nil { +		close(*ptr) +	} +} diff --git a/internal/transport/controller.go b/internal/transport/controller.go index 891a24495..519298d8e 100644 --- a/internal/transport/controller.go +++ b/internal/transport/controller.go @@ -28,7 +28,6 @@ import (  	"io"  	"net/http"  	"net/url" -	"runtime"  	"codeberg.org/gruf/go-byteutil"  	"codeberg.org/gruf/go-cache/v3" @@ -56,24 +55,16 @@ type controller struct {  	client    pub.HttpClient  	trspCache cache.TTLCache[string, *transport]  	userAgent string -	senders   int // no. concurrent batch delivery routines.  }  // NewController returns an implementation of the Controller interface for creating new transports  func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {  	var ( -		host             = config.GetHost() -		proto            = config.GetProtocol() -		version          = config.GetSoftwareVersion() -		senderMultiplier = config.GetAdvancedSenderMultiplier() +		host    = config.GetHost() +		proto   = config.GetProtocol() +		version = config.GetSoftwareVersion()  	) -	senders := senderMultiplier * runtime.GOMAXPROCS(0) -	if senders < 1 { -		// Clamp senders to 1. -		senders = 1 -	} -  	c := &controller{  		state:     state,  		fedDB:     federatingDB, @@ -81,7 +72,6 @@ func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.C  		client:    client,  		trspCache: cache.NewTTL[string, *transport](0, 100, 0),  		userAgent: fmt.Sprintf("gotosocial/%s (+%s://%s)", version, proto, host), -		senders:   senders,  	}  	return c diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index fe4d04582..a7e73465d 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -19,118 +19,188 @@ package transport  import (  	"context" +	"encoding/json"  	"net/http"  	"net/url" -	"sync"  	"codeberg.org/gruf/go-byteutil"  	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"  	"github.com/superseriousbusiness/gotosocial/internal/config" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/httpclient" +	"github.com/superseriousbusiness/gotosocial/internal/transport/delivery"  ) -func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error { +func (t *transport) BatchDeliver(ctx context.Context, obj map[string]interface{}, recipients []*url.URL) error {  	var ( -		// errs accumulates errors received during -		// attempted delivery by deliverer routines. -		errs gtserror.MultiError - -		// wait blocks until all sender -		// routines have returned. -		wait sync.WaitGroup +		// accumulated delivery reqs. +		reqs []*delivery.Delivery -		// mutex protects 'recipients' and -		// 'errs' for concurrent access. -		mutex sync.Mutex +		// accumulated preparation errs. +		errs gtserror.MultiError  		// Get current instance host info.  		domain = config.GetAccountDomain()  		host   = config.GetHost()  	) -	// Block on expect no. senders. -	wait.Add(t.controller.senders) - -	for i := 0; i < t.controller.senders; i++ { -		go func() { -			// Mark returned. -			defer wait.Done() - -			for { -				// Acquire lock. -				mutex.Lock() - -				if len(recipients) == 0 { -					// Reached end. -					mutex.Unlock() -					return -				} - -				// Pop next recipient. -				i := len(recipients) - 1 -				to := recipients[i] -				recipients = recipients[:i] - -				// Done with lock. -				mutex.Unlock() - -				// Skip delivery to recipient if it is "us". -				if to.Host == host || to.Host == domain { -					continue -				} - -				// Attempt to deliver data to recipient. -				if err := t.deliver(ctx, b, to); err != nil { -					mutex.Lock() // safely append err to accumulator. -					errs.Appendf("error delivering to %s: %w", to, err) -					mutex.Unlock() -				} -			} -		}() +	// Marshal object as JSON. +	b, err := json.Marshal(obj) +	if err != nil { +		return gtserror.Newf("error marshaling json: %w", err) +	} + +	// Extract object IDs. +	actID := getActorID(obj) +	objID := getObjectID(obj) +	tgtID := getTargetID(obj) + +	for _, to := range recipients { +		// Skip delivery to recipient if it is "us". +		if to.Host == host || to.Host == domain { +			continue +		} + +		// Prepare http client request. +		req, err := t.prepare(ctx, +			actID, +			objID, +			tgtID, +			b, +			to, +		) +		if err != nil { +			errs.Append(err) +			continue +		} + +		// Append to request queue. +		reqs = append(reqs, req)  	} -	// Wait for finish. -	wait.Wait() +	// Push prepared request list to the delivery queue. +	t.controller.state.Workers.Delivery.Queue.Push(reqs...)  	// Return combined err.  	return errs.Combine()  } -func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { +func (t *transport) Deliver(ctx context.Context, obj map[string]interface{}, to *url.URL) error {  	// if 'to' host is our own, skip as we don't need to deliver to ourselves...  	if to.Host == config.GetHost() || to.Host == config.GetAccountDomain() {  		return nil  	} -	// Deliver data to recipient. -	return t.deliver(ctx, b, to) +	// Marshal object as JSON. +	b, err := json.Marshal(obj) +	if err != nil { +		return gtserror.Newf("error marshaling json: %w", err) +	} + +	// Prepare http client request. +	req, err := t.prepare(ctx, +		getActorID(obj), +		getObjectID(obj), +		getTargetID(obj), +		b, +		to, +	) +	if err != nil { +		return err +	} + +	// Push prepared request to the delivery queue. +	t.controller.state.Workers.Delivery.Queue.Push(req) + +	return nil  } -func (t *transport) deliver(ctx context.Context, b []byte, to *url.URL) error { +// prepare will prepare a POST http.Request{} +// to recipient at 'to', wrapping in a queued +// request object with signing function. +func (t *transport) prepare( +	ctx context.Context, +	actorID string, +	objectID string, +	targetID string, +	data []byte, +	to *url.URL, +) ( +	*delivery.Delivery, +	error, +) {  	url := to.String() -	// Use rewindable bytes reader for body. +	// Use rewindable reader for body.  	var body byteutil.ReadNopCloser -	body.Reset(b) +	body.Reset(data) + +	// Prepare POST signer. +	sign := t.signPOST(data) -	req, err := http.NewRequestWithContext(ctx, "POST", url, &body) +	// Update to-be-used request context with signing details. +	ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID) +	ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign) + +	// Prepare a new request with data body directed at URL. +	r, err := http.NewRequestWithContext(ctx, "POST", url, &body)  	if err != nil { -		return err +		return nil, gtserror.Newf("error preparing request: %w", err)  	} -	req.Header.Add("Content-Type", string(apiutil.AppActivityLDJSON)) -	req.Header.Add("Accept-Charset", "utf-8") +	// Set the standard ActivityPub content-type + charset headers. +	r.Header.Add("Content-Type", string(apiutil.AppActivityLDJSON)) +	r.Header.Add("Accept-Charset", "utf-8") -	rsp, err := t.POST(req, b) -	if err != nil { -		return err +	// Validate the request before queueing for delivery. +	if err := httpclient.ValidateRequest(r); err != nil { +		return nil, err +	} + +	return &delivery.Delivery{ +		ActorID:  actorID, +		ObjectID: objectID, +		TargetID: targetID, +		Request:  httpclient.WrapRequest(r), +	}, nil +} + +// getObjectID extracts an object ID from 'serialized' ActivityPub object map. +func getObjectID(obj map[string]interface{}) string { +	switch t := obj["object"].(type) { +	case string: +		return t +	case map[string]interface{}: +		id, _ := t["id"].(string) +		return id +	default: +		return ""  	} -	defer rsp.Body.Close() +} -	if code := rsp.StatusCode; code != http.StatusOK && -		code != http.StatusCreated && code != http.StatusAccepted { -		return gtserror.NewFromResponse(rsp) +// getActorID extracts an actor ID from 'serialized' ActivityPub object map. +func getActorID(obj map[string]interface{}) string { +	switch t := obj["actor"].(type) { +	case string: +		return t +	case map[string]interface{}: +		id, _ := t["id"].(string) +		return id +	default: +		return ""  	} +} -	return nil +// getTargetID extracts a target ID from 'serialized' ActivityPub object map. +func getTargetID(obj map[string]interface{}) string { +	switch t := obj["target"].(type) { +	case string: +		return t +	case map[string]interface{}: +		id, _ := t["id"].(string) +		return id +	default: +		return "" +	}  } diff --git a/internal/transport/delivery/delivery.go b/internal/transport/delivery/delivery.go new file mode 100644 index 000000000..27281399f --- /dev/null +++ b/internal/transport/delivery/delivery.go @@ -0,0 +1,323 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package delivery + +import ( +	"context" +	"slices" +	"time" + +	"codeberg.org/gruf/go-runners" +	"codeberg.org/gruf/go-structr" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/httpclient" +	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/queue" +) + +// Delivery wraps an httpclient.Request{} +// to add ActivityPub ID IRI fields of the +// outgoing activity, so that deliveries may +// be indexed (and so, dropped from queue) +// by any of these possible ID IRIs. +type Delivery struct { + +	// ActorID contains the ActivityPub +	// actor ID IRI (if any) of the activity +	// being sent out by this request. +	ActorID string + +	// ObjectID contains the ActivityPub +	// object ID IRI (if any) of the activity +	// being sent out by this request. +	ObjectID string + +	// TargetID contains the ActivityPub +	// target ID IRI (if any) of the activity +	// being sent out by this request. +	TargetID string + +	// Request is the prepared (+ wrapped) +	// httpclient.Client{} request that +	// constitutes this ActivtyPub delivery. +	Request httpclient.Request + +	// internal fields. +	next time.Time +} + +func (dlv *Delivery) backoff() time.Duration { +	if dlv.next.IsZero() { +		return 0 +	} +	return time.Until(dlv.next) +} + +// WorkerPool wraps multiple Worker{}s in +// a singular struct for easy multi start/stop. +type WorkerPool struct { + +	// Client defines httpclient.Client{} +	// passed to each of delivery pool Worker{}s. +	Client *httpclient.Client + +	// Queue is the embedded queue.StructQueue{} +	// passed to each of delivery pool Worker{}s. +	Queue queue.StructQueue[*Delivery] + +	// internal fields. +	workers []*Worker +} + +// Init will initialize the Worker{} pool +// with given http client, request queue to pull +// from and number of delivery workers to spawn. +func (p *WorkerPool) Init(client *httpclient.Client) { +	p.Client = client +	p.Queue.Init(structr.QueueConfig[*Delivery]{ +		Indices: []structr.IndexConfig{ +			{Fields: "ActorID", Multiple: true}, +			{Fields: "ObjectID", Multiple: true}, +			{Fields: "TargetID", Multiple: true}, +		}, +	}) +} + +// Start will attempt to start 'n' Worker{}s. +func (p *WorkerPool) Start(n int) (ok bool) { +	if ok = (len(p.workers) == 0); ok { +		p.workers = make([]*Worker, n) +		for i := range p.workers { +			p.workers[i] = new(Worker) +			p.workers[i].Client = p.Client +			p.workers[i].Queue = &p.Queue +			ok = p.workers[i].Start() && ok +		} +	} +	return +} + +// Stop will attempt to stop contained Worker{}s. +func (p *WorkerPool) Stop() (ok bool) { +	if ok = (len(p.workers) > 0); ok { +		for i := range p.workers { +			ok = p.workers[i].Stop() && ok +			p.workers[i] = nil +		} +		p.workers = p.workers[:0] +	} +	return +} + +// Worker wraps an httpclient.Client{} to feed +// from queue.StructQueue{} for ActivityPub reqs +// to deliver. It does so while prioritizing new +// queued requests over backlogged retries. +type Worker struct { + +	// Client is the httpclient.Client{} that +	// delivery worker will use for requests. +	Client *httpclient.Client + +	// Queue is the Delivery{} message queue +	// that delivery worker will feed from. +	Queue *queue.StructQueue[*Delivery] + +	// internal fields. +	backlog []*Delivery +	service runners.Service +} + +// Start will attempt to start the Worker{}. +func (w *Worker) Start() bool { +	return w.service.GoRun(w.run) +} + +// Stop will attempt to stop the Worker{}. +func (w *Worker) Stop() bool { +	return w.service.Stop() +} + +// run wraps process to restart on any panic. +func (w *Worker) run(ctx context.Context) { +	if w.Client == nil || w.Queue == nil { +		panic("not yet initialized") +	} +	log.Infof(ctx, "%p: started delivery worker", w) +	defer log.Infof(ctx, "%p: stopped delivery worker", w) +	for returned := false; !returned; { +		func() { +			defer func() { +				if r := recover(); r != nil { +					log.Errorf(ctx, "recovered panic: %v", r) +				} +			}() +			w.process(ctx) +			returned = true +		}() +	} +} + +// process is the main delivery worker processing routine. +func (w *Worker) process(ctx context.Context) { +	if w.Client == nil || w.Queue == nil { +		// we perform this check here just +		// to ensure the compiler knows these +		// variables aren't nil in the loop, +		// even if already checked by caller. +		panic("not yet initialized") +	} + +loop: +	for { +		// Get next delivery. +		dlv, ok := w.next(ctx) +		if !ok { +			return +		} + +		// Check whether backoff required. +		const min = 100 * time.Millisecond +		if d := dlv.backoff(); d > min { + +			// Start backoff sleep timer. +			backoff := time.NewTimer(d) + +			select { +			case <-ctx.Done(): +				// Main ctx +				// cancelled. +				backoff.Stop() +				return + +			case <-w.Queue.Wait(): +				// A new message was +				// queued, re-add this +				// to backlog + retry. +				w.pushBacklog(dlv) +				backoff.Stop() +				continue loop + +			case <-backoff.C: +				// success! +			} +		} + +		// Attempt delivery of AP request. +		rsp, retry, err := w.Client.DoOnce( +			&dlv.Request, +		) + +		if err == nil { +			// Ensure body closed. +			_ = rsp.Body.Close() +			continue loop +		} + +		if !retry { +			// Drop deliveries when no +			// retry requested, or they +			// reached max (either). +			continue loop +		} + +		// Determine next delivery attempt. +		backoff := dlv.Request.BackOff() +		dlv.next = time.Now().Add(backoff) + +		// Push to backlog. +		w.pushBacklog(dlv) +	} +} + +// next gets the next available delivery, blocking until available if necessary. +func (w *Worker) next(ctx context.Context) (*Delivery, bool) { +loop: +	for { +		// Try pop next queued. +		dlv, ok := w.Queue.Pop() + +		if !ok { +			// Check the backlog. +			if len(w.backlog) > 0 { + +				// Sort by 'next' time. +				sortDeliveries(w.backlog) + +				// Pop next delivery. +				dlv := w.popBacklog() + +				return dlv, true +			} + +			select { +			// Backlog is empty, we MUST +			// block until next enqueued. +			case <-w.Queue.Wait(): +				continue loop + +			// Worker was stopped. +			case <-ctx.Done(): +				return nil, false +			} +		} + +		// Replace request context for worker state canceling. +		ctx := gtscontext.WithValues(ctx, dlv.Request.Context()) +		dlv.Request.Request = dlv.Request.Request.WithContext(ctx) + +		return dlv, true +	} +} + +// popBacklog pops next available from the backlog. +func (w *Worker) popBacklog() *Delivery { +	if len(w.backlog) == 0 { +		return nil +	} + +	// Pop from backlog. +	dlv := w.backlog[0] + +	// Shift backlog down by one. +	copy(w.backlog, w.backlog[1:]) +	w.backlog = w.backlog[:len(w.backlog)-1] + +	return dlv +} + +// pushBacklog pushes the given delivery to backlog. +func (w *Worker) pushBacklog(dlv *Delivery) { +	w.backlog = append(w.backlog, dlv) +} + +// sortDeliveries sorts deliveries according +// to when is the first requiring re-attempt. +func sortDeliveries(d []*Delivery) { +	slices.SortFunc(d, func(a, b *Delivery) int { +		const k = +1 +		switch { +		case a.next.Before(b.next): +			return +k +		case b.next.Before(a.next): +			return -k +		default: +			return 0 +		} +	}) +} diff --git a/internal/transport/delivery/delivery_test.go b/internal/transport/delivery/delivery_test.go new file mode 100644 index 000000000..852c6f6f3 --- /dev/null +++ b/internal/transport/delivery/delivery_test.go @@ -0,0 +1,205 @@ +package delivery_test + +import ( +	"fmt" +	"io" +	"math/rand" +	"net" +	"net/http" +	"strconv" +	"strings" +	"testing" + +	"codeberg.org/gruf/go-byteutil" +	"github.com/superseriousbusiness/gotosocial/internal/config" +	"github.com/superseriousbusiness/gotosocial/internal/httpclient" +	"github.com/superseriousbusiness/gotosocial/internal/queue" +	"github.com/superseriousbusiness/gotosocial/internal/transport/delivery" +) + +func TestDeliveryWorkerPool(t *testing.T) { +	for _, i := range []int{1, 2, 4, 8, 16, 32} { +		t.Run("size="+strconv.Itoa(i), func(t *testing.T) { +			testDeliveryWorkerPool(t, i, generateInput(100*i)) +		}) +	} +} + +func testDeliveryWorkerPool(t *testing.T, sz int, input []*testrequest) { +	wp := new(delivery.WorkerPool) +	wp.Init(httpclient.New(httpclient.Config{ +		AllowRanges: config.MustParseIPPrefixes([]string{ +			"127.0.0.0/8", +		}), +	})) +	if !wp.Start(sz) { +		t.Fatal("failed starting pool") +	} +	defer wp.Stop() +	test(t, &wp.Queue, input) +} + +func test( +	t *testing.T, +	queue *queue.StructQueue[*delivery.Delivery], +	input []*testrequest, +) { +	expect := make(chan *testrequest) +	errors := make(chan error) + +	// Prepare an HTTP test handler that ensures expected delivery is received. +	handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { +		errors <- (<-expect).Equal(r) +	}) + +	// Start new HTTP test server listener. +	l, err := net.Listen("tcp", "127.0.0.1:0") +	if err != nil { +		t.Fatal(err) +	} +	defer l.Close() + +	// Start the HTTP server. +	// +	// specifically not using httptest.Server{} here as httptest +	// links that server with its own http.Client{}, whereas we're +	// using an httpclient.Client{} (well, delivery routine is). +	srv := new(http.Server) +	srv.Addr = "http://" + l.Addr().String() +	srv.Handler = handler +	go srv.Serve(l) +	defer srv.Close() + +	// Range over test input. +	for _, test := range input { + +		// Generate req for input. +		req := test.Generate(srv.Addr) +		r := httpclient.WrapRequest(req) + +		// Wrap the request in delivery. +		dlv := new(delivery.Delivery) +		dlv.Request = r + +		// Enqueue delivery! +		queue.Push(dlv) +		expect <- test + +		// Wait for errors from handler. +		if err := <-errors; err != nil { +			t.Error(err) +		} +	} +} + +type testrequest struct { +	method string +	uri    string +	body   []byte +} + +// generateInput generates 'n' many testrequest cases. +func generateInput(n int) []*testrequest { +	tests := make([]*testrequest, n) +	for i := range tests { +		tests[i] = new(testrequest) +		tests[i].method = randomMethod() +		tests[i].uri = randomURI() +		tests[i].body = randomBody(tests[i].method) +	} +	return tests +} + +var methods = []string{ +	http.MethodConnect, +	http.MethodDelete, +	http.MethodGet, +	http.MethodHead, +	http.MethodOptions, +	http.MethodPatch, +	http.MethodPost, +	http.MethodPut, +	http.MethodTrace, +} + +// randomMethod generates a random http method. +func randomMethod() string { +	return methods[rand.Intn(len(methods))] +} + +// randomURI generates a random http uri. +func randomURI() string { +	n := rand.Intn(5) +	p := make([]string, n) +	for i := range p { +		p[i] = strconv.Itoa(rand.Int()) +	} +	return "/" + strings.Join(p, "/") +} + +// randomBody generates a random http body DEPENDING on method. +func randomBody(method string) []byte { +	if requiresBody(method) { +		return []byte(method + " " + randomURI()) +	} +	return nil +} + +// requiresBody returns whether method requires body. +func requiresBody(method string) bool { +	switch method { +	case http.MethodPatch, +		http.MethodPost, +		http.MethodPut: +		return true +	default: +		return false +	} +} + +// Generate will generate a real http.Request{} from test data. +func (t *testrequest) Generate(addr string) *http.Request { +	var body io.ReadCloser +	if t.body != nil { +		var b byteutil.ReadNopCloser +		b.Reset(t.body) +		body = &b +	} +	req, err := http.NewRequest(t.method, addr+t.uri, body) +	if err != nil { +		panic(err) +	} +	return req +} + +// Equal checks if request matches receiving test request. +func (t *testrequest) Equal(r *http.Request) error { +	// Ensure methods match. +	if t.method != r.Method { +		return fmt.Errorf("differing request methods: t=%q r=%q", t.method, r.Method) +	} + +	// Ensure request URIs match. +	if t.uri != r.URL.RequestURI() { +		return fmt.Errorf("differing request urls: t=%q r=%q", t.uri, r.URL.RequestURI()) +	} + +	// Ensure body cases match. +	if requiresBody(t.method) { + +		// Read request into memory. +		b, err := io.ReadAll(r.Body) +		if err != nil { +			return fmt.Errorf("error reading request body: %v", err) +		} + +		// Compare the request bodies. +		st := strings.TrimSpace(string(t.body)) +		sr := strings.TrimSpace(string(b)) +		if st != sr { +			return fmt.Errorf("differing request bodies: t=%q r=%q", st, sr) +		} +	} + +	return nil +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 3ae5c8967..110c19b3d 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -51,10 +51,10 @@ type Transport interface {  	POST(*http.Request, []byte) (*http.Response, error)  	// Deliver sends an ActivityStreams object. -	Deliver(ctx context.Context, b []byte, to *url.URL) error +	Deliver(ctx context.Context, obj map[string]interface{}, to *url.URL) error  	// BatchDeliver sends an ActivityStreams object to multiple recipients. -	BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error +	BatchDeliver(ctx context.Context, obj map[string]interface{}, recipients []*url.URL) error  	/*  		GET functions @@ -77,7 +77,8 @@ type Transport interface {  	Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error)  } -// transport implements the Transport interface. +// transport implements +// the Transport interface.  type transport struct {  	controller *controller  	pubKeyID   string diff --git a/internal/workers/workers.go b/internal/workers/workers.go index 3617ce333..17728c255 100644 --- a/internal/workers/workers.go +++ b/internal/workers/workers.go @@ -23,14 +23,22 @@ import (  	"runtime"  	"codeberg.org/gruf/go-runners" +	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/scheduler" +	"github.com/superseriousbusiness/gotosocial/internal/transport/delivery"  )  type Workers struct {  	// Main task scheduler instance.  	Scheduler scheduler.Scheduler +	// Delivery provides a worker pool that +	// handles outgoing ActivityPub deliveries. +	// It contains an embedded (but accessible) +	// indexed queue of Delivery{} objects. +	Delivery delivery.WorkerPool +  	// ClientAPI provides a worker pool that handles both  	// incoming client actions, and our own side-effects.  	ClientAPI runners.WorkerPool @@ -65,13 +73,23 @@ type Workers struct {  	_ nocopy  } -// Start will start all of the contained worker pools (and global scheduler). +// Start will start all of the contained +// worker pools (and global scheduler).  func (w *Workers) Start() {  	// Get currently set GOMAXPROCS.  	maxprocs := runtime.GOMAXPROCS(0)  	tryUntil("starting scheduler", 5, w.Scheduler.Start) +	tryUntil("start delivery workerpool", 5, func() bool { +		n := config.GetAdvancedSenderMultiplier() +		if n < 1 { +			// clamp min senders to 1. +			return w.Delivery.Start(1) +		} +		return w.Delivery.Start(n * maxprocs) +	}) +  	tryUntil("starting client API workerpool", 5, func() bool {  		return w.ClientAPI.Start(4*maxprocs, 400*maxprocs)  	}) @@ -88,6 +106,7 @@ func (w *Workers) Start() {  // Stop will stop all of the contained worker pools (and global scheduler).  func (w *Workers) Stop() {  	tryUntil("stopping scheduler", 5, w.Scheduler.Stop) +	tryUntil("stopping delivery workerpool", 5, w.Delivery.Stop)  	tryUntil("stopping client API workerpool", 5, w.ClientAPI.Stop)  	tryUntil("stopping federator workerpool", 5, w.Federator.Stop)  	tryUntil("stopping media workerpool", 5, w.Media.Stop)  | 
