summaryrefslogtreecommitdiff
path: root/internal/middleware
diff options
context:
space:
mode:
Diffstat (limited to 'internal/middleware')
-rw-r--r--internal/middleware/headerfilter.go251
-rw-r--r--internal/middleware/headerfilter_test.go299
-rw-r--r--internal/middleware/ratelimit.go2
-rw-r--r--internal/middleware/useragent.go9
-rw-r--r--internal/middleware/util.go51
5 files changed, 607 insertions, 5 deletions
diff --git a/internal/middleware/headerfilter.go b/internal/middleware/headerfilter.go
new file mode 100644
index 000000000..18c9d1e67
--- /dev/null
+++ b/internal/middleware/headerfilter.go
@@ -0,0 +1,251 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package middleware
+
+import (
+ "sync"
+
+ "github.com/gin-gonic/gin"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/headerfilter"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+)
+
+var (
+ allowMatches = matchstats{m: make(map[string]uint64)}
+ blockMatches = matchstats{m: make(map[string]uint64)}
+)
+
+// matchstats is a simple statistics
+// counter for header filter matches.
+// TODO: replace with otel.
+type matchstats struct {
+ m map[string]uint64
+ l sync.Mutex
+}
+
+func (m *matchstats) Add(hdr, regex string) {
+ m.l.Lock()
+ key := hdr + ":" + regex
+ m.m[key]++
+ m.l.Unlock()
+}
+
+// HeaderFilter returns a gin middleware handler that provides HTTP
+// request blocking (filtering) based on database allow / block filters.
+func HeaderFilter(state *state.State) gin.HandlerFunc {
+ switch mode := config.GetAdvancedHeaderFilterMode(); mode {
+ case config.RequestHeaderFilterModeDisabled:
+ return func(ctx *gin.Context) {}
+
+ case config.RequestHeaderFilterModeAllow:
+ return headerFilterAllowMode(state)
+
+ case config.RequestHeaderFilterModeBlock:
+ return headerFilterBlockMode(state)
+
+ default:
+ panic("unrecognized filter mode: " + mode)
+ }
+}
+
+func headerFilterAllowMode(state *state.State) func(c *gin.Context) {
+ _ = *state //nolint
+ // Allowlist mode: explicit block takes
+ // precedence over explicit allow.
+ //
+ // Headers that have neither block
+ // or allow entries are blocked.
+ return func(c *gin.Context) {
+
+ // Check if header is explicitly blocked.
+ block, err := isHeaderBlocked(state, c)
+ if err != nil {
+ respondInternalServerError(c, err)
+ return
+ }
+
+ if block {
+ respondBlocked(c)
+ return
+ }
+
+ // Check if header is missing explicit allow.
+ notAllow, err := isHeaderNotAllowed(state, c)
+ if err != nil {
+ respondInternalServerError(c, err)
+ return
+ }
+
+ if notAllow {
+ respondBlocked(c)
+ return
+ }
+
+ // Allowed!
+ c.Next()
+ }
+}
+
+func headerFilterBlockMode(state *state.State) func(c *gin.Context) {
+ _ = *state //nolint
+ // Blocklist/default mode: explicit allow
+ // takes precedence over explicit block.
+ //
+ // Headers that have neither block
+ // or allow entries are allowed.
+ return func(c *gin.Context) {
+
+ // Check if header is explicitly allowed.
+ allow, err := isHeaderAllowed(state, c)
+ if err != nil {
+ respondInternalServerError(c, err)
+ return
+ }
+
+ if !allow {
+ // Check if header is explicitly blocked.
+ block, err := isHeaderBlocked(state, c)
+ if err != nil {
+ respondInternalServerError(c, err)
+ return
+ }
+
+ if block {
+ respondBlocked(c)
+ return
+ }
+ }
+
+ // Allowed!
+ c.Next()
+ }
+}
+
+func isHeaderBlocked(state *state.State, c *gin.Context) (bool, error) {
+ var (
+ ctx = c.Request.Context()
+ hdr = c.Request.Header
+ )
+
+ // Perform an explicit is-blocked check on request header.
+ key, expr, err := state.DB.BlockHeaderRegularMatch(ctx, hdr)
+ switch err {
+ case nil:
+ break
+
+ case headerfilter.ErrLargeHeaderValue:
+ log.Warn(ctx, "large header value")
+ key = "*" // block large headers
+
+ default:
+ err := gtserror.Newf("error checking header: %w", err)
+ return false, err
+ }
+
+ if key != "" {
+ if expr != "" {
+ // Increment block matches stat.
+ // TODO: replace expvar with build
+ // taggable metrics types in State{}.
+ blockMatches.Add(key, expr)
+ }
+
+ // A header was matched against!
+ // i.e. this request is blocked.
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func isHeaderAllowed(state *state.State, c *gin.Context) (bool, error) {
+ var (
+ ctx = c.Request.Context()
+ hdr = c.Request.Header
+ )
+
+ // Perform an explicit is-allowed check on request header.
+ key, expr, err := state.DB.AllowHeaderRegularMatch(ctx, hdr)
+ switch err {
+ case nil:
+ break
+
+ case headerfilter.ErrLargeHeaderValue:
+ log.Warn(ctx, "large header value")
+ key = "" // block large headers
+
+ default:
+ err := gtserror.Newf("error checking header: %w", err)
+ return false, err
+ }
+
+ if key != "" {
+ if expr != "" {
+ // Increment allow matches stat.
+ // TODO: replace expvar with build
+ // taggable metrics types in State{}.
+ allowMatches.Add(key, expr)
+ }
+
+ // A header was matched against!
+ // i.e. this request is allowed.
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func isHeaderNotAllowed(state *state.State, c *gin.Context) (bool, error) {
+ var (
+ ctx = c.Request.Context()
+ hdr = c.Request.Header
+ )
+
+ // Perform an explicit is-NOT-allowed check on request header.
+ key, expr, err := state.DB.AllowHeaderInverseMatch(ctx, hdr)
+ switch err {
+ case nil:
+ break
+
+ case headerfilter.ErrLargeHeaderValue:
+ log.Warn(ctx, "large header value")
+ key = "*" // block large headers
+
+ default:
+ err := gtserror.Newf("error checking header: %w", err)
+ return false, err
+ }
+
+ if key != "" {
+ if expr != "" {
+ // Increment allow matches stat.
+ // TODO: replace expvar with build
+ // taggable metrics types in State{}.
+ allowMatches.Add(key, expr)
+ }
+
+ // A header was matched against!
+ // i.e. request is NOT allowed.
+ return true, nil
+ }
+
+ return false, nil
+}
diff --git a/internal/middleware/headerfilter_test.go b/internal/middleware/headerfilter_test.go
new file mode 100644
index 000000000..a28644153
--- /dev/null
+++ b/internal/middleware/headerfilter_test.go
@@ -0,0 +1,299 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package middleware_test
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/headerfilter"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/superseriousbusiness/gotosocial/internal/middleware"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+func TestHeaderFilter(t *testing.T) {
+ testrig.InitTestLog()
+ testrig.InitTestConfig()
+
+ for _, test := range []struct {
+ mode string
+ allow []filter
+ block []filter
+ input http.Header
+ expect bool
+ }{
+ {
+ // Allow mode with expected 200 OK.
+ mode: config.RequestHeaderFilterModeAllow,
+ allow: []filter{
+ {"User-Agent", ".*Firefox.*"},
+ },
+ block: []filter{},
+ input: http.Header{
+ "User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
+ },
+ expect: true,
+ },
+ {
+ // Allow mode with expected 403 Forbidden.
+ mode: config.RequestHeaderFilterModeAllow,
+ allow: []filter{
+ {"User-Agent", ".*Firefox.*"},
+ },
+ block: []filter{},
+ input: http.Header{
+ "User-Agent": []string{"Chromium v169.42; Extra Tracking Info"},
+ },
+ expect: false,
+ },
+ {
+ // Allow mode with too long header value expecting 403 Forbidden.
+ mode: config.RequestHeaderFilterModeAllow,
+ allow: []filter{
+ {"User-Agent", ".*"},
+ },
+ block: []filter{},
+ input: http.Header{
+ "User-Agent": []string{func() string {
+ var buf strings.Builder
+ for i := 0; i < headerfilter.MaxHeaderValue+1; i++ {
+ buf.WriteByte(' ')
+ }
+ return buf.String()
+ }()},
+ },
+ expect: false,
+ },
+ {
+ // Allow mode with explicit block expecting 403 Forbidden.
+ mode: config.RequestHeaderFilterModeAllow,
+ allow: []filter{
+ {"User-Agent", ".*Firefox.*"},
+ },
+ block: []filter{
+ {"User-Agent", ".*Firefox v169\\.42.*"},
+ },
+ input: http.Header{
+ "User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
+ },
+ expect: false,
+ },
+ {
+ // Block mode with an expected 403 Forbidden.
+ mode: config.RequestHeaderFilterModeBlock,
+ allow: []filter{},
+ block: []filter{
+ {"User-Agent", ".*Firefox.*"},
+ },
+ input: http.Header{
+ "User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
+ },
+ expect: false,
+ },
+ {
+ // Block mode with an expected 200 OK.
+ mode: config.RequestHeaderFilterModeBlock,
+ allow: []filter{},
+ block: []filter{
+ {"User-Agent", ".*Firefox.*"},
+ },
+ input: http.Header{
+ "User-Agent": []string{"Chromium v169.42; Extra Tracking Info"},
+ },
+ expect: true,
+ },
+ {
+ // Block mode with too long header value expecting 403 Forbidden.
+ mode: config.RequestHeaderFilterModeBlock,
+ allow: []filter{},
+ block: []filter{
+ {"User-Agent", "none"},
+ },
+ input: http.Header{
+ "User-Agent": []string{func() string {
+ var buf strings.Builder
+ for i := 0; i < headerfilter.MaxHeaderValue+1; i++ {
+ buf.WriteByte(' ')
+ }
+ return buf.String()
+ }()},
+ },
+ expect: false,
+ },
+ {
+ // Block mode with explicit allow expecting 200 OK.
+ mode: config.RequestHeaderFilterModeBlock,
+ allow: []filter{
+ {"User-Agent", ".*Firefox.*"},
+ },
+ block: []filter{
+ {"User-Agent", ".*Firefox v169\\.42.*"},
+ },
+ input: http.Header{
+ "User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
+ },
+ expect: true,
+ },
+ {
+ // Disabled mode with an expected 200 OK.
+ mode: config.RequestHeaderFilterModeDisabled,
+ allow: []filter{
+ {"Key1", "only-this"},
+ {"Key2", "only-this"},
+ {"Key3", "only-this"},
+ },
+ block: []filter{
+ {"Key1", "Value"},
+ {"Key2", "Value"},
+ {"Key3", "Value"},
+ },
+ input: http.Header{
+ "Key1": []string{"Value"},
+ "Key2": []string{"Value"},
+ "Key3": []string{"Value"},
+ },
+ expect: true,
+ },
+ } {
+ // Generate a unique name for this test case.
+ name := fmt.Sprintf("%s allow=%v block=%v => expect=%v",
+ test.mode,
+ test.allow,
+ test.block,
+ test.expect,
+ )
+
+ // Update header filter mode to test case.
+ config.SetAdvancedHeaderFilterMode(test.mode)
+
+ // Run this particular test case.
+ ok := t.Run(name, func(t *testing.T) {
+ testHeaderFilter(t,
+ test.allow,
+ test.block,
+ test.input,
+ test.expect,
+ )
+ })
+
+ if !ok {
+ return
+ }
+ }
+}
+
+func testHeaderFilter(t *testing.T, allow, block []filter, input http.Header, expect bool) {
+ var err error
+
+ // Create test context with cancel.
+ ctx := context.Background()
+ ctx, cncl := context.WithCancel(ctx)
+ defer cncl()
+
+ // Initialize caches.
+ var state state.State
+ state.Caches.Init()
+
+ // Create new database instance with test config.
+ state.DB, err = bundb.NewBunDBService(ctx, &state)
+ if err != nil {
+ t.Fatalf("error opening database: %v", err)
+ }
+
+ // Insert all allow filters into DB.
+ for _, filter := range allow {
+ filter := &gtsmodel.HeaderFilter{
+ ID: id.NewULID(),
+ Header: filter.header,
+ Regex: filter.regex,
+ AuthorID: "admin-id",
+ Author: nil,
+ }
+
+ if err := state.DB.PutAllowHeaderFilter(ctx, filter); err != nil {
+ t.Fatalf("error inserting allow filter into database: %v", err)
+ }
+ }
+
+ // Insert all block filters into DB.
+ for _, filter := range block {
+ filter := &gtsmodel.HeaderFilter{
+ ID: id.NewULID(),
+ Header: filter.header,
+ Regex: filter.regex,
+ AuthorID: "admin-id",
+ Author: nil,
+ }
+
+ if err := state.DB.PutBlockHeaderFilter(ctx, filter); err != nil {
+ t.Fatalf("error inserting block filter into database: %v", err)
+ }
+ }
+
+ // Gin test http engine
+ // (used for ctx init).
+ e := gin.New()
+
+ // Create new filter middleware to test against.
+ middleware := middleware.HeaderFilter(&state)
+ e.Use(middleware)
+
+ // Set the empty gin handler (always returns okay).
+ e.Handle("GET", "/", func(ctx *gin.Context) { ctx.Status(200) })
+
+ // Prepare a gin test context.
+ r := httptest.NewRequest("GET", "/", nil)
+ rw := httptest.NewRecorder()
+
+ // Set input headers.
+ r.Header = input
+
+ // Pass req through
+ // engine handler.
+ e.ServeHTTP(rw, r)
+
+ // Get http result.
+ res := rw.Result()
+
+ switch {
+ case expect && res.StatusCode != http.StatusOK:
+ t.Errorf("unexpected response (should allow): %s", res.Status)
+
+ case !expect && res.StatusCode != http.StatusForbidden:
+ t.Errorf("unexpected response (should block): %s", res.Status)
+ }
+}
+
+type filter struct {
+ header string
+ regex string
+}
+
+func (hf *filter) String() string {
+ return fmt.Sprintf("%s=%q", hf.header, hf.regex)
+}
diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go
index 57055fe70..352a30c22 100644
--- a/internal/middleware/ratelimit.go
+++ b/internal/middleware/ratelimit.go
@@ -146,7 +146,7 @@ func RateLimit(limit int, exceptions []string) gin.HandlerFunc {
apiutil.Data(c,
http.StatusTooManyRequests,
apiutil.AppJSON,
- apiutil.ErrorRateLimitReached,
+ apiutil.ErrorRateLimited,
)
c.Abort()
return
diff --git a/internal/middleware/useragent.go b/internal/middleware/useragent.go
index 6dc3e401f..38d28f4e5 100644
--- a/internal/middleware/useragent.go
+++ b/internal/middleware/useragent.go
@@ -18,21 +18,22 @@
package middleware
import (
- "errors"
"net/http"
"github.com/gin-gonic/gin"
+ apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
)
// UserAgent returns a gin middleware which aborts requests with
// empty user agent strings, returning code 418 - I'm a teapot.
func UserAgent() gin.HandlerFunc {
// todo: make this configurable
+ var rsp = []byte(`{"error": "I'm a teapot: no user-agent sent with request"}`)
return func(c *gin.Context) {
if ua := c.Request.UserAgent(); ua == "" {
- code := http.StatusTeapot
- err := errors.New(http.StatusText(code) + ": no user-agent sent with request")
- c.AbortWithStatusJSON(code, gin.H{"error": err.Error()})
+ apiutil.Data(c,
+ http.StatusTeapot, apiutil.AppJSON, rsp)
+ c.Abort()
}
}
}
diff --git a/internal/middleware/util.go b/internal/middleware/util.go
new file mode 100644
index 000000000..82850fd6d
--- /dev/null
+++ b/internal/middleware/util.go
@@ -0,0 +1,51 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package middleware
+
+import (
+ "net/http"
+
+ "github.com/gin-gonic/gin"
+ apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
+)
+
+// respondBlocked responds to the given gin context with
+// status forbidden, and a generic forbidden JSON response,
+// finally aborting the gin handler chain.
+func respondBlocked(c *gin.Context) {
+ apiutil.Data(c,
+ http.StatusForbidden,
+ apiutil.AppJSON,
+ apiutil.StatusForbiddenJSON,
+ )
+ c.Abort()
+}
+
+// respondInternalServerError responds to the given gin context
+// with status internal server error, a generic internal server
+// error JSON response, sets the given error on the gin context
+// for later logging, finally aborting the gin handler chain.
+func respondInternalServerError(c *gin.Context, err error) {
+ apiutil.Data(c,
+ http.StatusInternalServerError,
+ apiutil.AppJSON,
+ apiutil.StatusInternalServerErrorJSON,
+ )
+ _ = c.Error(err)
+ c.Abort()
+}