diff options
Diffstat (limited to 'internal/transport/transport.go')
-rw-r--r-- | internal/transport/transport.go | 163 |
1 files changed, 145 insertions, 18 deletions
diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 40c11ca17..c52686c43 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -21,11 +21,18 @@ package transport import ( "context" "crypto" + "crypto/x509" + "errors" "io" + "net/http" "net/url" + "strings" "sync" + "time" + errorsv2 "codeberg.org/gruf/go-errors/v2" "github.com/go-fed/httpsig" + "github.com/sirupsen/logrus" "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -43,28 +50,148 @@ type Transport interface { DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error) // Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body. Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error) - // SigTransport returns the underlying http signature transport wrapped by the GoToSocial transport. - SigTransport() pub.Transport } // transport implements the Transport interface type transport struct { - client pub.HttpClient - appAgent string - gofedAgent string - clock pub.Clock - pubKeyID string - privkey crypto.PrivateKey - sigTransport *pub.HttpSigTransport - getSigner httpsig.Signer - getSignerMu *sync.Mutex - - // shortcuts for dereferencing things that exist on our instance without making an http call to ourself - - dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error) - dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error) + controller *controller + pubKeyID string + privkey crypto.PrivateKey + + signerExp time.Time + getSigner httpsig.Signer + postSigner httpsig.Signer + signerMu sync.Mutex +} + +// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn. +func (t *transport) GET(r *http.Request, retryOn ...int) (*http.Response, error) { + if r.Method != http.MethodGet { + return nil, errors.New("must be GET request") + } + return t.do(r, func(r *http.Request) error { + return t.signGET(r) + }, retryOn...) +} + +// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn. +func (t *transport) POST(r *http.Request, body []byte, retryOn ...int) (*http.Response, error) { + if r.Method != http.MethodPost { + return nil, errors.New("must be POST request") + } + return t.do(r, func(r *http.Request) error { + return t.signPOST(r, body) + }, retryOn...) +} + +func (t *transport) do(r *http.Request, signer func(*http.Request) error, retryOn ...int) (*http.Response, error) { + const maxRetries = 5 + backoff := time.Second * 2 + + // Start a log entry for this request + l := logrus.WithFields(logrus.Fields{ + "pubKeyID": t.pubKeyID, + "method": r.Method, + "url": r.URL.String(), + }) + + for i := 0; i < maxRetries; i++ { + // Reset signing header fields + now := t.controller.clock.Now().UTC() + r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") + r.Header.Del("Signature") + r.Header.Del("Digest") + + // Perform request signing + if err := signer(r); err != nil { + return nil, err + } + + l.Infof("performing request") + + // Attempt to perform request + rsp, err := t.controller.client.Do(r) + if err == nil { //nolint shutup linter + // TooManyRequest means we need to slow + // down and retry our request. Codes over + // 500 generally indicate temp. outages. + if code := rsp.StatusCode; code < 500 && + code != http.StatusTooManyRequests && + !containsInt(retryOn, rsp.StatusCode) { + return rsp, nil + } + + // Generate error from status code for logging + err = errors.New(`http response "` + rsp.Status + `"`) + } else if errorsv2.Is(err, context.DeadlineExceeded, context.Canceled) { + // Return early if context has cancelled + return nil, err + } else if strings.Contains(err.Error(), "stopped after 10 redirects") { + // Don't bother if net/http returned after too many redirects + return nil, err + } else if errors.As(err, &x509.UnknownAuthorityError{}) { + // Unknown authority errors we do NOT recover from + return nil, err + } + + l.Errorf("backing off for %s after http request error: %v", backoff.String(), err) + + select { + // Request ctx cancelled + case <-r.Context().Done(): + return nil, r.Context().Err() + + // Backoff for some time + case <-time.After(backoff): + backoff *= 2 + } + } + + return nil, errors.New("transport reached max retries") +} + +// signGET will safely sign an HTTP GET request. +func (t *transport) signGET(r *http.Request) (err error) { + t.safesign(func() { + err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) + }) + return +} + +// signPOST will safely sign an HTTP POST request for given body. +func (t *transport) signPOST(r *http.Request, body []byte) (err error) { + t.safesign(func() { + err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) + }) + return +} + +// safesign will perform sign function within mutex protection, +// and ensured that httpsig.Signers are up-to-date. +func (t *transport) safesign(sign func()) { + // Perform within mu safety + t.signerMu.Lock() + defer t.signerMu.Unlock() + + if now := time.Now(); now.After(t.signerExp) { + const expiry = 120 + + // Signers have expired and require renewal + t.getSigner, _ = NewGETSigner(expiry) + t.postSigner, _ = NewPOSTSigner(expiry) + t.signerExp = now.Add(time.Second * expiry) + } + + // Perform signing + sign() } -func (t *transport) SigTransport() pub.Transport { - return t.sigTransport +// containsInt checks if slice contains check. +func containsInt(slice []int, check int) bool { + for _, i := range slice { + if i == check { + return true + } + } + return false } |