diff options
Diffstat (limited to 'internal/transport/transport.go')
-rw-r--r-- | internal/transport/transport.go | 187 |
1 files changed, 25 insertions, 162 deletions
diff --git a/internal/transport/transport.go b/internal/transport/transport.go index e8f742f5b..0123b3ea8 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -20,26 +20,17 @@ package transport import ( "context" "crypto" - "crypto/x509" "errors" - "fmt" "io" - "net" "net/http" "net/url" - "strconv" - "strings" "sync" "time" - "codeberg.org/gruf/go-byteutil" - errorsv2 "codeberg.org/gruf/go-errors/v2" - "codeberg.org/gruf/go-kv" "github.com/go-fed/httpsig" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/httpclient" - "github.com/superseriousbusiness/gotosocial/internal/log" ) // Transport implements the pub.Transport interface with some additional functionality for fetching remote media. @@ -78,7 +69,7 @@ type Transport interface { Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) } -// transport implements the Transport interface +// transport implements the Transport interface. type transport struct { controller *controller pubKeyID string @@ -95,9 +86,11 @@ func (t *transport) GET(r *http.Request) (*http.Response, error) { if r.Method != http.MethodGet { return nil, errors.New("must be GET request") } - return t.do(r, func(r *http.Request) error { - return t.signGET(r) - }) + ctx := r.Context() // extract, set pubkey ID. + ctx = gtscontext.SetPublicKeyID(ctx, t.pubKeyID) + r = r.WithContext(ctx) // replace request ctx. + r.Header.Set("User-Agent", t.controller.userAgent) + return t.controller.client.DoSigned(r, t.signGET()) } // POST will perform given http request using transport client, retrying on certain preset errors. @@ -105,161 +98,31 @@ func (t *transport) POST(r *http.Request, body []byte) (*http.Response, error) { if r.Method != http.MethodPost { return nil, errors.New("must be POST request") } - return t.do(r, func(r *http.Request) error { - return t.signPOST(r, body) - }) -} - -func (t *transport) do(r *http.Request, signer func(*http.Request) error) (*http.Response, error) { - const ( - // max no. attempts - maxRetries = 5 - - // starting backoff duration. - baseBackoff = 2 * time.Second - ) - - // Get request hostname - host := r.URL.Hostname() - - // Check whether request should fast fail, we check this - // before loop as each context.Value() requires mutex lock. - fastFail := IsFastfail(r.Context()) - if !fastFail { - // Check if recently reached max retries for this host - // so we don't bother with a retry-backoff loop. The only - // errors that are retried upon are server failure and - // domain resolution type errors, so this cached result - // indicates this server is likely having issues. - fastFail = t.controller.badHosts.Has(host) - } - - // Start a log entry for this request - l := log.WithContext(r.Context()). - WithFields(kv.Fields{ - {"pubKeyID", t.pubKeyID}, - {"method", r.Method}, - {"url", r.URL.String()}, - }...) - + ctx := r.Context() // extract, set pubkey ID. + ctx = gtscontext.SetPublicKeyID(ctx, t.pubKeyID) + r = r.WithContext(ctx) // replace request ctx. r.Header.Set("User-Agent", t.controller.userAgent) - - for i := 0; i < maxRetries; i++ { - var backoff time.Duration - - // Reset signing header fields - now := t.controller.clock.Now().UTC() - r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - r.Header.Del("Signature") - r.Header.Del("Digest") - - // Rewind body reader and content-length if set. - if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { - r.ContentLength = int64(rc.Len()) - rc.Rewind() - } - - // Perform request signing - if err := signer(r); err != nil { - return nil, err - } - - l.Infof("performing request") - - // Attempt to perform request - rsp, err := t.controller.client.Do(r) - if err == nil { //nolint:gocritic - // TooManyRequest means we need to slow - // down and retry our request. Codes over - // 500 generally indicate temp. outages. - if code := rsp.StatusCode; code < 500 && - code != http.StatusTooManyRequests { - return rsp, nil - } - - // Generate error from status code for logging - err = errors.New(`http response "` + rsp.Status + `"`) - - // Search for a provided "Retry-After" header value. - if after := rsp.Header.Get("Retry-After"); after != "" { - - if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { - // An integer number of backoff seconds was provided. - backoff = time.Duration(u) * time.Second - } else if at, _ := http.ParseTime(after); !at.Before(now) { - // An HTTP formatted future date-time was provided. - backoff = at.Sub(now) - } - - // Don't let their provided backoff exceed our max. - if max := baseBackoff * maxRetries; backoff > max { - backoff = max - } - } - - } else if errorsv2.Is(err, - context.DeadlineExceeded, - context.Canceled, - httpclient.ErrInvalidRequest, - httpclient.ErrBodyTooLarge, - httpclient.ErrReservedAddr, - ) { - // Return on non-retryable errors - return nil, err - } else if strings.Contains(err.Error(), "stopped after 10 redirects") { - // Don't bother if net/http returned after too many redirects - return nil, err - } else if errors.As(err, &x509.UnknownAuthorityError{}) { - // Unknown authority errors we do NOT recover from - return nil, err - } else if dnserr := (*net.DNSError)(nil); // nocollapse - errors.As(err, &dnserr) && dnserr.IsNotFound { - // DNS lookup failure, this domain does not exist - return nil, gtserror.SetNotFound(err) - } - - if fastFail { - // on fast-fail, don't bother backoff/retry - return nil, fmt.Errorf("%w (fast fail)", err) - } - - if backoff == 0 { - // No retry-after found, set our predefined backoff. - backoff = time.Duration(i) * baseBackoff - } - - l.Errorf("backing off for %s after http request error: %v", backoff, err) - - select { - // Request ctx cancelled - case <-r.Context().Done(): - return nil, r.Context().Err() - - // Backoff for some time - case <-time.After(backoff): - } - } - - // Add "bad" entry for this host. - t.controller.badHosts.Set(host, struct{}{}) - - return nil, errors.New("transport reached max retries") + return t.controller.client.DoSigned(r, t.signPOST(body)) } // signGET will safely sign an HTTP GET request. -func (t *transport) signGET(r *http.Request) (err error) { - t.safesign(func() { - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) - }) - return +func (t *transport) signGET() httpclient.SignFunc { + return func(r *http.Request) (err error) { + t.safesign(func() { + err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) + }) + return + } } // signPOST will safely sign an HTTP POST request for given body. -func (t *transport) signPOST(r *http.Request, body []byte) (err error) { - t.safesign(func() { - err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) - }) - return +func (t *transport) signPOST(body []byte) httpclient.SignFunc { + return func(r *http.Request) (err error) { + t.safesign(func() { + err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) + }) + return + } } // safesign will perform sign function within mutex protection, |