diff options
author | 2023-12-18 14:18:25 +0000 | |
---|---|---|
committer | 2023-12-18 14:18:25 +0000 | |
commit | 8ebb7775a35b632d49a8f294d83ac786666631f3 (patch) | |
tree | 02ac5475274125170132b0a4d9f69bd67491a32c /internal | |
parent | fix poll total vote double count (#2464) (diff) | |
download | gotosocial-8ebb7775a35b632d49a8f294d83ac786666631f3.tar.xz |
[feature] request blocking by http headers (#2409)
Diffstat (limited to 'internal')
31 files changed, 2254 insertions, 76 deletions
diff --git a/internal/api/client/admin/admin.go b/internal/api/client/admin/admin.go index 16c5fa8f8..6173218e0 100644 --- a/internal/api/client/admin/admin.go +++ b/internal/api/client/admin/admin.go @@ -35,6 +35,10 @@ const ( DomainAllowsPath = BasePath + "/domain_allows" DomainAllowsPathWithID = DomainAllowsPath + "/:" + IDKey DomainKeysExpirePath = BasePath + "/domain_keys_expire" + HeaderAllowsPath = BasePath + "/header_allows" + HeaderAllowsPathWithID = HeaderAllowsPath + "/:" + IDKey + HeaderBlocksPath = BasePath + "/header_blocks" + HeaderBlocksPathWithID = HeaderAllowsPath + "/:" + IDKey AccountsPath = BasePath + "/accounts" AccountsPathWithID = AccountsPath + "/:" + IDKey AccountsActionPath = AccountsPathWithID + "/action" @@ -95,6 +99,16 @@ func (m *Module) Route(attachHandler func(method string, path string, f ...gin.H attachHandler(http.MethodGet, DomainAllowsPathWithID, m.DomainAllowGETHandler) attachHandler(http.MethodDelete, DomainAllowsPathWithID, m.DomainAllowDELETEHandler) + // header filtering administration routes + attachHandler(http.MethodGet, HeaderAllowsPathWithID, m.HeaderFilterAllowGET) + attachHandler(http.MethodGet, HeaderBlocksPathWithID, m.HeaderFilterBlockGET) + attachHandler(http.MethodGet, HeaderAllowsPath, m.HeaderFilterAllowsGET) + attachHandler(http.MethodGet, HeaderBlocksPath, m.HeaderFilterBlocksGET) + attachHandler(http.MethodPost, HeaderAllowsPath, m.HeaderFilterAllowPOST) + attachHandler(http.MethodPost, HeaderBlocksPath, m.HeaderFilterBlockPOST) + attachHandler(http.MethodDelete, HeaderAllowsPathWithID, m.HeaderFilterAllowDELETE) + attachHandler(http.MethodDelete, HeaderBlocksPathWithID, m.HeaderFilterBlockDELETE) + // domain maintenance stuff attachHandler(http.MethodPost, DomainKeysExpirePath, m.DomainKeysExpirePOSTHandler) diff --git a/internal/api/client/admin/headerfilter.go b/internal/api/client/admin/headerfilter.go new file mode 100644 index 000000000..7b1a85c86 --- /dev/null +++ b/internal/api/client/admin/headerfilter.go @@ -0,0 +1,173 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package admin + +import ( + "context" + "errors" + "net/http" + + "github.com/gin-gonic/gin" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// getHeaderFilter is a gin handler function that returns details of an HTTP header filter with provided ID, using given get function. +func (m *Module) getHeaderFilter(c *gin.Context, get func(context.Context, string) (*apimodel.HeaderFilter, gtserror.WithCode)) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + errWithCode := gtserror.NewErrorUnauthorized(err, err.Error()) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if !*authed.User.Admin { + const text = "user not an admin" + errWithCode := gtserror.NewErrorForbidden(errors.New(text), text) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + errWithCode := gtserror.NewErrorNotAcceptable(err, err.Error()) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + filterID, errWithCode := apiutil.ParseID(c.Param("ID")) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + filter, errWithCode := get(c.Request.Context(), filterID) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiutil.JSON(c, http.StatusOK, filter) +} + +// getHeaderFilters is a gin handler function that returns details of all HTTP header filters using given get function. +func (m *Module) getHeaderFilters(c *gin.Context, get func(context.Context) ([]*apimodel.HeaderFilter, gtserror.WithCode)) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + errWithCode := gtserror.NewErrorUnauthorized(err, err.Error()) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if !*authed.User.Admin { + const text = "user not an admin" + errWithCode := gtserror.NewErrorForbidden(errors.New(text), text) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + errWithCode := gtserror.NewErrorNotAcceptable(err, err.Error()) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + filters, errWithCode := get(c.Request.Context()) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiutil.JSON(c, http.StatusOK, filters) +} + +// createHeaderFilter is a gin handler function that creates a HTTP header filter entry using provided form data, passing to given create function. +func (m *Module) createHeaderFilter(c *gin.Context, create func(context.Context, *gtsmodel.Account, *apimodel.HeaderFilterRequest) (*apimodel.HeaderFilter, gtserror.WithCode)) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + errWithCode := gtserror.NewErrorUnauthorized(err, err.Error()) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if !*authed.User.Admin { + const text = "user not an admin" + errWithCode := gtserror.NewErrorForbidden(errors.New(text), text) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + errWithCode := gtserror.NewErrorNotAcceptable(err, err.Error()) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + var form apimodel.HeaderFilterRequest + + if err := c.ShouldBind(&form); err != nil { + errWithCode := gtserror.NewErrorBadRequest(err, err.Error()) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + filter, errWithCode := create( + c.Request.Context(), + authed.Account, + &form, + ) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiutil.JSON(c, http.StatusOK, filter) +} + +// deleteHeaderFilter is a gin handler function that deletes an HTTP header filter with provided ID, using given delete function. +func (m *Module) deleteHeaderFilter(c *gin.Context, delete func(context.Context, string) gtserror.WithCode) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + errWithCode := gtserror.NewErrorUnauthorized(err, err.Error()) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if !*authed.User.Admin { + const text = "user not an admin" + errWithCode := gtserror.NewErrorForbidden(errors.New(text), text) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + filterID, errWithCode := apiutil.ParseID(c.Param("ID")) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + errWithCode = delete(c.Request.Context(), filterID) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.Status(http.StatusAccepted) +} diff --git a/internal/api/client/admin/headerfilter_create.go b/internal/api/client/admin/headerfilter_create.go new file mode 100644 index 000000000..d74dc5e15 --- /dev/null +++ b/internal/api/client/admin/headerfilter_create.go @@ -0,0 +1,102 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package admin + +import ( + "github.com/gin-gonic/gin" +) + +// HeaderFilterAllowPOST swagger:operation POST /api/v1/admin/header_allows headerFilterAllowCreate +// +// Create new "allow" HTTP request header filter. +// +// The parameters can also be given in the body of the request, as JSON, if the content-type is set to 'application/json'. +// The parameters can also be given in the body of the request, as XML, if the content-type is set to 'application/xml'. +// +// --- +// tags: +// - admin +// +// consumes: +// - application/json +// - application/xml +// - application/x-www-form-urlencoded +// +// produces: +// - application/json +// +// security: +// - OAuth2 Bearer: +// - admin +// +// responses: +// '200': +// description: The newly created "allow" header filter. +// schema: +// "$ref": "#/definitions/headerFilter" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '500': +// description: internal server error +func (m *Module) HeaderFilterAllowPOST(c *gin.Context) { + m.createHeaderFilter(c, m.processor.Admin().CreateAllowHeaderFilter) +} + +// HeaderFilterBlockPOST swagger:operation POST /api/v1/admin/header_blocks headerFilterBlockCreate +// +// Create new "block" HTTP request header filter. +// +// The parameters can also be given in the body of the request, as JSON, if the content-type is set to 'application/json'. +// The parameters can also be given in the body of the request, as XML, if the content-type is set to 'application/xml'. +// +// --- +// tags: +// - admin +// +// consumes: +// - application/json +// - application/xml +// - application/x-www-form-urlencoded +// +// produces: +// - application/json +// +// security: +// - OAuth2 Bearer: +// - admin +// +// responses: +// '200': +// description: The newly created "block" header filter. +// schema: +// "$ref": "#/definitions/headerFilter" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '500': +// description: internal server error +func (m *Module) HeaderFilterBlockPOST(c *gin.Context) { + m.createHeaderFilter(c, m.processor.Admin().CreateBlockHeaderFilter) +} diff --git a/internal/api/client/admin/headerfilter_delete.go b/internal/api/client/admin/headerfilter_delete.go new file mode 100644 index 000000000..806e62a04 --- /dev/null +++ b/internal/api/client/admin/headerfilter_delete.go @@ -0,0 +1,96 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package admin + +import ( + "github.com/gin-gonic/gin" +) + +// HeaderFilterAllowDELETE swagger:operation DELETE /api/v1/admin/header_allows/{id} headerFilterAllowDelete +// +// Delete the "allow" header filter with the given ID. +// +// --- +// tags: +// - admin +// +// parameters: +// - +// name: id +// type: string +// description: Target header filter ID. +// in: path +// required: true +// +// security: +// - OAuth2 Bearer: +// - admin +// +// responses: +// '202': +// description: Accepted +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '404': +// description: not found +// '500': +// description: internal server error +func (m *Module) HeaderFilterAllowDELETE(c *gin.Context) { + m.deleteHeaderFilter(c, m.processor.Admin().DeleteAllowHeaderFilter) +} + +// HeaderFilterBlockDELETE swagger:operation DELETE /api/v1/admin/header_blocks/{id} headerFilterBlockDelete +// +// Delete the "block" header filter with the given ID. +// +// --- +// tags: +// - admin +// +// parameters: +// - +// name: id +// type: string +// description: Target header filter ID. +// in: path +// required: true +// +// security: +// - OAuth2 Bearer: +// - admin +// +// responses: +// '202': +// description: Accepted +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '404': +// description: not found +// '500': +// description: internal server error +func (m *Module) HeaderFilterBlockDELETE(c *gin.Context) { + m.deleteHeaderFilter(c, m.processor.Admin().DeleteAllowHeaderFilter) +} diff --git a/internal/api/client/admin/headerfilter_get.go b/internal/api/client/admin/headerfilter_get.go new file mode 100644 index 000000000..5bca6d18d --- /dev/null +++ b/internal/api/client/admin/headerfilter_get.go @@ -0,0 +1,164 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package admin + +import "github.com/gin-gonic/gin" + +// HeaderFilterAllowGET swagger:operation GET /api/v1/admin/header_allows/{id} headerFilterAllowGet +// +// Get "allow" header filter with the given ID. +// +// --- +// tags: +// - admin +// +// parameters: +// - +// name: id +// type: string +// description: Target header filter ID. +// in: path +// required: true +// +// security: +// - OAuth2 Bearer: +// - admin +// +// responses: +// '200': +// description: The requested "allow" header filter. +// schema: +// "$ref": "#/definitions/headerFilter" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '404': +// description: not found +// '500': +// description: internal server error +func (m *Module) HeaderFilterAllowGET(c *gin.Context) { + m.getHeaderFilter(c, m.processor.Admin().GetAllowHeaderFilter) +} + +// HeaderFilterBlockGET swagger:operation GET /api/v1/admin/header_blocks/{id} headerFilterBlockGet +// +// Get "block" header filter with the given ID. +// +// --- +// tags: +// - admin +// +// parameters: +// - +// name: id +// type: string +// description: Target header filter ID. +// in: path +// required: true +// +// security: +// - OAuth2 Bearer: +// - admin +// +// responses: +// '200': +// description: The requested "block" header filter. +// schema: +// "$ref": "#/definitions/headerFilter" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '404': +// description: not found +// '500': +// description: internal server error +func (m *Module) HeaderFilterBlockGET(c *gin.Context) { + m.getHeaderFilter(c, m.processor.Admin().GetBlockHeaderFilter) +} + +// HeaderFilterAllowsGET swagger:operation GET /api/v1/admin/header_allows headerFilterAllowsGet +// +// Get all "allow" header filters currently in place. +// +// --- +// tags: +// - admin +// +// security: +// - OAuth2 Bearer: +// - admin +// +// responses: +// '200': +// description: All "allow" header filters currently in place. +// schema: +// type: array +// items: +// "$ref": "#/definitions/headerFilter" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '404': +// description: not found +// '500': +// description: internal server error +func (m *Module) HeaderFilterAllowsGET(c *gin.Context) { + m.getHeaderFilters(c, m.processor.Admin().GetAllowHeaderFilters) +} + +// HeaderFilterBlocksGET swagger:operation GET /api/v1/admin/header_blocks headerFilterBlocksGet +// +// Get all "allow" header filters currently in place. +// +// --- +// tags: +// - admin +// +// security: +// - OAuth2 Bearer: +// - admin +// +// responses: +// '200': +// description: All "block" header filters currently in place. +// schema: +// type: array +// items: +// "$ref": "#/definitions/headerFilter" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '404': +// description: not found +// '500': +// description: internal server error +func (m *Module) HeaderFilterBlocksGET(c *gin.Context) { + m.getHeaderFilters(c, m.processor.Admin().GetBlockHeaderFilters) +} diff --git a/internal/api/model/headerfilter.go b/internal/api/model/headerfilter.go new file mode 100644 index 000000000..96ba819f5 --- /dev/null +++ b/internal/api/model/headerfilter.go @@ -0,0 +1,55 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package model + +// HeaderFilter represents a regex value filter applied to one particular HTTP header (allow / block). +type HeaderFilter struct { + // The ID of the header filter. + // example: 01FBW21XJA09XYX51KV5JVBW0F + // readonly: true + ID string `json:"id"` + + // The HTTP header to match against. + // example: User-Agent + Header string `json:"header"` + + // The header value matching regular expression. + // example: .*Firefox.* + Regex string `json:"regex"` + + // The ID of the admin account that created this header filter. + // example: 01FBW2758ZB6PBR200YPDDJK4C + // readonly: true + CreatedBy string `json:"created_by"` + + // Time at which the header filter was created (ISO 8601 Datetime). + // example: 2021-07-30T09:20:25+00:00 + // readonly: true + CreatedAt string `json:"created_at"` +} + +// HeaderFilterRequest is the form submitted as a POST to create a new header filter entry (allow / block). +// +// swagger:model headerFilterCreateRequest +type HeaderFilterRequest struct { + // The HTTP header to match against (e.g. User-Agent). + Header string `form:"header" json:"header" xml:"header"` + + // The header value matching regular expression. + Regex string `form:"regex" json:"regex" xml:"regex"` +} diff --git a/internal/api/util/response.go b/internal/api/util/response.go index 150d2ac2e..753eaefb8 100644 --- a/internal/api/util/response.go +++ b/internal/api/util/response.go @@ -39,14 +39,17 @@ var ( StatusAcceptedJSON = mustJSON(map[string]string{ "status": http.StatusText(http.StatusAccepted), }) + StatusForbiddenJSON = mustJSON(map[string]string{ + "status": http.StatusText(http.StatusForbidden), + }) StatusInternalServerErrorJSON = mustJSON(map[string]string{ "status": http.StatusText(http.StatusInternalServerError), }) ErrorCapacityExceeded = mustJSON(map[string]string{ - "error": "server capacity exceeded!", + "error": "server capacity exceeded", }) - ErrorRateLimitReached = mustJSON(map[string]string{ - "error": "rate limit reached!", + ErrorRateLimited = mustJSON(map[string]string{ + "error": "rate limit reached", }) EmptyJSONObject = mustJSON("{}") EmptyJSONArray = mustJSON("[]") diff --git a/internal/cache/cache.go b/internal/cache/cache.go index dabf151ff..73e3ad6f0 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -18,21 +18,26 @@ package cache import ( + "github.com/superseriousbusiness/gotosocial/internal/cache/headerfilter" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" ) type Caches struct { - // GTS provides access to the collection of gtsmodel object caches. - // (used by the database). + // GTS provides access to the collection of + // gtsmodel object caches. (used by the database). GTS GTSCaches - // AP provides access to the collection of ActivityPub object caches. - // (planned to be used by the typeconverter). - AP APCaches + // AllowHeaderFilters provides access to + // the allow []headerfilter.Filter cache. + AllowHeaderFilters headerfilter.Cache - // Visibility provides access to the item visibility cache. - // (used by the visibility filter). + // BlockHeaderFilters provides access to + // the block []headerfilter.Filter cache. + BlockHeaderFilters headerfilter.Cache + + // Visibility provides access to the item visibility + // cache. (used by the visibility filter). Visibility VisibilityCache // prevent pass-by-value. @@ -45,7 +50,6 @@ func (c *Caches) Init() { log.Infof(nil, "init: %p", c) c.GTS.Init() - c.AP.Init() c.Visibility.Init() // Setup cache invalidate hooks. @@ -58,7 +62,6 @@ func (c *Caches) Start() { log.Infof(nil, "start: %p", c) c.GTS.Start() - c.AP.Start() c.Visibility.Start() } @@ -67,7 +70,6 @@ func (c *Caches) Stop() { log.Infof(nil, "stop: %p", c) c.GTS.Stop() - c.AP.Stop() c.Visibility.Stop() } diff --git a/internal/cache/domain/domain.go b/internal/cache/domain/domain.go index 051ec5c1b..1b836ed28 100644 --- a/internal/cache/domain/domain.go +++ b/internal/cache/domain/domain.go @@ -21,7 +21,6 @@ import ( "fmt" "strings" "sync/atomic" - "unsafe" "golang.org/x/exp/slices" ) @@ -37,17 +36,17 @@ import ( // The .Clear() function can be used to invalidate the cache, // e.g. when an entry is added / deleted from the database. type Cache struct { - // atomically updated ptr value to the // current domain cache radix trie. - rootptr unsafe.Pointer + rootptr atomic.Pointer[root] } // Matches checks whether domain matches an entry in the cache. // If the cache is not currently loaded, then the provided load // function is used to hydrate it. func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, error) { - // Load the current root pointer value. - ptr := atomic.LoadPointer(&c.rootptr) + // Load the current + // root pointer value. + ptr := c.rootptr.Load() if ptr == nil { // Cache is not hydrated. @@ -60,35 +59,32 @@ func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, err // Allocate new radix trie // node to store matches. - root := new(root) + ptr = new(root) // Add each domain to the trie. for _, domain := range domains { - root.Add(domain) + ptr.Add(domain) } // Sort the trie. - root.Sort() + ptr.Sort() - // Store the new node ptr. - ptr = unsafe.Pointer(root) - atomic.StorePointer(&c.rootptr, ptr) + // Store new node ptr. + c.rootptr.Store(ptr) } - // Look for a match in the trie node. - return (*root)(ptr).Match(domain), nil + // Look for match in trie node. + return ptr.Match(domain), nil } // Clear will drop the currently loaded domain list, // triggering a reload on next call to .Matches(). -func (c *Cache) Clear() { - atomic.StorePointer(&c.rootptr, nil) -} +func (c *Cache) Clear() { c.rootptr.Store(nil) } // String returns a string representation of stored domains in cache. func (c *Cache) String() string { - if ptr := atomic.LoadPointer(&c.rootptr); ptr != nil { - return (*root)(ptr).String() + if ptr := c.rootptr.Load(); ptr != nil { + return ptr.String() } return "<empty>" } diff --git a/internal/cache/headerfilter/filter.go b/internal/cache/headerfilter/filter.go new file mode 100644 index 000000000..96b6e757f --- /dev/null +++ b/internal/cache/headerfilter/filter.go @@ -0,0 +1,105 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package headerfilter + +import ( + "fmt" + "net/http" + "sync/atomic" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/headerfilter" +) + +// Cache provides a means of caching headerfilter.Filters in +// memory to reduce load on an underlying storage mechanism. +type Cache struct { + // current cached header filters slice. + ptr atomic.Pointer[headerfilter.Filters] +} + +// RegularMatch performs .RegularMatch() on cached headerfilter.Filters, loading using callback if necessary. +func (c *Cache) RegularMatch(h http.Header, load func() ([]*gtsmodel.HeaderFilter, error)) (string, string, error) { + // Load ptr value. + ptr := c.ptr.Load() + + if ptr == nil { + // Cache is not hydrated. + // Load filters from callback. + filters, err := loadFilters(load) + if err != nil { + return "", "", err + } + + // Store the new + // header filters. + ptr = &filters + c.ptr.Store(ptr) + } + + // Deref and perform match. + return ptr.RegularMatch(h) +} + +// InverseMatch performs .InverseMatch() on cached headerfilter.Filters, loading using callback if necessary. +func (c *Cache) InverseMatch(h http.Header, load func() ([]*gtsmodel.HeaderFilter, error)) (string, string, error) { + // Load ptr value. + ptr := c.ptr.Load() + + if ptr == nil { + // Cache is not hydrated. + // Load filters from callback. + filters, err := loadFilters(load) + if err != nil { + return "", "", err + } + + // Store the new + // header filters. + ptr = &filters + c.ptr.Store(ptr) + } + + // Deref and perform match. + return ptr.InverseMatch(h) +} + +// Clear will drop the currently loaded filters, +// triggering a reload on next call to ._Match(). +func (c *Cache) Clear() { c.ptr.Store(nil) } + +// loadFilters will load filters from given load callback, creating and parsing raw filters. +func loadFilters(load func() ([]*gtsmodel.HeaderFilter, error)) (headerfilter.Filters, error) { + // Load filters from callback. + hdrFilters, err := load() + if err != nil { + return nil, fmt.Errorf("error reloading cache: %w", err) + } + + // Allocate new header filter slice to store expressions. + filters := make(headerfilter.Filters, 0, len(hdrFilters)) + + // Add all raw expression to filter slice. + for _, filter := range hdrFilters { + if err := filters.Append(filter.Header, filter.Regex); err != nil { + return nil, fmt.Errorf("error appending exprs: %w", err) + } + } + + return filters, nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 173999b53..68c065852 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -163,6 +163,7 @@ type Configuration struct { AdvancedThrottlingRetryAfter time.Duration `name:"advanced-throttling-retry-after" usage:"Retry-After duration response to send for throttled requests."` AdvancedSenderMultiplier int `name:"advanced-sender-multiplier" usage:"Multiplier to use per cpu for batching outgoing fedi messages. 0 or less turns batching off (not recommended)."` AdvancedCSPExtraURIs []string `name:"advanced-csp-extra-uris" usage:"Additional URIs to allow when building content-security-policy for media + images."` + AdvancedHeaderFilterMode string `name:"advanced-header-filter-mode" usage:"Set incoming request header filtering mode."` // HTTPClient configuration vars. HTTPClient HTTPClientConfiguration `name:"http-client"` diff --git a/internal/config/const.go b/internal/config/const.go index 29e4b14e8..48087c4ce 100644 --- a/internal/config/const.go +++ b/internal/config/const.go @@ -17,10 +17,16 @@ package config -// Instance federation mode determines how this -// instance federates with others (if at all). const ( + // Instance federation mode determines how this + // instance federates with others (if at all). InstanceFederationModeBlocklist = "blocklist" InstanceFederationModeAllowlist = "allowlist" InstanceFederationModeDefault = InstanceFederationModeBlocklist + + // Request header filter mode determines how + // this instance will perform request filtering. + RequestHeaderFilterModeAllow = "allow" + RequestHeaderFilterModeBlock = "block" + RequestHeaderFilterModeDisabled = "" ) diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 5aba6c689..3996b133f 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -135,6 +135,7 @@ var Defaults = Configuration{ AdvancedThrottlingRetryAfter: time.Second * 30, AdvancedSenderMultiplier: 2, // 2 senders per CPU AdvancedCSPExtraURIs: []string{}, + AdvancedHeaderFilterMode: RequestHeaderFilterModeDisabled, Cache: CacheConfiguration{ // Rough memory target that the total diff --git a/internal/config/flags.go b/internal/config/flags.go index 45ba70f9e..350f56635 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -158,6 +158,7 @@ func (s *ConfigState) AddServerFlags(cmd *cobra.Command) { cmd.Flags().Duration(AdvancedThrottlingRetryAfterFlag(), cfg.AdvancedThrottlingRetryAfter, fieldtag("AdvancedThrottlingRetryAfter", "usage")) cmd.Flags().Int(AdvancedSenderMultiplierFlag(), cfg.AdvancedSenderMultiplier, fieldtag("AdvancedSenderMultiplier", "usage")) cmd.Flags().StringSlice(AdvancedCSPExtraURIsFlag(), cfg.AdvancedCSPExtraURIs, fieldtag("AdvancedCSPExtraURIs", "usage")) + cmd.Flags().String(AdvancedHeaderFilterModeFlag(), cfg.AdvancedHeaderFilterMode, fieldtag("AdvancedHeaderFilterMode", "usage")) cmd.Flags().String(RequestIDHeaderFlag(), cfg.RequestIDHeader, fieldtag("RequestIDHeader", "usage")) }) diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index 393a1b1e9..72b2a7fd9 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2600,6 +2600,31 @@ func GetAdvancedCSPExtraURIs() []string { return global.GetAdvancedCSPExtraURIs( // SetAdvancedCSPExtraURIs safely sets the value for global configuration 'AdvancedCSPExtraURIs' field func SetAdvancedCSPExtraURIs(v []string) { global.SetAdvancedCSPExtraURIs(v) } +// GetAdvancedHeaderFilterMode safely fetches the Configuration value for state's 'AdvancedHeaderFilterMode' field +func (st *ConfigState) GetAdvancedHeaderFilterMode() (v string) { + st.mutex.RLock() + v = st.config.AdvancedHeaderFilterMode + st.mutex.RUnlock() + return +} + +// SetAdvancedHeaderFilterMode safely sets the Configuration value for state's 'AdvancedHeaderFilterMode' field +func (st *ConfigState) SetAdvancedHeaderFilterMode(v string) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.AdvancedHeaderFilterMode = v + st.reloadToViper() +} + +// AdvancedHeaderFilterModeFlag returns the flag name for the 'AdvancedHeaderFilterMode' field +func AdvancedHeaderFilterModeFlag() string { return "advanced-header-filter-mode" } + +// GetAdvancedHeaderFilterMode safely fetches the value for global configuration 'AdvancedHeaderFilterMode' field +func GetAdvancedHeaderFilterMode() string { return global.GetAdvancedHeaderFilterMode() } + +// SetAdvancedHeaderFilterMode safely sets the value for global configuration 'AdvancedHeaderFilterMode' field +func SetAdvancedHeaderFilterMode(v string) { global.SetAdvancedHeaderFilterMode(v) } + // GetHTTPClientAllowIPs safely fetches the Configuration value for state's 'HTTPClient.AllowIPs' field func (st *ConfigState) GetHTTPClientAllowIPs() (v []string) { st.mutex.RLock() diff --git a/internal/db/basic.go b/internal/db/basic.go index 7cd690aef..3a8e2af8d 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -25,10 +25,6 @@ type Basic interface { // For implementations that don't use tables, this can just return nil. CreateTable(ctx context.Context, i interface{}) error - // CreateAllTables creates *all* tables necessary for the running of GoToSocial. - // Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized. - CreateAllTables(ctx context.Context) error - // DropTable drops the table for the given interface. // For implementations that don't use tables, this can just return nil. DropTable(ctx context.Context, i interface{}) error diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index e68903efa..488f59ad5 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -22,7 +22,6 @@ import ( "errors" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/uptrace/bun" ) @@ -120,39 +119,6 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error { return err } -func (b *basicDB) CreateAllTables(ctx context.Context) error { - models := []interface{}{ - >smodel.Account{}, - >smodel.Application{}, - >smodel.Block{}, - >smodel.DomainBlock{}, - >smodel.EmailDomainBlock{}, - >smodel.Follow{}, - >smodel.FollowRequest{}, - >smodel.MediaAttachment{}, - >smodel.Mention{}, - >smodel.Status{}, - >smodel.StatusToEmoji{}, - >smodel.StatusFave{}, - >smodel.StatusBookmark{}, - >smodel.ThreadMute{}, - >smodel.Tag{}, - >smodel.User{}, - >smodel.Emoji{}, - >smodel.Instance{}, - >smodel.Notification{}, - >smodel.RouterSession{}, - >smodel.Token{}, - >smodel.Client{}, - } - for _, i := range models { - if err := b.CreateTable(ctx, i); err != nil { - return err - } - } - return nil -} - func (b *basicDB) DropTable(ctx context.Context, i interface{}) error { _, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx) return err diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index f7417cfeb..d9415eff4 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -67,6 +67,7 @@ type DBService struct { db.Basic db.Domain db.Emoji + db.HeaderFilter db.Instance db.List db.Marker @@ -193,6 +194,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { db: db, state: state, }, + HeaderFilter: &headerFilterDB{ + db: db, + state: state, + }, Instance: &instanceDB{ db: db, state: state, diff --git a/internal/db/bundb/headerfilter.go b/internal/db/bundb/headerfilter.go new file mode 100644 index 000000000..087b65c82 --- /dev/null +++ b/internal/db/bundb/headerfilter.go @@ -0,0 +1,207 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "net/http" + "time" + "unsafe" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/uptrace/bun" +) + +type headerFilterDB struct { + db *DB + state *state.State +} + +func (h *headerFilterDB) AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) { + return h.state.Caches.AllowHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { + return h.GetAllowHeaderFilters(ctx) + }) +} + +func (h *headerFilterDB) AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) { + return h.state.Caches.AllowHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { + return h.GetAllowHeaderFilters(ctx) + }) +} + +func (h *headerFilterDB) BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) { + return h.state.Caches.BlockHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { + return h.GetBlockHeaderFilters(ctx) + }) +} + +func (h *headerFilterDB) BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) { + return h.state.Caches.BlockHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { + return h.GetBlockHeaderFilters(ctx) + }) +} + +func (h *headerFilterDB) GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) { + filter := new(gtsmodel.HeaderFilterAllow) + if err := h.db.NewSelect(). + Model(filter). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx); err != nil { + return nil, err + } + return fromAllowFilter(filter), nil +} + +func (h *headerFilterDB) GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) { + filter := new(gtsmodel.HeaderFilterBlock) + if err := h.db.NewSelect(). + Model(filter). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx); err != nil { + return nil, err + } + return fromBlockFilter(filter), nil +} + +func (h *headerFilterDB) GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) { + var filters []*gtsmodel.HeaderFilterAllow + err := h.db.NewSelect(). + Model(&filters). + Scan(ctx, &filters) + return fromAllowFilters(filters), err +} + +func (h *headerFilterDB) GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) { + var filters []*gtsmodel.HeaderFilterBlock + err := h.db.NewSelect(). + Model(&filters). + Scan(ctx, &filters) + return fromBlockFilters(filters), err +} + +func (h *headerFilterDB) PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error { + if _, err := h.db.NewInsert(). + Model(toAllowFilter(filter)). + Exec(ctx); err != nil { + return err + } + h.state.Caches.AllowHeaderFilters.Clear() + return nil +} + +func (h *headerFilterDB) PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error { + if _, err := h.db.NewInsert(). + Model(toBlockFilter(filter)). + Exec(ctx); err != nil { + return err + } + h.state.Caches.BlockHeaderFilters.Clear() + return nil +} + +func (h *headerFilterDB) UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error { + filter.UpdatedAt = time.Now() + if len(cols) > 0 { + // If we're updating by column, + // ensure "updated_at" is included. + cols = append(cols, "updated_at") + } + if _, err := h.db.NewUpdate(). + Model(toAllowFilter(filter)). + Column(cols...). + Where("? = ?", bun.Ident("id"), filter.ID). + Exec(ctx); err != nil { + return err + } + h.state.Caches.AllowHeaderFilters.Clear() + return nil +} + +func (h *headerFilterDB) UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error { + filter.UpdatedAt = time.Now() + if len(cols) > 0 { + // If we're updating by column, + // ensure "updated_at" is included. + cols = append(cols, "updated_at") + } + if _, err := h.db.NewUpdate(). + Model(toBlockFilter(filter)). + Column(cols...). + Where("? = ?", bun.Ident("id"), filter.ID). + Exec(ctx); err != nil { + return err + } + h.state.Caches.BlockHeaderFilters.Clear() + return nil +} + +func (h *headerFilterDB) DeleteAllowHeaderFilter(ctx context.Context, id string) error { + if _, err := h.db.NewDelete(). + Table("header_filter_allows"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + h.state.Caches.AllowHeaderFilters.Clear() + return nil +} + +func (h *headerFilterDB) DeleteBlockHeaderFilter(ctx context.Context, id string) error { + if _, err := h.db.NewDelete(). + Table("header_filter_blocks"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + h.state.Caches.BlockHeaderFilters.Clear() + return nil +} + +// NOTE: +// all of the below unsafe cast functions +// are only possible because HeaderFilterAllow{}, +// HeaderFilterBlock{}, HeaderFilter{} while +// different types in source, have exactly the +// same size and layout in memory. the unsafe +// cast simply changes the type associated with +// that block of memory. + +func toAllowFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterAllow { + return (*gtsmodel.HeaderFilterAllow)(unsafe.Pointer(filter)) +} + +func toBlockFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterBlock { + return (*gtsmodel.HeaderFilterBlock)(unsafe.Pointer(filter)) +} + +func fromAllowFilter(filter *gtsmodel.HeaderFilterAllow) *gtsmodel.HeaderFilter { + return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter)) +} + +func fromBlockFilter(filter *gtsmodel.HeaderFilterBlock) *gtsmodel.HeaderFilter { + return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter)) +} + +func fromAllowFilters(filters []*gtsmodel.HeaderFilterAllow) []*gtsmodel.HeaderFilter { + return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters)) +} + +func fromBlockFilters(filters []*gtsmodel.HeaderFilterBlock) []*gtsmodel.HeaderFilter { + return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters)) +} diff --git a/internal/db/bundb/headerfilter_test.go b/internal/db/bundb/headerfilter_test.go new file mode 100644 index 000000000..d7e2b26ee --- /dev/null +++ b/internal/db/bundb/headerfilter_test.go @@ -0,0 +1,125 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type HeaderFilterTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *HeaderFilterTestSuite) TestAllowHeaderFilterGetPutUpdateDelete() { + suite.testHeaderFilterGetPutUpdateDelete( + suite.db.GetAllowHeaderFilter, + suite.db.GetAllowHeaderFilters, + suite.db.PutAllowHeaderFilter, + suite.db.UpdateAllowHeaderFilter, + suite.db.DeleteAllowHeaderFilter, + ) +} + +func (suite *HeaderFilterTestSuite) TestBlockHeaderFilterGetPutUpdateDelete() { + suite.testHeaderFilterGetPutUpdateDelete( + suite.db.GetBlockHeaderFilter, + suite.db.GetBlockHeaderFilters, + suite.db.PutBlockHeaderFilter, + suite.db.UpdateBlockHeaderFilter, + suite.db.DeleteBlockHeaderFilter, + ) +} + +func (suite *HeaderFilterTestSuite) testHeaderFilterGetPutUpdateDelete( + get func(context.Context, string) (*gtsmodel.HeaderFilter, error), + getAll func(context.Context) ([]*gtsmodel.HeaderFilter, error), + put func(context.Context, *gtsmodel.HeaderFilter) error, + update func(context.Context, *gtsmodel.HeaderFilter, ...string) error, + delete func(context.Context, string) error, +) { + t := suite.T() + + // Create new example header filter. + filter := gtsmodel.HeaderFilter{ + ID: "some unique id", + Header: "Http-Header-Key", + Regex: ".*", + AuthorID: "some unique author id", + } + + // Create new cancellable test context. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Insert the example header filter into db. + if err := put(ctx, &filter); err != nil { + t.Fatalf("error inserting header filter: %v", err) + } + + // Now fetch newly created filter. + check, err := get(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching header filter: %v", err) + } + + // Check all expected fields match. + suite.Equal(filter.ID, check.ID) + suite.Equal(filter.Header, check.Header) + suite.Equal(filter.Regex, check.Regex) + suite.Equal(filter.AuthorID, check.AuthorID) + + // Fetch all header filters. + all, err := getAll(ctx) + if err != nil { + t.Fatalf("error fetching header filters: %v", err) + } + + // Ensure contains example. + suite.Equal(len(all), 1) + suite.Equal(all[0].ID, filter.ID) + + // Update the header filter regex value. + check.Regex = "new regex value" + if err := update(ctx, check); err != nil { + t.Fatalf("error updating header filter: %v", err) + } + + // Ensure 'updated_at' was updated on check model. + suite.True(check.UpdatedAt.After(filter.UpdatedAt)) + + // Now delete the header filter from db. + if err := delete(ctx, filter.ID); err != nil { + t.Fatalf("error deleting header filter: %v", err) + } + + // Ensure we can't refetch it. + _, err = get(ctx, filter.ID) + if err != db.ErrNoEntries { + t.Fatalf("deleted header filter returned unexpected error: %v", err) + } +} + +func TestHeaderFilterTestSuite(t *testing.T) { + suite.Run(t, new(HeaderFilterTestSuite)) +} diff --git a/internal/db/bundb/migrations/20231212144715_add_header_filters.go b/internal/db/bundb/migrations/20231212144715_add_header_filters.go new file mode 100644 index 000000000..2d671bf46 --- /dev/null +++ b/internal/db/bundb/migrations/20231212144715_add_header_filters.go @@ -0,0 +1,54 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + for _, model := range []any{ + >smodel.HeaderFilterAllow{}, + >smodel.HeaderFilterBlock{}, + } { + _, err := db.NewCreateTable(). + IfNotExists(). + Model(model). + Exec(ctx) + if err != nil { + return err + } + } + + return nil + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/db.go b/internal/db/db.go index 2914d9b59..361687e94 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -30,6 +30,7 @@ type DB interface { Basic Domain Emoji + HeaderFilter Instance List Marker diff --git a/internal/db/headerfilter.go b/internal/db/headerfilter.go new file mode 100644 index 000000000..5fe8a5b17 --- /dev/null +++ b/internal/db/headerfilter.go @@ -0,0 +1,73 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package db + +import ( + "context" + "net/http" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type HeaderFilter interface { + // AllowHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached allow header filters. + // (Note: the actual matching code can be found under ./internal/headerfilter/ ). + AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) + + // AllowHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached allow header filters. + // (Note: the actual matching code can be found under ./internal/headerfilter/ ). + AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) + + // BlockHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached block header filters. + // (Note: the actual matching code can be found under ./internal/headerfilter/ ). + BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) + + // BlockHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached block header filters. + // (Note: the actual matching code can be found under ./internal/headerfilter/ ). + BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) + + // GetAllowHeaderFilter fetches the allow header filter with ID from the database. + GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) + + // GetBlockHeaderFilter fetches the block header filter with ID from the database. + GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) + + // GetAllowHeaderFilters fetches all allow header filters from the database. + GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) + + // GetBlockHeaderFilters fetches all block header filters from the database. + GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) + + // PutAllowHeaderFilter inserts the given allow header filter into the database. + PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error + + // PutBlockHeaderFilter inserts the given block header filter into the database. + PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error + + // UpdateAllowHeaderFilter updates the given allow header filter in the database, only updating given columns if provided. + UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error + + // UpdateBlockHeaderFilter updates the given block header filter in the database, only updating given columns if provided. + UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error + + // DeleteAllowHeaderFilter deletes the allow header filter with ID from the database. + DeleteAllowHeaderFilter(ctx context.Context, id string) error + + // DeleteBlockHeaderFilter deletes the block header filter with ID from the database. + DeleteBlockHeaderFilter(ctx context.Context, id string) error +} diff --git a/internal/gtsmodel/headerfilter.go b/internal/gtsmodel/headerfilter.go new file mode 100644 index 000000000..d1fcb146e --- /dev/null +++ b/internal/gtsmodel/headerfilter.go @@ -0,0 +1,54 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package gtsmodel + +import ( + "time" + "unsafe" +) + +func init() { + // Note that since all of the below calculations are + // constant, these should be optimized out of builds. + const filterSz = unsafe.Sizeof(HeaderFilter{}) + if unsafe.Sizeof(HeaderFilterAllow{}) != filterSz { + panic("HeaderFilterAllow{} needs to have the same in-memory size / layout as HeaderFilter{}") + } + if unsafe.Sizeof(HeaderFilterBlock{}) != filterSz { + panic("HeaderFilterBlock{} needs to have the same in-memory size / layout as HeaderFilter{}") + } +} + +// HeaderFilterAllow represents an allow HTTP header filter in the database. +type HeaderFilterAllow struct{ HeaderFilter } + +// HeaderFilterBlock represents a block HTTP header filter in the database. +type HeaderFilterBlock struct{ HeaderFilter } + +// HeaderFilter represents an HTTP request filter in +// the database, with a header to match against, value +// matching regex, and details about its creation. +type HeaderFilter struct { + ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // ID of this item in the database + Header string `bun:",nullzero,notnull,unique:header_regex"` // Request header this filter pertains to + Regex string `bun:",nullzero,notnull,unique:header_regex"` // Request header value matching regular expression + AuthorID string `bun:"type:CHAR(26),nullzero,notnull"` // Account ID of the creator of this filter + Author *Account `bun:"-"` // Account corresponding to AuthorID + CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created + UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated +} diff --git a/internal/headerfilter/filter.go b/internal/headerfilter/filter.go new file mode 100644 index 000000000..ab2aa914c --- /dev/null +++ b/internal/headerfilter/filter.go @@ -0,0 +1,136 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package headerfilter + +import ( + "errors" + "fmt" + "net/http" + "net/textproto" + "regexp" +) + +// Maximum header value size before we return +// an instant negative match. They shouldn't +// go beyond this size in most cases anywho. +const MaxHeaderValue = 1024 + +// ErrLargeHeaderValue is returned on attempting to match on a value > MaxHeaderValue. +var ErrLargeHeaderValue = errors.New("header value too large") + +// Filters represents a set of http.Header regular +// expression filters built-in statistic tracking. +type Filters []headerfilter + +type headerfilter struct { + // key is the header key to match against + // in canonical textproto mime header format. + key string + + // exprs contains regular expressions to + // match values against for this header key. + exprs []*regexp.Regexp +} + +// Append will add new header filter expression under given header key. +func (fs *Filters) Append(key string, expr string) error { + var filter *headerfilter + + // Ensure in canonical mime header format. + key = textproto.CanonicalMIMEHeaderKey(key) + + // Look for existing filter + // with key in filter slice. + for i := range *fs { + if (*fs)[i].key == key { + filter = &((*fs)[i]) + break + } + } + + if filter == nil { + // No existing filter found, create new. + + // Append new header filter to slice. + (*fs) = append((*fs), headerfilter{}) + + // Then take ptr to this new filter + // at the last index in the slice. + filter = &((*fs)[len((*fs))-1]) + + // Setup new key. + filter.key = key + } + + // Compile regular expression. + reg, err := regexp.Compile(expr) + if err != nil { + return fmt.Errorf("error compiling regexp %q: %w", expr, err) + } + + // Append regular expression to filter. + filter.exprs = append(filter.exprs, reg) + + return nil +} + +// RegularMatch returns whether any values in http header +// matches any of the receiving filter regular expressions. +// This returns the matched header key, and matching regexp. +func (fs Filters) RegularMatch(h http.Header) (string, string, error) { + for _, filter := range fs { + for _, value := range h[filter.key] { + // Don't perform match on large values + // to mitigate denial of service attacks. + if len(value) > MaxHeaderValue { + return "", "", ErrLargeHeaderValue + } + + // Compare against regular exprs. + for _, expr := range filter.exprs { + if expr.MatchString(value) { + return filter.key, expr.String(), nil + } + } + } + } + return "", "", nil +} + +// InverseMatch returns whether any values in http header do +// NOT match any of the receiving filter regular expressions. +// This returns the matched header key, and matching regexp. +func (fs Filters) InverseMatch(h http.Header) (string, string, error) { + for _, filter := range fs { + for _, value := range h[filter.key] { + // Don't perform match on large values + // to mitigate denial of service attacks. + if len(value) > MaxHeaderValue { + return "", "", ErrLargeHeaderValue + } + + // Compare against regular exprs. + for _, expr := range filter.exprs { + if !expr.MatchString(value) { + return filter.key, expr.String(), nil + } + } + } + } + return "", "", nil +} diff --git a/internal/middleware/headerfilter.go b/internal/middleware/headerfilter.go new file mode 100644 index 000000000..18c9d1e67 --- /dev/null +++ b/internal/middleware/headerfilter.go @@ -0,0 +1,251 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package middleware + +import ( + "sync" + + "github.com/gin-gonic/gin" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/headerfilter" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" +) + +var ( + allowMatches = matchstats{m: make(map[string]uint64)} + blockMatches = matchstats{m: make(map[string]uint64)} +) + +// matchstats is a simple statistics +// counter for header filter matches. +// TODO: replace with otel. +type matchstats struct { + m map[string]uint64 + l sync.Mutex +} + +func (m *matchstats) Add(hdr, regex string) { + m.l.Lock() + key := hdr + ":" + regex + m.m[key]++ + m.l.Unlock() +} + +// HeaderFilter returns a gin middleware handler that provides HTTP +// request blocking (filtering) based on database allow / block filters. +func HeaderFilter(state *state.State) gin.HandlerFunc { + switch mode := config.GetAdvancedHeaderFilterMode(); mode { + case config.RequestHeaderFilterModeDisabled: + return func(ctx *gin.Context) {} + + case config.RequestHeaderFilterModeAllow: + return headerFilterAllowMode(state) + + case config.RequestHeaderFilterModeBlock: + return headerFilterBlockMode(state) + + default: + panic("unrecognized filter mode: " + mode) + } +} + +func headerFilterAllowMode(state *state.State) func(c *gin.Context) { + _ = *state //nolint + // Allowlist mode: explicit block takes + // precedence over explicit allow. + // + // Headers that have neither block + // or allow entries are blocked. + return func(c *gin.Context) { + + // Check if header is explicitly blocked. + block, err := isHeaderBlocked(state, c) + if err != nil { + respondInternalServerError(c, err) + return + } + + if block { + respondBlocked(c) + return + } + + // Check if header is missing explicit allow. + notAllow, err := isHeaderNotAllowed(state, c) + if err != nil { + respondInternalServerError(c, err) + return + } + + if notAllow { + respondBlocked(c) + return + } + + // Allowed! + c.Next() + } +} + +func headerFilterBlockMode(state *state.State) func(c *gin.Context) { + _ = *state //nolint + // Blocklist/default mode: explicit allow + // takes precedence over explicit block. + // + // Headers that have neither block + // or allow entries are allowed. + return func(c *gin.Context) { + + // Check if header is explicitly allowed. + allow, err := isHeaderAllowed(state, c) + if err != nil { + respondInternalServerError(c, err) + return + } + + if !allow { + // Check if header is explicitly blocked. + block, err := isHeaderBlocked(state, c) + if err != nil { + respondInternalServerError(c, err) + return + } + + if block { + respondBlocked(c) + return + } + } + + // Allowed! + c.Next() + } +} + +func isHeaderBlocked(state *state.State, c *gin.Context) (bool, error) { + var ( + ctx = c.Request.Context() + hdr = c.Request.Header + ) + + // Perform an explicit is-blocked check on request header. + key, expr, err := state.DB.BlockHeaderRegularMatch(ctx, hdr) + switch err { + case nil: + break + + case headerfilter.ErrLargeHeaderValue: + log.Warn(ctx, "large header value") + key = "*" // block large headers + + default: + err := gtserror.Newf("error checking header: %w", err) + return false, err + } + + if key != "" { + if expr != "" { + // Increment block matches stat. + // TODO: replace expvar with build + // taggable metrics types in State{}. + blockMatches.Add(key, expr) + } + + // A header was matched against! + // i.e. this request is blocked. + return true, nil + } + + return false, nil +} + +func isHeaderAllowed(state *state.State, c *gin.Context) (bool, error) { + var ( + ctx = c.Request.Context() + hdr = c.Request.Header + ) + + // Perform an explicit is-allowed check on request header. + key, expr, err := state.DB.AllowHeaderRegularMatch(ctx, hdr) + switch err { + case nil: + break + + case headerfilter.ErrLargeHeaderValue: + log.Warn(ctx, "large header value") + key = "" // block large headers + + default: + err := gtserror.Newf("error checking header: %w", err) + return false, err + } + + if key != "" { + if expr != "" { + // Increment allow matches stat. + // TODO: replace expvar with build + // taggable metrics types in State{}. + allowMatches.Add(key, expr) + } + + // A header was matched against! + // i.e. this request is allowed. + return true, nil + } + + return false, nil +} + +func isHeaderNotAllowed(state *state.State, c *gin.Context) (bool, error) { + var ( + ctx = c.Request.Context() + hdr = c.Request.Header + ) + + // Perform an explicit is-NOT-allowed check on request header. + key, expr, err := state.DB.AllowHeaderInverseMatch(ctx, hdr) + switch err { + case nil: + break + + case headerfilter.ErrLargeHeaderValue: + log.Warn(ctx, "large header value") + key = "*" // block large headers + + default: + err := gtserror.Newf("error checking header: %w", err) + return false, err + } + + if key != "" { + if expr != "" { + // Increment allow matches stat. + // TODO: replace expvar with build + // taggable metrics types in State{}. + allowMatches.Add(key, expr) + } + + // A header was matched against! + // i.e. request is NOT allowed. + return true, nil + } + + return false, nil +} diff --git a/internal/middleware/headerfilter_test.go b/internal/middleware/headerfilter_test.go new file mode 100644 index 000000000..a28644153 --- /dev/null +++ b/internal/middleware/headerfilter_test.go @@ -0,0 +1,299 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package middleware_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/headerfilter" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/middleware" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func TestHeaderFilter(t *testing.T) { + testrig.InitTestLog() + testrig.InitTestConfig() + + for _, test := range []struct { + mode string + allow []filter + block []filter + input http.Header + expect bool + }{ + { + // Allow mode with expected 200 OK. + mode: config.RequestHeaderFilterModeAllow, + allow: []filter{ + {"User-Agent", ".*Firefox.*"}, + }, + block: []filter{}, + input: http.Header{ + "User-Agent": []string{"Firefox v169.42; Extra Tracking Info"}, + }, + expect: true, + }, + { + // Allow mode with expected 403 Forbidden. + mode: config.RequestHeaderFilterModeAllow, + allow: []filter{ + {"User-Agent", ".*Firefox.*"}, + }, + block: []filter{}, + input: http.Header{ + "User-Agent": []string{"Chromium v169.42; Extra Tracking Info"}, + }, + expect: false, + }, + { + // Allow mode with too long header value expecting 403 Forbidden. + mode: config.RequestHeaderFilterModeAllow, + allow: []filter{ + {"User-Agent", ".*"}, + }, + block: []filter{}, + input: http.Header{ + "User-Agent": []string{func() string { + var buf strings.Builder + for i := 0; i < headerfilter.MaxHeaderValue+1; i++ { + buf.WriteByte(' ') + } + return buf.String() + }()}, + }, + expect: false, + }, + { + // Allow mode with explicit block expecting 403 Forbidden. + mode: config.RequestHeaderFilterModeAllow, + allow: []filter{ + {"User-Agent", ".*Firefox.*"}, + }, + block: []filter{ + {"User-Agent", ".*Firefox v169\\.42.*"}, + }, + input: http.Header{ + "User-Agent": []string{"Firefox v169.42; Extra Tracking Info"}, + }, + expect: false, + }, + { + // Block mode with an expected 403 Forbidden. + mode: config.RequestHeaderFilterModeBlock, + allow: []filter{}, + block: []filter{ + {"User-Agent", ".*Firefox.*"}, + }, + input: http.Header{ + "User-Agent": []string{"Firefox v169.42; Extra Tracking Info"}, + }, + expect: false, + }, + { + // Block mode with an expected 200 OK. + mode: config.RequestHeaderFilterModeBlock, + allow: []filter{}, + block: []filter{ + {"User-Agent", ".*Firefox.*"}, + }, + input: http.Header{ + "User-Agent": []string{"Chromium v169.42; Extra Tracking Info"}, + }, + expect: true, + }, + { + // Block mode with too long header value expecting 403 Forbidden. + mode: config.RequestHeaderFilterModeBlock, + allow: []filter{}, + block: []filter{ + {"User-Agent", "none"}, + }, + input: http.Header{ + "User-Agent": []string{func() string { + var buf strings.Builder + for i := 0; i < headerfilter.MaxHeaderValue+1; i++ { + buf.WriteByte(' ') + } + return buf.String() + }()}, + }, + expect: false, + }, + { + // Block mode with explicit allow expecting 200 OK. + mode: config.RequestHeaderFilterModeBlock, + allow: []filter{ + {"User-Agent", ".*Firefox.*"}, + }, + block: []filter{ + {"User-Agent", ".*Firefox v169\\.42.*"}, + }, + input: http.Header{ + "User-Agent": []string{"Firefox v169.42; Extra Tracking Info"}, + }, + expect: true, + }, + { + // Disabled mode with an expected 200 OK. + mode: config.RequestHeaderFilterModeDisabled, + allow: []filter{ + {"Key1", "only-this"}, + {"Key2", "only-this"}, + {"Key3", "only-this"}, + }, + block: []filter{ + {"Key1", "Value"}, + {"Key2", "Value"}, + {"Key3", "Value"}, + }, + input: http.Header{ + "Key1": []string{"Value"}, + "Key2": []string{"Value"}, + "Key3": []string{"Value"}, + }, + expect: true, + }, + } { + // Generate a unique name for this test case. + name := fmt.Sprintf("%s allow=%v block=%v => expect=%v", + test.mode, + test.allow, + test.block, + test.expect, + ) + + // Update header filter mode to test case. + config.SetAdvancedHeaderFilterMode(test.mode) + + // Run this particular test case. + ok := t.Run(name, func(t *testing.T) { + testHeaderFilter(t, + test.allow, + test.block, + test.input, + test.expect, + ) + }) + + if !ok { + return + } + } +} + +func testHeaderFilter(t *testing.T, allow, block []filter, input http.Header, expect bool) { + var err error + + // Create test context with cancel. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Initialize caches. + var state state.State + state.Caches.Init() + + // Create new database instance with test config. + state.DB, err = bundb.NewBunDBService(ctx, &state) + if err != nil { + t.Fatalf("error opening database: %v", err) + } + + // Insert all allow filters into DB. + for _, filter := range allow { + filter := >smodel.HeaderFilter{ + ID: id.NewULID(), + Header: filter.header, + Regex: filter.regex, + AuthorID: "admin-id", + Author: nil, + } + + if err := state.DB.PutAllowHeaderFilter(ctx, filter); err != nil { + t.Fatalf("error inserting allow filter into database: %v", err) + } + } + + // Insert all block filters into DB. + for _, filter := range block { + filter := >smodel.HeaderFilter{ + ID: id.NewULID(), + Header: filter.header, + Regex: filter.regex, + AuthorID: "admin-id", + Author: nil, + } + + if err := state.DB.PutBlockHeaderFilter(ctx, filter); err != nil { + t.Fatalf("error inserting block filter into database: %v", err) + } + } + + // Gin test http engine + // (used for ctx init). + e := gin.New() + + // Create new filter middleware to test against. + middleware := middleware.HeaderFilter(&state) + e.Use(middleware) + + // Set the empty gin handler (always returns okay). + e.Handle("GET", "/", func(ctx *gin.Context) { ctx.Status(200) }) + + // Prepare a gin test context. + r := httptest.NewRequest("GET", "/", nil) + rw := httptest.NewRecorder() + + // Set input headers. + r.Header = input + + // Pass req through + // engine handler. + e.ServeHTTP(rw, r) + + // Get http result. + res := rw.Result() + + switch { + case expect && res.StatusCode != http.StatusOK: + t.Errorf("unexpected response (should allow): %s", res.Status) + + case !expect && res.StatusCode != http.StatusForbidden: + t.Errorf("unexpected response (should block): %s", res.Status) + } +} + +type filter struct { + header string + regex string +} + +func (hf *filter) String() string { + return fmt.Sprintf("%s=%q", hf.header, hf.regex) +} diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go index 57055fe70..352a30c22 100644 --- a/internal/middleware/ratelimit.go +++ b/internal/middleware/ratelimit.go @@ -146,7 +146,7 @@ func RateLimit(limit int, exceptions []string) gin.HandlerFunc { apiutil.Data(c, http.StatusTooManyRequests, apiutil.AppJSON, - apiutil.ErrorRateLimitReached, + apiutil.ErrorRateLimited, ) c.Abort() return diff --git a/internal/middleware/useragent.go b/internal/middleware/useragent.go index 6dc3e401f..38d28f4e5 100644 --- a/internal/middleware/useragent.go +++ b/internal/middleware/useragent.go @@ -18,21 +18,22 @@ package middleware import ( - "errors" "net/http" "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" ) // UserAgent returns a gin middleware which aborts requests with // empty user agent strings, returning code 418 - I'm a teapot. func UserAgent() gin.HandlerFunc { // todo: make this configurable + var rsp = []byte(`{"error": "I'm a teapot: no user-agent sent with request"}`) return func(c *gin.Context) { if ua := c.Request.UserAgent(); ua == "" { - code := http.StatusTeapot - err := errors.New(http.StatusText(code) + ": no user-agent sent with request") - c.AbortWithStatusJSON(code, gin.H{"error": err.Error()}) + apiutil.Data(c, + http.StatusTeapot, apiutil.AppJSON, rsp) + c.Abort() } } } diff --git a/internal/middleware/util.go b/internal/middleware/util.go new file mode 100644 index 000000000..82850fd6d --- /dev/null +++ b/internal/middleware/util.go @@ -0,0 +1,51 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package middleware + +import ( + "net/http" + + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" +) + +// respondBlocked responds to the given gin context with +// status forbidden, and a generic forbidden JSON response, +// finally aborting the gin handler chain. +func respondBlocked(c *gin.Context) { + apiutil.Data(c, + http.StatusForbidden, + apiutil.AppJSON, + apiutil.StatusForbiddenJSON, + ) + c.Abort() +} + +// respondInternalServerError responds to the given gin context +// with status internal server error, a generic internal server +// error JSON response, sets the given error on the gin context +// for later logging, finally aborting the gin handler chain. +func respondInternalServerError(c *gin.Context, err error) { + apiutil.Data(c, + http.StatusInternalServerError, + apiutil.AppJSON, + apiutil.StatusInternalServerErrorJSON, + ) + _ = c.Error(err) + c.Abort() +} diff --git a/internal/processing/admin/headerfilter.go b/internal/processing/admin/headerfilter.go new file mode 100644 index 000000000..13105d191 --- /dev/null +++ b/internal/processing/admin/headerfilter.go @@ -0,0 +1,215 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package admin + +import ( + "context" + "errors" + "net/textproto" + "regexp" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/headerfilter" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// GetAllowHeaderFilter fetches allow HTTP header filter with provided ID from the database. +func (p *Processor) GetAllowHeaderFilter(ctx context.Context, id string) (*apimodel.HeaderFilter, gtserror.WithCode) { + return p.getHeaderFilter(ctx, id, p.state.DB.GetAllowHeaderFilter) +} + +// GetBlockHeaderFilter fetches block HTTP header filter with provided ID from the database. +func (p *Processor) GetBlockHeaderFilter(ctx context.Context, id string) (*apimodel.HeaderFilter, gtserror.WithCode) { + return p.getHeaderFilter(ctx, id, p.state.DB.GetBlockHeaderFilter) +} + +// GetAllowHeaderFilters fetches all allow HTTP header filters stored in the database. +func (p *Processor) GetAllowHeaderFilters(ctx context.Context) ([]*apimodel.HeaderFilter, gtserror.WithCode) { + return p.getHeaderFilters(ctx, p.state.DB.GetAllowHeaderFilters) +} + +// GetBlockHeaderFilters fetches all block HTTP header filters stored in the database. +func (p *Processor) GetBlockHeaderFilters(ctx context.Context) ([]*apimodel.HeaderFilter, gtserror.WithCode) { + return p.getHeaderFilters(ctx, p.state.DB.GetBlockHeaderFilters) +} + +// CreateAllowHeaderFilter inserts the incoming allow HTTP header filter into the database, marking as authored by provided admin account. +func (p *Processor) CreateAllowHeaderFilter(ctx context.Context, admin *gtsmodel.Account, request *apimodel.HeaderFilterRequest) (*apimodel.HeaderFilter, gtserror.WithCode) { + return p.createHeaderFilter(ctx, admin, request, p.state.DB.PutAllowHeaderFilter) +} + +// CreateBlockHeaderFilter inserts the incoming block HTTP header filter into the database, marking as authored by provided admin account. +func (p *Processor) CreateBlockHeaderFilter(ctx context.Context, admin *gtsmodel.Account, request *apimodel.HeaderFilterRequest) (*apimodel.HeaderFilter, gtserror.WithCode) { + return p.createHeaderFilter(ctx, admin, request, p.state.DB.PutBlockHeaderFilter) +} + +// DeleteAllowHeaderFilter deletes the allowing HTTP header filter with provided ID from the database. +func (p *Processor) DeleteAllowHeaderFilter(ctx context.Context, id string) gtserror.WithCode { + return p.deleteHeaderFilter(ctx, id, p.state.DB.DeleteAllowHeaderFilter) +} + +// DeleteBlockHeaderFilter deletes the blocking HTTP header filter with provided ID from the database. +func (p *Processor) DeleteBlockHeaderFilter(ctx context.Context, id string) gtserror.WithCode { + return p.deleteHeaderFilter(ctx, id, p.state.DB.DeleteBlockHeaderFilter) +} + +// getHeaderFilter fetches an HTTP header filter with +// provided ID, using given get function, converting the +// resulting filter to returnable frontend API model. +func (p *Processor) getHeaderFilter( + ctx context.Context, + id string, + get func(context.Context, string) (*gtsmodel.HeaderFilter, error), +) ( + *apimodel.HeaderFilter, + gtserror.WithCode, +) { + // Select filter by ID from db. + filter, err := get(ctx, id) + + switch { + // Successfully found. + case err == nil: + return toAPIHeaderFilter(filter), nil + + // Filter does not exist with ID. + case errors.Is(err, db.ErrNoEntries): + const text = "filter not found" + return nil, gtserror.NewErrorNotFound(errors.New(text), text) + + // Any other error type. + default: + err := gtserror.Newf("error selecting from database: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } +} + +// getHeaderFilters fetches all HTTP header filters +// using given get function, converting the resulting +// filters to returnable frontend API models. +func (p *Processor) getHeaderFilters( + ctx context.Context, + get func(context.Context) ([]*gtsmodel.HeaderFilter, error), +) ( + []*apimodel.HeaderFilter, + gtserror.WithCode, +) { + // Select all filters from DB. + filters, err := get(ctx) + + if err != nil && !errors.Is(err, db.ErrNoEntries) { + // Only handle errors other than not-found types. + err := gtserror.Newf("error selecting from database: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + // Convert passed header filters to apimodel filters. + apiFilters := make([]*apimodel.HeaderFilter, len(filters)) + for i := range filters { + apiFilters[i] = toAPIHeaderFilter(filters[i]) + } + + return apiFilters, nil +} + +// createHeaderFilter inserts the given HTTP header +// filter into database, marking as authored by the +// provided admin, using the given insert function. +func (p *Processor) createHeaderFilter( + ctx context.Context, + admin *gtsmodel.Account, + request *apimodel.HeaderFilterRequest, + insert func(context.Context, *gtsmodel.HeaderFilter) error, +) ( + *apimodel.HeaderFilter, + gtserror.WithCode, +) { + // Convert header key to canonical mime header format. + request.Header = textproto.CanonicalMIMEHeaderKey(request.Header) + + // Validate incoming header filter. + if errWithCode := validateHeaderFilter( + request.Header, + request.Regex, + ); errWithCode != nil { + return nil, errWithCode + } + + // Create new database model with ID. + var filter gtsmodel.HeaderFilter + filter.ID = id.NewULID() + filter.Header = request.Header + filter.Regex = request.Regex + filter.AuthorID = admin.ID + filter.Author = admin + + // Insert new header filter into the database. + if err := insert(ctx, &filter); err != nil { + err := gtserror.Newf("error inserting into database: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + // Finally return API model response. + return toAPIHeaderFilter(&filter), nil +} + +// deleteHeaderFilter deletes the HTTP header filter +// with provided ID, using the given delete function. +func (p *Processor) deleteHeaderFilter( + ctx context.Context, + id string, + delete func(context.Context, string) error, +) gtserror.WithCode { + if err := delete(ctx, id); err != nil && !errors.Is(err, db.ErrNoEntries) { + err := gtserror.Newf("error deleting from database: %w", err) + return gtserror.NewErrorInternalError(err) + } + return nil +} + +// toAPIFilter performs a simple conversion of database model HeaderFilter to API model. +func toAPIHeaderFilter(filter *gtsmodel.HeaderFilter) *apimodel.HeaderFilter { + return &apimodel.HeaderFilter{ + ID: filter.ID, + Header: filter.Header, + Regex: filter.Regex, + CreatedBy: filter.AuthorID, + CreatedAt: util.FormatISO8601(filter.CreatedAt), + } +} + +// validateHeaderFilter validates incoming filter's header key, and regular expression. +func validateHeaderFilter(header, regex string) gtserror.WithCode { + // Check header validity (within our own bound checks). + if header == "" || len(header) > headerfilter.MaxHeaderValue { + const text = "invalid request header key (empty or too long)" + return gtserror.NewErrorBadRequest(errors.New(text), text) + } + + // Ensure this is compilable regex. + _, err := regexp.Compile(regex) + if err != nil { + return gtserror.NewErrorBadRequest(err, err.Error()) + } + + return nil +} |