summaryrefslogtreecommitdiff
path: root/internal/transport/transport.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/transport/transport.go')
-rw-r--r--internal/transport/transport.go163
1 files changed, 145 insertions, 18 deletions
diff --git a/internal/transport/transport.go b/internal/transport/transport.go
index 40c11ca17..c52686c43 100644
--- a/internal/transport/transport.go
+++ b/internal/transport/transport.go
@@ -21,11 +21,18 @@ package transport
import (
"context"
"crypto"
+ "crypto/x509"
+ "errors"
"io"
+ "net/http"
"net/url"
+ "strings"
"sync"
+ "time"
+ errorsv2 "codeberg.org/gruf/go-errors/v2"
"github.com/go-fed/httpsig"
+ "github.com/sirupsen/logrus"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -43,28 +50,148 @@ type Transport interface {
DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error)
// Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body.
Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error)
- // SigTransport returns the underlying http signature transport wrapped by the GoToSocial transport.
- SigTransport() pub.Transport
}
// transport implements the Transport interface
type transport struct {
- client pub.HttpClient
- appAgent string
- gofedAgent string
- clock pub.Clock
- pubKeyID string
- privkey crypto.PrivateKey
- sigTransport *pub.HttpSigTransport
- getSigner httpsig.Signer
- getSignerMu *sync.Mutex
-
- // shortcuts for dereferencing things that exist on our instance without making an http call to ourself
-
- dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
- dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
+ controller *controller
+ pubKeyID string
+ privkey crypto.PrivateKey
+
+ signerExp time.Time
+ getSigner httpsig.Signer
+ postSigner httpsig.Signer
+ signerMu sync.Mutex
+}
+
+// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
+func (t *transport) GET(r *http.Request, retryOn ...int) (*http.Response, error) {
+ if r.Method != http.MethodGet {
+ return nil, errors.New("must be GET request")
+ }
+ return t.do(r, func(r *http.Request) error {
+ return t.signGET(r)
+ }, retryOn...)
+}
+
+// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
+func (t *transport) POST(r *http.Request, body []byte, retryOn ...int) (*http.Response, error) {
+ if r.Method != http.MethodPost {
+ return nil, errors.New("must be POST request")
+ }
+ return t.do(r, func(r *http.Request) error {
+ return t.signPOST(r, body)
+ }, retryOn...)
+}
+
+func (t *transport) do(r *http.Request, signer func(*http.Request) error, retryOn ...int) (*http.Response, error) {
+ const maxRetries = 5
+ backoff := time.Second * 2
+
+ // Start a log entry for this request
+ l := logrus.WithFields(logrus.Fields{
+ "pubKeyID": t.pubKeyID,
+ "method": r.Method,
+ "url": r.URL.String(),
+ })
+
+ for i := 0; i < maxRetries; i++ {
+ // Reset signing header fields
+ now := t.controller.clock.Now().UTC()
+ r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
+ r.Header.Del("Signature")
+ r.Header.Del("Digest")
+
+ // Perform request signing
+ if err := signer(r); err != nil {
+ return nil, err
+ }
+
+ l.Infof("performing request")
+
+ // Attempt to perform request
+ rsp, err := t.controller.client.Do(r)
+ if err == nil { //nolint shutup linter
+ // TooManyRequest means we need to slow
+ // down and retry our request. Codes over
+ // 500 generally indicate temp. outages.
+ if code := rsp.StatusCode; code < 500 &&
+ code != http.StatusTooManyRequests &&
+ !containsInt(retryOn, rsp.StatusCode) {
+ return rsp, nil
+ }
+
+ // Generate error from status code for logging
+ err = errors.New(`http response "` + rsp.Status + `"`)
+ } else if errorsv2.Is(err, context.DeadlineExceeded, context.Canceled) {
+ // Return early if context has cancelled
+ return nil, err
+ } else if strings.Contains(err.Error(), "stopped after 10 redirects") {
+ // Don't bother if net/http returned after too many redirects
+ return nil, err
+ } else if errors.As(err, &x509.UnknownAuthorityError{}) {
+ // Unknown authority errors we do NOT recover from
+ return nil, err
+ }
+
+ l.Errorf("backing off for %s after http request error: %v", backoff.String(), err)
+
+ select {
+ // Request ctx cancelled
+ case <-r.Context().Done():
+ return nil, r.Context().Err()
+
+ // Backoff for some time
+ case <-time.After(backoff):
+ backoff *= 2
+ }
+ }
+
+ return nil, errors.New("transport reached max retries")
+}
+
+// signGET will safely sign an HTTP GET request.
+func (t *transport) signGET(r *http.Request) (err error) {
+ t.safesign(func() {
+ err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil)
+ })
+ return
+}
+
+// signPOST will safely sign an HTTP POST request for given body.
+func (t *transport) signPOST(r *http.Request, body []byte) (err error) {
+ t.safesign(func() {
+ err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body)
+ })
+ return
+}
+
+// safesign will perform sign function within mutex protection,
+// and ensured that httpsig.Signers are up-to-date.
+func (t *transport) safesign(sign func()) {
+ // Perform within mu safety
+ t.signerMu.Lock()
+ defer t.signerMu.Unlock()
+
+ if now := time.Now(); now.After(t.signerExp) {
+ const expiry = 120
+
+ // Signers have expired and require renewal
+ t.getSigner, _ = NewGETSigner(expiry)
+ t.postSigner, _ = NewPOSTSigner(expiry)
+ t.signerExp = now.Add(time.Second * expiry)
+ }
+
+ // Perform signing
+ sign()
}
-func (t *transport) SigTransport() pub.Transport {
- return t.sigTransport
+// containsInt checks if slice contains check.
+func containsInt(slice []int, check int) bool {
+ for _, i := range slice {
+ if i == check {
+ return true
+ }
+ }
+ return false
}