diff options
Diffstat (limited to 'internal/api/util/negotiate.go')
-rw-r--r-- | internal/api/util/negotiate.go | 56 |
1 files changed, 55 insertions, 1 deletions
diff --git a/internal/api/util/negotiate.go b/internal/api/util/negotiate.go index 8e7f41134..6d68a0df3 100644 --- a/internal/api/util/negotiate.go +++ b/internal/api/util/negotiate.go @@ -20,6 +20,7 @@ package util import ( "errors" "fmt" + "strings" "github.com/gin-gonic/gin" ) @@ -108,10 +109,63 @@ func NegotiateAccept(c *gin.Context, offers ...MIME) (string, error) { return strings[0], nil } - format := c.NegotiateFormat(strings...) + format := NegotiateFormat(c, strings...) if format == "" { return "", fmt.Errorf("no format can be offered for requested Accept header(s) %s; this endpoint offers %s", accepts, offers) } return format, nil } + +// This is the exact same thing as gin.Context.NegotiateFormat except it contains +// tsmethurst's fix to make it work properly with multiple accept headers. +// +// https://github.com/gin-gonic/gin/pull/3156 +func NegotiateFormat(c *gin.Context, offered ...string) string { + if len(offered) == 0 { + panic("you must provide at least one offer") + } + + if c.Accepted == nil { + for _, a := range c.Request.Header.Values("Accept") { + c.Accepted = append(c.Accepted, parseAccept(a)...) + } + } + if len(c.Accepted) == 0 { + return offered[0] + } + for _, accepted := range c.Accepted { + for _, offer := range offered { + // According to RFC 2616 and RFC 2396, non-ASCII characters are not allowed in headers, + // therefore we can just iterate over the string without casting it into []rune + i := 0 + for ; i < len(accepted); i++ { + if accepted[i] == '*' || offer[i] == '*' { + return offer + } + if accepted[i] != offer[i] { + break + } + } + if i == len(accepted) { + return offer + } + } + } + return "" +} + +// https://github.com/gin-gonic/gin/blob/4787b8203b79012877ac98d7806422da3a678ba2/utils.go#L103 +func parseAccept(acceptHeader string) []string { + parts := strings.Split(acceptHeader, ",") + out := make([]string, 0, len(parts)) + for _, part := range parts { + if i := strings.IndexByte(part, ';'); i > 0 { + part = part[:i] + } + if part = strings.TrimSpace(part); part != "" { + out = append(out, part) + } + } + return out +} |