summaryrefslogtreecommitdiff
path: root/vendor/github.com/gin-contrib
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/gin-contrib')
-rw-r--r--vendor/github.com/gin-contrib/cors/config.go44
-rw-r--r--vendor/github.com/gin-contrib/cors/cors.go32
2 files changed, 55 insertions, 21 deletions
diff --git a/vendor/github.com/gin-contrib/cors/config.go b/vendor/github.com/gin-contrib/cors/config.go
index 427cfc00b..8a295e3db 100644
--- a/vendor/github.com/gin-contrib/cors/config.go
+++ b/vendor/github.com/gin-contrib/cors/config.go
@@ -8,14 +8,15 @@ import (
)
type cors struct {
- allowAllOrigins bool
- allowCredentials bool
- allowOriginFunc func(string) bool
- allowOrigins []string
- normalHeaders http.Header
- preflightHeaders http.Header
- wildcardOrigins [][]string
- optionsResponseStatusCode int
+ allowAllOrigins bool
+ allowCredentials bool
+ allowOriginFunc func(string) bool
+ allowOriginWithContextFunc func(*gin.Context, string) bool
+ allowOrigins []string
+ normalHeaders http.Header
+ preflightHeaders http.Header
+ wildcardOrigins [][]string
+ optionsResponseStatusCode int
}
var (
@@ -54,14 +55,15 @@ func newCors(config Config) *cors {
}
return &cors{
- allowOriginFunc: config.AllowOriginFunc,
- allowAllOrigins: config.AllowAllOrigins,
- allowCredentials: config.AllowCredentials,
- allowOrigins: normalize(config.AllowOrigins),
- normalHeaders: generateNormalHeaders(config),
- preflightHeaders: generatePreflightHeaders(config),
- wildcardOrigins: config.parseWildcardRules(),
- optionsResponseStatusCode: config.OptionsResponseStatusCode,
+ allowOriginFunc: config.AllowOriginFunc,
+ allowOriginWithContextFunc: config.AllowOriginWithContextFunc,
+ allowAllOrigins: config.AllowAllOrigins,
+ allowCredentials: config.AllowCredentials,
+ allowOrigins: normalize(config.AllowOrigins),
+ normalHeaders: generateNormalHeaders(config),
+ preflightHeaders: generatePreflightHeaders(config),
+ wildcardOrigins: config.parseWildcardRules(),
+ optionsResponseStatusCode: config.OptionsResponseStatusCode,
}
}
@@ -79,7 +81,7 @@ func (cors *cors) applyCors(c *gin.Context) {
return
}
- if !cors.validateOrigin(origin) {
+ if !cors.isOriginValid(c, origin) {
c.AbortWithStatus(http.StatusForbidden)
return
}
@@ -112,6 +114,14 @@ func (cors *cors) validateWildcardOrigin(origin string) bool {
return false
}
+func (cors *cors) isOriginValid(c *gin.Context, origin string) bool {
+ valid := cors.validateOrigin(origin)
+ if !valid && cors.allowOriginWithContextFunc != nil {
+ valid = cors.allowOriginWithContextFunc(c, origin)
+ }
+ return valid
+}
+
func (cors *cors) validateOrigin(origin string) bool {
if cors.allowAllOrigins {
return true
diff --git a/vendor/github.com/gin-contrib/cors/cors.go b/vendor/github.com/gin-contrib/cors/cors.go
index b32522277..2261df759 100644
--- a/vendor/github.com/gin-contrib/cors/cors.go
+++ b/vendor/github.com/gin-contrib/cors/cors.go
@@ -2,6 +2,7 @@ package cors
import (
"errors"
+ "fmt"
"strings"
"time"
@@ -22,6 +23,12 @@ type Config struct {
// set, the content of AllowOrigins is ignored.
AllowOriginFunc func(origin string) bool
+ // Same as AllowOriginFunc except also receives the full request context.
+ // This function should use the context as a read only source and not
+ // have any side effects on the request, such as aborting or injecting
+ // values on the request.
+ AllowOriginWithContextFunc func(c *gin.Context, origin string) bool
+
// AllowMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS)
AllowMethods []string
@@ -51,6 +58,9 @@ type Config struct {
// Allows usage of popular browser extensions schemas
AllowBrowserExtensions bool
+ // Allows to add custom schema like tauri://
+ CustomSchemas []string
+
// Allows usage of WebSocket protocol
AllowWebSockets bool
@@ -87,6 +97,9 @@ func (c Config) getAllowedSchemas() []string {
if c.AllowFiles {
allowedSchemas = append(allowedSchemas, FileSchemas...)
}
+ if c.CustomSchemas != nil {
+ allowedSchemas = append(allowedSchemas, c.CustomSchemas...)
+ }
return allowedSchemas
}
@@ -102,10 +115,21 @@ func (c Config) validateAllowedSchemas(origin string) bool {
// Validate is check configuration of user defined.
func (c Config) Validate() error {
- if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
- return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed")
+ hasOriginFn := c.AllowOriginFunc != nil
+ hasOriginFn = hasOriginFn || c.AllowOriginWithContextFunc != nil
+
+ if c.AllowAllOrigins && (hasOriginFn || len(c.AllowOrigins) > 0) {
+ originFields := strings.Join([]string{
+ "AllowOriginFunc",
+ "AllowOriginFuncWithContext",
+ "AllowOrigins",
+ }, " or ")
+ return fmt.Errorf(
+ "conflict settings: all origins enabled. %s is not needed",
+ originFields,
+ )
}
- if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
+ if !c.AllowAllOrigins && !hasOriginFn && len(c.AllowOrigins) == 0 {
return errors.New("conflict settings: all origins disabled")
}
for _, origin := range c.AllowOrigins {
@@ -138,7 +162,7 @@ func (c Config) parseWildcardRules() [][]string {
continue
}
if i == (len(o) - 1) {
- wRules = append(wRules, []string{o[:i-1], "*"})
+ wRules = append(wRules, []string{o[:i], "*"})
continue
}