diff options
Diffstat (limited to 'vendor/github.com/gin-contrib')
-rw-r--r-- | vendor/github.com/gin-contrib/cors/config.go | 44 | ||||
-rw-r--r-- | vendor/github.com/gin-contrib/cors/cors.go | 32 |
2 files changed, 55 insertions, 21 deletions
diff --git a/vendor/github.com/gin-contrib/cors/config.go b/vendor/github.com/gin-contrib/cors/config.go index 427cfc00b..8a295e3db 100644 --- a/vendor/github.com/gin-contrib/cors/config.go +++ b/vendor/github.com/gin-contrib/cors/config.go @@ -8,14 +8,15 @@ import ( ) type cors struct { - allowAllOrigins bool - allowCredentials bool - allowOriginFunc func(string) bool - allowOrigins []string - normalHeaders http.Header - preflightHeaders http.Header - wildcardOrigins [][]string - optionsResponseStatusCode int + allowAllOrigins bool + allowCredentials bool + allowOriginFunc func(string) bool + allowOriginWithContextFunc func(*gin.Context, string) bool + allowOrigins []string + normalHeaders http.Header + preflightHeaders http.Header + wildcardOrigins [][]string + optionsResponseStatusCode int } var ( @@ -54,14 +55,15 @@ func newCors(config Config) *cors { } return &cors{ - allowOriginFunc: config.AllowOriginFunc, - allowAllOrigins: config.AllowAllOrigins, - allowCredentials: config.AllowCredentials, - allowOrigins: normalize(config.AllowOrigins), - normalHeaders: generateNormalHeaders(config), - preflightHeaders: generatePreflightHeaders(config), - wildcardOrigins: config.parseWildcardRules(), - optionsResponseStatusCode: config.OptionsResponseStatusCode, + allowOriginFunc: config.AllowOriginFunc, + allowOriginWithContextFunc: config.AllowOriginWithContextFunc, + allowAllOrigins: config.AllowAllOrigins, + allowCredentials: config.AllowCredentials, + allowOrigins: normalize(config.AllowOrigins), + normalHeaders: generateNormalHeaders(config), + preflightHeaders: generatePreflightHeaders(config), + wildcardOrigins: config.parseWildcardRules(), + optionsResponseStatusCode: config.OptionsResponseStatusCode, } } @@ -79,7 +81,7 @@ func (cors *cors) applyCors(c *gin.Context) { return } - if !cors.validateOrigin(origin) { + if !cors.isOriginValid(c, origin) { c.AbortWithStatus(http.StatusForbidden) return } @@ -112,6 +114,14 @@ func (cors *cors) validateWildcardOrigin(origin string) bool { return false } +func (cors *cors) isOriginValid(c *gin.Context, origin string) bool { + valid := cors.validateOrigin(origin) + if !valid && cors.allowOriginWithContextFunc != nil { + valid = cors.allowOriginWithContextFunc(c, origin) + } + return valid +} + func (cors *cors) validateOrigin(origin string) bool { if cors.allowAllOrigins { return true diff --git a/vendor/github.com/gin-contrib/cors/cors.go b/vendor/github.com/gin-contrib/cors/cors.go index b32522277..2261df759 100644 --- a/vendor/github.com/gin-contrib/cors/cors.go +++ b/vendor/github.com/gin-contrib/cors/cors.go @@ -2,6 +2,7 @@ package cors import ( "errors" + "fmt" "strings" "time" @@ -22,6 +23,12 @@ type Config struct { // set, the content of AllowOrigins is ignored. AllowOriginFunc func(origin string) bool + // Same as AllowOriginFunc except also receives the full request context. + // This function should use the context as a read only source and not + // have any side effects on the request, such as aborting or injecting + // values on the request. + AllowOriginWithContextFunc func(c *gin.Context, origin string) bool + // AllowMethods is a list of methods the client is allowed to use with // cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS) AllowMethods []string @@ -51,6 +58,9 @@ type Config struct { // Allows usage of popular browser extensions schemas AllowBrowserExtensions bool + // Allows to add custom schema like tauri:// + CustomSchemas []string + // Allows usage of WebSocket protocol AllowWebSockets bool @@ -87,6 +97,9 @@ func (c Config) getAllowedSchemas() []string { if c.AllowFiles { allowedSchemas = append(allowedSchemas, FileSchemas...) } + if c.CustomSchemas != nil { + allowedSchemas = append(allowedSchemas, c.CustomSchemas...) + } return allowedSchemas } @@ -102,10 +115,21 @@ func (c Config) validateAllowedSchemas(origin string) bool { // Validate is check configuration of user defined. func (c Config) Validate() error { - if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { - return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed") + hasOriginFn := c.AllowOriginFunc != nil + hasOriginFn = hasOriginFn || c.AllowOriginWithContextFunc != nil + + if c.AllowAllOrigins && (hasOriginFn || len(c.AllowOrigins) > 0) { + originFields := strings.Join([]string{ + "AllowOriginFunc", + "AllowOriginFuncWithContext", + "AllowOrigins", + }, " or ") + return fmt.Errorf( + "conflict settings: all origins enabled. %s is not needed", + originFields, + ) } - if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { + if !c.AllowAllOrigins && !hasOriginFn && len(c.AllowOrigins) == 0 { return errors.New("conflict settings: all origins disabled") } for _, origin := range c.AllowOrigins { @@ -138,7 +162,7 @@ func (c Config) parseWildcardRules() [][]string { continue } if i == (len(o) - 1) { - wRules = append(wRules, []string{o[:i-1], "*"}) + wRules = append(wRules, []string{o[:i], "*"}) continue } |