diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/util/negotiate.go | 56 | ||||
| -rw-r--r-- | internal/api/util/negotiate_test.go | 65 | ||||
| -rw-r--r-- | internal/transport/finger.go | 7 | ||||
| -rw-r--r-- | internal/web/profile.go | 2 | ||||
| -rw-r--r-- | internal/web/thread.go | 2 | 
5 files changed, 127 insertions, 5 deletions
diff --git a/internal/api/util/negotiate.go b/internal/api/util/negotiate.go index 8e7f41134..6d68a0df3 100644 --- a/internal/api/util/negotiate.go +++ b/internal/api/util/negotiate.go @@ -20,6 +20,7 @@ package util  import (  	"errors"  	"fmt" +	"strings"  	"github.com/gin-gonic/gin"  ) @@ -108,10 +109,63 @@ func NegotiateAccept(c *gin.Context, offers ...MIME) (string, error) {  		return strings[0], nil  	} -	format := c.NegotiateFormat(strings...) +	format := NegotiateFormat(c, strings...)  	if format == "" {  		return "", fmt.Errorf("no format can be offered for requested Accept header(s) %s; this endpoint offers %s", accepts, offers)  	}  	return format, nil  } + +// This is the exact same thing as gin.Context.NegotiateFormat except it contains +// tsmethurst's fix to make it work properly with multiple accept headers. +// +// https://github.com/gin-gonic/gin/pull/3156 +func NegotiateFormat(c *gin.Context, offered ...string) string { +	if len(offered) == 0 { +		panic("you must provide at least one offer") +	} + +	if c.Accepted == nil { +		for _, a := range c.Request.Header.Values("Accept") { +			c.Accepted = append(c.Accepted, parseAccept(a)...) +		} +	} +	if len(c.Accepted) == 0 { +		return offered[0] +	} +	for _, accepted := range c.Accepted { +		for _, offer := range offered { +			// According to RFC 2616 and RFC 2396, non-ASCII characters are not allowed in headers, +			// therefore we can just iterate over the string without casting it into []rune +			i := 0 +			for ; i < len(accepted); i++ { +				if accepted[i] == '*' || offer[i] == '*' { +					return offer +				} +				if accepted[i] != offer[i] { +					break +				} +			} +			if i == len(accepted) { +				return offer +			} +		} +	} +	return "" +} + +// https://github.com/gin-gonic/gin/blob/4787b8203b79012877ac98d7806422da3a678ba2/utils.go#L103 +func parseAccept(acceptHeader string) []string { +	parts := strings.Split(acceptHeader, ",") +	out := make([]string, 0, len(parts)) +	for _, part := range parts { +		if i := strings.IndexByte(part, ';'); i > 0 { +			part = part[:i] +		} +		if part = strings.TrimSpace(part); part != "" { +			out = append(out, part) +		} +	} +	return out +} diff --git a/internal/api/util/negotiate_test.go b/internal/api/util/negotiate_test.go new file mode 100644 index 000000000..a8b28b55f --- /dev/null +++ b/internal/api/util/negotiate_test.go @@ -0,0 +1,65 @@ +package util + +import ( +	"net/http" +	"net/http/httptest" +	"strings" +	"testing" + +	"github.com/gin-gonic/gin" +) + +type testMIMES []MIME + +func (tm testMIMES) String(t *testing.T) string { +	t.Helper() + +	res := tm.StringS(t) +	return strings.Join(res, ",") +} + +func (tm testMIMES) StringS(t *testing.T) []string { +	t.Helper() + +	res := make([]string, 0, len(tm)) +	for _, m := range tm { +		res = append(res, string(m)) +	} +	return res +} + +func TestNegotiateFormat(t *testing.T) { +	tests := []struct { +		incoming []string +		offered  testMIMES +		format   string +	}{ +		{incoming: testMIMES{AppJSON}.StringS(t), offered: testMIMES{AppJRDJSON, AppJSON}, format: "application/json"}, +		{incoming: testMIMES{AppJRDJSON}.StringS(t), offered: testMIMES{AppJRDJSON, AppJSON}, format: "application/jrd+json"}, +		{incoming: testMIMES{AppJRDJSON, AppJSON}.StringS(t), offered: testMIMES{AppJRDJSON}, format: "application/jrd+json"}, +		{incoming: testMIMES{AppJRDJSON, AppJSON}.StringS(t), offered: testMIMES{AppJSON}, format: "application/json"}, +		{incoming: testMIMES{"text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8"}.StringS(t), offered: testMIMES{AppJSON, AppXML}, format: "application/xml"}, +		{incoming: testMIMES{"text/html,application/xhtml+xml,application/xml;q=0.9;q=0.8"}.StringS(t), offered: testMIMES{TextHTML, AppXML}, format: "text/html"}, +	} + +	for _, tt := range tests { +		name := "incoming:" + strings.Join(tt.incoming, ",") + " offered:" + tt.offered.String(t) +		t.Run(name, func(t *testing.T) { +			tt := tt +			t.Parallel() + +			c, _ := gin.CreateTestContext(httptest.NewRecorder()) +			c.Request = &http.Request{ +				Header: make(http.Header), +			} +			for _, header := range tt.incoming { +				c.Request.Header.Add("accept", header) +			} + +			format := NegotiateFormat(c, tt.offered.StringS(t)...) +			if tt.format != format { +				t.Fatalf("expected format: '%s', got format: '%s'", tt.format, format) +			} +		}) +	} +} diff --git a/internal/transport/finger.go b/internal/transport/finger.go index f106019b5..18b028a64 100644 --- a/internal/transport/finger.go +++ b/internal/transport/finger.go @@ -61,8 +61,11 @@ func prepWebfingerReq(ctx context.Context, loc, domain, username string) (*http.  	// Prefer application/jrd+json, fall back to application/json.  	// See https://www.rfc-editor.org/rfc/rfc7033#section-10.2. -	req.Header.Add("Accept", string(apiutil.AppJRDJSON)) -	req.Header.Add("Accept", string(apiutil.AppJSON)) +	// +	// Some implementations don't handle multiple accept headers properly, +	// including Gin itself. So concat the accept header with a comma +	// instead which seems to work reliably +	req.Header.Add("Accept", string(apiutil.AppJRDJSON)+","+string(apiutil.AppJSON))  	req.Header.Set("Host", req.URL.Host)  	return req, nil diff --git a/internal/web/profile.go b/internal/web/profile.go index a4fddbafe..56f8e0a56 100644 --- a/internal/web/profile.go +++ b/internal/web/profile.go @@ -73,7 +73,7 @@ func (m *Module) profileGETHandler(c *gin.Context) {  	// if we're getting an AP request on this endpoint we  	// should render the account's AP representation instead -	accept := c.NegotiateFormat(string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON)) +	accept := apiutil.NegotiateFormat(c, string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON))  	if accept == string(apiutil.AppActivityJSON) || accept == string(apiutil.AppActivityLDJSON) {  		m.returnAPProfile(ctx, c, username, accept)  		return diff --git a/internal/web/thread.go b/internal/web/thread.go index fe57ddf1f..8d4e99bef 100644 --- a/internal/web/thread.go +++ b/internal/web/thread.go @@ -90,7 +90,7 @@ func (m *Module) threadGETHandler(c *gin.Context) {  	// if we're getting an AP request on this endpoint we  	// should render the status's AP representation instead -	accept := c.NegotiateFormat(string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON)) +	accept := apiutil.NegotiateFormat(c, string(apiutil.TextHTML), string(apiutil.AppActivityJSON), string(apiutil.AppActivityLDJSON))  	if accept == string(apiutil.AppActivityJSON) || accept == string(apiutil.AppActivityLDJSON) {  		m.returnAPStatus(ctx, c, username, statusID, accept)  		return  | 
