diff options
Diffstat (limited to 'internal/transport/transport.go')
-rw-r--r-- | internal/transport/transport.go | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/internal/transport/transport.go b/internal/transport/transport.go new file mode 100644 index 000000000..afd408519 --- /dev/null +++ b/internal/transport/transport.go @@ -0,0 +1,77 @@ +package transport + +import ( + "context" + "crypto" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "sync" + + "github.com/go-fed/activity/pub" + "github.com/go-fed/httpsig" +) + +// Transport wraps the pub.Transport interface with some additional +// functionality for fetching remote media. +type Transport interface { + pub.Transport + DereferenceMedia(c context.Context, iri *url.URL, expectedContentType string) ([]byte, error) +} + +// transport implements the Transport interface +type transport struct { + client pub.HttpClient + appAgent string + gofedAgent string + clock pub.Clock + pubKeyID string + privkey crypto.PrivateKey + sigTransport *pub.HttpSigTransport + getSigner httpsig.Signer + getSignerMu *sync.Mutex +} + +func (t *transport) BatchDeliver(c context.Context, b []byte, recipients []*url.URL) error { + return t.sigTransport.BatchDeliver(c, b, recipients) +} + +func (t *transport) Deliver(c context.Context, b []byte, to *url.URL) error { + return t.sigTransport.Deliver(c, b, to) +} + +func (t *transport) Dereference(c context.Context, iri *url.URL) ([]byte, error) { + return t.sigTransport.Dereference(c, iri) +} + +func (t *transport) DereferenceMedia(c context.Context, iri *url.URL, expectedContentType string) ([]byte, error) { + req, err := http.NewRequest("GET", iri.String(), nil) + if err != nil { + return nil, err + } + req = req.WithContext(c) + if expectedContentType == "" { + req.Header.Add("Accept", "*/*") + } else { + req.Header.Add("Accept", expectedContentType) + } + req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") + req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) + req.Header.Set("Host", iri.Host) + t.getSignerMu.Lock() + err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil) + t.getSignerMu.Unlock() + if err != nil { + return nil, err + } + resp, err := t.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status) + } + return ioutil.ReadAll(resp.Body) +} |