diff options
Diffstat (limited to 'internal/transport')
-rw-r--r-- | internal/transport/controller.go | 196 | ||||
-rw-r--r-- | internal/transport/deliver.go | 29 | ||||
-rw-r--r-- | internal/transport/dereference.go | 39 | ||||
-rw-r--r-- | internal/transport/derefinstance.go | 85 | ||||
-rw-r--r-- | internal/transport/derefmedia.go | 33 | ||||
-rw-r--r-- | internal/transport/finger.go | 50 | ||||
-rw-r--r-- | internal/transport/signing.go | 43 | ||||
-rw-r--r-- | internal/transport/transport.go | 163 |
8 files changed, 422 insertions, 216 deletions
diff --git a/internal/transport/controller.go b/internal/transport/controller.go index 56a922a8b..280d4bc0b 100644 --- a/internal/transport/controller.go +++ b/internal/transport/controller.go @@ -20,13 +20,17 @@ package transport import ( "context" - "crypto" + "crypto/rsa" + "crypto/x509" "encoding/json" "fmt" "net/url" - "sync" + "runtime/debug" + "time" - "github.com/go-fed/httpsig" + "codeberg.org/gruf/go-byteutil" + "codeberg.org/gruf/go-cache/v2" + "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/streams" @@ -37,109 +41,85 @@ import ( // Controller generates transports for use in making federation requests to other servers. type Controller interface { - NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) + // NewTransport returns an http signature transport with the given public key ID (URL location of pubkey), and the given private key. + NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error) + + // NewTransportForUsername searches for account with username, and returns result of .NewTransport(). NewTransportForUsername(ctx context.Context, username string) (Transport, error) } type controller struct { - db db.DB - clock pub.Clock - client pub.HttpClient - appAgent string - - // dereferenceFollowersShortcut is a shortcut to dereference followers of an - // account on this instance, without making any external api/http calls. - // - // It is passed to new transports, and should only be invoked when the iri.Host == this host. - dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error) - - // dereferenceUserShortcut is a shortcut to dereference followers an account on - // this instance, without making any external api/http calls. - // - // It is passed to new transports, and should only be invoked when the iri.Host == this host. - dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error) + db db.DB + fedDB federatingdb.DB + clock pub.Clock + client pub.HttpClient + cache cache.Cache[string, *transport] + userAgent string } -func dereferenceFollowersShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) { - return func(ctx context.Context, iri *url.URL) ([]byte, error) { - followers, err := federatingDB.Followers(ctx, iri) - if err != nil { - return nil, err - } +// NewController returns an implementation of the Controller interface for creating new transports +func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller { + applicationName := viper.GetString(config.Keys.ApplicationName) + host := viper.GetString(config.Keys.Host) - i, err := streams.Serialize(followers) - if err != nil { - return nil, err - } + // Determine build information + build, _ := debug.ReadBuildInfo() - return json.Marshal(i) + c := &controller{ + db: db, + fedDB: federatingDB, + clock: clock, + client: client, + cache: cache.New[string, *transport](), + userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, build.Main.Version), } -} -func dereferenceUserShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) { - return func(ctx context.Context, iri *url.URL) ([]byte, error) { - user, err := federatingDB.Get(ctx, iri) - if err != nil { - return nil, err - } - - i, err := streams.Serialize(user) - if err != nil { - return nil, err - } - - return json.Marshal(i) + // Transport cache has TTL=1hr freq=1m + c.cache.SetTTL(time.Hour, false) + if !c.cache.Start(time.Minute) { + logrus.Panic("failed to start transport controller cache") } -} -// NewController returns an implementation of the Controller interface for creating new transports -func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller { - applicationName := viper.GetString(config.Keys.ApplicationName) - host := viper.GetString(config.Keys.Host) - appAgent := fmt.Sprintf("%s %s", applicationName, host) - - return &controller{ - db: db, - clock: clock, - client: client, - appAgent: appAgent, - dereferenceFollowersShortcut: dereferenceFollowersShortcut(federatingDB), - dereferenceUserShortcut: dereferenceUserShortcut(federatingDB), - } + return c } -// NewTransport returns a new http signature transport with the given public key id (a URL), and the given private key. -func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) { - prefs := []httpsig.Algorithm{httpsig.RSA_SHA256} - digestAlgo := httpsig.DigestSha256 - getHeaders := []string{httpsig.RequestTarget, "host", "date"} - postHeaders := []string{httpsig.RequestTarget, "host", "date", "digest"} +func (c *controller) NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error) { + // Generate public key string for cache key + // + // NOTE: it is safe to use the public key as the cache + // key here as we are generating it ourselves from the + // private key. If we were simply using a public key + // provided as argument that would absolutely NOT be safe. + pubStr := privkeyToPublicStr(privkey) + + // First check for cached transport + transp, ok := c.cache.Get(pubStr) + if ok { + return transp, nil + } - getSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, 120) - if err != nil { - return nil, fmt.Errorf("error creating get signer: %s", err) + // Create the transport + transp = &transport{ + controller: c, + pubKeyID: pubKeyID, + privkey: privkey, } - postSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, 120) - if err != nil { - return nil, fmt.Errorf("error creating post signer: %s", err) + // Cache this transport under pubkey + if !c.cache.Put(pubStr, transp) { + var cached *transport + + cached, ok = c.cache.Get(pubStr) + if !ok { + // Some ridiculous race cond. + c.cache.Set(pubStr, transp) + } else { + // Use already cached + transp = cached + } } - sigTransport := pub.NewHttpSigTransport(c.client, c.appAgent, c.clock, getSigner, postSigner, pubKeyID, privkey) - - return &transport{ - client: c.client, - appAgent: c.appAgent, - gofedAgent: "(go-fed/activity v1.0.0)", - clock: c.clock, - pubKeyID: pubKeyID, - privkey: privkey, - sigTransport: sigTransport, - getSigner: getSigner, - getSignerMu: &sync.Mutex{}, - dereferenceFollowersShortcut: c.dereferenceFollowersShortcut, - dereferenceUserShortcut: c.dereferenceUserShortcut, - }, nil + return transp, nil } func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) { @@ -164,3 +144,45 @@ func (c *controller) NewTransportForUsername(ctx context.Context, username strin } return transport, nil } + +// dereferenceLocalFollowers is a shortcut to dereference followers of an +// account on this instance, without making any external api/http calls. +// +// It is passed to new transports, and should only be invoked when the iri.Host == this host. +func (c *controller) dereferenceLocalFollowers(ctx context.Context, iri *url.URL) ([]byte, error) { + followers, err := c.fedDB.Followers(ctx, iri) + if err != nil { + return nil, err + } + + i, err := streams.Serialize(followers) + if err != nil { + return nil, err + } + + return json.Marshal(i) +} + +// dereferenceLocalUser is a shortcut to dereference followers an account on +// this instance, without making any external api/http calls. +// +// It is passed to new transports, and should only be invoked when the iri.Host == this host. +func (c *controller) dereferenceLocalUser(ctx context.Context, iri *url.URL) ([]byte, error) { + user, err := c.fedDB.Get(ctx, iri) + if err != nil { + return nil, err + } + + i, err := streams.Serialize(user) + if err != nil { + return nil, err + } + + return json.Marshal(i) +} + +// privkeyToPublicStr will create a string representation of RSA public key from private. +func privkeyToPublicStr(privkey *rsa.PrivateKey) string { + b := x509.MarshalPKCS1PublicKey(&privkey.PublicKey) + return byteutil.B2S(b) +} diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index fe17f7761..bacaa9b3a 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -19,13 +19,14 @@ package transport import ( + "bytes" "context" "fmt" + "net/http" "net/url" "strings" "sync" - "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/superseriousbusiness/gotosocial/internal/config" ) @@ -72,6 +73,28 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { return nil } - logrus.Debugf("Deliver: posting as %s to %s", t.pubKeyID, to.String()) - return t.sigTransport.Deliver(ctx, b, to) + urlStr := to.String() + + req, err := http.NewRequestWithContext(ctx, "POST", urlStr, bytes.NewReader(b)) + if err != nil { + return err + } + + req.Header.Add("Content-Type", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"") + req.Header.Add("Accept-Charset", "utf-8") + req.Header.Add("User-Agent", t.controller.userAgent) + req.Header.Set("Host", to.Host) + + resp, err := t.POST(req, b) + if err != nil { + return err + } + defer resp.Body.Close() + + if code := resp.StatusCode; code != http.StatusOK && + code != http.StatusCreated && code != http.StatusAccepted { + return fmt.Errorf("POST request to %s failed (%d): %s", urlStr, resp.StatusCode, resp.Status) + } + + return nil } diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go index 61d99c5c5..36157b673 100644 --- a/internal/transport/dereference.go +++ b/internal/transport/dereference.go @@ -20,32 +20,55 @@ package transport import ( "context" + "fmt" + "io/ioutil" + "net/http" "net/url" - "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/uris" ) func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) { - l := logrus.WithField("func", "Dereference") - // if the request is to us, we can shortcut for certain URIs rather than going through // the normal request flow, thereby saving time and energy if iri.Host == viper.GetString(config.Keys.Host) { if uris.IsFollowersPath(iri) { // the request is for followers of one of our accounts, which we can shortcut - return t.dereferenceFollowersShortcut(ctx, iri) + return t.controller.dereferenceLocalFollowers(ctx, iri) } if uris.IsUserPath(iri) { // the request is for one of our accounts, which we can shortcut - return t.dereferenceUserShortcut(ctx, iri) + return t.controller.dereferenceLocalUser(ctx, iri) } } - // the request is either for a remote host or for us but we don't have a shortcut, so continue as normal - l.Debugf("performing GET to %s", iri.String()) - return t.sigTransport.Dereference(ctx, iri) + // Build IRI just once + iriStr := iri.String() + + // Prepare new HTTP request to endpoint + req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil) + if err != nil { + return nil, err + } + req.Header.Add("Accept", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"") + req.Header.Add("Accept-Charset", "utf-8") + req.Header.Add("User-Agent", t.controller.userAgent) + req.Header.Set("Host", iri.Host) + + // Perform the HTTP request + rsp, err := t.GET(req) + if err != nil { + return nil, err + } + defer rsp.Body.Close() + + // Check for an expected status code + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status) + } + + return ioutil.ReadAll(rsp.Body) } diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go index c64dced0f..1acbcc364 100644 --- a/internal/transport/derefinstance.go +++ b/internal/transport/derefinstance.go @@ -80,43 +80,38 @@ func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gts } func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) { - l := logrus.WithField("func", "dereferenceByAPIV1Instance") - cleanIRI := &url.URL{ Scheme: iri.Scheme, Host: iri.Host, Path: "api/v1/instance", } - l.Debugf("performing GET to %s", cleanIRI.String()) - req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil) + // Build IRI just once + iriStr := cleanIRI.String() + + req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil) if err != nil { return nil, err } + req.Header.Add("Accept", "application/json") - req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) + req.Header.Add("User-Agent", t.controller.userAgent) req.Header.Set("Host", cleanIRI.Host) - t.getSignerMu.Lock() - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil) - t.getSignerMu.Unlock() - if err != nil { - return nil, err - } - resp, err := t.client.Do(req) + + resp, err := t.GET(req) if err != nil { return nil, err } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status) + return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status) } + b, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err - } - - if len(b) == 0 { + } else if len(b) == 0 { return nil, errors.New("response bytes was len 0") } @@ -237,44 +232,37 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm } func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) { - l := logrus.WithField("func", "callNodeInfoWellKnown") - cleanIRI := &url.URL{ Scheme: iri.Scheme, Host: iri.Host, Path: ".well-known/nodeinfo", } - l.Debugf("performing GET to %s", cleanIRI.String()) - req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil) + // Build IRI just once + iriStr := cleanIRI.String() + + req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil) if err != nil { return nil, err } - req.Header.Add("Accept", "application/json") - req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) + req.Header.Add("User-Agent", t.controller.userAgent) req.Header.Set("Host", cleanIRI.Host) - t.getSignerMu.Lock() - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil) - t.getSignerMu.Unlock() - if err != nil { - return nil, err - } - resp, err := t.client.Do(req) + + resp, err := t.GET(req) if err != nil { return nil, err } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status) + return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status) } + b, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err - } - - if len(b) == 0 { + } else if len(b) == 0 { return nil, errors.New("callNodeInfoWellKnown: response bytes was len 0") } @@ -302,38 +290,31 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur } func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) { - l := logrus.WithField("func", "callNodeInfo") + // Build IRI just once + iriStr := iri.String() - l.Debugf("performing GET to %s", iri.String()) - req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil) if err != nil { return nil, err } - req.Header.Add("Accept", "application/json") - req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) + req.Header.Add("User-Agent", t.controller.userAgent) req.Header.Set("Host", iri.Host) - t.getSignerMu.Lock() - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil) - t.getSignerMu.Unlock() - if err != nil { - return nil, err - } - resp, err := t.client.Do(req) + + resp, err := t.GET(req) if err != nil { return nil, err } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status) + return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status) } + b, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err - } - - if len(b) == 0 { + } else if len(b) == 0 { return nil, errors.New("callNodeInfo: response bytes was len 0") } diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go index e3c86ce1e..8feb7ed20 100644 --- a/internal/transport/derefmedia.go +++ b/internal/transport/derefmedia.go @@ -24,34 +24,31 @@ import ( "io" "net/http" "net/url" - - "github.com/sirupsen/logrus" ) func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.ReadCloser, int, error) { - l := logrus.WithField("func", "DereferenceMedia") - l.Debugf("performing GET to %s", iri.String()) - req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) + // Build IRI just once + iriStr := iri.String() + + // Prepare HTTP request to this media's IRI + req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil) if err != nil { return nil, 0, err } - req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here - req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) + req.Header.Add("User-Agent", t.controller.userAgent) req.Header.Set("Host", iri.Host) - t.getSignerMu.Lock() - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil) - t.getSignerMu.Unlock() - if err != nil { - return nil, 0, err - } - resp, err := t.client.Do(req) + + // Perform the HTTP request + rsp, err := t.GET(req) if err != nil { return nil, 0, err } - if resp.StatusCode != http.StatusOK { - return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status) + + // Check for an expected status code + if rsp.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status) } - return resp.Body, int(resp.ContentLength), nil + + return rsp.Body, int(rsp.ContentLength), nil } diff --git a/internal/transport/finger.go b/internal/transport/finger.go index a71bbb51e..7554a242f 100644 --- a/internal/transport/finger.go +++ b/internal/transport/finger.go @@ -23,46 +23,36 @@ import ( "fmt" "io/ioutil" "net/http" - "net/url" - - "github.com/sirupsen/logrus" ) func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) { - l := logrus.WithField("func", "Finger") - urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain) - l.Debugf("performing GET to %s", urlString) - - iri, err := url.Parse(urlString) - if err != nil { - return nil, fmt.Errorf("Finger: error parsing url %s: %s", urlString, err) - } - - l.Debugf("performing GET to %s", iri.String()) - - req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) + // Prepare URL string + urlStr := "https://" + + targetDomain + + "/.well-known/webfinger?resource=acct:" + + targetUsername + "@" + targetDomain + + // Generate new GET request from URL string + req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) if err != nil { return nil, err } - req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/jrd+json") - req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) - req.Header.Set("Host", iri.Host) - t.getSignerMu.Lock() - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil) - t.getSignerMu.Unlock() - if err != nil { - return nil, err - } - resp, err := t.client.Do(req) + req.Header.Add("User-Agent", t.controller.userAgent) + req.Header.Set("Host", req.URL.Host) + + // Perform the HTTP request + rsp, err := t.GET(req) if err != nil { return nil, err } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status) + defer rsp.Body.Close() + + // Check for an expected status code + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GET request to %s failed (%d): %s", urlStr, rsp.StatusCode, rsp.Status) } - return ioutil.ReadAll(resp.Body) + + return ioutil.ReadAll(rsp.Body) } diff --git a/internal/transport/signing.go b/internal/transport/signing.go new file mode 100644 index 000000000..39896a2a8 --- /dev/null +++ b/internal/transport/signing.go @@ -0,0 +1,43 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package transport + +import ( + "github.com/go-fed/httpsig" +) + +var ( + // http signer preferences + prefs = []httpsig.Algorithm{httpsig.RSA_SHA256} + digestAlgo = httpsig.DigestSha256 + getHeaders = []string{httpsig.RequestTarget, "host", "date"} + postHeaders = []string{httpsig.RequestTarget, "host", "date", "digest"} +) + +// NewGETSigner returns a new httpsig.Signer instance initialized with GTS GET preferences. +func NewGETSigner(expiresIn int64) (httpsig.Signer, error) { + sig, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, expiresIn) + return sig, err +} + +// NewPOSTSigner returns a new httpsig.Signer instance initialized with GTS POST preferences. +func NewPOSTSigner(expiresIn int64) (httpsig.Signer, error) { + sig, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, expiresIn) + return sig, err +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 40c11ca17..c52686c43 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -21,11 +21,18 @@ package transport import ( "context" "crypto" + "crypto/x509" + "errors" "io" + "net/http" "net/url" + "strings" "sync" + "time" + errorsv2 "codeberg.org/gruf/go-errors/v2" "github.com/go-fed/httpsig" + "github.com/sirupsen/logrus" "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -43,28 +50,148 @@ type Transport interface { DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error) // Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body. Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error) - // SigTransport returns the underlying http signature transport wrapped by the GoToSocial transport. - SigTransport() pub.Transport } // transport implements the Transport interface type transport struct { - client pub.HttpClient - appAgent string - gofedAgent string - clock pub.Clock - pubKeyID string - privkey crypto.PrivateKey - sigTransport *pub.HttpSigTransport - getSigner httpsig.Signer - getSignerMu *sync.Mutex - - // shortcuts for dereferencing things that exist on our instance without making an http call to ourself - - dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error) - dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error) + controller *controller + pubKeyID string + privkey crypto.PrivateKey + + signerExp time.Time + getSigner httpsig.Signer + postSigner httpsig.Signer + signerMu sync.Mutex +} + +// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn. +func (t *transport) GET(r *http.Request, retryOn ...int) (*http.Response, error) { + if r.Method != http.MethodGet { + return nil, errors.New("must be GET request") + } + return t.do(r, func(r *http.Request) error { + return t.signGET(r) + }, retryOn...) +} + +// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn. +func (t *transport) POST(r *http.Request, body []byte, retryOn ...int) (*http.Response, error) { + if r.Method != http.MethodPost { + return nil, errors.New("must be POST request") + } + return t.do(r, func(r *http.Request) error { + return t.signPOST(r, body) + }, retryOn...) +} + +func (t *transport) do(r *http.Request, signer func(*http.Request) error, retryOn ...int) (*http.Response, error) { + const maxRetries = 5 + backoff := time.Second * 2 + + // Start a log entry for this request + l := logrus.WithFields(logrus.Fields{ + "pubKeyID": t.pubKeyID, + "method": r.Method, + "url": r.URL.String(), + }) + + for i := 0; i < maxRetries; i++ { + // Reset signing header fields + now := t.controller.clock.Now().UTC() + r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") + r.Header.Del("Signature") + r.Header.Del("Digest") + + // Perform request signing + if err := signer(r); err != nil { + return nil, err + } + + l.Infof("performing request") + + // Attempt to perform request + rsp, err := t.controller.client.Do(r) + if err == nil { //nolint shutup linter + // TooManyRequest means we need to slow + // down and retry our request. Codes over + // 500 generally indicate temp. outages. + if code := rsp.StatusCode; code < 500 && + code != http.StatusTooManyRequests && + !containsInt(retryOn, rsp.StatusCode) { + return rsp, nil + } + + // Generate error from status code for logging + err = errors.New(`http response "` + rsp.Status + `"`) + } else if errorsv2.Is(err, context.DeadlineExceeded, context.Canceled) { + // Return early if context has cancelled + return nil, err + } else if strings.Contains(err.Error(), "stopped after 10 redirects") { + // Don't bother if net/http returned after too many redirects + return nil, err + } else if errors.As(err, &x509.UnknownAuthorityError{}) { + // Unknown authority errors we do NOT recover from + return nil, err + } + + l.Errorf("backing off for %s after http request error: %v", backoff.String(), err) + + select { + // Request ctx cancelled + case <-r.Context().Done(): + return nil, r.Context().Err() + + // Backoff for some time + case <-time.After(backoff): + backoff *= 2 + } + } + + return nil, errors.New("transport reached max retries") +} + +// signGET will safely sign an HTTP GET request. +func (t *transport) signGET(r *http.Request) (err error) { + t.safesign(func() { + err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) + }) + return +} + +// signPOST will safely sign an HTTP POST request for given body. +func (t *transport) signPOST(r *http.Request, body []byte) (err error) { + t.safesign(func() { + err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) + }) + return +} + +// safesign will perform sign function within mutex protection, +// and ensured that httpsig.Signers are up-to-date. +func (t *transport) safesign(sign func()) { + // Perform within mu safety + t.signerMu.Lock() + defer t.signerMu.Unlock() + + if now := time.Now(); now.After(t.signerExp) { + const expiry = 120 + + // Signers have expired and require renewal + t.getSigner, _ = NewGETSigner(expiry) + t.postSigner, _ = NewPOSTSigner(expiry) + t.signerExp = now.Add(time.Second * expiry) + } + + // Perform signing + sign() } -func (t *transport) SigTransport() pub.Transport { - return t.sigTransport +// containsInt checks if slice contains check. +func containsInt(slice []int, check int) bool { + for _, i := range slice { + if i == check { + return true + } + } + return false } |