diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/transport/deliver.go | 12 | ||||
| -rw-r--r-- | internal/transport/dereference.go | 2 | ||||
| -rw-r--r-- | internal/transport/derefinstance.go | 6 | ||||
| -rw-r--r-- | internal/transport/derefmedia.go | 2 | ||||
| -rw-r--r-- | internal/transport/finger.go | 2 | ||||
| -rw-r--r-- | internal/transport/transport.go | 30 | 
6 files changed, 32 insertions, 22 deletions
diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index 476152c10..7db3bf338 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -19,7 +19,6 @@  package transport  import ( -	"bytes"  	"context"  	"fmt"  	"net/http" @@ -27,6 +26,7 @@ import (  	"strings"  	"sync" +	"codeberg.org/gruf/go-byteutil"  	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"  	"github.com/superseriousbusiness/gotosocial/internal/config"  ) @@ -49,7 +49,7 @@ func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*ur  	wg.Wait()  	// receive any buffered errors -	errs := make([]string, 0, len(recipients)) +	errs := make([]string, 0, len(errCh))  outer:  	for {  		select { @@ -75,7 +75,11 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {  	urlStr := to.String() -	req, err := http.NewRequestWithContext(ctx, "POST", urlStr, bytes.NewReader(b)) +	// Use rewindable bytes reader for body. +	var body byteutil.ReadNopCloser +	body.Reset(b) + +	req, err := http.NewRequestWithContext(ctx, "POST", urlStr, &body)  	if err != nil {  		return err  	} @@ -92,7 +96,7 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {  	if code := resp.StatusCode; code != http.StatusOK &&  		code != http.StatusCreated && code != http.StatusAccepted { -		return fmt.Errorf("POST request to %s failed (%d): %s", urlStr, resp.StatusCode, resp.Status) +		return fmt.Errorf("POST request to %s failed: %s", urlStr, resp.Status)  	}  	return nil diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go index 37bdbd58f..f42f146ea 100644 --- a/internal/transport/dereference.go +++ b/internal/transport/dereference.go @@ -78,6 +78,6 @@ func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, erro  	case http.StatusGone:  		return nil, ErrGone  	default: -		return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status) +		return nil, fmt.Errorf("GET request to %s failed: %s", iriStr, rsp.Status)  	}  } diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go index e46b52554..2dcf367a0 100644 --- a/internal/transport/derefinstance.go +++ b/internal/transport/derefinstance.go @@ -102,7 +102,7 @@ func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL)  	defer resp.Body.Close()  	if resp.StatusCode != http.StatusOK { -		return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status) +		return nil, fmt.Errorf("GET request to %s failed: %s", iriStr, resp.Status)  	}  	b, err := io.ReadAll(resp.Body) @@ -252,7 +252,7 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur  	defer resp.Body.Close()  	if resp.StatusCode != http.StatusOK { -		return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status) +		return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed: %s", iriStr, resp.Status)  	}  	b, err := io.ReadAll(resp.Body) @@ -303,7 +303,7 @@ func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.No  	defer resp.Body.Close()  	if resp.StatusCode != http.StatusOK { -		return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status) +		return nil, fmt.Errorf("callNodeInfo: GET request to %s failed: %s", iriStr, resp.Status)  	}  	b, err := io.ReadAll(resp.Body) diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go index 5edfc0e44..c8a817eef 100644 --- a/internal/transport/derefmedia.go +++ b/internal/transport/derefmedia.go @@ -46,7 +46,7 @@ func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.Read  	// Check for an expected status code  	if rsp.StatusCode != http.StatusOK { -		return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status) +		return nil, 0, fmt.Errorf("GET request to %s failed: %s", iriStr, rsp.Status)  	}  	return rsp.Body, rsp.ContentLength, nil diff --git a/internal/transport/finger.go b/internal/transport/finger.go index 1e52a59f2..4e6594df4 100644 --- a/internal/transport/finger.go +++ b/internal/transport/finger.go @@ -52,7 +52,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom  	// Check for an expected status code  	if rsp.StatusCode != http.StatusOK { -		return nil, fmt.Errorf("GET request to %s failed (%d): %s", urlStr, rsp.StatusCode, rsp.Status) +		return nil, fmt.Errorf("GET request to %s failed: %s", urlStr, rsp.Status)  	}  	return io.ReadAll(rsp.Body) diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 18c40f79f..b0f68f707 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -32,6 +32,7 @@ import (  	"sync"  	"time" +	"codeberg.org/gruf/go-byteutil"  	errorsv2 "codeberg.org/gruf/go-errors/v2"  	"codeberg.org/gruf/go-kv"  	"github.com/go-fed/httpsig" @@ -84,7 +85,7 @@ type transport struct {  	signerMu   sync.Mutex  } -// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn. +// GET will perform given http request using transport client, retrying on certain preset errors.  func (t *transport) GET(r *http.Request) (*http.Response, error) {  	if r.Method != http.MethodGet {  		return nil, errors.New("must be GET request") @@ -94,7 +95,7 @@ func (t *transport) GET(r *http.Request) (*http.Response, error) {  	})  } -// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn. +// POST will perform given http request using transport client, retrying on certain preset errors.  func (t *transport) POST(r *http.Request, body []byte) (*http.Response, error) {  	if r.Method != http.MethodPost {  		return nil, errors.New("must be POST request") @@ -116,18 +117,17 @@ func (t *transport) do(r *http.Request, signer func(*http.Request) error) (*http  	// Get request hostname  	host := r.URL.Hostname() -	// Check if recently reached max retries for this host -	// so we don't need to bother reattempting it. The only -	// errors that are retried upon are server failure and -	// domain resolution type errors, so this cached result -	// indicates this server is likely having issues. -	if t.controller.badHosts.Has(host) { -		return nil, errors.New("too many failed attempts") -	} -  	// Check whether request should fast fail, we check this  	// before loop as each context.Value() requires mutex lock.  	fastFail := IsFastfail(r.Context()) +	if !fastFail { +		// Check if recently reached max retries for this host +		// so we don't bother with a retry-backoff loop. The only +		// errors that are retried upon are server failure and +		// domain resolution type errors, so this cached result +		// indicates this server is likely having issues. +		fastFail = t.controller.badHosts.Has(host) +	}  	// Start a log entry for this request  	l := log.WithContext(r.Context()). @@ -148,6 +148,12 @@ func (t *transport) do(r *http.Request, signer func(*http.Request) error) (*http  		r.Header.Del("Signature")  		r.Header.Del("Digest") +		// Rewind body reader and content-length if set. +		if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { +			r.ContentLength = int64(rc.Len()) +			rc.Rewind() +		} +  		// Perform request signing  		if err := signer(r); err != nil {  			return nil, err @@ -226,7 +232,7 @@ func (t *transport) do(r *http.Request, signer func(*http.Request) error) (*http  		}  	} -	// Add "bad" entry for this host +	// Add "bad" entry for this host.  	t.controller.badHosts.Set(host, struct{}{})  	return nil, errors.New("transport reached max retries")  | 
