diff options
Diffstat (limited to 'internal/transport/delivery')
| -rw-r--r-- | internal/transport/delivery/delivery.go | 16 | ||||
| -rw-r--r-- | internal/transport/delivery/delivery_test.go | 31 | ||||
| -rw-r--r-- | internal/transport/delivery/worker.go | 80 |
3 files changed, 73 insertions, 54 deletions
diff --git a/internal/transport/delivery/delivery.go b/internal/transport/delivery/delivery.go index 1e3ebb054..e11eea83c 100644 --- a/internal/transport/delivery/delivery.go +++ b/internal/transport/delivery/delivery.go @@ -33,10 +33,6 @@ import ( // be indexed (and so, dropped from queue) // by any of these possible ID IRIs. type Delivery struct { - // PubKeyID is the signing public key - // ID of the actor performing request. - PubKeyID string - // ActorID contains the ActivityPub // actor ID IRI (if any) of the activity // being sent out by this request. @@ -55,7 +51,7 @@ type Delivery struct { // Request is the prepared (+ wrapped) // httpclient.Client{} request that // constitutes this ActivtyPub delivery. - Request httpclient.Request + Request *httpclient.Request // internal fields. next time.Time @@ -66,7 +62,6 @@ type Delivery struct { // a json serialize / deserialize // able shape that minimizes data. type delivery struct { - PubKeyID string `json:"pub_key_id,omitempty"` ActorID string `json:"actor_id,omitempty"` ObjectID string `json:"object_id,omitempty"` TargetID string `json:"target_id,omitempty"` @@ -101,7 +96,6 @@ func (dlv *Delivery) Serialize() ([]byte, error) { // Marshal as internal JSON type. return json.Marshal(delivery{ - PubKeyID: dlv.PubKeyID, ActorID: dlv.ActorID, ObjectID: dlv.ObjectID, TargetID: dlv.TargetID, @@ -125,7 +119,6 @@ func (dlv *Delivery) Deserialize(data []byte) error { } // Copy over simplest fields. - dlv.PubKeyID = idlv.PubKeyID dlv.ActorID = idlv.ActorID dlv.ObjectID = idlv.ObjectID dlv.TargetID = idlv.TargetID @@ -143,6 +136,13 @@ func (dlv *Delivery) Deserialize(data []byte) error { return err } + // Copy over any stored header values. + for key, values := range idlv.Header { + for _, value := range values { + r.Header.Add(key, value) + } + } + // Wrap request in httpclient type. dlv.Request = httpclient.WrapRequest(r) diff --git a/internal/transport/delivery/delivery_test.go b/internal/transport/delivery/delivery_test.go index e9eaf8fd1..81f32d5f8 100644 --- a/internal/transport/delivery/delivery_test.go +++ b/internal/transport/delivery/delivery_test.go @@ -35,32 +35,30 @@ var deliveryCases = []struct { }{ { msg: delivery.Delivery{ - PubKeyID: "https://google.com/users/bigboy#pubkey", ActorID: "https://google.com/users/bigboy", ObjectID: "https://google.com/users/bigboy/follow/1", TargetID: "https://askjeeves.com/users/smallboy", - Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!")), + Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Hello": {"world1", "world2"}}), }, data: toJSON(map[string]any{ - "pub_key_id": "https://google.com/users/bigboy#pubkey", - "actor_id": "https://google.com/users/bigboy", - "object_id": "https://google.com/users/bigboy/follow/1", - "target_id": "https://askjeeves.com/users/smallboy", - "method": "POST", - "url": "https://askjeeves.com/users/smallboy/inbox", - "body": []byte("data!"), - // "header": map[string][]string{}, + "actor_id": "https://google.com/users/bigboy", + "object_id": "https://google.com/users/bigboy/follow/1", + "target_id": "https://askjeeves.com/users/smallboy", + "method": "POST", + "url": "https://askjeeves.com/users/smallboy/inbox", + "body": []byte("data!"), + "header": map[string][]string{"Hello": {"world1", "world2"}}, }), }, { msg: delivery.Delivery{ - Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin")), + Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), nil), }, data: toJSON(map[string]any{ "method": "GET", "url": "https://google.com", "body": []byte("uwu im just a wittle seawch engwin"), - // "header": map[string][]string{}, + // "header": map[string][]string{}, }), }, } @@ -89,18 +87,18 @@ func TestDeserializeDelivery(t *testing.T) { } // Check that delivery fields are as expected. - assert.Equal(t, test.msg.PubKeyID, msg.PubKeyID) assert.Equal(t, test.msg.ActorID, msg.ActorID) assert.Equal(t, test.msg.ObjectID, msg.ObjectID) assert.Equal(t, test.msg.TargetID, msg.TargetID) assert.Equal(t, test.msg.Request.Method, msg.Request.Method) assert.Equal(t, test.msg.Request.URL, msg.Request.URL) assert.Equal(t, readBody(test.msg.Request.Body), readBody(msg.Request.Body)) + assert.Equal(t, test.msg.Request.Header, msg.Request.Header) } } // toRequest creates httpclient.Request from HTTP method, URL and body data. -func toRequest(method string, url string, body []byte) httpclient.Request { +func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request { var rbody io.Reader if body != nil { rbody = bytes.NewReader(body) @@ -109,6 +107,11 @@ func toRequest(method string, url string, body []byte) httpclient.Request { if err != nil { panic(err) } + for key, values := range hdr { + for _, value := range values { + req.Header.Add(key, value) + } + } return httpclient.WrapRequest(req) } diff --git a/internal/transport/delivery/worker.go b/internal/transport/delivery/worker.go index ef31e94a6..d6d253769 100644 --- a/internal/transport/delivery/worker.go +++ b/internal/transport/delivery/worker.go @@ -19,6 +19,7 @@ package delivery import ( "context" + "errors" "slices" "time" @@ -160,6 +161,13 @@ func (w *Worker) process(ctx context.Context) bool { loop: for { + // Before trying to get + // next delivery, check + // context still valid. + if ctx.Err() != nil { + return true + } + // Get next delivery. dlv, ok := w.next(ctx) if !ok { @@ -195,16 +203,30 @@ loop: // Attempt delivery of AP request. rsp, retry, err := w.Client.DoOnce( - &dlv.Request, + dlv.Request, ) - if err == nil { + switch { + case err == nil: // Ensure body closed. _ = rsp.Body.Close() continue loop - } - if !retry { + case errors.Is(err, context.Canceled) && + ctx.Err() != nil: + // In the case of our own context + // being cancelled, push delivery + // back onto queue for persisting. + // + // Note we specifically check against + // context.Canceled here as it will + // be faster than the mutex lock of + // ctx.Err(), so gives an initial + // faster check in the if-clause. + w.Queue.Push(dlv) + continue loop + + case !retry: // Drop deliveries when no // retry requested, or they // reached max (either). @@ -222,42 +244,36 @@ loop: // next gets the next available delivery, blocking until available if necessary. func (w *Worker) next(ctx context.Context) (*Delivery, bool) { -loop: - for { - // Try pop next queued. - dlv, ok := w.Queue.Pop() + // Try a fast-pop of queued + // delivery before anything. + dlv, ok := w.Queue.Pop() - if !ok { - // Check the backlog. - if len(w.backlog) > 0 { + if !ok { + // Check the backlog. + if len(w.backlog) > 0 { - // Sort by 'next' time. - sortDeliveries(w.backlog) + // Sort by 'next' time. + sortDeliveries(w.backlog) - // Pop next delivery. - dlv := w.popBacklog() + // Pop next delivery. + dlv := w.popBacklog() - return dlv, true - } - - select { - // Backlog is empty, we MUST - // block until next enqueued. - case <-w.Queue.Wait(): - continue loop + return dlv, true + } - // Worker was stopped. - case <-ctx.Done(): - return nil, false - } + // Block on next delivery push + // OR worker context canceled. + dlv, ok = w.Queue.PopCtx(ctx) + if !ok { + return nil, false } + } - // Replace request context for worker state canceling. - ctx := gtscontext.WithValues(ctx, dlv.Request.Context()) - dlv.Request.Request = dlv.Request.Request.WithContext(ctx) + // Replace request context for worker state canceling. + ctx = gtscontext.WithValues(ctx, dlv.Request.Context()) + dlv.Request.Request = dlv.Request.Request.WithContext(ctx) - return dlv, true - } + return dlv, true } // popBacklog pops next available from the backlog. |
