summaryrefslogtreecommitdiff
path: root/internal/transport/deliver.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/transport/deliver.go')
-rw-r--r--internal/transport/deliver.go111
1 files changed, 73 insertions, 38 deletions
diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go
index 8ec939503..fff7dbcf4 100644
--- a/internal/transport/deliver.go
+++ b/internal/transport/deliver.go
@@ -22,7 +22,6 @@ import (
"fmt"
"net/http"
"net/url"
- "strings"
"sync"
"codeberg.org/gruf/go-byteutil"
@@ -32,54 +31,90 @@ import (
)
func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error {
- // concurrently deliver to recipients; for each delivery, buffer the error if it fails
- wg := sync.WaitGroup{}
- errCh := make(chan error, len(recipients))
- for _, recipient := range recipients {
- wg.Add(1)
- go func(r *url.URL) {
- defer wg.Done()
- if err := t.Deliver(ctx, b, r); err != nil {
- errCh <- err
+ var (
+ // errs accumulates errors received during
+ // attempted delivery by deliverer routines.
+ errs gtserror.MultiError
+
+ // wait blocks until all sender
+ // routines have returned.
+ wait sync.WaitGroup
+
+ // mutex protects 'recipients' and
+ // 'errs' for concurrent access.
+ mutex sync.Mutex
+
+ // Get current instance host info.
+ domain = config.GetAccountDomain()
+ host = config.GetHost()
+ )
+
+ // Block on expect no. senders.
+ wait.Add(t.controller.senders)
+
+ for i := 0; i < t.controller.senders; i++ {
+ go func() {
+ // Mark returned.
+ defer wait.Done()
+
+ for {
+ // Acquire lock.
+ mutex.Lock()
+
+ if len(recipients) == 0 {
+ // Reached end.
+ mutex.Unlock()
+ return
+ }
+
+ // Pop next recipient.
+ i := len(recipients) - 1
+ to := recipients[i]
+ recipients = recipients[:i]
+
+ // Done with lock.
+ mutex.Unlock()
+
+ // Skip delivery to recipient if it is "us".
+ if to.Host == host || to.Host == domain {
+ continue
+ }
+
+ // Attempt to deliver data to recipient.
+ if err := t.deliver(ctx, b, to); err != nil {
+ mutex.Lock() // safely append err to accumulator.
+ errs.Appendf("error delivering to %s: %v", to, err)
+ mutex.Unlock()
+ }
}
- }(recipient)
+ }()
}
- // wait until all deliveries have succeeded or failed
- wg.Wait()
-
- // receive any buffered errors
- errs := make([]string, 0, len(errCh))
-outer:
- for {
- select {
- case e := <-errCh:
- errs = append(errs, e.Error())
- default:
- break outer
- }
- }
-
- if len(errs) > 0 {
- return fmt.Errorf("BatchDeliver: at least one failure: %s", strings.Join(errs, "; "))
- }
+ // Wait for finish.
+ wait.Wait()
- return nil
+ // Return combined err.
+ return errs.Combine()
}
func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
- // if the 'to' host is our own, just skip this delivery since we by definition already have the message!
+ // if 'to' host is our own, skip as we don't need to deliver to ourselves...
if to.Host == config.GetHost() || to.Host == config.GetAccountDomain() {
return nil
}
- urlStr := to.String()
+ // Deliver data to recipient.
+ return t.deliver(ctx, b, to)
+}
+
+func (t *transport) deliver(ctx context.Context, b []byte, to *url.URL) error {
+ url := to.String()
// Use rewindable bytes reader for body.
var body byteutil.ReadNopCloser
body.Reset(b)
- req, err := http.NewRequestWithContext(ctx, "POST", urlStr, &body)
+ req, err := http.NewRequestWithContext(ctx, "POST", url, &body)
if err != nil {
return err
}
@@ -88,16 +123,16 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
req.Header.Add("Accept-Charset", "utf-8")
req.Header.Set("Host", to.Host)
- resp, err := t.POST(req, b)
+ rsp, err := t.POST(req, b)
if err != nil {
return err
}
- defer resp.Body.Close()
+ defer rsp.Body.Close()
- if code := resp.StatusCode; code != http.StatusOK &&
+ if code := rsp.StatusCode; code != http.StatusOK &&
code != http.StatusCreated && code != http.StatusAccepted {
- err := fmt.Errorf("POST request to %s failed: %s", urlStr, resp.Status)
- return gtserror.WithStatusCode(err, resp.StatusCode)
+ err := fmt.Errorf("POST request to %s failed: %s", url, rsp.Status)
+ return gtserror.WithStatusCode(err, rsp.StatusCode)
}
return nil