diff options
| -rw-r--r-- | internal/gtserror/error.go | 4 | ||||
| -rw-r--r-- | internal/gtserror/new.go | 66 | ||||
| -rw-r--r-- | internal/gtserror/new_test.go | 91 | ||||
| -rw-r--r-- | internal/gtserror/util.go | 42 | ||||
| -rw-r--r-- | internal/httpclient/client.go | 21 | ||||
| -rw-r--r-- | internal/httpclient/validate.go | 62 | ||||
| -rw-r--r-- | internal/transport/deliver.go | 4 | ||||
| -rw-r--r-- | internal/transport/dereference.go | 4 | ||||
| -rw-r--r-- | internal/transport/derefinstance.go | 11 | ||||
| -rw-r--r-- | internal/transport/derefmedia.go | 4 | ||||
| -rw-r--r-- | internal/transport/finger.go | 21 | ||||
| -rw-r--r-- | testrig/transportcontroller.go | 1 | 
12 files changed, 299 insertions, 32 deletions
diff --git a/internal/gtserror/error.go b/internal/gtserror/error.go index 56e546cf1..e68ed7d3b 100644 --- a/internal/gtserror/error.go +++ b/internal/gtserror/error.go @@ -34,8 +34,8 @@ const (  	notFoundKey  	errorTypeKey -	// error types -	TypeSMTP ErrorType = "smtp" // smtp (mail) error +	// Types returnable from Type(...). +	TypeSMTP ErrorType = "smtp" // smtp (mail)  )  // StatusCode checks error for a stored status code value. For example diff --git a/internal/gtserror/new.go b/internal/gtserror/new.go new file mode 100644 index 000000000..ad20e5cac --- /dev/null +++ b/internal/gtserror/new.go @@ -0,0 +1,66 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package gtserror + +import ( +	"errors" +	"net/http" + +	"codeberg.org/gruf/go-byteutil" +) + +// NewResponseError crafts an error from provided HTTP response +// including the method, status and body (if any provided). This +// will also wrap the returned error using WithStatusCode(). +func NewResponseError(rsp *http.Response) error { +	var buf byteutil.Buffer + +	// Get URL string ahead of time. +	urlStr := rsp.Request.URL.String() + +	// Alloc guesstimate of required buf size. +	buf.Guarantee(0 + +		len(rsp.Request.Method) + +		12 + //  request to +		len(urlStr) + +		17 + //  failed: status=" +		len(rsp.Status) + +		8 + // " body=" +		256 + // max body size +		1, // " +	) + +	// Build error message string without +	// using "fmt", as chances are this will +	// be used in a hot code path and we +	// know all the incoming types involved. +	_, _ = buf.WriteString(rsp.Request.Method) +	_, _ = buf.WriteString(" request to ") +	_, _ = buf.WriteString(urlStr) +	_, _ = buf.WriteString(" failed: status=\"") +	_, _ = buf.WriteString(rsp.Status) +	_, _ = buf.WriteString("\" body=\"") +	_, _ = buf.WriteString(drainBody(rsp.Body, 256)) +	_, _ = buf.WriteString("\"") + +	// Create new error from msg. +	err := errors.New(buf.String()) + +	// Wrap error to provide status code. +	return WithStatusCode(err, rsp.StatusCode) +} diff --git a/internal/gtserror/new_test.go b/internal/gtserror/new_test.go new file mode 100644 index 000000000..b0824b5a7 --- /dev/null +++ b/internal/gtserror/new_test.go @@ -0,0 +1,91 @@ +package gtserror_test + +import ( +	"bytes" +	"fmt" +	"io" +	"net/http" +	"net/url" +	"strings" +	"testing" + +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +) + +func TestResponseError(t *testing.T) { +	testResponseError(t, http.Response{ +		Body: toBody(`{"error": "user not found"}`), +		Request: &http.Request{ +			Method: "GET", +			URL:    toURL("https://google.com/users/sundar"), +		}, +		Status: "404 Not Found", +	}) +	testResponseError(t, http.Response{ +		Body: toBody("Unauthorized"), +		Request: &http.Request{ +			Method: "POST", +			URL:    toURL("https://google.com/inbox"), +		}, +		Status: "401 Unauthorized", +	}) +	testResponseError(t, http.Response{ +		Body: toBody(""), +		Request: &http.Request{ +			Method: "GET", +			URL:    toURL("https://google.com/users/sundar"), +		}, +		Status: "404 Not Found", +	}) +} + +func testResponseError(t *testing.T, rsp http.Response) { +	var body string +	if rsp.Body == http.NoBody { +		body = "<empty>" +	} else { +		var b []byte +		rsp.Body, b = copyBody(rsp.Body) +		trunc := len(b) +		if trunc > 256 { +			trunc = 256 +		} +		body = string(b[:trunc]) +	} +	expect := fmt.Sprintf( +		"%s request to %s failed: status=\"%s\" body=\"%s\"", +		rsp.Request.Method, +		rsp.Request.URL.String(), +		rsp.Status, +		body, +	) +	err := gtserror.NewResponseError(&rsp) +	if str := err.Error(); str != expect { +		t.Errorf("unexpected error string: recv=%q expct=%q", str, expect) +	} +} + +func toURL(u string) *url.URL { +	url, err := url.Parse(u) +	if err != nil { +		panic(err) +	} +	return url +} + +func toBody(s string) io.ReadCloser { +	if s == "" { +		return http.NoBody +	} +	r := strings.NewReader(s) +	return io.NopCloser(r) +} + +func copyBody(rc io.ReadCloser) (io.ReadCloser, []byte) { +	b, err := io.ReadAll(rc) +	if err != nil { +		panic(err) +	} +	r := bytes.NewReader(b) +	return io.NopCloser(r), b +} diff --git a/internal/gtserror/util.go b/internal/gtserror/util.go new file mode 100644 index 000000000..635518b76 --- /dev/null +++ b/internal/gtserror/util.go @@ -0,0 +1,42 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package gtserror + +import ( +	"io" + +	"codeberg.org/gruf/go-byteutil" +) + +// drainBody will produce a truncated output of the content +// of given io.ReadCloser body, useful for logs / errors. +func drainBody(body io.ReadCloser, trunc int) string { +	// Limit response to 'trunc' bytes. +	buf := make([]byte, trunc) + +	// Read body into err buffer. +	n, _ := io.ReadFull(body, buf) + +	if n == 0 { +		// No error body, return +		// reasonable error str. +		return "<empty>" +	} + +	return byteutil.B2S(buf[:n]) +} diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go index dd1a1bd6b..efbf4cd18 100644 --- a/internal/httpclient/client.go +++ b/internal/httpclient/client.go @@ -41,6 +41,9 @@ import (  )  var ( +	// ErrInvalidRequest is returned if a given HTTP request is invalid and cannot be performed. +	ErrInvalidRequest = errors.New("invalid http request") +  	// ErrInvalidNetwork is returned if the request would not be performed over TCP  	ErrInvalidNetwork = errors.New("invalid network type") @@ -90,6 +93,9 @@ type Config struct {  //     cases to protect against forged / unknown content-lengths  //   - protection from server side request forgery (SSRF) by only dialing  //     out to known public IP prefixes, configurable with allows/blocks +//   - retry-backoff logic for error temporary HTTP error responses +//   - optional request signing +//   - request logging  type Client struct {  	client   http.Client  	badHosts cache.Cache[string, struct{}] @@ -156,14 +162,14 @@ func New(cfg Config) *Client {  	return &c  } -// Do ... +// Do will essentially perform http.Client{}.Do() with retry-backoff functionality.  func (c *Client) Do(r *http.Request) (*http.Response, error) {  	return c.DoSigned(r, func(r *http.Request) error {  		return nil // no request signing  	})  } -// DoSigned ... +// DoSigned will essentially perform http.Client{}.Do() with retry-backoff functionality and requesting signing..  func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, err error) {  	const (  		// max no. attempts. @@ -173,6 +179,11 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e  		baseBackoff = 2 * time.Second  	) +	// First validate incoming request. +	if err := ValidateRequest(r); err != nil { +		return nil, err +	} +  	// Get request hostname.  	host := r.URL.Hostname() @@ -234,8 +245,8 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e  				return rsp, nil  			} -			// Generate error from status code for logging -			err = errors.New(`http response "` + rsp.Status + `"`) +			// Create loggable error from response status code. +			err = fmt.Errorf(`http response: %s`, rsp.Status)  			// Search for a provided "Retry-After" header value.  			if after := rsp.Header.Get("Retry-After"); after != "" { @@ -307,7 +318,7 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e  	return  } -// do ... +// do wraps http.Client{}.Do() to provide safely limited response bodies.  func (c *Client) do(req *http.Request) (*http.Response, error) {  	// Perform the HTTP request.  	rsp, err := c.client.Do(req) diff --git a/internal/httpclient/validate.go b/internal/httpclient/validate.go new file mode 100644 index 000000000..881d3f699 --- /dev/null +++ b/internal/httpclient/validate.go @@ -0,0 +1,62 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package httpclient + +import ( +	"fmt" +	"net/http" +	"strings" + +	"golang.org/x/net/http/httpguts" +) + +// ValidateRequest performs the same request validation logic found in the default +// net/http.Transport{}.roundTrip() function, but pulls it out into this separate +// function allowing validation errors to be wrapped under a single error type. +func ValidateRequest(r *http.Request) error { +	switch { +	case r.URL == nil: +		return fmt.Errorf("%w: nil url", ErrInvalidRequest) +	case r.Header == nil: +		return fmt.Errorf("%w: nil header", ErrInvalidRequest) +	case r.URL.Host == "": +		return fmt.Errorf("%w: empty url host", ErrInvalidRequest) +	case r.URL.Scheme != "http" && r.URL.Scheme != "https": +		return fmt.Errorf("%w: unsupported protocol %q", ErrInvalidRequest, r.URL.Scheme) +	case strings.IndexFunc(r.Method, func(r rune) bool { return !httpguts.IsTokenRune(r) }) != -1: +		return fmt.Errorf("%w: invalid method %q", ErrInvalidRequest, r.Method) +	} + +	for key, values := range r.Header { +		// Check field key name is valid +		if !httpguts.ValidHeaderFieldName(key) { +			return fmt.Errorf("%w: invalid header field name %q", ErrInvalidRequest, key) +		} + +		// Check each field value is valid +		for i := 0; i < len(values); i++ { +			if !httpguts.ValidHeaderFieldValue(values[i]) { +				return fmt.Errorf("%w: invalid header field value %q", ErrInvalidRequest, values[i]) +			} +		} +	} + +	// ps. kim wrote this + +	return nil +} diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index fff7dbcf4..054baa6a5 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -19,7 +19,6 @@ package transport  import (  	"context" -	"fmt"  	"net/http"  	"net/url"  	"sync" @@ -131,8 +130,7 @@ func (t *transport) deliver(ctx context.Context, b []byte, to *url.URL) error {  	if code := rsp.StatusCode; code != http.StatusOK &&  		code != http.StatusCreated && code != http.StatusAccepted { -		err := fmt.Errorf("POST request to %s failed: %s", url, rsp.Status) -		return gtserror.WithStatusCode(err, rsp.StatusCode) +		return gtserror.NewResponseError(rsp)  	}  	return nil diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go index 71b10a0f1..e231e0954 100644 --- a/internal/transport/dereference.go +++ b/internal/transport/dereference.go @@ -19,7 +19,6 @@ package transport  import (  	"context" -	"fmt"  	"io"  	"net/http"  	"net/url" @@ -66,8 +65,7 @@ func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, erro  	defer rsp.Body.Close()  	if rsp.StatusCode != http.StatusOK { -		err := fmt.Errorf("GET request to %s failed: %s", iriStr, rsp.Status) -		return nil, gtserror.WithStatusCode(err, rsp.StatusCode) +		return nil, gtserror.NewResponseError(rsp)  	}  	return io.ReadAll(rsp.Body) diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go index 466981348..c373a140a 100644 --- a/internal/transport/derefinstance.go +++ b/internal/transport/derefinstance.go @@ -102,8 +102,7 @@ func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL)  	defer resp.Body.Close()  	if resp.StatusCode != http.StatusOK { -		err := fmt.Errorf("GET request to %s failed: %s", iriStr, resp.Status) -		return nil, gtserror.WithStatusCode(err, resp.StatusCode) +		return nil, gtserror.NewResponseError(resp)  	}  	b, err := io.ReadAll(resp.Body) @@ -133,7 +132,7 @@ func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL)  		ID:                     ulid,  		Domain:                 iri.Host,  		Title:                  apiResp.Title, -		URI:                    fmt.Sprintf("%s://%s", iri.Scheme, iri.Host), +		URI:                    iri.Scheme + "://" + iri.Host,  		ShortDescription:       apiResp.ShortDescription,  		Description:            apiResp.Description,  		ContactEmail:           apiResp.Email, @@ -253,8 +252,7 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur  	defer resp.Body.Close()  	if resp.StatusCode != http.StatusOK { -		err := fmt.Errorf("GET request to %s failed: %s", iriStr, resp.Status) -		return nil, gtserror.WithStatusCode(err, resp.StatusCode) +		return nil, gtserror.NewResponseError(resp)  	}  	b, err := io.ReadAll(resp.Body) @@ -305,8 +303,7 @@ func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.No  	defer resp.Body.Close()  	if resp.StatusCode != http.StatusOK { -		err := fmt.Errorf("GET request to %s failed: %s", iriStr, resp.Status) -		return nil, gtserror.WithStatusCode(err, resp.StatusCode) +		return nil, gtserror.NewResponseError(resp)  	}  	b, err := io.ReadAll(resp.Body) diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go index 2d9096493..ad47d99b5 100644 --- a/internal/transport/derefmedia.go +++ b/internal/transport/derefmedia.go @@ -19,7 +19,6 @@ package transport  import (  	"context" -	"fmt"  	"io"  	"net/http"  	"net/url" @@ -47,8 +46,7 @@ func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.Read  	// Check for an expected status code  	if rsp.StatusCode != http.StatusOK { -		err := fmt.Errorf("GET request to %s failed: %s", iriStr, rsp.Status) -		return nil, 0, gtserror.WithStatusCode(err, rsp.StatusCode) +		return nil, 0, gtserror.NewResponseError(rsp)  	}  	return rsp.Body, rsp.ContentLength, nil diff --git a/internal/transport/finger.go b/internal/transport/finger.go index 18b028a64..e6086747b 100644 --- a/internal/transport/finger.go +++ b/internal/transport/finger.go @@ -27,6 +27,7 @@ import (  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"  	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  )  // webfingerURLFor returns the URL to try a webfinger request against, as @@ -105,14 +106,16 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom  	// From here on out, we're handling different failure scenarios and  	// deciding whether we should do a host-meta based fallback or not -	if (rsp.StatusCode >= 500 && rsp.StatusCode < 600) || cached { -		// In case we got a 5xx, bail out irrespective of if the value -		// was cached or not. The target may be broken or be signalling -		// us to back-off. -		// -		// If it's any error but the URL was cached, bail out too -		return nil, fmt.Errorf("GET request to %s failed: %s", req.URL.String(), rsp.Status) -	} +	// Response status codes >= 500 are returned as errors by the wrapped HTTP client. +	// +	// if (rsp.StatusCode >= 500 && rsp.StatusCode < 600) || cached { +	// In case we got a 5xx, bail out irrespective of if the value +	// was cached or not. The target may be broken or be signalling +	// us to back-off. +	// +	// If it's any error but the URL was cached, bail out too +	// return nil, gtserror.NewResponseError(rsp) +	// }  	// So far we've failed to get a successful response from the expected  	// webfinger endpoint. Lets try and discover the webfinger endpoint @@ -153,7 +156,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom  		}  		// We've reached the end of the line here, both the original request  		// and our attempt to resolve it through the fallback have failed -		return nil, fmt.Errorf("GET request to %s failed: %s", req.URL.String(), rsp.Status) +		return nil, gtserror.NewResponseError(rsp)  	}  	// Set the URL in cache here, since host-meta told us this should be the diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go index b74888934..1c75e1974 100644 --- a/testrig/transportcontroller.go +++ b/testrig/transportcontroller.go @@ -209,6 +209,7 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat  		reader := bytes.NewReader(responseBytes)  		readCloser := io.NopCloser(reader)  		return &http.Response{ +			Request:       req,  			StatusCode:    responseCode,  			Body:          readCloser,  			ContentLength: int64(responseContentLength),  | 
