diff options
24 files changed, 434 insertions, 496 deletions
@@ -19,7 +19,6 @@ require (  	github.com/abema/go-mp4 v0.10.1  	github.com/buckket/go-blurhash v1.1.0  	github.com/coreos/go-oidc/v3 v3.5.0 -	github.com/cornelk/hashmap v1.0.8  	github.com/disintegration/imaging v1.6.2  	github.com/gin-contrib/cors v1.4.0  	github.com/gin-contrib/gzip v0.0.6 @@ -82,6 +81,7 @@ require (  	github.com/cilium/ebpf v0.9.1 // indirect  	github.com/containerd/cgroups/v3 v3.0.1 // indirect  	github.com/coreos/go-systemd/v22 v22.3.2 // indirect +	github.com/cornelk/hashmap v1.0.8 // indirect  	github.com/davecgh/go-spew v1.1.1 // indirect  	github.com/docker/go-units v0.4.0 // indirect  	github.com/dsoprea/go-exif/v3 v3.0.0-20210625224831-a6301f85c82b // indirect diff --git a/internal/api/util/errorhandling.go b/internal/api/util/errorhandling.go index 45bcf1d7a..4daaf44c8 100644 --- a/internal/api/util/errorhandling.go +++ b/internal/api/util/errorhandling.go @@ -24,9 +24,9 @@ import (  	"codeberg.org/gruf/go-kv"  	"github.com/gin-gonic/gin"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/log" -	"github.com/superseriousbusiness/gotosocial/internal/middleware"  )  // TODO: add more templated html pages here for different error types @@ -51,7 +51,7 @@ func NotFoundHandler(c *gin.Context, instanceGet func(ctx context.Context) (*api  		c.HTML(http.StatusNotFound, "404.tmpl", gin.H{  			"instance":  instance, -			"requestID": middleware.RequestID(ctx), +			"requestID": gtscontext.RequestID(ctx),  		})  	default:  		c.JSON(http.StatusNotFound, gin.H{ @@ -76,7 +76,7 @@ func genericErrorHandler(c *gin.Context, instanceGet func(ctx context.Context) (  			"instance":  instance,  			"code":      errWithCode.Code(),  			"error":     errWithCode.Safe(), -			"requestID": middleware.RequestID(ctx), +			"requestID": gtscontext.RequestID(ctx),  		})  	default:  		c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/federation/authenticate.go b/internal/federation/authenticate.go index 96436ee0e..5fe4873d4 100644 --- a/internal/federation/authenticate.go +++ b/internal/federation/authenticate.go @@ -34,10 +34,10 @@ import (  	"github.com/superseriousbusiness/activity/streams/vocab"  	"github.com/superseriousbusiness/gotosocial/internal/ap"  	"github.com/superseriousbusiness/gotosocial/internal/config" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" -	"github.com/superseriousbusiness/gotosocial/internal/transport"  )  /* @@ -216,7 +216,7 @@ func (f *federator) AuthenticateFederatedRequest(ctx context.Context, requestedU  		}  		log.Tracef(ctx, "proceeding with dereference for uncached public key %s", requestingPublicKeyID) -		trans, err := f.transportController.NewTransportForUsername(transport.WithFastfail(ctx), requestedUsername) +		trans, err := f.transportController.NewTransportForUsername(gtscontext.SetFastFail(ctx), requestedUsername)  		if err != nil {  			errWithCode := gtserror.NewErrorInternalError(fmt.Errorf("error creating transport for %s: %s", requestedUsername, err))  			log.Debug(ctx, errWithCode) diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go index 52f46586d..7995faa84 100644 --- a/internal/federation/federatingprotocol.go +++ b/internal/federation/federatingprotocol.go @@ -29,10 +29,10 @@ import (  	"github.com/superseriousbusiness/activity/streams/vocab"  	"github.com/superseriousbusiness/gotosocial/internal/ap"  	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" -	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/internal/uris"  	"github.com/superseriousbusiness/gotosocial/internal/util"  ) @@ -191,9 +191,8 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr  			return ctx, false, err  		} -		// We don't yet have an entry for -		// the instance, go dereference it. -		instance, err := f.GetRemoteInstance(transport.WithFastfail(ctx), username, &url.URL{ +		// we don't have an entry for this instance yet so dereference it +		instance, err := f.GetRemoteInstance(gtscontext.SetFastFail(ctx), username, &url.URL{  			Scheme: publicKeyOwnerURI.Scheme,  			Host:   publicKeyOwnerURI.Host,  		}) @@ -212,7 +211,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr  	// dereference the remote account (or just get it  	// from the db if we already have it).  	requestingAccount, err := f.GetAccountByURI( -		transport.WithFastfail(ctx), username, publicKeyOwnerURI, false, +		gtscontext.SetFastFail(ctx), username, publicKeyOwnerURI, false,  	)  	if err != nil {  		if gtserror.StatusCode(err) == http.StatusGone { diff --git a/internal/gtscontext/context.go b/internal/gtscontext/context.go index 7d4a44774..d52bf2801 100644 --- a/internal/gtscontext/context.go +++ b/internal/gtscontext/context.go @@ -17,7 +17,9 @@  package gtscontext -import "context" +import ( +	"context" +)  // package private context key type.  type ctxkey uint @@ -26,8 +28,54 @@ const (  	// context keys.  	_ ctxkey = iota  	barebonesKey +	fastFailKey +	pubKeyIDKey +	requestIDKey  ) +// RequestID returns the request ID associated with context. This value will usually +// be set by the request ID middleware handler, either pulling an existing supplied +// value from request headers, or generating a unique new entry. This is useful for +// tying together log entries associated with an original incoming request. +func RequestID(ctx context.Context) string { +	id, _ := ctx.Value(requestIDKey).(string) +	return id +} + +// SetRequestID stores the given request ID value and returns the wrapped +// context. See RequestID() for further information on the request ID value. +func SetRequestID(ctx context.Context, id string) context.Context { +	return context.WithValue(ctx, requestIDKey, id) +} + +// PublicKeyID returns the public key ID (URI) associated with context. This +// value is useful for logging situations in which a given public key URI is +// relevant, e.g. for outgoing requests being signed by the given key. +func PublicKeyID(ctx context.Context) string { +	id, _ := ctx.Value(pubKeyIDKey).(string) +	return id +} + +// SetPublicKeyID stores the given public key ID value and returns the wrapped +// context. See PublicKeyID() for further information on the public key ID value. +func SetPublicKeyID(ctx context.Context, id string) context.Context { +	return context.WithValue(ctx, pubKeyIDKey, id) +} + +// IsFastFail returns whether the "fastfail" context key has been set. This +// can be used to indicate to an http client, for example, that the result +// of an outgoing request is time sensitive and so not to bother with retries. +func IsFastfail(ctx context.Context) bool { +	_, ok := ctx.Value(fastFailKey).(struct{}) +	return ok +} + +// SetFastFail sets the "fastfail" context flag and returns this wrapped context. +// See IsFastFail() for further information on the "fastfail" context flag. +func SetFastFail(ctx context.Context) context.Context { +	return context.WithValue(ctx, fastFailKey, struct{}{}) +} +  // Barebones returns whether the "barebones" context key has been set. This  // can be used to indicate to the database, for example, that only a barebones  // model need be returned, Allowing it to skip populating sub models. @@ -37,7 +85,7 @@ func Barebones(ctx context.Context) bool {  }  // SetBarebones sets the "barebones" context flag and returns this wrapped context. -// See Barebones() for further information on the "barebones" context flag.. +// See Barebones() for further information on the "barebones" context flag.  func SetBarebones(ctx context.Context) context.Context {  	return context.WithValue(ctx, barebonesKey, struct{}{})  } diff --git a/internal/gtscontext/log_hooks.go b/internal/gtscontext/log_hooks.go new file mode 100644 index 000000000..2fe43e488 --- /dev/null +++ b/internal/gtscontext/log_hooks.go @@ -0,0 +1,44 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package gtscontext + +import ( +	"context" + +	"codeberg.org/gruf/go-kv" +	"github.com/superseriousbusiness/gotosocial/internal/log" +) + +func init() { +	// Add our required logging hooks on application initialization. +	// +	// Request ID middleware hook. +	log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { +		if id := RequestID(ctx); id != "" { +			return append(kvs, kv.Field{K: "requestID", V: id}) +		} +		return kvs +	}) +	// Client IP middleware hook. +	log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { +		if id := PublicKeyID(ctx); id != "" { +			return append(kvs, kv.Field{K: "pubKeyID", V: id}) +		} +		return kvs +	}) +} diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go index 9562bdc48..67a1d0715 100644 --- a/internal/httpclient/client.go +++ b/internal/httpclient/client.go @@ -18,31 +18,39 @@  package httpclient  import ( +	"context" +	"crypto/x509"  	"errors" +	"fmt"  	"io"  	"net"  	"net/http"  	"net/netip"  	"runtime" +	"strconv" +	"strings"  	"time"  	"codeberg.org/gruf/go-bytesize" +	"codeberg.org/gruf/go-byteutil" +	"codeberg.org/gruf/go-cache/v3" +	errorsv2 "codeberg.org/gruf/go-errors/v2"  	"codeberg.org/gruf/go-kv" -	"github.com/cornelk/hashmap" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/log"  ) -// ErrInvalidRequest is returned if a given HTTP request is invalid and cannot be performed. -var ErrInvalidRequest = errors.New("invalid http request") +var ( +	// ErrInvalidNetwork is returned if the request would not be performed over TCP +	ErrInvalidNetwork = errors.New("invalid network type") -// ErrInvalidNetwork is returned if the request would not be performed over TCP -var ErrInvalidNetwork = errors.New("invalid network type") +	// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net. +	ErrReservedAddr = errors.New("dial within blocked / reserved IP range") -// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net. -var ErrReservedAddr = errors.New("dial within blocked / reserved IP range") - -// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB). -var ErrBodyTooLarge = errors.New("body size too large") +	// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB). +	ErrBodyTooLarge = errors.New("body size too large") +)  // Config provides configuration details for setting up a new  // instance of httpclient.Client{}. Within are a subset of the @@ -83,13 +91,10 @@ type Config struct {  //     cases to protect against forged / unknown content-lengths  //   - protection from server side request forgery (SSRF) by only dialing  //     out to known public IP prefixes, configurable with allows/blocks -//   - limit number of concurrent requests, else blocking until a slot -//     is available (context channels still respected)  type Client struct { -	client http.Client -	queue  *hashmap.Map[string, chan struct{}] -	bmax   int64 // max response body size -	cmax   int   // max open conns per host +	client   http.Client +	badHosts cache.Cache[string, struct{}] +	bodyMax  int64  }  // New returns a new instance of Client initialized using configuration. @@ -109,28 +114,26 @@ func New(cfg Config) *Client {  	}  	if cfg.MaxIdleConns <= 0 { -		// By default base this value on MaxOpenConns +		// By default base this value on MaxOpenConns.  		cfg.MaxIdleConns = cfg.MaxOpenConnsPerHost * 10  	}  	if cfg.MaxBodySize <= 0 { -		// By default set this to a reasonable 40MB +		// By default set this to a reasonable 40MB.  		cfg.MaxBodySize = int64(40 * bytesize.MiB)  	} -	// Protect dialer with IP range sanitizer +	// Protect dialer with IP range sanitizer.  	d.Control = (&sanitizer{  		allow: cfg.AllowRanges,  		block: cfg.BlockRanges,  	}).Sanitize -	// Prepare client fields +	// Prepare client fields.  	c.client.Timeout = cfg.Timeout -	c.cmax = cfg.MaxOpenConnsPerHost -	c.bmax = cfg.MaxBodySize -	c.queue = hashmap.New[string, chan struct{}]() +	c.bodyMax = cfg.MaxBodySize -	// Set underlying HTTP client roundtripper +	// Set underlying HTTP client roundtripper.  	c.client.Transport = &http.Transport{  		Proxy:                 http.ProxyFromEnvironment,  		ForceAttemptHTTP2:     true, @@ -144,90 +147,185 @@ func New(cfg Config) *Client {  		DisableCompression:    cfg.DisableCompression,  	} +	// Initiate outgoing bad hosts lookup cache. +	c.badHosts = cache.New[string, struct{}](0, 1000, 0) +	c.badHosts.SetTTL(15*time.Minute, false) +	if !c.badHosts.Start(time.Minute) { +		log.Panic(nil, "failed to start transport controller cache") +	} +  	return &c  } -// Do will perform given request when an available slot in the queue is available, -// and block until this time. For returned values, this follows the same semantics -// as the standard http.Client{}.Do() implementation except that response body will -// be wrapped by an io.LimitReader() to limit response body sizes. -func (c *Client) Do(req *http.Request) (*http.Response, error) { -	// Ensure this is a valid request -	if err := ValidateRequest(req); err != nil { -		return nil, err -	} +// Do ... +func (c *Client) Do(r *http.Request) (*http.Response, error) { +	return c.DoSigned(r, func(r *http.Request) error { +		return nil // no request signing +	}) +} -	// Get host's wait queue -	wait := c.wait(req.Host) - -	var ok bool - -	select { -	// Quickly try grab a spot -	case wait <- struct{}{}: -		// it's our turn! -		ok = true - -		// NOTE: -		// Ideally here we would set the slot release to happen either -		// on error return, or via callback from the response body closer. -		// However when implementing this, there appear deadlocks between -		// the channel queue here and the media manager worker pool. So -		// currently we only place a limit on connections dialing out, but -		// there may still be more connections open than len(c.queue) given -		// that connections may not be closed until response body is closed. -		// The current implementation will reduce the viability of denial of -		// service attacks, but if there are future issues heed this advice :] -		defer func() { <-wait }() -	default: +// DoSigned ... +func (c *Client) DoSigned(r *http.Request, sign SignFunc) (*http.Response, error) { +	const ( +		// max no. attempts. +		maxRetries = 5 + +		// starting backoff duration. +		baseBackoff = 2 * time.Second +	) + +	// Get request hostname. +	host := r.URL.Hostname() + +	// Check whether request should fast fail. +	fastFail := gtscontext.IsFastfail(r.Context()) +	if !fastFail { +		// Check if recently reached max retries for this host +		// so we don't bother with a retry-backoff loop. The only +		// errors that are retried upon are server failure and +		// domain resolution type errors, so this cached result +		// indicates this server is likely having issues. +		fastFail = c.badHosts.Has(host)  	} -	if !ok { -		// No spot acquired, log warning -		log.WithContext(req.Context()). -			WithFields(kv.Fields{ -				{K: "queue", V: len(wait)}, -				{K: "method", V: req.Method}, -				{K: "host", V: req.Host}, -				{K: "uri", V: req.URL.RequestURI()}, -			}...).Warn("full request queue") +	// Start a log entry for this request +	l := log.WithContext(r.Context()). +		WithFields(kv.Fields{ +			{"method", r.Method}, +			{"url", r.URL.String()}, +		}...) + +	for i := 0; i < maxRetries; i++ { +		var backoff time.Duration + +		// Reset signing header fields +		now := time.Now().UTC() +		r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") +		r.Header.Del("Signature") +		r.Header.Del("Digest") + +		// Rewind body reader and content-length if set. +		if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { +			r.ContentLength = int64(rc.Len()) +			rc.Rewind() +		} + +		// Sign the outgoing request. +		if err := sign(r); err != nil { +			return nil, err +		} + +		l.Infof("performing request") + +		// Perform the request. +		rsp, err := c.do(r) +		if err == nil { //nolint:gocritic + +			// TooManyRequest means we need to slow +			// down and retry our request. Codes over +			// 500 generally indicate temp. outages. +			if code := rsp.StatusCode; code < 500 && +				code != http.StatusTooManyRequests { +				return rsp, nil +			} + +			// Generate error from status code for logging +			err = errors.New(`http response "` + rsp.Status + `"`) + +			// Search for a provided "Retry-After" header value. +			if after := rsp.Header.Get("Retry-After"); after != "" { + +				if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { +					// An integer number of backoff seconds was provided. +					backoff = time.Duration(u) * time.Second +				} else if at, _ := http.ParseTime(after); !at.Before(now) { +					// An HTTP formatted future date-time was provided. +					backoff = at.Sub(now) +				} + +				// Don't let their provided backoff exceed our max. +				if max := baseBackoff * maxRetries; backoff > max { +					backoff = max +				} +			} + +		} else if errorsv2.Is(err, +			context.DeadlineExceeded, +			context.Canceled, +			ErrBodyTooLarge, +			ErrReservedAddr, +		) { +			// Return on non-retryable errors +			return nil, err +		} else if strings.Contains(err.Error(), "stopped after 10 redirects") { +			// Don't bother if net/http returned after too many redirects +			return nil, err +		} else if errors.As(err, &x509.UnknownAuthorityError{}) { +			// Unknown authority errors we do NOT recover from +			return nil, err +		} else if dnserr := (*net.DNSError)(nil); // nocollapse +		errors.As(err, &dnserr) && dnserr.IsNotFound { +			// DNS lookup failure, this domain does not exist +			return nil, gtserror.SetNotFound(err) +		} + +		if fastFail { +			// on fast-fail, don't bother backoff/retry +			return nil, fmt.Errorf("%w (fast fail)", err) +		} + +		if backoff == 0 { +			// No retry-after found, set our predefined +			// backoff according to a multiplier of 2^n. +			backoff = baseBackoff * 1 << (i + 1) +		} + +		l.Errorf("backing off for %s after http request error: %v", backoff, err)  		select { -		case <-req.Context().Done(): -			// the request was canceled before we -			// got to our turn: no need to release -			return nil, req.Context().Err() -		case wait <- struct{}{}: -			defer func() { <-wait }() +		// Request ctx cancelled +		case <-r.Context().Done(): +			return nil, r.Context().Err() + +		// Backoff for some time +		case <-time.After(backoff):  		}  	} -	// Perform the HTTP request +	// Add "bad" entry for this host. +	c.badHosts.Set(host, struct{}{}) + +	return nil, errors.New("transport reached max retries") +} + +// do ... +func (c *Client) do(req *http.Request) (*http.Response, error) { +	// Perform the HTTP request.  	rsp, err := c.client.Do(req)  	if err != nil {  		return nil, err  	} -	// Check response body not too large -	if rsp.ContentLength > c.bmax { +	// Check response body not too large. +	if rsp.ContentLength > c.bodyMax {  		return nil, ErrBodyTooLarge  	} -	// Seperate the body implementers +	// Seperate the body implementers.  	rbody := (io.Reader)(rsp.Body)  	cbody := (io.Closer)(rsp.Body)  	var limit int64  	if limit = rsp.ContentLength; limit < 0 { -		// If unknown, use max as reader limit -		limit = c.bmax +		// If unknown, use max as reader limit. +		limit = c.bodyMax  	} -	// Don't trust them, limit body reads +	// Don't trust them, limit body reads.  	rbody = io.LimitReader(rbody, limit) -	// Wrap body with limit +	// Wrap body with limit.  	rsp.Body = &struct {  		io.Reader  		io.Closer @@ -235,17 +333,3 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {  	return rsp, nil  } - -// wait acquires the 'wait' queue for the given host string, or allocates new. -func (c *Client) wait(host string) chan struct{} { -	// Look for an existing queue -	queue, ok := c.queue.Get(host) -	if ok { -		return queue -	} - -	// Allocate a new host queue (or return a sneaky existing one). -	queue, _ = c.queue.GetOrInsert(host, make(chan struct{}, c.cmax)) - -	return queue -} diff --git a/internal/httpclient/client_test.go b/internal/httpclient/client_test.go index 9eab0fed4..f0ec01ec3 100644 --- a/internal/httpclient/client_test.go +++ b/internal/httpclient/client_test.go @@ -48,14 +48,6 @@ var bodies = []string{  	"body with\r\nnewlines",  } -// Note: -// There is no test for the .MaxOpenConns implementation -// in the httpclient.Client{}, due to the difficult to test -// this. The block is only held for the actual dial out to -// the connection, so the usual test of blocking and holding -// open this queue slot to check we can't open another isn't -// an easy test here. -  func TestHTTPClientSmallBody(t *testing.T) {  	for _, body := range bodies {  		_TestHTTPClientWithBody(t, []byte(body), int(^uint16(0))) diff --git a/internal/httpclient/request.go b/internal/httpclient/request.go deleted file mode 100644 index 881d3f699..000000000 --- a/internal/httpclient/request.go +++ /dev/null @@ -1,62 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program.  If not, see <http://www.gnu.org/licenses/>. - -package httpclient - -import ( -	"fmt" -	"net/http" -	"strings" - -	"golang.org/x/net/http/httpguts" -) - -// ValidateRequest performs the same request validation logic found in the default -// net/http.Transport{}.roundTrip() function, but pulls it out into this separate -// function allowing validation errors to be wrapped under a single error type. -func ValidateRequest(r *http.Request) error { -	switch { -	case r.URL == nil: -		return fmt.Errorf("%w: nil url", ErrInvalidRequest) -	case r.Header == nil: -		return fmt.Errorf("%w: nil header", ErrInvalidRequest) -	case r.URL.Host == "": -		return fmt.Errorf("%w: empty url host", ErrInvalidRequest) -	case r.URL.Scheme != "http" && r.URL.Scheme != "https": -		return fmt.Errorf("%w: unsupported protocol %q", ErrInvalidRequest, r.URL.Scheme) -	case strings.IndexFunc(r.Method, func(r rune) bool { return !httpguts.IsTokenRune(r) }) != -1: -		return fmt.Errorf("%w: invalid method %q", ErrInvalidRequest, r.Method) -	} - -	for key, values := range r.Header { -		// Check field key name is valid -		if !httpguts.ValidHeaderFieldName(key) { -			return fmt.Errorf("%w: invalid header field name %q", ErrInvalidRequest, key) -		} - -		// Check each field value is valid -		for i := 0; i < len(values); i++ { -			if !httpguts.ValidHeaderFieldValue(values[i]) { -				return fmt.Errorf("%w: invalid header field value %q", ErrInvalidRequest, values[i]) -			} -		} -	} - -	// ps. kim wrote this - -	return nil -} diff --git a/internal/transport/context_test.go b/internal/httpclient/sign.go index e06e7c4d5..78046aa28 100644 --- a/internal/transport/context_test.go +++ b/internal/httpclient/sign.go @@ -15,19 +15,14 @@  // You should have received a copy of the GNU Affero General Public License  // along with this program.  If not, see <http://www.gnu.org/licenses/>. -package transport_test +package httpclient -import ( -	"context" -	"testing" +import "net/http" -	"github.com/superseriousbusiness/gotosocial/internal/transport" -) +// SignFunc is a function signature that provides request signing. +type SignFunc func(r *http.Request) error -func TestFastFailContext(t *testing.T) { -	ctx := context.Background() -	ctx = transport.WithFastfail(ctx) -	if !transport.IsFastfail(ctx) { -		t.Fatal("failed to set fast-fail context key") -	} +type SigningClient interface { +	Do(r *http.Request) (*http.Response, error) +	DoSigned(r *http.Request, sign SignFunc) (*http.Response, error)  } diff --git a/internal/middleware/logger.go b/internal/middleware/logger.go index e80488330..50e5542c3 100644 --- a/internal/middleware/logger.go +++ b/internal/middleware/logger.go @@ -34,7 +34,7 @@ import (  func Logger() gin.HandlerFunc {  	return func(c *gin.Context) {  		// Initialize the logging fields -		fields := make(kv.Fields, 6, 7) +		fields := make(kv.Fields, 5, 7)  		// Determine pre-handler time  		before := time.Now() @@ -68,11 +68,18 @@ func Logger() gin.HandlerFunc {  			// Set request logging fields  			fields[0] = kv.Field{"latency", time.Since(before)} -			fields[1] = kv.Field{"clientIP", c.ClientIP()} -			fields[2] = kv.Field{"userAgent", c.Request.UserAgent()} -			fields[3] = kv.Field{"method", c.Request.Method} -			fields[4] = kv.Field{"statusCode", code} -			fields[5] = kv.Field{"path", path} +			fields[1] = kv.Field{"userAgent", c.Request.UserAgent()} +			fields[2] = kv.Field{"method", c.Request.Method} +			fields[3] = kv.Field{"statusCode", code} +			fields[4] = kv.Field{"path", path} +			if includeClientIP := true; includeClientIP { +				// TODO: make this configurable. +				// +				// Include clientIP if enabled. +				fields = append(fields, kv.Field{ +					"clientIP", c.ClientIP(), +				}) +			}  			// Create log entry with fields  			l := log.WithContext(c.Request.Context()). diff --git a/internal/middleware/requestid.go b/internal/middleware/requestid.go index 27189b219..6e2a83c68 100644 --- a/internal/middleware/requestid.go +++ b/internal/middleware/requestid.go @@ -19,7 +19,6 @@ package middleware  import (  	"bufio" -	"context"  	"crypto/rand"  	"encoding/base32"  	"encoding/binary" @@ -27,17 +26,11 @@ import (  	"sync"  	"time" -	"codeberg.org/gruf/go-kv"  	"github.com/gin-gonic/gin" -	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  ) -type ctxType string -  var ( -	// ridCtxKey is the key underwhich we store request IDs in a context. -	ridCtxKey ctxType = "id" -  	// crand provides buffered reads of random input.  	crand = bufio.NewReader(rand.Reader)  	mrand sync.Mutex @@ -69,22 +62,8 @@ func generateID() string {  	return base32enc.EncodeToString(b)  } -// RequestID fetches the stored request ID from context. -func RequestID(ctx context.Context) string { -	id, _ := ctx.Value(ridCtxKey).(string) -	return id -} -  // AddRequestID returns a gin middleware which adds a unique ID to each request (both response header and context).  func AddRequestID(header string) gin.HandlerFunc { -	log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { -		if id, _ := ctx.Value(ridCtxKey).(string); id != "" { -			// Add stored request ID to log entry fields. -			return append(kvs, kv.Field{K: "requestID", V: id}) -		} -		return kvs -	}) -  	return func(c *gin.Context) {  		// Look for existing ID.  		id := c.GetHeader(header) @@ -100,8 +79,8 @@ func AddRequestID(header string) gin.HandlerFunc {  			c.Request.Header.Set(header, id)  		} -		// Store request ID in new request ctx and set new gin request obj. -		ctx := context.WithValue(c.Request.Context(), ridCtxKey, id) +		// Store request ID in new request context and set on gin ctx. +		ctx := gtscontext.SetRequestID(c.Request.Context(), id)  		c.Request = c.Request.WithContext(ctx)  		// Set the request ID in the rsp header. diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go index 84d00c46b..d0ea96ca2 100644 --- a/internal/processing/account/get.go +++ b/internal/processing/account/get.go @@ -25,9 +25,9 @@ import (  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"  	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -	"github.com/superseriousbusiness/gotosocial/internal/transport"  )  // Get processes the given request for account information. @@ -96,7 +96,7 @@ func (p *Processor) getFor(ctx context.Context, requestingAccount *gtsmodel.Acco  		}  		a, err := p.federator.GetAccountByURI( -			transport.WithFastfail(ctx), requestingAccount.Username, targetAccountURI, true, +			gtscontext.SetFastFail(ctx), requestingAccount.Username, targetAccountURI, true,  		)  		if err == nil {  			targetAccount = a diff --git a/internal/processing/fedi/common.go b/internal/processing/fedi/common.go index 91b3030e1..3fade397b 100644 --- a/internal/processing/fedi/common.go +++ b/internal/processing/fedi/common.go @@ -22,9 +22,9 @@ import (  	"fmt"  	"net/url" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -	"github.com/superseriousbusiness/gotosocial/internal/transport"  )  func (p *Processor) authenticate(ctx context.Context, requestedUsername string) (requestedAccount, requestingAccount *gtsmodel.Account, errWithCode gtserror.WithCode) { @@ -40,7 +40,7 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string)  		return  	} -	if requestingAccount, err = p.federator.GetAccountByURI(transport.WithFastfail(ctx), requestedUsername, requestingAccountURI, false); err != nil { +	if requestingAccount, err = p.federator.GetAccountByURI(gtscontext.SetFastFail(ctx), requestedUsername, requestingAccountURI, false); err != nil {  		errWithCode = gtserror.NewErrorUnauthorized(err)  		return  	} diff --git a/internal/processing/fedi/user.go b/internal/processing/fedi/user.go index 3343ae8bc..28dc3c857 100644 --- a/internal/processing/fedi/user.go +++ b/internal/processing/fedi/user.go @@ -24,8 +24,8 @@ import (  	"github.com/superseriousbusiness/activity/streams"  	"github.com/superseriousbusiness/activity/streams/vocab" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror" -	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/internal/uris"  ) @@ -56,7 +56,7 @@ func (p *Processor) UserGet(ctx context.Context, requestedUsername string, reque  		// if we're not already handshaking/dereferencing a remote account, dereference it now  		if !p.federator.Handshaking(requestedUsername, requestingAccountURI) {  			requestingAccount, err := p.federator.GetAccountByURI( -				transport.WithFastfail(ctx), requestedUsername, requestingAccountURI, false, +				gtscontext.SetFastFail(ctx), requestedUsername, requestingAccountURI, false,  			)  			if err != nil {  				return nil, gtserror.NewErrorUnauthorized(err) diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go index 293093ac2..2694fde13 100644 --- a/internal/processing/media/getfile.go +++ b/internal/processing/media/getfile.go @@ -25,10 +25,10 @@ import (  	"strings"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/internal/uris"  ) @@ -157,7 +157,7 @@ func (p *Processor) getAttachmentContent(ctx context.Context, requestingAccount  			if err != nil {  				return nil, 0, err  			} -			return t.DereferenceMedia(transport.WithFastfail(innerCtx), remoteMediaIRI) +			return t.DereferenceMedia(gtscontext.SetFastFail(innerCtx), remoteMediaIRI)  		}  		// Start recaching this media with the prepared data function. diff --git a/internal/processing/search.go b/internal/processing/search.go index 0c9ef43fd..624537b6a 100644 --- a/internal/processing/search.go +++ b/internal/processing/search.go @@ -30,11 +30,11 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/oauth" -	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/internal/util"  ) @@ -226,14 +226,14 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a  }  func (p *Processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL) (*gtsmodel.Status, error) { -	status, statusable, err := p.federator.GetStatus(transport.WithFastfail(ctx), authed.Account.Username, uri, true, true) +	status, statusable, err := p.federator.GetStatus(gtscontext.SetFastFail(ctx), authed.Account.Username, uri, true, true)  	if err != nil {  		return nil, err  	}  	if !*status.Local && statusable != nil {  		// Attempt to dereference the status thread while we are here -		p.federator.DereferenceThread(transport.WithFastfail(ctx), authed.Account.Username, uri, status, statusable) +		p.federator.DereferenceThread(gtscontext.SetFastFail(ctx), authed.Account.Username, uri, status, statusable)  	}  	return status, nil @@ -268,7 +268,7 @@ func (p *Processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth,  	}  	return p.federator.GetAccountByURI( -		transport.WithFastfail(ctx), +		gtscontext.SetFastFail(ctx),  		authed.Account.Username,  		uri, false,  	) @@ -295,7 +295,7 @@ func (p *Processor) searchAccountByUsernameDomain(ctx context.Context, authed *o  	}  	return p.federator.GetAccountByUsernameDomain( -		transport.WithFastfail(ctx), +		gtscontext.SetFastFail(ctx),  		authed.Account.Username,  		username, domain, false,  	) diff --git a/internal/processing/util.go b/internal/processing/util.go index 3f3f7ec79..967c03f9f 100644 --- a/internal/processing/util.go +++ b/internal/processing/util.go @@ -24,9 +24,9 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/federation" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/id" -	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/internal/util"  ) @@ -58,7 +58,7 @@ func GetParseMentionFunc(dbConn db.DB, federator federation.Federator) gtsmodel.  			}  			remoteAccount, err := federator.GetAccountByUsernameDomain( -				transport.WithFastfail(ctx), +				gtscontext.SetFastFail(ctx),  				requestingUsername,  				username,  				domain, diff --git a/internal/transport/context.go b/internal/transport/context.go deleted file mode 100644 index 96d3f23f7..000000000 --- a/internal/transport/context.go +++ /dev/null @@ -1,42 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program.  If not, see <http://www.gnu.org/licenses/>. - -package transport - -import "context" - -// ctxkey is our own unique context key type to prevent setting outside package. -type ctxkey string - -// fastfailkey is our unique context key to indicate fast-fail is enabled. -var fastfailkey = ctxkey("ff") - -// WithFastfail returns a Context which indicates that any http requests made -// with it should return after the first failed attempt, instead of retrying. -// -// This can be used to fail quickly when you're making an outgoing http request -// inside the context of an incoming http request, and you want to be able to -// provide a snappy response to the user, instead of retrying + backing off. -func WithFastfail(parent context.Context) context.Context { -	return context.WithValue(parent, fastfailkey, struct{}{}) -} - -// IsFastfail returns true if the given context was created by WithFastfail. -func IsFastfail(ctx context.Context) bool { -	_, ok := ctx.Value(fastfailkey).(struct{}) -	return ok -} diff --git a/internal/transport/controller.go b/internal/transport/controller.go index 331659f64..e1271d202 100644 --- a/internal/transport/controller.go +++ b/internal/transport/controller.go @@ -24,7 +24,7 @@ import (  	"encoding/json"  	"fmt"  	"net/url" -	"time" +	"runtime"  	"codeberg.org/gruf/go-byteutil"  	"codeberg.org/gruf/go-cache/v3" @@ -32,7 +32,7 @@ import (  	"github.com/superseriousbusiness/activity/streams"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" -	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/httpclient"  	"github.com/superseriousbusiness/gotosocial/internal/state"  ) @@ -49,14 +49,14 @@ type controller struct {  	state     *state.State  	fedDB     federatingdb.DB  	clock     pub.Clock -	client    pub.HttpClient +	client    httpclient.SigningClient  	trspCache cache.Cache[string, *transport] -	badHosts  cache.Cache[string, struct{}]  	userAgent string +	senders   int // no. concurrent batch delivery routines.  }  // NewController returns an implementation of the Controller interface for creating new transports -func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller { +func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client httpclient.SigningClient) Controller {  	applicationName := config.GetApplicationName()  	host := config.GetHost()  	proto := config.GetProtocol() @@ -68,20 +68,8 @@ func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.C  		clock:     clock,  		client:    client,  		trspCache: cache.New[string, *transport](0, 100, 0), -		badHosts:  cache.New[string, struct{}](0, 1000, 0),  		userAgent: fmt.Sprintf("%s (+%s://%s) gotosocial/%s", applicationName, proto, host, version), -	} - -	// Transport cache has TTL=1hr freq=1min -	c.trspCache.SetTTL(time.Hour, false) -	if !c.trspCache.Start(time.Minute) { -		log.Panic(nil, "failed to start transport controller cache") -	} - -	// Bad hosts cache has TTL=15min freq=1min -	c.badHosts.SetTTL(15*time.Minute, false) -	if !c.badHosts.Start(time.Minute) { -		log.Panic(nil, "failed to start transport controller cache") +		senders:   runtime.GOMAXPROCS(0), // on batch delivery, only ever send GOMAXPROCS at a time.  	}  	return c diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index 8ec939503..fff7dbcf4 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -22,7 +22,6 @@ import (  	"fmt"  	"net/http"  	"net/url" -	"strings"  	"sync"  	"codeberg.org/gruf/go-byteutil" @@ -32,54 +31,90 @@ import (  )  func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error { -	// concurrently deliver to recipients; for each delivery, buffer the error if it fails -	wg := sync.WaitGroup{} -	errCh := make(chan error, len(recipients)) -	for _, recipient := range recipients { -		wg.Add(1) -		go func(r *url.URL) { -			defer wg.Done() -			if err := t.Deliver(ctx, b, r); err != nil { -				errCh <- err +	var ( +		// errs accumulates errors received during +		// attempted delivery by deliverer routines. +		errs gtserror.MultiError + +		// wait blocks until all sender +		// routines have returned. +		wait sync.WaitGroup + +		// mutex protects 'recipients' and +		// 'errs' for concurrent access. +		mutex sync.Mutex + +		// Get current instance host info. +		domain = config.GetAccountDomain() +		host   = config.GetHost() +	) + +	// Block on expect no. senders. +	wait.Add(t.controller.senders) + +	for i := 0; i < t.controller.senders; i++ { +		go func() { +			// Mark returned. +			defer wait.Done() + +			for { +				// Acquire lock. +				mutex.Lock() + +				if len(recipients) == 0 { +					// Reached end. +					mutex.Unlock() +					return +				} + +				// Pop next recipient. +				i := len(recipients) - 1 +				to := recipients[i] +				recipients = recipients[:i] + +				// Done with lock. +				mutex.Unlock() + +				// Skip delivery to recipient if it is "us". +				if to.Host == host || to.Host == domain { +					continue +				} + +				// Attempt to deliver data to recipient. +				if err := t.deliver(ctx, b, to); err != nil { +					mutex.Lock() // safely append err to accumulator. +					errs.Appendf("error delivering to %s: %v", to, err) +					mutex.Unlock() +				}  			} -		}(recipient) +		}()  	} -	// wait until all deliveries have succeeded or failed -	wg.Wait() - -	// receive any buffered errors -	errs := make([]string, 0, len(errCh)) -outer: -	for { -		select { -		case e := <-errCh: -			errs = append(errs, e.Error()) -		default: -			break outer -		} -	} - -	if len(errs) > 0 { -		return fmt.Errorf("BatchDeliver: at least one failure: %s", strings.Join(errs, "; ")) -	} +	// Wait for finish. +	wait.Wait() -	return nil +	// Return combined err. +	return errs.Combine()  }  func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { -	// if the 'to' host is our own, just skip this delivery since we by definition already have the message! +	// if 'to' host is our own, skip as we don't need to deliver to ourselves...  	if to.Host == config.GetHost() || to.Host == config.GetAccountDomain() {  		return nil  	} -	urlStr := to.String() +	// Deliver data to recipient. +	return t.deliver(ctx, b, to) +} + +func (t *transport) deliver(ctx context.Context, b []byte, to *url.URL) error { +	url := to.String()  	// Use rewindable bytes reader for body.  	var body byteutil.ReadNopCloser  	body.Reset(b) -	req, err := http.NewRequestWithContext(ctx, "POST", urlStr, &body) +	req, err := http.NewRequestWithContext(ctx, "POST", url, &body)  	if err != nil {  		return err  	} @@ -88,16 +123,16 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {  	req.Header.Add("Accept-Charset", "utf-8")  	req.Header.Set("Host", to.Host) -	resp, err := t.POST(req, b) +	rsp, err := t.POST(req, b)  	if err != nil {  		return err  	} -	defer resp.Body.Close() +	defer rsp.Body.Close() -	if code := resp.StatusCode; code != http.StatusOK && +	if code := rsp.StatusCode; code != http.StatusOK &&  		code != http.StatusCreated && code != http.StatusAccepted { -		err := fmt.Errorf("POST request to %s failed: %s", urlStr, resp.Status) -		return gtserror.WithStatusCode(err, resp.StatusCode) +		err := fmt.Errorf("POST request to %s failed: %s", url, rsp.Status) +		return gtserror.WithStatusCode(err, rsp.StatusCode)  	}  	return nil diff --git a/internal/transport/transport.go b/internal/transport/transport.go index e8f742f5b..0123b3ea8 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -20,26 +20,17 @@ package transport  import (  	"context"  	"crypto" -	"crypto/x509"  	"errors" -	"fmt"  	"io" -	"net"  	"net/http"  	"net/url" -	"strconv" -	"strings"  	"sync"  	"time" -	"codeberg.org/gruf/go-byteutil" -	errorsv2 "codeberg.org/gruf/go-errors/v2" -	"codeberg.org/gruf/go-kv"  	"github.com/go-fed/httpsig" -	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/httpclient" -	"github.com/superseriousbusiness/gotosocial/internal/log"  )  // Transport implements the pub.Transport interface with some additional functionality for fetching remote media. @@ -78,7 +69,7 @@ type Transport interface {  	Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error)  } -// transport implements the Transport interface +// transport implements the Transport interface.  type transport struct {  	controller *controller  	pubKeyID   string @@ -95,9 +86,11 @@ func (t *transport) GET(r *http.Request) (*http.Response, error) {  	if r.Method != http.MethodGet {  		return nil, errors.New("must be GET request")  	} -	return t.do(r, func(r *http.Request) error { -		return t.signGET(r) -	}) +	ctx := r.Context() // extract, set pubkey ID. +	ctx = gtscontext.SetPublicKeyID(ctx, t.pubKeyID) +	r = r.WithContext(ctx) // replace request ctx. +	r.Header.Set("User-Agent", t.controller.userAgent) +	return t.controller.client.DoSigned(r, t.signGET())  }  // POST will perform given http request using transport client, retrying on certain preset errors. @@ -105,161 +98,31 @@ func (t *transport) POST(r *http.Request, body []byte) (*http.Response, error) {  	if r.Method != http.MethodPost {  		return nil, errors.New("must be POST request")  	} -	return t.do(r, func(r *http.Request) error { -		return t.signPOST(r, body) -	}) -} - -func (t *transport) do(r *http.Request, signer func(*http.Request) error) (*http.Response, error) { -	const ( -		// max no. attempts -		maxRetries = 5 - -		// starting backoff duration. -		baseBackoff = 2 * time.Second -	) - -	// Get request hostname -	host := r.URL.Hostname() - -	// Check whether request should fast fail, we check this -	// before loop as each context.Value() requires mutex lock. -	fastFail := IsFastfail(r.Context()) -	if !fastFail { -		// Check if recently reached max retries for this host -		// so we don't bother with a retry-backoff loop. The only -		// errors that are retried upon are server failure and -		// domain resolution type errors, so this cached result -		// indicates this server is likely having issues. -		fastFail = t.controller.badHosts.Has(host) -	} - -	// Start a log entry for this request -	l := log.WithContext(r.Context()). -		WithFields(kv.Fields{ -			{"pubKeyID", t.pubKeyID}, -			{"method", r.Method}, -			{"url", r.URL.String()}, -		}...) - +	ctx := r.Context() // extract, set pubkey ID. +	ctx = gtscontext.SetPublicKeyID(ctx, t.pubKeyID) +	r = r.WithContext(ctx) // replace request ctx.  	r.Header.Set("User-Agent", t.controller.userAgent) - -	for i := 0; i < maxRetries; i++ { -		var backoff time.Duration - -		// Reset signing header fields -		now := t.controller.clock.Now().UTC() -		r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") -		r.Header.Del("Signature") -		r.Header.Del("Digest") - -		// Rewind body reader and content-length if set. -		if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { -			r.ContentLength = int64(rc.Len()) -			rc.Rewind() -		} - -		// Perform request signing -		if err := signer(r); err != nil { -			return nil, err -		} - -		l.Infof("performing request") - -		// Attempt to perform request -		rsp, err := t.controller.client.Do(r) -		if err == nil { //nolint:gocritic -			// TooManyRequest means we need to slow -			// down and retry our request. Codes over -			// 500 generally indicate temp. outages. -			if code := rsp.StatusCode; code < 500 && -				code != http.StatusTooManyRequests { -				return rsp, nil -			} - -			// Generate error from status code for logging -			err = errors.New(`http response "` + rsp.Status + `"`) - -			// Search for a provided "Retry-After" header value. -			if after := rsp.Header.Get("Retry-After"); after != "" { - -				if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { -					// An integer number of backoff seconds was provided. -					backoff = time.Duration(u) * time.Second -				} else if at, _ := http.ParseTime(after); !at.Before(now) { -					// An HTTP formatted future date-time was provided. -					backoff = at.Sub(now) -				} - -				// Don't let their provided backoff exceed our max. -				if max := baseBackoff * maxRetries; backoff > max { -					backoff = max -				} -			} - -		} else if errorsv2.Is(err, -			context.DeadlineExceeded, -			context.Canceled, -			httpclient.ErrInvalidRequest, -			httpclient.ErrBodyTooLarge, -			httpclient.ErrReservedAddr, -		) { -			// Return on non-retryable errors -			return nil, err -		} else if strings.Contains(err.Error(), "stopped after 10 redirects") { -			// Don't bother if net/http returned after too many redirects -			return nil, err -		} else if errors.As(err, &x509.UnknownAuthorityError{}) { -			// Unknown authority errors we do NOT recover from -			return nil, err -		} else if dnserr := (*net.DNSError)(nil); // nocollapse -		errors.As(err, &dnserr) && dnserr.IsNotFound { -			// DNS lookup failure, this domain does not exist -			return nil, gtserror.SetNotFound(err) -		} - -		if fastFail { -			// on fast-fail, don't bother backoff/retry -			return nil, fmt.Errorf("%w (fast fail)", err) -		} - -		if backoff == 0 { -			// No retry-after found, set our predefined backoff. -			backoff = time.Duration(i) * baseBackoff -		} - -		l.Errorf("backing off for %s after http request error: %v", backoff, err) - -		select { -		// Request ctx cancelled -		case <-r.Context().Done(): -			return nil, r.Context().Err() - -		// Backoff for some time -		case <-time.After(backoff): -		} -	} - -	// Add "bad" entry for this host. -	t.controller.badHosts.Set(host, struct{}{}) - -	return nil, errors.New("transport reached max retries") +	return t.controller.client.DoSigned(r, t.signPOST(body))  }  // signGET will safely sign an HTTP GET request. -func (t *transport) signGET(r *http.Request) (err error) { -	t.safesign(func() { -		err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) -	}) -	return +func (t *transport) signGET() httpclient.SignFunc { +	return func(r *http.Request) (err error) { +		t.safesign(func() { +			err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) +		}) +		return +	}  }  // signPOST will safely sign an HTTP POST request for given body. -func (t *transport) signPOST(r *http.Request, body []byte) (err error) { -	t.safesign(func() { -		err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) -	}) -	return +func (t *transport) signPOST(body []byte) httpclient.SignFunc { +	return func(r *http.Request) (err error) { +		t.safesign(func() { +			err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) +		}) +		return +	}  }  // safesign will perform sign function within mutex protection, diff --git a/internal/workers/workers.go b/internal/workers/workers.go index bf64a28ee..aa8e40e1c 100644 --- a/internal/workers/workers.go +++ b/internal/workers/workers.go @@ -31,8 +31,12 @@ type Workers struct {  	// Main task scheduler instance.  	Scheduler sched.Scheduler -	// ClientAPI / federator worker pools. +	// ClientAPI provides a worker pool that handles both +	// incoming client actions, and our own side-effects.  	ClientAPI runners.WorkerPool + +	// Federator provides a worker pool that handles both +	// incoming federated actions, and our own side-effects.  	Federator runners.WorkerPool  	// Enqueue functions for clientAPI / federator worker pools, diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go index f2c6b1d28..b74888934 100644 --- a/testrig/transportcontroller.go +++ b/testrig/transportcontroller.go @@ -26,12 +26,12 @@ import (  	"strings"  	"sync" -	"github.com/superseriousbusiness/activity/pub"  	"github.com/superseriousbusiness/activity/streams"  	"github.com/superseriousbusiness/activity/streams/vocab"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/httpclient"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/transport" @@ -51,7 +51,7 @@ const (  // Unlike the other test interfaces provided in this package, you'll probably want to call this function  // PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular)  // basis. -func NewTestTransportController(state *state.State, client pub.HttpClient) transport.Controller { +func NewTestTransportController(state *state.State, client httpclient.SigningClient) transport.Controller {  	return transport.NewController(state, NewTestFederatingDB(state), &federation.Clock{}, client)  } @@ -225,6 +225,10 @@ func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {  	return m.do(req)  } +func (m *MockHTTPClient) DoSigned(req *http.Request, sign httpclient.SignFunc) (*http.Response, error) { +	return m.do(req) +} +  func HostMetaResponse(req *http.Request) (responseCode int, responseBytes []byte, responseContentType string, responseContentLength int) {  	var hm *apimodel.HostMeta  | 
