summaryrefslogtreecommitdiff
path: root/internal/transport
diff options
context:
space:
mode:
Diffstat (limited to 'internal/transport')
-rw-r--r--internal/transport/controller.go196
-rw-r--r--internal/transport/deliver.go29
-rw-r--r--internal/transport/dereference.go39
-rw-r--r--internal/transport/derefinstance.go85
-rw-r--r--internal/transport/derefmedia.go33
-rw-r--r--internal/transport/finger.go50
-rw-r--r--internal/transport/signing.go43
-rw-r--r--internal/transport/transport.go163
8 files changed, 422 insertions, 216 deletions
diff --git a/internal/transport/controller.go b/internal/transport/controller.go
index 56a922a8b..280d4bc0b 100644
--- a/internal/transport/controller.go
+++ b/internal/transport/controller.go
@@ -20,13 +20,17 @@ package transport
import (
"context"
- "crypto"
+ "crypto/rsa"
+ "crypto/x509"
"encoding/json"
"fmt"
"net/url"
- "sync"
+ "runtime/debug"
+ "time"
- "github.com/go-fed/httpsig"
+ "codeberg.org/gruf/go-byteutil"
+ "codeberg.org/gruf/go-cache/v2"
+ "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
@@ -37,109 +41,85 @@ import (
// Controller generates transports for use in making federation requests to other servers.
type Controller interface {
- NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error)
+ // NewTransport returns an http signature transport with the given public key ID (URL location of pubkey), and the given private key.
+ NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error)
+
+ // NewTransportForUsername searches for account with username, and returns result of .NewTransport().
NewTransportForUsername(ctx context.Context, username string) (Transport, error)
}
type controller struct {
- db db.DB
- clock pub.Clock
- client pub.HttpClient
- appAgent string
-
- // dereferenceFollowersShortcut is a shortcut to dereference followers of an
- // account on this instance, without making any external api/http calls.
- //
- // It is passed to new transports, and should only be invoked when the iri.Host == this host.
- dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
-
- // dereferenceUserShortcut is a shortcut to dereference followers an account on
- // this instance, without making any external api/http calls.
- //
- // It is passed to new transports, and should only be invoked when the iri.Host == this host.
- dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
+ db db.DB
+ fedDB federatingdb.DB
+ clock pub.Clock
+ client pub.HttpClient
+ cache cache.Cache[string, *transport]
+ userAgent string
}
-func dereferenceFollowersShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
- return func(ctx context.Context, iri *url.URL) ([]byte, error) {
- followers, err := federatingDB.Followers(ctx, iri)
- if err != nil {
- return nil, err
- }
+// NewController returns an implementation of the Controller interface for creating new transports
+func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
+ applicationName := viper.GetString(config.Keys.ApplicationName)
+ host := viper.GetString(config.Keys.Host)
- i, err := streams.Serialize(followers)
- if err != nil {
- return nil, err
- }
+ // Determine build information
+ build, _ := debug.ReadBuildInfo()
- return json.Marshal(i)
+ c := &controller{
+ db: db,
+ fedDB: federatingDB,
+ clock: clock,
+ client: client,
+ cache: cache.New[string, *transport](),
+ userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, build.Main.Version),
}
-}
-func dereferenceUserShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
- return func(ctx context.Context, iri *url.URL) ([]byte, error) {
- user, err := federatingDB.Get(ctx, iri)
- if err != nil {
- return nil, err
- }
-
- i, err := streams.Serialize(user)
- if err != nil {
- return nil, err
- }
-
- return json.Marshal(i)
+ // Transport cache has TTL=1hr freq=1m
+ c.cache.SetTTL(time.Hour, false)
+ if !c.cache.Start(time.Minute) {
+ logrus.Panic("failed to start transport controller cache")
}
-}
-// NewController returns an implementation of the Controller interface for creating new transports
-func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
- applicationName := viper.GetString(config.Keys.ApplicationName)
- host := viper.GetString(config.Keys.Host)
- appAgent := fmt.Sprintf("%s %s", applicationName, host)
-
- return &controller{
- db: db,
- clock: clock,
- client: client,
- appAgent: appAgent,
- dereferenceFollowersShortcut: dereferenceFollowersShortcut(federatingDB),
- dereferenceUserShortcut: dereferenceUserShortcut(federatingDB),
- }
+ return c
}
-// NewTransport returns a new http signature transport with the given public key id (a URL), and the given private key.
-func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) {
- prefs := []httpsig.Algorithm{httpsig.RSA_SHA256}
- digestAlgo := httpsig.DigestSha256
- getHeaders := []string{httpsig.RequestTarget, "host", "date"}
- postHeaders := []string{httpsig.RequestTarget, "host", "date", "digest"}
+func (c *controller) NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error) {
+ // Generate public key string for cache key
+ //
+ // NOTE: it is safe to use the public key as the cache
+ // key here as we are generating it ourselves from the
+ // private key. If we were simply using a public key
+ // provided as argument that would absolutely NOT be safe.
+ pubStr := privkeyToPublicStr(privkey)
+
+ // First check for cached transport
+ transp, ok := c.cache.Get(pubStr)
+ if ok {
+ return transp, nil
+ }
- getSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, 120)
- if err != nil {
- return nil, fmt.Errorf("error creating get signer: %s", err)
+ // Create the transport
+ transp = &transport{
+ controller: c,
+ pubKeyID: pubKeyID,
+ privkey: privkey,
}
- postSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, 120)
- if err != nil {
- return nil, fmt.Errorf("error creating post signer: %s", err)
+ // Cache this transport under pubkey
+ if !c.cache.Put(pubStr, transp) {
+ var cached *transport
+
+ cached, ok = c.cache.Get(pubStr)
+ if !ok {
+ // Some ridiculous race cond.
+ c.cache.Set(pubStr, transp)
+ } else {
+ // Use already cached
+ transp = cached
+ }
}
- sigTransport := pub.NewHttpSigTransport(c.client, c.appAgent, c.clock, getSigner, postSigner, pubKeyID, privkey)
-
- return &transport{
- client: c.client,
- appAgent: c.appAgent,
- gofedAgent: "(go-fed/activity v1.0.0)",
- clock: c.clock,
- pubKeyID: pubKeyID,
- privkey: privkey,
- sigTransport: sigTransport,
- getSigner: getSigner,
- getSignerMu: &sync.Mutex{},
- dereferenceFollowersShortcut: c.dereferenceFollowersShortcut,
- dereferenceUserShortcut: c.dereferenceUserShortcut,
- }, nil
+ return transp, nil
}
func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) {
@@ -164,3 +144,45 @@ func (c *controller) NewTransportForUsername(ctx context.Context, username strin
}
return transport, nil
}
+
+// dereferenceLocalFollowers is a shortcut to dereference followers of an
+// account on this instance, without making any external api/http calls.
+//
+// It is passed to new transports, and should only be invoked when the iri.Host == this host.
+func (c *controller) dereferenceLocalFollowers(ctx context.Context, iri *url.URL) ([]byte, error) {
+ followers, err := c.fedDB.Followers(ctx, iri)
+ if err != nil {
+ return nil, err
+ }
+
+ i, err := streams.Serialize(followers)
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(i)
+}
+
+// dereferenceLocalUser is a shortcut to dereference followers an account on
+// this instance, without making any external api/http calls.
+//
+// It is passed to new transports, and should only be invoked when the iri.Host == this host.
+func (c *controller) dereferenceLocalUser(ctx context.Context, iri *url.URL) ([]byte, error) {
+ user, err := c.fedDB.Get(ctx, iri)
+ if err != nil {
+ return nil, err
+ }
+
+ i, err := streams.Serialize(user)
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(i)
+}
+
+// privkeyToPublicStr will create a string representation of RSA public key from private.
+func privkeyToPublicStr(privkey *rsa.PrivateKey) string {
+ b := x509.MarshalPKCS1PublicKey(&privkey.PublicKey)
+ return byteutil.B2S(b)
+}
diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go
index fe17f7761..bacaa9b3a 100644
--- a/internal/transport/deliver.go
+++ b/internal/transport/deliver.go
@@ -19,13 +19,14 @@
package transport
import (
+ "bytes"
"context"
"fmt"
+ "net/http"
"net/url"
"strings"
"sync"
- "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
@@ -72,6 +73,28 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
return nil
}
- logrus.Debugf("Deliver: posting as %s to %s", t.pubKeyID, to.String())
- return t.sigTransport.Deliver(ctx, b, to)
+ urlStr := to.String()
+
+ req, err := http.NewRequestWithContext(ctx, "POST", urlStr, bytes.NewReader(b))
+ if err != nil {
+ return err
+ }
+
+ req.Header.Add("Content-Type", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
+ req.Header.Add("Accept-Charset", "utf-8")
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", to.Host)
+
+ resp, err := t.POST(req, b)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if code := resp.StatusCode; code != http.StatusOK &&
+ code != http.StatusCreated && code != http.StatusAccepted {
+ return fmt.Errorf("POST request to %s failed (%d): %s", urlStr, resp.StatusCode, resp.Status)
+ }
+
+ return nil
}
diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go
index 61d99c5c5..36157b673 100644
--- a/internal/transport/dereference.go
+++ b/internal/transport/dereference.go
@@ -20,32 +20,55 @@ package transport
import (
"context"
+ "fmt"
+ "io/ioutil"
+ "net/http"
"net/url"
- "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/uris"
)
func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) {
- l := logrus.WithField("func", "Dereference")
-
// if the request is to us, we can shortcut for certain URIs rather than going through
// the normal request flow, thereby saving time and energy
if iri.Host == viper.GetString(config.Keys.Host) {
if uris.IsFollowersPath(iri) {
// the request is for followers of one of our accounts, which we can shortcut
- return t.dereferenceFollowersShortcut(ctx, iri)
+ return t.controller.dereferenceLocalFollowers(ctx, iri)
}
if uris.IsUserPath(iri) {
// the request is for one of our accounts, which we can shortcut
- return t.dereferenceUserShortcut(ctx, iri)
+ return t.controller.dereferenceLocalUser(ctx, iri)
}
}
- // the request is either for a remote host or for us but we don't have a shortcut, so continue as normal
- l.Debugf("performing GET to %s", iri.String())
- return t.sigTransport.Dereference(ctx, iri)
+ // Build IRI just once
+ iriStr := iri.String()
+
+ // Prepare new HTTP request to endpoint
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Add("Accept", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
+ req.Header.Add("Accept-Charset", "utf-8")
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", iri.Host)
+
+ // Perform the HTTP request
+ rsp, err := t.GET(req)
+ if err != nil {
+ return nil, err
+ }
+ defer rsp.Body.Close()
+
+ // Check for an expected status code
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
+ }
+
+ return ioutil.ReadAll(rsp.Body)
}
diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go
index c64dced0f..1acbcc364 100644
--- a/internal/transport/derefinstance.go
+++ b/internal/transport/derefinstance.go
@@ -80,43 +80,38 @@ func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gts
}
func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) {
- l := logrus.WithField("func", "dereferenceByAPIV1Instance")
-
cleanIRI := &url.URL{
Scheme: iri.Scheme,
Host: iri.Host,
Path: "api/v1/instance",
}
- l.Debugf("performing GET to %s", cleanIRI.String())
- req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
+ // Build IRI just once
+ iriStr := cleanIRI.String()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
+
req.Header.Add("Accept", "application/json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
+ req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", cleanIRI.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, err
- }
- resp, err := t.client.Do(req)
+
+ resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
+
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
+ return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
+
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
- }
-
- if len(b) == 0 {
+ } else if len(b) == 0 {
return nil, errors.New("response bytes was len 0")
}
@@ -237,44 +232,37 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm
}
func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) {
- l := logrus.WithField("func", "callNodeInfoWellKnown")
-
cleanIRI := &url.URL{
Scheme: iri.Scheme,
Host: iri.Host,
Path: ".well-known/nodeinfo",
}
- l.Debugf("performing GET to %s", cleanIRI.String())
- req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
+ // Build IRI just once
+ iriStr := cleanIRI.String()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
-
req.Header.Add("Accept", "application/json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
+ req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", cleanIRI.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, err
- }
- resp, err := t.client.Do(req)
+
+ resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
+
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
+ return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
+
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
- }
-
- if len(b) == 0 {
+ } else if len(b) == 0 {
return nil, errors.New("callNodeInfoWellKnown: response bytes was len 0")
}
@@ -302,38 +290,31 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur
}
func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) {
- l := logrus.WithField("func", "callNodeInfo")
+ // Build IRI just once
+ iriStr := iri.String()
- l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
-
req.Header.Add("Accept", "application/json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
+ req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", iri.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, err
- }
- resp, err := t.client.Do(req)
+
+ resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
+
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
+ return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
+
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
- }
-
- if len(b) == 0 {
+ } else if len(b) == 0 {
return nil, errors.New("callNodeInfo: response bytes was len 0")
}
diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go
index e3c86ce1e..8feb7ed20 100644
--- a/internal/transport/derefmedia.go
+++ b/internal/transport/derefmedia.go
@@ -24,34 +24,31 @@ import (
"io"
"net/http"
"net/url"
-
- "github.com/sirupsen/logrus"
)
func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.ReadCloser, int, error) {
- l := logrus.WithField("func", "DereferenceMedia")
- l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
+ // Build IRI just once
+ iriStr := iri.String()
+
+ // Prepare HTTP request to this media's IRI
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, 0, err
}
-
req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
+ req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", iri.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, 0, err
- }
- resp, err := t.client.Do(req)
+
+ // Perform the HTTP request
+ rsp, err := t.GET(req)
if err != nil {
return nil, 0, err
}
- if resp.StatusCode != http.StatusOK {
- return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
+
+ // Check for an expected status code
+ if rsp.StatusCode != http.StatusOK {
+ return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
}
- return resp.Body, int(resp.ContentLength), nil
+
+ return rsp.Body, int(rsp.ContentLength), nil
}
diff --git a/internal/transport/finger.go b/internal/transport/finger.go
index a71bbb51e..7554a242f 100644
--- a/internal/transport/finger.go
+++ b/internal/transport/finger.go
@@ -23,46 +23,36 @@ import (
"fmt"
"io/ioutil"
"net/http"
- "net/url"
-
- "github.com/sirupsen/logrus"
)
func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
- l := logrus.WithField("func", "Finger")
- urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain)
- l.Debugf("performing GET to %s", urlString)
-
- iri, err := url.Parse(urlString)
- if err != nil {
- return nil, fmt.Errorf("Finger: error parsing url %s: %s", urlString, err)
- }
-
- l.Debugf("performing GET to %s", iri.String())
-
- req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
+ // Prepare URL string
+ urlStr := "https://" +
+ targetDomain +
+ "/.well-known/webfinger?resource=acct:" +
+ targetUsername + "@" + targetDomain
+
+ // Generate new GET request from URL string
+ req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
return nil, err
}
-
req.Header.Add("Accept", "application/json")
req.Header.Add("Accept", "application/jrd+json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
- req.Header.Set("Host", iri.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, err
- }
- resp, err := t.client.Do(req)
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", req.URL.Host)
+
+ // Perform the HTTP request
+ rsp, err := t.GET(req)
if err != nil {
return nil, err
}
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
+ defer rsp.Body.Close()
+
+ // Check for an expected status code
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("GET request to %s failed (%d): %s", urlStr, rsp.StatusCode, rsp.Status)
}
- return ioutil.ReadAll(resp.Body)
+
+ return ioutil.ReadAll(rsp.Body)
}
diff --git a/internal/transport/signing.go b/internal/transport/signing.go
new file mode 100644
index 000000000..39896a2a8
--- /dev/null
+++ b/internal/transport/signing.go
@@ -0,0 +1,43 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package transport
+
+import (
+ "github.com/go-fed/httpsig"
+)
+
+var (
+ // http signer preferences
+ prefs = []httpsig.Algorithm{httpsig.RSA_SHA256}
+ digestAlgo = httpsig.DigestSha256
+ getHeaders = []string{httpsig.RequestTarget, "host", "date"}
+ postHeaders = []string{httpsig.RequestTarget, "host", "date", "digest"}
+)
+
+// NewGETSigner returns a new httpsig.Signer instance initialized with GTS GET preferences.
+func NewGETSigner(expiresIn int64) (httpsig.Signer, error) {
+ sig, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, expiresIn)
+ return sig, err
+}
+
+// NewPOSTSigner returns a new httpsig.Signer instance initialized with GTS POST preferences.
+func NewPOSTSigner(expiresIn int64) (httpsig.Signer, error) {
+ sig, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, expiresIn)
+ return sig, err
+}
diff --git a/internal/transport/transport.go b/internal/transport/transport.go
index 40c11ca17..c52686c43 100644
--- a/internal/transport/transport.go
+++ b/internal/transport/transport.go
@@ -21,11 +21,18 @@ package transport
import (
"context"
"crypto"
+ "crypto/x509"
+ "errors"
"io"
+ "net/http"
"net/url"
+ "strings"
"sync"
+ "time"
+ errorsv2 "codeberg.org/gruf/go-errors/v2"
"github.com/go-fed/httpsig"
+ "github.com/sirupsen/logrus"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -43,28 +50,148 @@ type Transport interface {
DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error)
// Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body.
Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error)
- // SigTransport returns the underlying http signature transport wrapped by the GoToSocial transport.
- SigTransport() pub.Transport
}
// transport implements the Transport interface
type transport struct {
- client pub.HttpClient
- appAgent string
- gofedAgent string
- clock pub.Clock
- pubKeyID string
- privkey crypto.PrivateKey
- sigTransport *pub.HttpSigTransport
- getSigner httpsig.Signer
- getSignerMu *sync.Mutex
-
- // shortcuts for dereferencing things that exist on our instance without making an http call to ourself
-
- dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
- dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
+ controller *controller
+ pubKeyID string
+ privkey crypto.PrivateKey
+
+ signerExp time.Time
+ getSigner httpsig.Signer
+ postSigner httpsig.Signer
+ signerMu sync.Mutex
+}
+
+// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
+func (t *transport) GET(r *http.Request, retryOn ...int) (*http.Response, error) {
+ if r.Method != http.MethodGet {
+ return nil, errors.New("must be GET request")
+ }
+ return t.do(r, func(r *http.Request) error {
+ return t.signGET(r)
+ }, retryOn...)
+}
+
+// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
+func (t *transport) POST(r *http.Request, body []byte, retryOn ...int) (*http.Response, error) {
+ if r.Method != http.MethodPost {
+ return nil, errors.New("must be POST request")
+ }
+ return t.do(r, func(r *http.Request) error {
+ return t.signPOST(r, body)
+ }, retryOn...)
+}
+
+func (t *transport) do(r *http.Request, signer func(*http.Request) error, retryOn ...int) (*http.Response, error) {
+ const maxRetries = 5
+ backoff := time.Second * 2
+
+ // Start a log entry for this request
+ l := logrus.WithFields(logrus.Fields{
+ "pubKeyID": t.pubKeyID,
+ "method": r.Method,
+ "url": r.URL.String(),
+ })
+
+ for i := 0; i < maxRetries; i++ {
+ // Reset signing header fields
+ now := t.controller.clock.Now().UTC()
+ r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
+ r.Header.Del("Signature")
+ r.Header.Del("Digest")
+
+ // Perform request signing
+ if err := signer(r); err != nil {
+ return nil, err
+ }
+
+ l.Infof("performing request")
+
+ // Attempt to perform request
+ rsp, err := t.controller.client.Do(r)
+ if err == nil { //nolint shutup linter
+ // TooManyRequest means we need to slow
+ // down and retry our request. Codes over
+ // 500 generally indicate temp. outages.
+ if code := rsp.StatusCode; code < 500 &&
+ code != http.StatusTooManyRequests &&
+ !containsInt(retryOn, rsp.StatusCode) {
+ return rsp, nil
+ }
+
+ // Generate error from status code for logging
+ err = errors.New(`http response "` + rsp.Status + `"`)
+ } else if errorsv2.Is(err, context.DeadlineExceeded, context.Canceled) {
+ // Return early if context has cancelled
+ return nil, err
+ } else if strings.Contains(err.Error(), "stopped after 10 redirects") {
+ // Don't bother if net/http returned after too many redirects
+ return nil, err
+ } else if errors.As(err, &x509.UnknownAuthorityError{}) {
+ // Unknown authority errors we do NOT recover from
+ return nil, err
+ }
+
+ l.Errorf("backing off for %s after http request error: %v", backoff.String(), err)
+
+ select {
+ // Request ctx cancelled
+ case <-r.Context().Done():
+ return nil, r.Context().Err()
+
+ // Backoff for some time
+ case <-time.After(backoff):
+ backoff *= 2
+ }
+ }
+
+ return nil, errors.New("transport reached max retries")
+}
+
+// signGET will safely sign an HTTP GET request.
+func (t *transport) signGET(r *http.Request) (err error) {
+ t.safesign(func() {
+ err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil)
+ })
+ return
+}
+
+// signPOST will safely sign an HTTP POST request for given body.
+func (t *transport) signPOST(r *http.Request, body []byte) (err error) {
+ t.safesign(func() {
+ err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body)
+ })
+ return
+}
+
+// safesign will perform sign function within mutex protection,
+// and ensured that httpsig.Signers are up-to-date.
+func (t *transport) safesign(sign func()) {
+ // Perform within mu safety
+ t.signerMu.Lock()
+ defer t.signerMu.Unlock()
+
+ if now := time.Now(); now.After(t.signerExp) {
+ const expiry = 120
+
+ // Signers have expired and require renewal
+ t.getSigner, _ = NewGETSigner(expiry)
+ t.postSigner, _ = NewPOSTSigner(expiry)
+ t.signerExp = now.Add(time.Second * expiry)
+ }
+
+ // Perform signing
+ sign()
}
-func (t *transport) SigTransport() pub.Transport {
- return t.sigTransport
+// containsInt checks if slice contains check.
+func containsInt(slice []int, check int) bool {
+ for _, i := range slice {
+ if i == check {
+ return true
+ }
+ }
+ return false
}