summaryrefslogtreecommitdiff
path: root/internal/transport/delivery
diff options
context:
space:
mode:
Diffstat (limited to 'internal/transport/delivery')
-rw-r--r--internal/transport/delivery/delivery.go16
-rw-r--r--internal/transport/delivery/delivery_test.go31
-rw-r--r--internal/transport/delivery/worker.go80
3 files changed, 73 insertions, 54 deletions
diff --git a/internal/transport/delivery/delivery.go b/internal/transport/delivery/delivery.go
index 1e3ebb054..e11eea83c 100644
--- a/internal/transport/delivery/delivery.go
+++ b/internal/transport/delivery/delivery.go
@@ -33,10 +33,6 @@ import (
// be indexed (and so, dropped from queue)
// by any of these possible ID IRIs.
type Delivery struct {
- // PubKeyID is the signing public key
- // ID of the actor performing request.
- PubKeyID string
-
// ActorID contains the ActivityPub
// actor ID IRI (if any) of the activity
// being sent out by this request.
@@ -55,7 +51,7 @@ type Delivery struct {
// Request is the prepared (+ wrapped)
// httpclient.Client{} request that
// constitutes this ActivtyPub delivery.
- Request httpclient.Request
+ Request *httpclient.Request
// internal fields.
next time.Time
@@ -66,7 +62,6 @@ type Delivery struct {
// a json serialize / deserialize
// able shape that minimizes data.
type delivery struct {
- PubKeyID string `json:"pub_key_id,omitempty"`
ActorID string `json:"actor_id,omitempty"`
ObjectID string `json:"object_id,omitempty"`
TargetID string `json:"target_id,omitempty"`
@@ -101,7 +96,6 @@ func (dlv *Delivery) Serialize() ([]byte, error) {
// Marshal as internal JSON type.
return json.Marshal(delivery{
- PubKeyID: dlv.PubKeyID,
ActorID: dlv.ActorID,
ObjectID: dlv.ObjectID,
TargetID: dlv.TargetID,
@@ -125,7 +119,6 @@ func (dlv *Delivery) Deserialize(data []byte) error {
}
// Copy over simplest fields.
- dlv.PubKeyID = idlv.PubKeyID
dlv.ActorID = idlv.ActorID
dlv.ObjectID = idlv.ObjectID
dlv.TargetID = idlv.TargetID
@@ -143,6 +136,13 @@ func (dlv *Delivery) Deserialize(data []byte) error {
return err
}
+ // Copy over any stored header values.
+ for key, values := range idlv.Header {
+ for _, value := range values {
+ r.Header.Add(key, value)
+ }
+ }
+
// Wrap request in httpclient type.
dlv.Request = httpclient.WrapRequest(r)
diff --git a/internal/transport/delivery/delivery_test.go b/internal/transport/delivery/delivery_test.go
index e9eaf8fd1..81f32d5f8 100644
--- a/internal/transport/delivery/delivery_test.go
+++ b/internal/transport/delivery/delivery_test.go
@@ -35,32 +35,30 @@ var deliveryCases = []struct {
}{
{
msg: delivery.Delivery{
- PubKeyID: "https://google.com/users/bigboy#pubkey",
ActorID: "https://google.com/users/bigboy",
ObjectID: "https://google.com/users/bigboy/follow/1",
TargetID: "https://askjeeves.com/users/smallboy",
- Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!")),
+ Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Hello": {"world1", "world2"}}),
},
data: toJSON(map[string]any{
- "pub_key_id": "https://google.com/users/bigboy#pubkey",
- "actor_id": "https://google.com/users/bigboy",
- "object_id": "https://google.com/users/bigboy/follow/1",
- "target_id": "https://askjeeves.com/users/smallboy",
- "method": "POST",
- "url": "https://askjeeves.com/users/smallboy/inbox",
- "body": []byte("data!"),
- // "header": map[string][]string{},
+ "actor_id": "https://google.com/users/bigboy",
+ "object_id": "https://google.com/users/bigboy/follow/1",
+ "target_id": "https://askjeeves.com/users/smallboy",
+ "method": "POST",
+ "url": "https://askjeeves.com/users/smallboy/inbox",
+ "body": []byte("data!"),
+ "header": map[string][]string{"Hello": {"world1", "world2"}},
}),
},
{
msg: delivery.Delivery{
- Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin")),
+ Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), nil),
},
data: toJSON(map[string]any{
"method": "GET",
"url": "https://google.com",
"body": []byte("uwu im just a wittle seawch engwin"),
- // "header": map[string][]string{},
+ // "header": map[string][]string{},
}),
},
}
@@ -89,18 +87,18 @@ func TestDeserializeDelivery(t *testing.T) {
}
// Check that delivery fields are as expected.
- assert.Equal(t, test.msg.PubKeyID, msg.PubKeyID)
assert.Equal(t, test.msg.ActorID, msg.ActorID)
assert.Equal(t, test.msg.ObjectID, msg.ObjectID)
assert.Equal(t, test.msg.TargetID, msg.TargetID)
assert.Equal(t, test.msg.Request.Method, msg.Request.Method)
assert.Equal(t, test.msg.Request.URL, msg.Request.URL)
assert.Equal(t, readBody(test.msg.Request.Body), readBody(msg.Request.Body))
+ assert.Equal(t, test.msg.Request.Header, msg.Request.Header)
}
}
// toRequest creates httpclient.Request from HTTP method, URL and body data.
-func toRequest(method string, url string, body []byte) httpclient.Request {
+func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request {
var rbody io.Reader
if body != nil {
rbody = bytes.NewReader(body)
@@ -109,6 +107,11 @@ func toRequest(method string, url string, body []byte) httpclient.Request {
if err != nil {
panic(err)
}
+ for key, values := range hdr {
+ for _, value := range values {
+ req.Header.Add(key, value)
+ }
+ }
return httpclient.WrapRequest(req)
}
diff --git a/internal/transport/delivery/worker.go b/internal/transport/delivery/worker.go
index ef31e94a6..d6d253769 100644
--- a/internal/transport/delivery/worker.go
+++ b/internal/transport/delivery/worker.go
@@ -19,6 +19,7 @@ package delivery
import (
"context"
+ "errors"
"slices"
"time"
@@ -160,6 +161,13 @@ func (w *Worker) process(ctx context.Context) bool {
loop:
for {
+ // Before trying to get
+ // next delivery, check
+ // context still valid.
+ if ctx.Err() != nil {
+ return true
+ }
+
// Get next delivery.
dlv, ok := w.next(ctx)
if !ok {
@@ -195,16 +203,30 @@ loop:
// Attempt delivery of AP request.
rsp, retry, err := w.Client.DoOnce(
- &dlv.Request,
+ dlv.Request,
)
- if err == nil {
+ switch {
+ case err == nil:
// Ensure body closed.
_ = rsp.Body.Close()
continue loop
- }
- if !retry {
+ case errors.Is(err, context.Canceled) &&
+ ctx.Err() != nil:
+ // In the case of our own context
+ // being cancelled, push delivery
+ // back onto queue for persisting.
+ //
+ // Note we specifically check against
+ // context.Canceled here as it will
+ // be faster than the mutex lock of
+ // ctx.Err(), so gives an initial
+ // faster check in the if-clause.
+ w.Queue.Push(dlv)
+ continue loop
+
+ case !retry:
// Drop deliveries when no
// retry requested, or they
// reached max (either).
@@ -222,42 +244,36 @@ loop:
// next gets the next available delivery, blocking until available if necessary.
func (w *Worker) next(ctx context.Context) (*Delivery, bool) {
-loop:
- for {
- // Try pop next queued.
- dlv, ok := w.Queue.Pop()
+ // Try a fast-pop of queued
+ // delivery before anything.
+ dlv, ok := w.Queue.Pop()
- if !ok {
- // Check the backlog.
- if len(w.backlog) > 0 {
+ if !ok {
+ // Check the backlog.
+ if len(w.backlog) > 0 {
- // Sort by 'next' time.
- sortDeliveries(w.backlog)
+ // Sort by 'next' time.
+ sortDeliveries(w.backlog)
- // Pop next delivery.
- dlv := w.popBacklog()
+ // Pop next delivery.
+ dlv := w.popBacklog()
- return dlv, true
- }
-
- select {
- // Backlog is empty, we MUST
- // block until next enqueued.
- case <-w.Queue.Wait():
- continue loop
+ return dlv, true
+ }
- // Worker was stopped.
- case <-ctx.Done():
- return nil, false
- }
+ // Block on next delivery push
+ // OR worker context canceled.
+ dlv, ok = w.Queue.PopCtx(ctx)
+ if !ok {
+ return nil, false
}
+ }
- // Replace request context for worker state canceling.
- ctx := gtscontext.WithValues(ctx, dlv.Request.Context())
- dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
+ // Replace request context for worker state canceling.
+ ctx = gtscontext.WithValues(ctx, dlv.Request.Context())
+ dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
- return dlv, true
- }
+ return dlv, true
}
// popBacklog pops next available from the backlog.