diff options
author | 2023-12-18 14:18:25 +0000 | |
---|---|---|
committer | 2023-12-18 14:18:25 +0000 | |
commit | 8ebb7775a35b632d49a8f294d83ac786666631f3 (patch) | |
tree | 02ac5475274125170132b0a4d9f69bd67491a32c /internal/middleware/headerfilter.go | |
parent | fix poll total vote double count (#2464) (diff) | |
download | gotosocial-8ebb7775a35b632d49a8f294d83ac786666631f3.tar.xz |
[feature] request blocking by http headers (#2409)
Diffstat (limited to 'internal/middleware/headerfilter.go')
-rw-r--r-- | internal/middleware/headerfilter.go | 251 |
1 files changed, 251 insertions, 0 deletions
diff --git a/internal/middleware/headerfilter.go b/internal/middleware/headerfilter.go new file mode 100644 index 000000000..18c9d1e67 --- /dev/null +++ b/internal/middleware/headerfilter.go @@ -0,0 +1,251 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package middleware + +import ( + "sync" + + "github.com/gin-gonic/gin" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/headerfilter" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" +) + +var ( + allowMatches = matchstats{m: make(map[string]uint64)} + blockMatches = matchstats{m: make(map[string]uint64)} +) + +// matchstats is a simple statistics +// counter for header filter matches. +// TODO: replace with otel. +type matchstats struct { + m map[string]uint64 + l sync.Mutex +} + +func (m *matchstats) Add(hdr, regex string) { + m.l.Lock() + key := hdr + ":" + regex + m.m[key]++ + m.l.Unlock() +} + +// HeaderFilter returns a gin middleware handler that provides HTTP +// request blocking (filtering) based on database allow / block filters. +func HeaderFilter(state *state.State) gin.HandlerFunc { + switch mode := config.GetAdvancedHeaderFilterMode(); mode { + case config.RequestHeaderFilterModeDisabled: + return func(ctx *gin.Context) {} + + case config.RequestHeaderFilterModeAllow: + return headerFilterAllowMode(state) + + case config.RequestHeaderFilterModeBlock: + return headerFilterBlockMode(state) + + default: + panic("unrecognized filter mode: " + mode) + } +} + +func headerFilterAllowMode(state *state.State) func(c *gin.Context) { + _ = *state //nolint + // Allowlist mode: explicit block takes + // precedence over explicit allow. + // + // Headers that have neither block + // or allow entries are blocked. + return func(c *gin.Context) { + + // Check if header is explicitly blocked. + block, err := isHeaderBlocked(state, c) + if err != nil { + respondInternalServerError(c, err) + return + } + + if block { + respondBlocked(c) + return + } + + // Check if header is missing explicit allow. + notAllow, err := isHeaderNotAllowed(state, c) + if err != nil { + respondInternalServerError(c, err) + return + } + + if notAllow { + respondBlocked(c) + return + } + + // Allowed! + c.Next() + } +} + +func headerFilterBlockMode(state *state.State) func(c *gin.Context) { + _ = *state //nolint + // Blocklist/default mode: explicit allow + // takes precedence over explicit block. + // + // Headers that have neither block + // or allow entries are allowed. + return func(c *gin.Context) { + + // Check if header is explicitly allowed. + allow, err := isHeaderAllowed(state, c) + if err != nil { + respondInternalServerError(c, err) + return + } + + if !allow { + // Check if header is explicitly blocked. + block, err := isHeaderBlocked(state, c) + if err != nil { + respondInternalServerError(c, err) + return + } + + if block { + respondBlocked(c) + return + } + } + + // Allowed! + c.Next() + } +} + +func isHeaderBlocked(state *state.State, c *gin.Context) (bool, error) { + var ( + ctx = c.Request.Context() + hdr = c.Request.Header + ) + + // Perform an explicit is-blocked check on request header. + key, expr, err := state.DB.BlockHeaderRegularMatch(ctx, hdr) + switch err { + case nil: + break + + case headerfilter.ErrLargeHeaderValue: + log.Warn(ctx, "large header value") + key = "*" // block large headers + + default: + err := gtserror.Newf("error checking header: %w", err) + return false, err + } + + if key != "" { + if expr != "" { + // Increment block matches stat. + // TODO: replace expvar with build + // taggable metrics types in State{}. + blockMatches.Add(key, expr) + } + + // A header was matched against! + // i.e. this request is blocked. + return true, nil + } + + return false, nil +} + +func isHeaderAllowed(state *state.State, c *gin.Context) (bool, error) { + var ( + ctx = c.Request.Context() + hdr = c.Request.Header + ) + + // Perform an explicit is-allowed check on request header. + key, expr, err := state.DB.AllowHeaderRegularMatch(ctx, hdr) + switch err { + case nil: + break + + case headerfilter.ErrLargeHeaderValue: + log.Warn(ctx, "large header value") + key = "" // block large headers + + default: + err := gtserror.Newf("error checking header: %w", err) + return false, err + } + + if key != "" { + if expr != "" { + // Increment allow matches stat. + // TODO: replace expvar with build + // taggable metrics types in State{}. + allowMatches.Add(key, expr) + } + + // A header was matched against! + // i.e. this request is allowed. + return true, nil + } + + return false, nil +} + +func isHeaderNotAllowed(state *state.State, c *gin.Context) (bool, error) { + var ( + ctx = c.Request.Context() + hdr = c.Request.Header + ) + + // Perform an explicit is-NOT-allowed check on request header. + key, expr, err := state.DB.AllowHeaderInverseMatch(ctx, hdr) + switch err { + case nil: + break + + case headerfilter.ErrLargeHeaderValue: + log.Warn(ctx, "large header value") + key = "*" // block large headers + + default: + err := gtserror.Newf("error checking header: %w", err) + return false, err + } + + if key != "" { + if expr != "" { + // Increment allow matches stat. + // TODO: replace expvar with build + // taggable metrics types in State{}. + allowMatches.Add(key, expr) + } + + // A header was matched against! + // i.e. request is NOT allowed. + return true, nil + } + + return false, nil +} |