diff options
Diffstat (limited to 'vendor/github.com/gin-contrib/cors/cors.go')
-rw-r--r-- | vendor/github.com/gin-contrib/cors/cors.go | 32 |
1 files changed, 28 insertions, 4 deletions
diff --git a/vendor/github.com/gin-contrib/cors/cors.go b/vendor/github.com/gin-contrib/cors/cors.go index b32522277..2261df759 100644 --- a/vendor/github.com/gin-contrib/cors/cors.go +++ b/vendor/github.com/gin-contrib/cors/cors.go @@ -2,6 +2,7 @@ package cors import ( "errors" + "fmt" "strings" "time" @@ -22,6 +23,12 @@ type Config struct { // set, the content of AllowOrigins is ignored. AllowOriginFunc func(origin string) bool + // Same as AllowOriginFunc except also receives the full request context. + // This function should use the context as a read only source and not + // have any side effects on the request, such as aborting or injecting + // values on the request. + AllowOriginWithContextFunc func(c *gin.Context, origin string) bool + // AllowMethods is a list of methods the client is allowed to use with // cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS) AllowMethods []string @@ -51,6 +58,9 @@ type Config struct { // Allows usage of popular browser extensions schemas AllowBrowserExtensions bool + // Allows to add custom schema like tauri:// + CustomSchemas []string + // Allows usage of WebSocket protocol AllowWebSockets bool @@ -87,6 +97,9 @@ func (c Config) getAllowedSchemas() []string { if c.AllowFiles { allowedSchemas = append(allowedSchemas, FileSchemas...) } + if c.CustomSchemas != nil { + allowedSchemas = append(allowedSchemas, c.CustomSchemas...) + } return allowedSchemas } @@ -102,10 +115,21 @@ func (c Config) validateAllowedSchemas(origin string) bool { // Validate is check configuration of user defined. func (c Config) Validate() error { - if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { - return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed") + hasOriginFn := c.AllowOriginFunc != nil + hasOriginFn = hasOriginFn || c.AllowOriginWithContextFunc != nil + + if c.AllowAllOrigins && (hasOriginFn || len(c.AllowOrigins) > 0) { + originFields := strings.Join([]string{ + "AllowOriginFunc", + "AllowOriginFuncWithContext", + "AllowOrigins", + }, " or ") + return fmt.Errorf( + "conflict settings: all origins enabled. %s is not needed", + originFields, + ) } - if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { + if !c.AllowAllOrigins && !hasOriginFn && len(c.AllowOrigins) == 0 { return errors.New("conflict settings: all origins disabled") } for _, origin := range c.AllowOrigins { @@ -138,7 +162,7 @@ func (c Config) parseWildcardRules() [][]string { continue } if i == (len(o) - 1) { - wRules = append(wRules, []string{o[:i-1], "*"}) + wRules = append(wRules, []string{o[:i], "*"}) continue } |