diff options
author | 2024-03-06 02:15:58 -0800 | |
---|---|---|
committer | 2024-03-06 11:15:58 +0100 | |
commit | 61a2b91f454a6eb0dd383fc8614fee154654fa08 (patch) | |
tree | fcf6159f00c3a0833e6647dd00cd03d03774e2b2 /internal | |
parent | [chore]: Bump github.com/stretchr/testify from 1.8.4 to 1.9.0 (#2714) (diff) | |
download | gotosocial-61a2b91f454a6eb0dd383fc8614fee154654fa08.tar.xz |
[feature] Filters v1 (#2594)
* Implement client-side v1 filters
* Exclude linter false positives
* Update test/envparsing.sh
* Fix minor Swagger, style, and Bun usage issues
* Regenerate Swagger
* De-generify filter keywords
* Remove updating filter statuses
This is an operation that the Mastodon v2 filter API doesn't actually have, because filter statuses, unlike keywords, don't have options: the only info they contain is the status ID to be filtered.
* Add a test for filter statuses specifically
* De-generify filter statuses
* Inline FilterEntry
* Use vertical style for Bun operations consistently
* Add comment on Filter DB interface
* Remove GoLand linter control comments
Our existing linters should catch these, or they don't matter very much
* Reduce memory ratio for filters
Diffstat (limited to 'internal')
44 files changed, 4259 insertions, 52 deletions
diff --git a/internal/api/client.go b/internal/api/client.go index 1112efa31..d41add017 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -29,7 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/api/client/customemojis" "github.com/superseriousbusiness/gotosocial/internal/api/client/favourites" "github.com/superseriousbusiness/gotosocial/internal/api/client/featuredtags" - filter "github.com/superseriousbusiness/gotosocial/internal/api/client/filters" + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" "github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests" "github.com/superseriousbusiness/gotosocial/internal/api/client/instance" "github.com/superseriousbusiness/gotosocial/internal/api/client/lists" @@ -62,7 +62,7 @@ type Client struct { customEmojis *customemojis.Module // api/v1/custom_emojis favourites *favourites.Module // api/v1/favourites featuredTags *featuredtags.Module // api/v1/featured_tags - filters *filter.Module // api/v1/filters + filtersV1 *filtersV1.Module // api/v1/filters followRequests *followrequests.Module // api/v1/follow_requests instance *instance.Module // api/v1/instance lists *lists.Module // api/v1/lists @@ -104,7 +104,7 @@ func (c *Client) Route(r *router.Router, m ...gin.HandlerFunc) { c.customEmojis.Route(h) c.favourites.Route(h) c.featuredTags.Route(h) - c.filters.Route(h) + c.filtersV1.Route(h) c.followRequests.Route(h) c.instance.Route(h) c.lists.Route(h) @@ -134,7 +134,7 @@ func NewClient(db db.DB, p *processing.Processor) *Client { customEmojis: customemojis.New(p), favourites: favourites.New(p), featuredTags: featuredtags.New(p), - filters: filter.New(p), + filtersV1: filtersV1.New(p), followRequests: followrequests.New(p), instance: instance.New(p), lists: lists.New(p), diff --git a/internal/api/client/filters/filter.go b/internal/api/client/filters/v1/filter.go index 68c99e825..9daeb75d3 100644 --- a/internal/api/client/filters/filter.go +++ b/internal/api/client/filters/v1/filter.go @@ -15,20 +15,23 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see <http://www.gnu.org/licenses/>. -package filter +package v1 import ( - "net/http" - "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/processing" + "net/http" ) const ( // BasePath is the base path for serving the filters API, minus the 'api' prefix BasePath = "/v1/filters" + // BasePathWithID is the base path with the ID key in it, for operations on an existing filter. + BasePathWithID = BasePath + "/:" + apiutil.IDKey ) +// Module implements APIs for client-side aka "v1" filtering. type Module struct { processor *processing.Processor } @@ -41,4 +44,8 @@ func New(processor *processing.Processor) *Module { func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { attachHandler(http.MethodGet, BasePath, m.FiltersGETHandler) + attachHandler(http.MethodPost, BasePath, m.FilterPOSTHandler) + attachHandler(http.MethodGet, BasePathWithID, m.FilterGETHandler) + attachHandler(http.MethodPut, BasePathWithID, m.FilterPUTHandler) + attachHandler(http.MethodDelete, BasePathWithID, m.FilterDELETEHandler) } diff --git a/internal/api/client/filters/v1/filter_test.go b/internal/api/client/filters/v1/filter_test.go new file mode 100644 index 000000000..c92e22a05 --- /dev/null +++ b/internal/api/client/filters/v1/filter_test.go @@ -0,0 +1,117 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1_test + +import ( + "github.com/stretchr/testify/suite" + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/email" + "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/filter/visibility" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/media" + "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/testrig" + "testing" +) + +type FiltersTestSuite struct { + suite.Suite + db db.DB + storage *storage.Driver + mediaManager *media.Manager + federator *federation.Federator + processor *processing.Processor + emailSender email.Sender + sentEmails map[string]string + state state.State + + // standard suite models + testTokens map[string]*gtsmodel.Token + testClients map[string]*gtsmodel.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + testStatuses map[string]*gtsmodel.Status + testFilters map[string]*gtsmodel.Filter + testFilterKeywords map[string]*gtsmodel.FilterKeyword + testFilterStatuses map[string]*gtsmodel.FilterStatus + + // module being tested + filtersModule *filtersV1.Module +} + +func (suite *FiltersTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testStatuses = testrig.NewTestStatuses() + suite.testFilters = testrig.NewTestFilters() + suite.testFilterKeywords = testrig.NewTestFilterKeywords() + suite.testFilterStatuses = testrig.NewTestFilterStatuses() +} + +func (suite *FiltersTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartNoopWorkers(&suite.state) + + testrig.InitTestConfig() + config.Config(func(cfg *config.Configuration) { + cfg.WebAssetBaseDir = "../../../../../web/assets/" + cfg.WebTemplateBaseDir = "../../../../../web/templates/" + }) + testrig.InitTestLog() + + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db + suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + typeutils.NewConverter(&suite.state), + ) + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../../testrig/media")), suite.mediaManager) + suite.sentEmails = make(map[string]string) + suite.emailSender = testrig.NewEmailSender("../../../../../web/template/", suite.sentEmails) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) + suite.filtersModule = filtersV1.New(suite.processor) + + testrig.StandardDBSetup(suite.db, nil) + testrig.StandardStorageSetup(suite.storage, "../../../../../testrig/media") +} + +func (suite *FiltersTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) + testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) +} + +func TestFiltersTestSuite(t *testing.T) { + suite.Run(t, new(FiltersTestSuite)) +} diff --git a/internal/api/client/filters/v1/filterdelete.go b/internal/api/client/filters/v1/filterdelete.go new file mode 100644 index 000000000..d86b277a6 --- /dev/null +++ b/internal/api/client/filters/v1/filterdelete.go @@ -0,0 +1,90 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1 + +import ( + "net/http" + + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// FilterDELETEHandler swagger:operation DELETE /api/v1/filters/{id} filterV1Delete +// +// Delete a single filter with the given ID. +// +// --- +// tags: +// - filters +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: ID of the list +// in: path +// required: true +// +// security: +// - OAuth2 Bearer: +// - write:filters +// +// responses: +// '200': +// description: filter deleted +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) FilterDELETEHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + id, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey)) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + errWithCode = m.processor.FiltersV1().Delete(c.Request.Context(), authed.Account, id) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, apiutil.EmptyJSONObject) +} diff --git a/internal/api/client/filters/v1/filterdelete_test.go b/internal/api/client/filters/v1/filterdelete_test.go new file mode 100644 index 000000000..83155f08a --- /dev/null +++ b/internal/api/client/filters/v1/filterdelete_test.go @@ -0,0 +1,112 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1_test + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func (suite *FiltersTestSuite) deleteFilter( + filterKeywordID string, + expectedHTTPStatus int, + expectedBody string, +) error { + // instantiate recorder + test context + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // create the request + ctx.Request = httptest.NewRequest(http.MethodDelete, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath+"/"+filterKeywordID, nil) + ctx.Request.Header.Set("accept", "application/json") + + ctx.AddParam("id", filterKeywordID) + + // trigger the handler + suite.filtersModule.FilterDELETEHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + return err + } + + errs := gtserror.NewMultiError(2) + + // check code + body + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs.Appendf("expected %s got %s", expectedBody, string(b)) + } + return errs.Combine() + } + + resp := &struct{}{} + if err := json.Unmarshal(b, resp); err != nil { + return err + } + + return nil +} + +func (suite *FiltersTestSuite) TestDeleteFilter() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + + err := suite.deleteFilter(id, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestDeleteAnotherAccountsFilter() { + id := suite.testFilterKeywords["local_account_2_filter_1_keyword_1"].ID + + err := suite.deleteFilter(id, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestDeleteNonexistentFilter() { + id := "not_even_a_real_ULID" + + err := suite.deleteFilter(id, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} diff --git a/internal/api/client/filters/v1/filterget.go b/internal/api/client/filters/v1/filterget.go new file mode 100644 index 000000000..35c44b60c --- /dev/null +++ b/internal/api/client/filters/v1/filterget.go @@ -0,0 +1,93 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1 + +import ( + "net/http" + + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// FilterGETHandler swagger:operation GET /api/v1/filters/{id} filterV1Get +// +// Get a single filter with the given ID. +// +// --- +// tags: +// - filters +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: ID of the filter +// in: path +// required: true +// +// security: +// - OAuth2 Bearer: +// - read:filters +// +// responses: +// '200': +// name: filter +// description: Requested filter. +// schema: +// "$ref": "#/definitions/filterV1" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) FilterGETHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + id, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey)) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiFilter, errWithCode := m.processor.FiltersV1().Get(c.Request.Context(), authed.Account, id) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, apiFilter) +} diff --git a/internal/api/client/filters/v1/filterget_test.go b/internal/api/client/filters/v1/filterget_test.go new file mode 100644 index 000000000..a9dbf6dbb --- /dev/null +++ b/internal/api/client/filters/v1/filterget_test.go @@ -0,0 +1,121 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1_test + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func (suite *FiltersTestSuite) getFilter( + filterKeywordID string, + expectedHTTPStatus int, + expectedBody string, +) (*apimodel.FilterV1, error) { + // instantiate recorder + test context + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // create the request + ctx.Request = httptest.NewRequest(http.MethodGet, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath+"/"+filterKeywordID, nil) + ctx.Request.Header.Set("accept", "application/json") + + ctx.AddParam("id", filterKeywordID) + + // trigger the handler + suite.filtersModule.FilterGETHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + return nil, err + } + + errs := gtserror.NewMultiError(2) + + // check code + body + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs.Appendf("expected %s got %s", expectedBody, string(b)) + } + return nil, errs.Combine() + } + + resp := &apimodel.FilterV1{} + if err := json.Unmarshal(b, resp); err != nil { + return nil, err + } + + return resp, nil +} + +func (suite *FiltersTestSuite) TestGetFilter() { + // v1 filters map to individual filter keywords, but also use the settings of the associated filter. + expectedFilterGtsModel := suite.testFilters["local_account_1_filter_1"] + expectedFilterKeywordGtsModel := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"] + + filter, err := suite.getFilter(expectedFilterKeywordGtsModel.ID, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.NotEmpty(filter) + suite.Equal(expectedFilterGtsModel.Action == gtsmodel.FilterActionHide, filter.Irreversible) + suite.Equal(expectedFilterKeywordGtsModel.ID, filter.ID) + suite.Equal(expectedFilterKeywordGtsModel.Keyword, filter.Phrase) +} + +func (suite *FiltersTestSuite) TestGetAnotherAccountsFilter() { + id := suite.testFilterKeywords["local_account_2_filter_1_keyword_1"].ID + + _, err := suite.getFilter(id, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestGetNonexistentFilter() { + id := "not_even_a_real_ULID" + + _, err := suite.getFilter(id, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} diff --git a/internal/api/client/filters/v1/filterpost.go b/internal/api/client/filters/v1/filterpost.go new file mode 100644 index 000000000..b0a626199 --- /dev/null +++ b/internal/api/client/filters/v1/filterpost.go @@ -0,0 +1,147 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1 + +import ( + "net/http" + + "github.com/gin-gonic/gin" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// FilterPOSTHandler swagger:operation POST /api/v1/filters filterV1Post +// +// Create a single filter. +// +// --- +// tags: +// - filters +// +// consumes: +// - application/json +// - application/xml +// - application/x-www-form-urlencoded +// +// produces: +// - application/json +// +// parameters: +// - +// name: phrase +// in: formData +// required: true +// description: The text to be filtered. +// maxLength: 40 +// type: string +// example: "fnord" +// - +// name: context +// in: formData +// required: true +// description: The contexts in which the filter should be applied. +// enum: +// - home +// - notifications +// - public +// - thread +// - account +// example: +// - home +// - public +// items: +// $ref: '#/definitions/filterContext' +// minLength: 1 +// type: array +// uniqueItems: true +// - +// name: expires_in +// in: formData +// description: Number of seconds from now that the filter should expire. If omitted, filter never expires. +// type: number +// example: 86400 +// - +// name: irreversible +// in: formData +// description: Should matching entities be removed from the user's timelines/views, instead of hidden? Not supported yet. +// type: boolean +// default: false +// example: false +// - +// name: whole_word +// in: formData +// description: Should the filter consider word boundaries? +// type: boolean +// default: false +// example: true +// +// security: +// - OAuth2 Bearer: +// - write:filters +// +// responses: +// '200': +// name: filter +// description: New filter. +// schema: +// "$ref": "#/definitions/filterV1" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '422': +// description: unprocessable content +// '500': +// description: internal server error +func (m *Module) FilterPOSTHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + form := &apimodel.FilterCreateUpdateRequestV1{} + if err := c.ShouldBind(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if err := validateNormalizeCreateUpdateFilter(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnprocessableEntity(err, err.Error()), m.processor.InstanceGetV1) + return + } + + apiFilter, errWithCode := m.processor.FiltersV1().Create(c.Request.Context(), authed.Account, form) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiutil.JSON(c, http.StatusOK, apiFilter) +} diff --git a/internal/api/client/filters/v1/filterpost_test.go b/internal/api/client/filters/v1/filterpost_test.go new file mode 100644 index 000000000..729b2bd72 --- /dev/null +++ b/internal/api/client/filters/v1/filterpost_test.go @@ -0,0 +1,239 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1_test + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func (suite *FiltersTestSuite) postFilter( + phrase *string, + context *[]string, + irreversible *bool, + wholeWord *bool, + expiresIn *int, + requestJson *string, + expectedHTTPStatus int, + expectedBody string, +) (*apimodel.FilterV1, error) { + // instantiate recorder + test context + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // create the request + ctx.Request = httptest.NewRequest(http.MethodPost, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath, nil) + ctx.Request.Header.Set("accept", "application/json") + if requestJson != nil { + ctx.Request.Header.Set("content-type", "application/json") + ctx.Request.Body = io.NopCloser(strings.NewReader(*requestJson)) + } else { + ctx.Request.Form = make(url.Values) + if phrase != nil { + ctx.Request.Form["phrase"] = []string{*phrase} + } + if context != nil { + ctx.Request.Form["context[]"] = *context + } + if irreversible != nil { + ctx.Request.Form["irreversible"] = []string{strconv.FormatBool(*irreversible)} + } + if wholeWord != nil { + ctx.Request.Form["whole_word"] = []string{strconv.FormatBool(*wholeWord)} + } + if expiresIn != nil { + ctx.Request.Form["expires_in"] = []string{strconv.Itoa(*expiresIn)} + } + } + + // trigger the handler + suite.filtersModule.FilterPOSTHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + return nil, err + } + + errs := gtserror.NewMultiError(2) + + // check code + body + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs.Appendf("expected %s got %s", expectedBody, string(b)) + } + return nil, errs.Combine() + } + + resp := &apimodel.FilterV1{} + if err := json.Unmarshal(b, resp); err != nil { + return nil, err + } + + return resp, nil +} + +func (suite *FiltersTestSuite) TestPostFilterFull() { + phrase := "GNU/Linux" + context := []string{"home", "public"} + irreversible := false + wholeWord := true + expiresIn := 86400 + filter, err := suite.postFilter(&phrase, &context, &irreversible, &wholeWord, &expiresIn, nil, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal(phrase, filter.Phrase) + filterContext := make([]string, 0, len(filter.Context)) + for _, c := range filter.Context { + filterContext = append(filterContext, string(c)) + } + suite.ElementsMatch(context, filterContext) + suite.Equal(irreversible, filter.Irreversible) + suite.Equal(wholeWord, filter.WholeWord) + if suite.NotNil(filter.ExpiresAt) { + suite.NotEmpty(*filter.ExpiresAt) + } +} + +func (suite *FiltersTestSuite) TestPostFilterFullJSON() { + // Use a numeric literal with a fractional part to test the JSON-specific handling for non-integer "expires_in". + requestJson := `{ + "phrase":"GNU/Linux", + "context": ["home", "public"], + "irreversible": false, + "whole_word": true, + "expires_in": 86400.1 + }` + filter, err := suite.postFilter(nil, nil, nil, nil, nil, &requestJson, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal("GNU/Linux", filter.Phrase) + suite.ElementsMatch( + []apimodel.FilterContext{ + apimodel.FilterContextHome, + apimodel.FilterContextPublic, + }, + filter.Context, + ) + suite.Equal(false, filter.Irreversible) + suite.Equal(true, filter.WholeWord) + if suite.NotNil(filter.ExpiresAt) { + suite.NotEmpty(*filter.ExpiresAt) + } +} + +func (suite *FiltersTestSuite) TestPostFilterMinimal() { + phrase := "GNU/Linux" + context := []string{"home"} + filter, err := suite.postFilter(&phrase, &context, nil, nil, nil, nil, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal(phrase, filter.Phrase) + filterContext := make([]string, 0, len(filter.Context)) + for _, c := range filter.Context { + filterContext = append(filterContext, string(c)) + } + suite.ElementsMatch(context, filterContext) + suite.False(filter.Irreversible) + suite.False(filter.WholeWord) + suite.Nil(filter.ExpiresAt) +} + +func (suite *FiltersTestSuite) TestPostFilterEmptyPhrase() { + phrase := "" + context := []string{"home"} + _, err := suite.postFilter(&phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPostFilterMissingPhrase() { + context := []string{"home"} + _, err := suite.postFilter(nil, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPostFilterEmptyContext() { + phrase := "GNU/Linux" + context := []string{} + _, err := suite.postFilter(&phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPostFilterMissingContext() { + phrase := "GNU/Linux" + _, err := suite.postFilter(&phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +// There should be a filter with this phrase as its title in our test fixtures. Creating another should fail. +func (suite *FiltersTestSuite) TestPostFilterTitleConflict() { + phrase := "fnord" + _, err := suite.postFilter(&phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +// FUTURE: this should be removed once we support server-side filters. +func (suite *FiltersTestSuite) TestPostFilterIrreversibleNotSupported() { + phrase := "GNU/Linux" + context := []string{"home"} + irreversible := true + _, err := suite.postFilter(&phrase, &context, &irreversible, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} diff --git a/internal/api/client/filters/v1/filterput.go b/internal/api/client/filters/v1/filterput.go new file mode 100644 index 000000000..c686e4515 --- /dev/null +++ b/internal/api/client/filters/v1/filterput.go @@ -0,0 +1,159 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1 + +import ( + "net/http" + + "github.com/gin-gonic/gin" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// FilterPUTHandler swagger:operation PUT /api/v1/filters/{id} filterV1Put +// +// Update a single filter with the given ID. +// +// --- +// tags: +// - filters +// +// consumes: +// - application/json +// - application/xml +// - application/x-www-form-urlencoded +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// in: path +// type: string +// required: true +// description: ID of the filter. +// - +// name: phrase +// in: formData +// required: true +// description: The text to be filtered. +// maxLength: 40 +// type: string +// example: "fnord" +// - +// name: context +// in: formData +// required: true +// description: The contexts in which the filter should be applied. +// enum: +// - home +// - notifications +// - public +// - thread +// - account +// example: +// - home +// - public +// items: +// $ref: '#/definitions/filterContext' +// minLength: 1 +// type: array +// uniqueItems: true +// - +// name: expires_in +// in: formData +// description: Number of seconds from now that the filter should expire. If omitted, filter never expires. +// type: number +// example: 86400 +// - +// name: irreversible +// in: formData +// description: Should matching entities be removed from the user's timelines/views, instead of hidden? Not supported yet. +// type: boolean +// default: false +// example: false +// - +// name: whole_word +// in: formData +// description: Should the filter consider word boundaries? +// type: boolean +// default: false +// example: true +// +// security: +// - OAuth2 Bearer: +// - write:filters +// +// responses: +// '200': +// name: filter +// description: Updated filter. +// schema: +// "$ref": "#/definitions/filterV1" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '422': +// description: unprocessable content +// '500': +// description: internal server error +func (m *Module) FilterPUTHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + id, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey)) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + form := &apimodel.FilterCreateUpdateRequestV1{} + if err := c.ShouldBind(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if err := validateNormalizeCreateUpdateFilter(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnprocessableEntity(err, err.Error()), m.processor.InstanceGetV1) + return + } + + apiFilter, errWithCode := m.processor.FiltersV1().Update(c.Request.Context(), authed.Account, id, form) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiutil.JSON(c, http.StatusOK, apiFilter) +} diff --git a/internal/api/client/filters/v1/filterput_test.go b/internal/api/client/filters/v1/filterput_test.go new file mode 100644 index 000000000..0308e53d9 --- /dev/null +++ b/internal/api/client/filters/v1/filterput_test.go @@ -0,0 +1,269 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1_test + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func (suite *FiltersTestSuite) putFilter( + filterKeywordID string, + phrase *string, + context *[]string, + irreversible *bool, + wholeWord *bool, + expiresIn *int, + requestJson *string, + expectedHTTPStatus int, + expectedBody string, +) (*apimodel.FilterV1, error) { + // instantiate recorder + test context + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // create the request + ctx.Request = httptest.NewRequest(http.MethodPut, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath+"/"+filterKeywordID, nil) + ctx.Request.Header.Set("accept", "application/json") + if requestJson != nil { + ctx.Request.Header.Set("content-type", "application/json") + ctx.Request.Body = io.NopCloser(strings.NewReader(*requestJson)) + } else { + ctx.Request.Form = make(url.Values) + if phrase != nil { + ctx.Request.Form["phrase"] = []string{*phrase} + } + if context != nil { + ctx.Request.Form["context[]"] = *context + } + if irreversible != nil { + ctx.Request.Form["irreversible"] = []string{strconv.FormatBool(*irreversible)} + } + if wholeWord != nil { + ctx.Request.Form["whole_word"] = []string{strconv.FormatBool(*wholeWord)} + } + if expiresIn != nil { + ctx.Request.Form["expires_in"] = []string{strconv.Itoa(*expiresIn)} + } + } + + ctx.AddParam("id", filterKeywordID) + + // trigger the handler + suite.filtersModule.FilterPUTHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + return nil, err + } + + errs := gtserror.NewMultiError(2) + + // check code + body + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs.Appendf("expected %s got %s", expectedBody, string(b)) + } + return nil, errs.Combine() + } + + resp := &apimodel.FilterV1{} + if err := json.Unmarshal(b, resp); err != nil { + return nil, err + } + + return resp, nil +} + +func (suite *FiltersTestSuite) TestPutFilterFull() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "GNU/Linux" + context := []string{"home", "public"} + irreversible := false + wholeWord := true + expiresIn := 86400 + filter, err := suite.putFilter(id, &phrase, &context, &irreversible, &wholeWord, &expiresIn, nil, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal(phrase, filter.Phrase) + filterContext := make([]string, 0, len(filter.Context)) + for _, c := range filter.Context { + filterContext = append(filterContext, string(c)) + } + suite.ElementsMatch(context, filterContext) + suite.Equal(irreversible, filter.Irreversible) + suite.Equal(wholeWord, filter.WholeWord) + if suite.NotNil(filter.ExpiresAt) { + suite.NotEmpty(*filter.ExpiresAt) + } +} + +func (suite *FiltersTestSuite) TestPutFilterFullJSON() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + // Use a numeric literal with a fractional part to test the JSON-specific handling for non-integer "expires_in". + requestJson := `{ + "phrase":"GNU/Linux", + "context": ["home", "public"], + "irreversible": false, + "whole_word": true, + "expires_in": 86400.1 + }` + filter, err := suite.putFilter(id, nil, nil, nil, nil, nil, &requestJson, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal("GNU/Linux", filter.Phrase) + suite.ElementsMatch( + []apimodel.FilterContext{ + apimodel.FilterContextHome, + apimodel.FilterContextPublic, + }, + filter.Context, + ) + suite.Equal(false, filter.Irreversible) + suite.Equal(true, filter.WholeWord) + if suite.NotNil(filter.ExpiresAt) { + suite.NotEmpty(*filter.ExpiresAt) + } +} + +func (suite *FiltersTestSuite) TestPutFilterMinimal() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "GNU/Linux" + context := []string{"home"} + filter, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal(phrase, filter.Phrase) + filterContext := make([]string, 0, len(filter.Context)) + for _, c := range filter.Context { + filterContext = append(filterContext, string(c)) + } + suite.ElementsMatch(context, filterContext) + suite.False(filter.Irreversible) + suite.False(filter.WholeWord) + suite.Nil(filter.ExpiresAt) +} + +func (suite *FiltersTestSuite) TestPutFilterEmptyPhrase() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "" + context := []string{"home"} + _, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPutFilterMissingPhrase() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + context := []string{"home"} + _, err := suite.putFilter(id, nil, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPutFilterEmptyContext() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "GNU/Linux" + context := []string{} + _, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPutFilterMissingContext() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "GNU/Linux" + _, err := suite.putFilter(id, &phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +// There should be a filter with this phrase as its title in our test fixtures. Changing ours to that title should fail. +func (suite *FiltersTestSuite) TestPutFilterTitleConflict() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "metasyntactic variables" + _, err := suite.putFilter(id, &phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +// FUTURE: this should be removed once we support server-side filters. +func (suite *FiltersTestSuite) TestPutFilterIrreversibleNotSupported() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + irreversible := true + _, err := suite.putFilter(id, nil, nil, &irreversible, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPutAnotherAccountsFilter() { + id := suite.testFilterKeywords["local_account_2_filter_1_keyword_1"].ID + phrase := "GNU/Linux" + context := []string{"home"} + _, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPutNonexistentFilter() { + id := "not_even_a_real_ULID" + phrase := "GNU/Linux" + context := []string{"home"} + _, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} diff --git a/internal/api/client/filters/filtersget.go b/internal/api/client/filters/v1/filtersget.go index 38dd330a7..84d638676 100644 --- a/internal/api/client/filters/filtersget.go +++ b/internal/api/client/filters/v1/filtersget.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see <http://www.gnu.org/licenses/>. -package filter +package v1 import ( "net/http" @@ -26,9 +26,40 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -// FiltersGETHandler returns a list of filters set by/for the authed account +// FiltersGETHandler swagger:operation GET /api/v1/filters filtersV1Get +// +// Get all filters for the authenticated account. +// +// --- +// tags: +// - filters +// +// produces: +// - application/json +// +// security: +// - OAuth2 Bearer: +// - read:filters +// +// responses: +// '200': +// name: filter +// description: Requested filters. +// schema: +// "$ref": "#/definitions/filterV1" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error func (m *Module) FiltersGETHandler(c *gin.Context) { - if _, err := oauth.Authed(c, true, true, true, true); err != nil { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) return } @@ -38,5 +69,11 @@ func (m *Module) FiltersGETHandler(c *gin.Context) { return } - apiutil.Data(c, http.StatusOK, apiutil.AppJSON, apiutil.EmptyJSONArray) + apiFilters, errWithCode := m.processor.FiltersV1().GetAll(c.Request.Context(), authed.Account) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, apiFilters) } diff --git a/internal/api/client/filters/v1/filtersget_test.go b/internal/api/client/filters/v1/filtersget_test.go new file mode 100644 index 000000000..a568239ef --- /dev/null +++ b/internal/api/client/filters/v1/filtersget_test.go @@ -0,0 +1,114 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1_test + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func (suite *FiltersTestSuite) getFilters( + expectedHTTPStatus int, + expectedBody string, +) ([]*apimodel.FilterV1, error) { + // instantiate recorder + test context + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // create the request + ctx.Request = httptest.NewRequest(http.MethodGet, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath, nil) + ctx.Request.Header.Set("accept", "application/json") + + // trigger the handler + suite.filtersModule.FiltersGETHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + return nil, err + } + + errs := gtserror.NewMultiError(2) + + // check code + body + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs.Appendf("expected %s got %s", expectedBody, string(b)) + } + return nil, errs.Combine() + } + + resp := make([]*apimodel.FilterV1, 0) + if err := json.Unmarshal(b, &resp); err != nil { + return nil, err + } + + return resp, nil +} + +func (suite *FiltersTestSuite) TestGetFilters() { + // v1 filters map to individual filter keywords. + expectedFilterIDs := make([]string, 0, len(suite.testFilterKeywords)) + expectedFilterKeywords := make([]string, 0, len(suite.testFilterKeywords)) + for _, filterKeyword := range suite.testFilterKeywords { + if filterKeyword.AccountID == suite.testAccounts["local_account_1"].ID { + expectedFilterIDs = append(expectedFilterIDs, filterKeyword.ID) + expectedFilterKeywords = append(expectedFilterKeywords, filterKeyword.Keyword) + } + } + suite.NotEmpty(expectedFilterIDs) + suite.NotEmpty(expectedFilterKeywords) + + // Fetch all filters for the logged-in account. + filters, err := suite.getFilters(http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotEmpty(filters) + + // Check that we got the right ones. + actualFilterIDs := make([]string, 0, len(filters)) + actualFilterKeywords := make([]string, 0, len(filters)) + for _, filter := range filters { + actualFilterIDs = append(actualFilterIDs, filter.ID) + actualFilterKeywords = append(actualFilterKeywords, filter.Phrase) + } + suite.ElementsMatch(expectedFilterIDs, actualFilterIDs) + suite.ElementsMatch(expectedFilterKeywords, actualFilterKeywords) +} diff --git a/internal/api/client/filters/v1/validate.go b/internal/api/client/filters/v1/validate.go new file mode 100644 index 000000000..b539c9563 --- /dev/null +++ b/internal/api/client/filters/v1/validate.go @@ -0,0 +1,68 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1 + +import ( + "errors" + "fmt" + "strconv" + + "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/superseriousbusiness/gotosocial/internal/validate" +) + +func validateNormalizeCreateUpdateFilter(form *model.FilterCreateUpdateRequestV1) error { + if err := validate.FilterKeyword(form.Phrase); err != nil { + return err + } + if err := validate.FilterContexts(form.Context); err != nil { + return err + } + + // Apply defaults for missing fields. + form.WholeWord = util.Ptr(util.PtrValueOr(form.WholeWord, false)) + form.Irreversible = util.Ptr(util.PtrValueOr(form.Irreversible, false)) + + if *form.Irreversible { + return errors.New("irreversible aka server-side drop filters are not supported yet") + } + + // Normalize filter expiry if necessary. + // If we parsed this as JSON, expires_in + // may be either a float64 or a string. + if ei := form.ExpiresInI; ei != nil { + switch e := ei.(type) { + case float64: + form.ExpiresIn = util.Ptr(int(e)) + + case string: + expiresIn, err := strconv.Atoi(e) + if err != nil { + return fmt.Errorf("could not parse expires_in value %s as integer: %w", e, err) + } + + form.ExpiresIn = &expiresIn + + default: + return fmt.Errorf("could not parse expires_in type %T as integer", ei) + } + } + + return nil +} diff --git a/internal/api/model/filter.go b/internal/api/model/filter.go index 4a5d29690..027dea48c 100644 --- a/internal/api/model/filter.go +++ b/internal/api/model/filter.go @@ -17,29 +17,23 @@ package model -// Filter represents a user-defined filter for determining which statuses should not be shown to the user. -// If whole_word is true , client app should do: -// Define ‘word constituent character’ for your app. In the official implementation, it’s [A-Za-z0-9_] in JavaScript, and [[:word:]] in Ruby. -// Ruby uses the POSIX character class (Letter | Mark | Decimal_Number | Connector_Punctuation). -// If the phrase starts with a word character, and if the previous character before matched range is a word character, its matched range should be treated to not match. -// If the phrase ends with a word character, and if the next character after matched range is a word character, its matched range should be treated to not match. -// Please check app/javascript/mastodon/selectors/index.js and app/lib/feed_manager.rb in the Mastodon source code for more details. -type Filter struct { - // The ID of the filter in the database. - ID string `json:"id"` - // The text to be filtered. - Phrase string `json:"text"` - // The contexts in which the filter should be applied. - // Array of String (Enumerable anyOf) - // home = home timeline and lists - // notifications = notifications timeline - // public = public timelines - // thread = expanded thread of a detailed status - Context []string `json:"context"` - // Should the filter consider word boundaries? - WholeWord bool `json:"whole_word"` - // When the filter should no longer be applied (ISO 8601 Datetime), or null if the filter does not expire - ExpiresAt string `json:"expires_at,omitempty"` - // Should matching entities in home and notifications be dropped by the server? - Irreversible bool `json:"irreversible"` -} +// FilterContext represents the context in which to apply a filter. +// v1 and v2 filter APIs use the same set of contexts. +// +// swagger:model filterContext +type FilterContext string + +const ( + // FilterContextHome means this filter should be applied to the home timeline and lists. + FilterContextHome FilterContext = "home" + // FilterContextNotifications means this filter should be applied to the notifications timeline. + FilterContextNotifications FilterContext = "notifications" + // FilterContextPublic means this filter should be applied to public timelines. + FilterContextPublic FilterContext = "public" + // FilterContextThread means this filter should be applied to the expanded thread of a detailed status. + FilterContextThread FilterContext = "thread" + // FilterContextAccount means this filter should be applied when viewing a profile. + FilterContextAccount FilterContext = "account" + + FilterContextNumValues = 5 +) diff --git a/internal/api/model/filterv1.go b/internal/api/model/filterv1.go new file mode 100644 index 000000000..52250f537 --- /dev/null +++ b/internal/api/model/filterv1.go @@ -0,0 +1,99 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package model + +// FilterV1 represents a user-defined filter for determining which statuses should not be shown to the user. +// Note that v1 filters are mapped to v2 filters and v2 filter keywords internally. +// If whole_word is true, client app should do: +// Define ‘word constituent character’ for your app. In the official implementation, it’s [A-Za-z0-9_] in JavaScript, and [[:word:]] in Ruby. +// Ruby uses the POSIX character class (Letter | Mark | Decimal_Number | Connector_Punctuation). +// If the phrase starts with a word character, and if the previous character before matched range is a word character, its matched range should be treated to not match. +// If the phrase ends with a word character, and if the next character after matched range is a word character, its matched range should be treated to not match. +// Please check app/javascript/mastodon/selectors/index.js and app/lib/feed_manager.rb in the Mastodon source code for more details. +// +// swagger:model filterV1 +// +// --- +// tags: +// - filters +type FilterV1 struct { + // The ID of the filter in the database. + ID string `json:"id"` + // The text to be filtered. + // + // Example: fnord + Phrase string `json:"phrase"` + // The contexts in which the filter should be applied. + // + // Minimum length: 1 + // Unique: true + // Enum: + // - home + // - notifications + // - public + // - thread + // - account + // Example: ["home", "public"] + Context []FilterContext `json:"context"` + // Should the filter consider word boundaries? + // + // Example: true + WholeWord bool `json:"whole_word"` + // Should matching entities be removed from the user's timelines/views, instead of hidden? + // + // Example: false + Irreversible bool `json:"irreversible"` + // When the filter should no longer be applied. Null if the filter does not expire. + // + // Example: 2024-02-01T02:57:49Z + ExpiresAt *string `json:"expires_at"` +} + +// FilterCreateUpdateRequestV1 captures params for creating or updating a v1 filter. +// +// swagger:ignore +type FilterCreateUpdateRequestV1 struct { + // The text to be filtered. + // + // Required: true + // Maximum length: 40 + // Example: fnord + Phrase string `form:"phrase" json:"phrase" xml:"phrase"` + // The contexts in which the filter should be applied. + // + // Required: true + // Minimum length: 1 + // Unique: true + // Enum: home,notifications,public,thread,account + // Example: ["home", "public"] + Context []FilterContext `form:"context[]" json:"context" xml:"context"` + // Should matching entities be removed from the user's timelines/views, instead of hidden? + // + // Example: false + Irreversible *bool `form:"irreversible" json:"irreversible" xml:"irreversible"` + // Should the filter consider word boundaries? + // + // Example: true + WholeWord *bool `form:"whole_word" json:"whole_word" xml:"whole_word"` + // Number of seconds from now that the filter should expire. If omitted, filter never expires. + ExpiresIn *int `json:"-" form:"expires_in" xml:"expires_in"` + // Number of seconds from now that the filter should expire. If omitted, filter never expires. + // + // Example: 86400 + ExpiresInI interface{} `json:"expires_in"` +} diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 17fa03323..9b70a565c 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -61,6 +61,9 @@ func (c *Caches) Init() { c.initDomainBlock() c.initEmoji() c.initEmojiCategory() + c.initFilter() + c.initFilterKeyword() + c.initFilterStatus() c.initFollow() c.initFollowIDs() c.initFollowRequest() @@ -119,6 +122,9 @@ func (c *Caches) Sweep(threshold float64) { c.GTS.BlockIDs.Trim(threshold) c.GTS.Emoji.Trim(threshold) c.GTS.EmojiCategory.Trim(threshold) + c.GTS.Filter.Trim(threshold) + c.GTS.FilterKeyword.Trim(threshold) + c.GTS.FilterStatus.Trim(threshold) c.GTS.Follow.Trim(threshold) c.GTS.FollowIDs.Trim(threshold) c.GTS.FollowRequest.Trim(threshold) diff --git a/internal/cache/db.go b/internal/cache/db.go index 275a25451..dc9e385cd 100644 --- a/internal/cache/db.go +++ b/internal/cache/db.go @@ -67,6 +67,15 @@ type GTSCaches struct { // EmojiCategory provides access to the gtsmodel EmojiCategory database cache. EmojiCategory structr.Cache[*gtsmodel.EmojiCategory] + // Filter provides access to the gtsmodel Filter database cache. + Filter structr.Cache[*gtsmodel.Filter] + + // FilterKeyword provides access to the gtsmodel FilterKeyword database cache. + FilterKeyword structr.Cache[*gtsmodel.FilterKeyword] + + // FilterStatus provides access to the gtsmodel FilterStatus database cache. + FilterStatus structr.Cache[*gtsmodel.FilterStatus] + // Follow provides access to the gtsmodel Follow database cache. Follow structr.Cache[*gtsmodel.Follow] @@ -409,6 +418,105 @@ func (c *Caches) initEmojiCategory() { }) } +func (c *Caches) initFilter() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofFilter(), // model in-mem size. + config.GetCacheFilterMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(filter1 *gtsmodel.Filter) *gtsmodel.Filter { + filter2 := new(gtsmodel.Filter) + *filter2 = *filter1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/filter.go. + filter2.Keywords = nil + filter2.Statuses = nil + + return filter2 + } + + c.GTS.Filter.Init(structr.Config[*gtsmodel.Filter]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "AccountID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initFilterKeyword() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofFilterKeyword(), // model in-mem size. + config.GetCacheFilterKeywordMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(filterKeyword1 *gtsmodel.FilterKeyword) *gtsmodel.FilterKeyword { + filterKeyword2 := new(gtsmodel.FilterKeyword) + *filterKeyword2 = *filterKeyword1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/filter.go. + filterKeyword2.Filter = nil + + return filterKeyword2 + } + + c.GTS.FilterKeyword.Init(structr.Config[*gtsmodel.FilterKeyword]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "AccountID", Multiple: true}, + {Fields: "FilterID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initFilterStatus() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofFilterStatus(), // model in-mem size. + config.GetCacheFilterStatusMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(filterStatus1 *gtsmodel.FilterStatus) *gtsmodel.FilterStatus { + filterStatus2 := new(gtsmodel.FilterStatus) + *filterStatus2 = *filterStatus1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/filter.go. + filterStatus2.Filter = nil + + return filterStatus2 + } + + c.GTS.FilterStatus.Init(structr.Config[*gtsmodel.FilterStatus]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "AccountID", Multiple: true}, + {Fields: "FilterID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + func (c *Caches) initFollow() { // Calculate maximum cache size. cap := calculateResultCacheMax( diff --git a/internal/cache/size.go b/internal/cache/size.go index 62fb31469..f9d88491d 100644 --- a/internal/cache/size.go +++ b/internal/cache/size.go @@ -309,6 +309,38 @@ func sizeofEmojiCategory() uintptr { })) } +func sizeofFilter() uintptr { + return uintptr(size.Of(>smodel.Filter{ + ID: exampleID, + CreatedAt: exampleTime, + UpdatedAt: exampleTime, + ExpiresAt: exampleTime, + AccountID: exampleID, + Title: exampleTextSmall, + Action: gtsmodel.FilterActionHide, + })) +} + +func sizeofFilterKeyword() uintptr { + return uintptr(size.Of(>smodel.FilterKeyword{ + ID: exampleID, + CreatedAt: exampleTime, + UpdatedAt: exampleTime, + FilterID: exampleID, + Keyword: exampleTextSmall, + })) +} + +func sizeofFilterStatus() uintptr { + return uintptr(size.Of(>smodel.FilterStatus{ + ID: exampleID, + CreatedAt: exampleTime, + UpdatedAt: exampleTime, + FilterID: exampleID, + StatusID: exampleID, + })) +} + func sizeofFollow() uintptr { return uintptr(size.Of(>smodel.Follow{ ID: exampleID, diff --git a/internal/config/config.go b/internal/config/config.go index c810222a1..ea84a4af7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -201,6 +201,9 @@ type CacheConfiguration struct { BoostOfIDsMemRatio float64 `name:"boost-of-ids-mem-ratio"` EmojiMemRatio float64 `name:"emoji-mem-ratio"` EmojiCategoryMemRatio float64 `name:"emoji-category-mem-ratio"` + FilterMemRatio float64 `name:"filter-mem-ratio"` + FilterKeywordMemRatio float64 `name:"filter-keyword-mem-ratio"` + FilterStatusMemRatio float64 `name:"filter-status-mem-ratio"` FollowMemRatio float64 `name:"follow-mem-ratio"` FollowIDsMemRatio float64 `name:"follow-ids-mem-ratio"` FollowRequestMemRatio float64 `name:"follow-request-mem-ratio"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 78474539f..c98b54b0b 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -165,6 +165,9 @@ var Defaults = Configuration{ BoostOfIDsMemRatio: 3, EmojiMemRatio: 3, EmojiCategoryMemRatio: 0.1, + FilterMemRatio: 0.5, + FilterKeywordMemRatio: 0.5, + FilterStatusMemRatio: 0.5, FollowMemRatio: 2, FollowIDsMemRatio: 4, FollowRequestMemRatio: 2, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index f458074b1..c5d4c992b 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2975,6 +2975,81 @@ func GetCacheEmojiCategoryMemRatio() float64 { return global.GetCacheEmojiCatego // SetCacheEmojiCategoryMemRatio safely sets the value for global configuration 'Cache.EmojiCategoryMemRatio' field func SetCacheEmojiCategoryMemRatio(v float64) { global.SetCacheEmojiCategoryMemRatio(v) } +// GetCacheFilterMemRatio safely fetches the Configuration value for state's 'Cache.FilterMemRatio' field +func (st *ConfigState) GetCacheFilterMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.FilterMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheFilterMemRatio safely sets the Configuration value for state's 'Cache.FilterMemRatio' field +func (st *ConfigState) SetCacheFilterMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.FilterMemRatio = v + st.reloadToViper() +} + +// CacheFilterMemRatioFlag returns the flag name for the 'Cache.FilterMemRatio' field +func CacheFilterMemRatioFlag() string { return "cache-filter-mem-ratio" } + +// GetCacheFilterMemRatio safely fetches the value for global configuration 'Cache.FilterMemRatio' field +func GetCacheFilterMemRatio() float64 { return global.GetCacheFilterMemRatio() } + +// SetCacheFilterMemRatio safely sets the value for global configuration 'Cache.FilterMemRatio' field +func SetCacheFilterMemRatio(v float64) { global.SetCacheFilterMemRatio(v) } + +// GetCacheFilterKeywordMemRatio safely fetches the Configuration value for state's 'Cache.FilterKeywordMemRatio' field +func (st *ConfigState) GetCacheFilterKeywordMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.FilterKeywordMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheFilterKeywordMemRatio safely sets the Configuration value for state's 'Cache.FilterKeywordMemRatio' field +func (st *ConfigState) SetCacheFilterKeywordMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.FilterKeywordMemRatio = v + st.reloadToViper() +} + +// CacheFilterKeywordMemRatioFlag returns the flag name for the 'Cache.FilterKeywordMemRatio' field +func CacheFilterKeywordMemRatioFlag() string { return "cache-filter-keyword-mem-ratio" } + +// GetCacheFilterKeywordMemRatio safely fetches the value for global configuration 'Cache.FilterKeywordMemRatio' field +func GetCacheFilterKeywordMemRatio() float64 { return global.GetCacheFilterKeywordMemRatio() } + +// SetCacheFilterKeywordMemRatio safely sets the value for global configuration 'Cache.FilterKeywordMemRatio' field +func SetCacheFilterKeywordMemRatio(v float64) { global.SetCacheFilterKeywordMemRatio(v) } + +// GetCacheFilterStatusMemRatio safely fetches the Configuration value for state's 'Cache.FilterStatusMemRatio' field +func (st *ConfigState) GetCacheFilterStatusMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.FilterStatusMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheFilterStatusMemRatio safely sets the Configuration value for state's 'Cache.FilterStatusMemRatio' field +func (st *ConfigState) SetCacheFilterStatusMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.FilterStatusMemRatio = v + st.reloadToViper() +} + +// CacheFilterStatusMemRatioFlag returns the flag name for the 'Cache.FilterStatusMemRatio' field +func CacheFilterStatusMemRatioFlag() string { return "cache-filter-status-mem-ratio" } + +// GetCacheFilterStatusMemRatio safely fetches the value for global configuration 'Cache.FilterStatusMemRatio' field +func GetCacheFilterStatusMemRatio() float64 { return global.GetCacheFilterStatusMemRatio() } + +// SetCacheFilterStatusMemRatio safely sets the value for global configuration 'Cache.FilterStatusMemRatio' field +func SetCacheFilterStatusMemRatio(v float64) { global.SetCacheFilterStatusMemRatio(v) } + // GetCacheFollowMemRatio safely fetches the Configuration value for state's 'Cache.FollowMemRatio' field func (st *ConfigState) GetCacheFollowMemRatio() (v float64) { st.mutex.RLock() diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 4ecbec7b9..c49da272b 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -62,6 +62,7 @@ type DBService struct { db.Emoji db.HeaderFilter db.Instance + db.Filter db.List db.Marker db.Media @@ -200,6 +201,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { db: db, state: state, }, + Filter: &filterDB{ + db: db, + state: state, + }, List: &listDB{ db: db, state: state, diff --git a/internal/db/bundb/filter.go b/internal/db/bundb/filter.go new file mode 100644 index 000000000..bcd572f34 --- /dev/null +++ b/internal/db/bundb/filter.go @@ -0,0 +1,339 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" +) + +type filterDB struct { + db *bun.DB + state *state.State +} + +func (f *filterDB) GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error) { + filter, err := f.state.Caches.GTS.Filter.LoadOne( + "ID", + func() (*gtsmodel.Filter, error) { + var filter gtsmodel.Filter + err := f.db. + NewSelect(). + Model(&filter). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + return &filter, err + }, + id, + ) + if err != nil { + // already processed + return nil, err + } + + if !gtscontext.Barebones(ctx) { + if err := f.populateFilter(ctx, filter); err != nil { + return nil, err + } + } + + return filter, nil +} + +func (f *filterDB) GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) { + // Fetch IDs of all filters owned by this account. + var filterIDs []string + if err := f.db. + NewSelect(). + Model((*gtsmodel.Filter)(nil)). + Column("id"). + Where("? = ?", bun.Ident("account_id"), accountID). + Scan(ctx, &filterIDs); err != nil { + return nil, err + } + if len(filterIDs) == 0 { + return nil, nil + } + + // Get each filter by ID from the cache or DB. + uncachedFilterIDs := make([]string, 0, len(filterIDs)) + filters, err := f.state.Caches.GTS.Filter.Load( + "ID", + func(load func(keyParts ...any) bool) { + for _, id := range filterIDs { + if !load(id) { + uncachedFilterIDs = append(uncachedFilterIDs, id) + } + } + }, + func() ([]*gtsmodel.Filter, error) { + uncachedFilters := make([]*gtsmodel.Filter, 0, len(uncachedFilterIDs)) + if err := f.db. + NewSelect(). + Model(&uncachedFilters). + Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterIDs)). + Scan(ctx); err != nil { + return nil, err + } + return uncachedFilters, nil + }, + ) + if err != nil { + return nil, err + } + + // Put the filter structs in the same order as the filter IDs. + util.OrderBy(filters, filterIDs, func(filter *gtsmodel.Filter) string { return filter.ID }) + + if gtscontext.Barebones(ctx) { + return filters, nil + } + + // Populate the filters. Remove any that we can't populate from the return slice. + errs := gtserror.NewMultiError(len(filters)) + filters = slices.DeleteFunc(filters, func(filter *gtsmodel.Filter) bool { + if err := f.populateFilter(ctx, filter); err != nil { + errs.Appendf("error populating filter %s: %w", filter.ID, err) + return true + } + return false + }) + + return filters, errs.Combine() +} + +func (f *filterDB) populateFilter(ctx context.Context, filter *gtsmodel.Filter) error { + var err error + errs := gtserror.NewMultiError(2) + + if filter.Keywords == nil { + // Filter keywords are not set, fetch from the database. + filter.Keywords, err = f.state.DB.GetFilterKeywordsForFilterID( + gtscontext.SetBarebones(ctx), + filter.ID, + ) + if err != nil { + errs.Appendf("error populating filter keywords: %w", err) + } + for i := range filter.Keywords { + filter.Keywords[i].Filter = filter + } + } + + if filter.Statuses == nil { + // Filter statuses are not set, fetch from the database. + filter.Statuses, err = f.state.DB.GetFilterStatusesForFilterID( + gtscontext.SetBarebones(ctx), + filter.ID, + ) + if err != nil { + errs.Appendf("error populating filter statuses: %w", err) + } + for i := range filter.Statuses { + filter.Statuses[i].Filter = filter + } + } + + return errs.Combine() +} + +func (f *filterDB) PutFilter(ctx context.Context, filter *gtsmodel.Filter) error { + // Update database. + if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + if _, err := tx. + NewInsert(). + Model(filter). + Exec(ctx); err != nil { + return err + } + + if len(filter.Keywords) > 0 { + if _, err := tx. + NewInsert(). + Model(&filter.Keywords). + Exec(ctx); err != nil { + return err + } + } + + if len(filter.Statuses) > 0 { + if _, err := tx. + NewInsert(). + Model(&filter.Statuses). + Exec(ctx); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + + // Update cache. + f.state.Caches.GTS.Filter.Put(filter) + f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...) + f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...) + + return nil +} + +func (f *filterDB) UpdateFilter( + ctx context.Context, + filter *gtsmodel.Filter, + filterColumns []string, + filterKeywordColumns []string, + deleteFilterKeywordIDs []string, + deleteFilterStatusIDs []string, +) error { + updatedAt := time.Now() + filter.UpdatedAt = updatedAt + for _, filterKeyword := range filter.Keywords { + filterKeyword.UpdatedAt = updatedAt + } + for _, filterStatus := range filter.Statuses { + filterStatus.UpdatedAt = updatedAt + } + + // If we're updating by column, ensure "updated_at" is included. + if len(filterColumns) > 0 { + filterColumns = append(filterColumns, "updated_at") + } + if len(filterKeywordColumns) > 0 { + filterKeywordColumns = append(filterKeywordColumns, "updated_at") + } + + // Update database. + if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + if _, err := tx. + NewUpdate(). + Model(filter). + Column(filterColumns...). + Where("? = ?", bun.Ident("id"), filter.ID). + Exec(ctx); err != nil { + return err + } + + if len(filter.Keywords) > 0 { + if _, err := NewUpsert(tx). + Model(&filter.Keywords). + Constraint("id"). + Column(filterKeywordColumns...). + Exec(ctx); err != nil { + return err + } + } + + if len(filter.Statuses) > 0 { + if _, err := tx. + NewInsert(). + Ignore(). + Model(&filter.Statuses). + Exec(ctx); err != nil { + return err + } + } + + if len(deleteFilterKeywordIDs) > 0 { + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterKeyword)(nil)). + Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterKeywordIDs)). + Exec(ctx); err != nil { + return err + } + } + + if len(deleteFilterStatusIDs) > 0 { + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterStatus)(nil)). + Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterStatusIDs)). + Exec(ctx); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + + // Update cache. + f.state.Caches.GTS.Filter.Put(filter) + f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...) + f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...) + // TODO: (Vyr) replace with cache multi-invalidate call + for _, id := range deleteFilterKeywordIDs { + f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id) + } + for _, id := range deleteFilterStatusIDs { + f.state.Caches.GTS.FilterStatus.Invalidate("ID", id) + } + + return nil +} + +func (f *filterDB) DeleteFilterByID(ctx context.Context, id string) error { + if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Delete all keywords attached to filter. + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterKeyword)(nil)). + Where("? = ?", bun.Ident("filter_id"), id). + Exec(ctx); err != nil { + return err + } + + // Delete all statuses attached to filter. + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterStatus)(nil)). + Where("? = ?", bun.Ident("filter_id"), id). + Exec(ctx); err != nil { + return err + } + + // Delete the filter itself. + _, err := tx. + NewDelete(). + Model((*gtsmodel.Filter)(nil)). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx) + return err + }); err != nil { + return err + } + + // Invalidate this filter. + f.state.Caches.GTS.Filter.Invalidate("ID", id) + + // Invalidate all keywords and statuses for this filter. + f.state.Caches.GTS.FilterKeyword.Invalidate("FilterID", id) + f.state.Caches.GTS.FilterStatus.Invalidate("FilterID", id) + + return nil +} diff --git a/internal/db/bundb/filter_test.go b/internal/db/bundb/filter_test.go new file mode 100644 index 000000000..7940b6651 --- /dev/null +++ b/internal/db/bundb/filter_test.go @@ -0,0 +1,252 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +type FilterTestSuite struct { + BunDBStandardTestSuite +} + +// TestFilterCRUD tests CRUD and read-all operations on filters. +func (suite *FilterTestSuite) TestFilterCRUD() { + t := suite.T() + + // Create new example filter with attached keyword. + filter := >smodel.Filter{ + ID: "01HNEJNVZZVXJTRB3FX3K2B1YF", + AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47", + Title: "foss jail", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + } + filterKeyword := >smodel.FilterKeyword{ + ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4", + AccountID: filter.AccountID, + FilterID: filter.ID, + Keyword: "GNU/Linux", + } + filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} + + // Create new cancellable test context. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Insert the example filter into db. + if err := suite.db.PutFilter(ctx, filter); err != nil { + t.Fatalf("error inserting filter: %v", err) + } + + // Now fetch newly created filter. + check, err := suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter: %v", err) + } + + // Check all expected fields match. + suite.Equal(filter.ID, check.ID) + suite.Equal(filter.AccountID, check.AccountID) + suite.Equal(filter.Title, check.Title) + suite.Equal(filter.Action, check.Action) + suite.Equal(filter.ContextHome, check.ContextHome) + suite.Equal(filter.ContextNotifications, check.ContextNotifications) + suite.Equal(filter.ContextPublic, check.ContextPublic) + suite.Equal(filter.ContextThread, check.ContextThread) + suite.Equal(filter.ContextAccount, check.ContextAccount) + suite.NotZero(check.CreatedAt) + suite.NotZero(check.UpdatedAt) + + suite.Equal(len(filter.Keywords), len(check.Keywords)) + suite.Equal(filter.Keywords[0].ID, check.Keywords[0].ID) + suite.Equal(filter.Keywords[0].AccountID, check.Keywords[0].AccountID) + suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID) + suite.Equal(filter.Keywords[0].Keyword, check.Keywords[0].Keyword) + suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID) + suite.NotZero(check.Keywords[0].CreatedAt) + suite.NotZero(check.Keywords[0].UpdatedAt) + + suite.Equal(len(filter.Statuses), len(check.Statuses)) + + // Fetch all filters. + all, err := suite.db.GetFiltersForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filters: %v", err) + } + + // Ensure the result contains our example filter. + suite.Len(all, 1) + suite.Equal(filter.ID, all[0].ID) + + suite.Len(all[0].Keywords, 1) + suite.Equal(filter.Keywords[0].ID, all[0].Keywords[0].ID) + + suite.Empty(all[0].Statuses) + + // Update the filter context and add another keyword and a status. + check.ContextNotifications = util.Ptr(true) + + newKeyword := >smodel.FilterKeyword{ + ID: "01HNEMY810E5XKWDDMN5ZRE749", + FilterID: filter.ID, + AccountID: filter.AccountID, + Keyword: "tux", + } + check.Keywords = append(check.Keywords, newKeyword) + + newStatus := >smodel.FilterStatus{ + ID: "01HNEMYD5XE7C8HH8TNCZ76FN2", + FilterID: filter.ID, + AccountID: filter.AccountID, + StatusID: "01HNEKZW34SQZ8PSDQ0Z10NZES", + } + check.Statuses = append(check.Statuses, newStatus) + + if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil { + t.Fatalf("error updating filter: %v", err) + } + // Now fetch newly updated filter. + check, err = suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching updated filter: %v", err) + } + + // Ensure expected fields were modified on check filter. + suite.True(check.UpdatedAt.After(filter.UpdatedAt)) + if suite.NotNil(check.ContextHome) { + suite.True(*check.ContextHome) + } + if suite.NotNil(check.ContextNotifications) { + suite.True(*check.ContextNotifications) + } + if suite.NotNil(check.ContextPublic) { + suite.True(*check.ContextPublic) + } + if suite.NotNil(check.ContextThread) { + suite.False(*check.ContextThread) + } + if suite.NotNil(check.ContextAccount) { + suite.False(*check.ContextAccount) + } + + // Ensure keyword entries were added. + suite.Len(check.Keywords, 2) + checkFilterKeywordIDs := make([]string, 0, 2) + for _, checkFilterKeyword := range check.Keywords { + checkFilterKeywordIDs = append(checkFilterKeywordIDs, checkFilterKeyword.ID) + } + suite.ElementsMatch([]string{filterKeyword.ID, newKeyword.ID}, checkFilterKeywordIDs) + + // Ensure status entry was added. + suite.Len(check.Statuses, 1) + checkFilterStatusIDs := make([]string, 0, 1) + for _, checkFilterStatus := range check.Statuses { + checkFilterStatusIDs = append(checkFilterStatusIDs, checkFilterStatus.ID) + } + suite.ElementsMatch([]string{newStatus.ID}, checkFilterStatusIDs) + + // Update one filter keyword and delete another. Don't change the filter or the filter status. + filterKeyword.WholeWord = util.Ptr(true) + check.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} + check.Statuses = nil + + if err := suite.db.UpdateFilter(ctx, check, nil, nil, []string{newKeyword.ID}, nil); err != nil { + t.Fatalf("error updating filter: %v", err) + } + check, err = suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching updated filter: %v", err) + } + + // Ensure expected fields were not modified. + suite.Equal(filter.Title, check.Title) + suite.Equal(gtsmodel.FilterActionWarn, check.Action) + if suite.NotNil(check.ContextHome) { + suite.True(*check.ContextHome) + } + if suite.NotNil(check.ContextNotifications) { + suite.True(*check.ContextNotifications) + } + if suite.NotNil(check.ContextPublic) { + suite.True(*check.ContextPublic) + } + if suite.NotNil(check.ContextThread) { + suite.False(*check.ContextThread) + } + if suite.NotNil(check.ContextAccount) { + suite.False(*check.ContextAccount) + } + + // Ensure only changed field of keyword was modified, and other keyword was deleted. + suite.Len(check.Keywords, 1) + suite.Equal(filterKeyword.ID, check.Keywords[0].ID) + suite.Equal("GNU/Linux", check.Keywords[0].Keyword) + if suite.NotNil(check.Keywords[0].WholeWord) { + suite.True(*check.Keywords[0].WholeWord) + } + + // Ensure status entry was not deleted. + suite.Len(check.Statuses, 1) + suite.Equal(newStatus.ID, check.Statuses[0].ID) + + // Add another status entry for the same status ID. It should be ignored without problems. + redundantStatus := >smodel.FilterStatus{ + ID: "01HQXJ5Y405XZSQ67C2BSQ6HJ0", + FilterID: filter.ID, + AccountID: filter.AccountID, + StatusID: newStatus.StatusID, + } + check.Statuses = []*gtsmodel.FilterStatus{redundantStatus} + if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil { + t.Fatalf("error updating filter: %v", err) + } + check, err = suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching updated filter: %v", err) + } + + // Ensure status entry was not deleted, updated, or duplicated. + suite.Len(check.Statuses, 1) + suite.Equal(newStatus.ID, check.Statuses[0].ID) + suite.Equal(newStatus.StatusID, check.Statuses[0].StatusID) + + // Now delete the filter from the DB. + if err := suite.db.DeleteFilterByID(ctx, filter.ID); err != nil { + t.Fatalf("error deleting filter: %v", err) + } + + // Ensure we can't refetch it. + _, err = suite.db.GetFilterByID(ctx, filter.ID) + if !errors.Is(err, db.ErrNoEntries) { + t.Fatalf("fetching deleted filter returned unexpected error: %v", err) + } +} + +func TestFilterTestSuite(t *testing.T) { + suite.Run(t, new(FilterTestSuite)) +} diff --git a/internal/db/bundb/filterkeyword.go b/internal/db/bundb/filterkeyword.go new file mode 100644 index 000000000..703d58d43 --- /dev/null +++ b/internal/db/bundb/filterkeyword.go @@ -0,0 +1,191 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" +) + +func (f *filterDB) GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error) { + filterKeyword, err := f.state.Caches.GTS.FilterKeyword.LoadOne( + "ID", + func() (*gtsmodel.FilterKeyword, error) { + var filterKeyword gtsmodel.FilterKeyword + err := f.db. + NewSelect(). + Model(&filterKeyword). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + return &filterKeyword, err + }, + id, + ) + if err != nil { + return nil, err + } + + if !gtscontext.Barebones(ctx) { + err = f.populateFilterKeyword(ctx, filterKeyword) + if err != nil { + return nil, err + } + } + + return filterKeyword, nil +} + +func (f *filterDB) populateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error { + if filterKeyword.Filter == nil { + // Filter is not set, fetch from the cache or database. + filter, err := f.state.DB.GetFilterByID( + // Don't populate the filter with all of its keywords and statuses or we'll just end up back here. + gtscontext.SetBarebones(ctx), + filterKeyword.FilterID, + ) + if err != nil { + return err + } + filterKeyword.Filter = filter + } + + return nil +} + +func (f *filterDB) GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error) { + return f.getFilterKeywords(ctx, "filter_id", filterID) +} + +func (f *filterDB) GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error) { + return f.getFilterKeywords(ctx, "account_id", accountID) +} + +func (f *filterDB) getFilterKeywords(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterKeyword, error) { + var filterKeywordIDs []string + if err := f.db. + NewSelect(). + Model((*gtsmodel.FilterKeyword)(nil)). + Column("id"). + Where("? = ?", bun.Ident(idColumn), id). + Scan(ctx, &filterKeywordIDs); err != nil { + return nil, err + } + if len(filterKeywordIDs) == 0 { + return nil, nil + } + + // Get each filter keyword by ID from the cache or DB. + uncachedFilterKeywordIDs := make([]string, 0, len(filterKeywordIDs)) + filterKeywords, err := f.state.Caches.GTS.FilterKeyword.Load( + "ID", + func(load func(keyParts ...any) bool) { + for _, id := range filterKeywordIDs { + if !load(id) { + uncachedFilterKeywordIDs = append(uncachedFilterKeywordIDs, id) + } + } + }, + func() ([]*gtsmodel.FilterKeyword, error) { + uncachedFilterKeywords := make([]*gtsmodel.FilterKeyword, 0, len(uncachedFilterKeywordIDs)) + if err := f.db. + NewSelect(). + Model(&uncachedFilterKeywords). + Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterKeywordIDs)). + Scan(ctx); err != nil { + return nil, err + } + return uncachedFilterKeywords, nil + }, + ) + if err != nil { + return nil, err + } + + // Put the filter keyword structs in the same order as the filter keyword IDs. + util.OrderBy(filterKeywords, filterKeywordIDs, func(filterKeyword *gtsmodel.FilterKeyword) string { + return filterKeyword.ID + }) + + if gtscontext.Barebones(ctx) { + return filterKeywords, nil + } + + // Populate the filter keywords. Remove any that we can't populate from the return slice. + errs := gtserror.NewMultiError(len(filterKeywords)) + filterKeywords = slices.DeleteFunc(filterKeywords, func(filterKeyword *gtsmodel.FilterKeyword) bool { + if err := f.populateFilterKeyword(ctx, filterKeyword); err != nil { + errs.Appendf( + "error populating filter keyword %s: %w", + filterKeyword.ID, + err, + ) + return true + } + return false + }) + + return filterKeywords, errs.Combine() +} + +func (f *filterDB) PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error { + return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error { + _, err := f.db. + NewInsert(). + Model(filterKeyword). + Exec(ctx) + return err + }) +} + +func (f *filterDB) UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error { + filterKeyword.UpdatedAt = time.Now() + if len(columns) > 0 { + columns = append(columns, "updated_at") + } + + return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error { + _, err := f.db. + NewUpdate(). + Model(filterKeyword). + Where("? = ?", bun.Ident("id"), filterKeyword.ID). + Column(columns...). + Exec(ctx) + return err + }) +} + +func (f *filterDB) DeleteFilterKeywordByID(ctx context.Context, id string) error { + if _, err := f.db. + NewDelete(). + Model((*gtsmodel.FilterKeyword)(nil)). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + + f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id) + + return nil +} diff --git a/internal/db/bundb/filterkeyword_test.go b/internal/db/bundb/filterkeyword_test.go new file mode 100644 index 000000000..91c8d192c --- /dev/null +++ b/internal/db/bundb/filterkeyword_test.go @@ -0,0 +1,143 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( + "context" + "errors" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// TestFilterKeywordCRUD tests CRUD and read-all operations on filter keywords. +func (suite *FilterTestSuite) TestFilterKeywordCRUD() { + t := suite.T() + + // Create new filter. + filter := >smodel.Filter{ + ID: "01HNEJNVZZVXJTRB3FX3K2B1YF", + AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47", + Title: "foss jail", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + } + + // Create new cancellable test context. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Insert the new filter into the DB. + err := suite.db.PutFilter(ctx, filter) + if err != nil { + t.Fatalf("error inserting filter: %v", err) + } + + // There should be no filter keywords yet. + all, err := suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter keywords: %v", err) + } + suite.Empty(all) + + // Add a filter keyword to it. + filterKeyword := >smodel.FilterKeyword{ + ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4", + AccountID: filter.AccountID, + FilterID: filter.ID, + Keyword: "GNU/Linux", + } + + // Insert the new filter keyword into the DB. + err = suite.db.PutFilterKeyword(ctx, filterKeyword) + if err != nil { + t.Fatalf("error inserting filter keyword: %v", err) + } + + // Try to find it again and ensure it has the fields we expect. + check, err := suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID) + if err != nil { + t.Fatalf("error fetching filter keyword: %v", err) + } + suite.Equal(filterKeyword.ID, check.ID) + suite.NotZero(check.CreatedAt) + suite.NotZero(check.UpdatedAt) + suite.Equal(filterKeyword.AccountID, check.AccountID) + suite.Equal(filterKeyword.FilterID, check.FilterID) + suite.Equal(filterKeyword.Keyword, check.Keyword) + suite.Equal(filterKeyword.WholeWord, check.WholeWord) + + // Loading filter keywords by account ID should find the one we inserted. + all, err = suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter keywords: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterKeyword.ID, all[0].ID) + + // Loading filter keywords by filter ID should also find the one we inserted. + all, err = suite.db.GetFilterKeywordsForFilterID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter keywords: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterKeyword.ID, all[0].ID) + + // Modify the filter keyword. + filterKeyword.WholeWord = util.Ptr(true) + err = suite.db.UpdateFilterKeyword(ctx, filterKeyword) + if err != nil { + t.Fatalf("error updating filter keyword: %v", err) + } + + // Try to find it again and ensure it has the updated fields we expect. + check, err = suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID) + if err != nil { + t.Fatalf("error fetching filter keyword: %v", err) + } + suite.Equal(filterKeyword.ID, check.ID) + suite.NotZero(check.CreatedAt) + suite.True(check.UpdatedAt.After(check.CreatedAt)) + suite.Equal(filterKeyword.AccountID, check.AccountID) + suite.Equal(filterKeyword.FilterID, check.FilterID) + suite.Equal(filterKeyword.Keyword, check.Keyword) + suite.Equal(filterKeyword.WholeWord, check.WholeWord) + + // Delete the filter keyword from the DB. + err = suite.db.DeleteFilterKeywordByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error deleting filter keyword: %v", err) + } + + // Ensure we can't refetch it. + check, err = suite.db.GetFilterKeywordByID(ctx, filter.ID) + if !errors.Is(err, db.ErrNoEntries) { + t.Fatalf("fetching deleted filter keyword returned unexpected error: %v", err) + } + suite.Nil(check) + + // Ensure the filter itself is still there. + checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter: %v", err) + } + suite.Equal(filter.ID, checkFilter.ID) +} diff --git a/internal/db/bundb/filterstatus.go b/internal/db/bundb/filterstatus.go new file mode 100644 index 000000000..1e98f5958 --- /dev/null +++ b/internal/db/bundb/filterstatus.go @@ -0,0 +1,191 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" +) + +func (f *filterDB) GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error) { + filterStatus, err := f.state.Caches.GTS.FilterStatus.LoadOne( + "ID", + func() (*gtsmodel.FilterStatus, error) { + var filterStatus gtsmodel.FilterStatus + err := f.db. + NewSelect(). + Model(&filterStatus). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + return &filterStatus, err + }, + id, + ) + if err != nil { + return nil, err + } + + if !gtscontext.Barebones(ctx) { + err = f.populateFilterStatus(ctx, filterStatus) + if err != nil { + return nil, err + } + } + + return filterStatus, nil +} + +func (f *filterDB) populateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error { + if filterStatus.Filter == nil { + // Filter is not set, fetch from the cache or database. + filter, err := f.state.DB.GetFilterByID( + // Don't populate the filter with all of its keywords and statuses or we'll just end up back here. + gtscontext.SetBarebones(ctx), + filterStatus.FilterID, + ) + if err != nil { + return err + } + filterStatus.Filter = filter + } + + return nil +} + +func (f *filterDB) GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error) { + return f.getFilterStatuses(ctx, "filter_id", filterID) +} + +func (f *filterDB) GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error) { + return f.getFilterStatuses(ctx, "account_id", accountID) +} + +func (f *filterDB) getFilterStatuses(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterStatus, error) { + var filterStatusIDs []string + if err := f.db. + NewSelect(). + Model((*gtsmodel.FilterStatus)(nil)). + Column("id"). + Where("? = ?", bun.Ident(idColumn), id). + Scan(ctx, &filterStatusIDs); err != nil { + return nil, err + } + if len(filterStatusIDs) == 0 { + return nil, nil + } + + // Get each filter status by ID from the cache or DB. + uncachedFilterStatusIDs := make([]string, 0, len(filterStatusIDs)) + filterStatuses, err := f.state.Caches.GTS.FilterStatus.Load( + "ID", + func(load func(keyParts ...any) bool) { + for _, id := range filterStatusIDs { + if !load(id) { + uncachedFilterStatusIDs = append(uncachedFilterStatusIDs, id) + } + } + }, + func() ([]*gtsmodel.FilterStatus, error) { + uncachedFilterStatuses := make([]*gtsmodel.FilterStatus, 0, len(uncachedFilterStatusIDs)) + if err := f.db. + NewSelect(). + Model(&uncachedFilterStatuses). + Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterStatusIDs)). + Scan(ctx); err != nil { + return nil, err + } + return uncachedFilterStatuses, nil + }, + ) + if err != nil { + return nil, err + } + + // Put the filter status structs in the same order as the filter status IDs. + util.OrderBy(filterStatuses, filterStatusIDs, func(filterStatus *gtsmodel.FilterStatus) string { + return filterStatus.ID + }) + + if gtscontext.Barebones(ctx) { + return filterStatuses, nil + } + + // Populate the filter statuses. Remove any that we can't populate from the return slice. + errs := gtserror.NewMultiError(len(filterStatuses)) + filterStatuses = slices.DeleteFunc(filterStatuses, func(filterStatus *gtsmodel.FilterStatus) bool { + if err := f.populateFilterStatus(ctx, filterStatus); err != nil { + errs.Appendf( + "error populating filter status %s: %w", + filterStatus.ID, + err, + ) + return true + } + return false + }) + + return filterStatuses, errs.Combine() +} + +func (f *filterDB) PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error { + return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error { + _, err := f.db. + NewInsert(). + Model(filterStatus). + Exec(ctx) + return err + }) +} + +func (f *filterDB) UpdateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus, columns ...string) error { + filterStatus.UpdatedAt = time.Now() + if len(columns) > 0 { + columns = append(columns, "updated_at") + } + + return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error { + _, err := f.db. + NewUpdate(). + Model(filterStatus). + Where("? = ?", bun.Ident("id"), filterStatus.ID). + Column(columns...). + Exec(ctx) + return err + }) +} + +func (f *filterDB) DeleteFilterStatusByID(ctx context.Context, id string) error { + if _, err := f.db. + NewDelete(). + Model((*gtsmodel.FilterStatus)(nil)). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + + f.state.Caches.GTS.FilterStatus.Invalidate("ID", id) + + return nil +} diff --git a/internal/db/bundb/filterstatus_test.go b/internal/db/bundb/filterstatus_test.go new file mode 100644 index 000000000..48ddb1bed --- /dev/null +++ b/internal/db/bundb/filterstatus_test.go @@ -0,0 +1,122 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( + "context" + "errors" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// TestFilterStatusCRD tests CRD (no U) and read-all operations on filter statuses. +func (suite *FilterTestSuite) TestFilterStatusCRD() { + t := suite.T() + + // Create new filter. + filter := >smodel.Filter{ + ID: "01HNEJNVZZVXJTRB3FX3K2B1YF", + AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47", + Title: "foss jail", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + } + + // Create new cancellable test context. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Insert the new filter into the DB. + err := suite.db.PutFilter(ctx, filter) + if err != nil { + t.Fatalf("error inserting filter: %v", err) + } + + // There should be no filter statuses yet. + all, err := suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter statuses: %v", err) + } + suite.Empty(all) + + // Add a filter status to it. + filterStatus := >smodel.FilterStatus{ + ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4", + AccountID: filter.AccountID, + FilterID: filter.ID, + StatusID: "01HQXGMQ3QFXRT4GX9WNQ8KC0X", + } + + // Insert the new filter status into the DB. + err = suite.db.PutFilterStatus(ctx, filterStatus) + if err != nil { + t.Fatalf("error inserting filter status: %v", err) + } + + // Try to find it again and ensure it has the fields we expect. + check, err := suite.db.GetFilterStatusByID(ctx, filterStatus.ID) + if err != nil { + t.Fatalf("error fetching filter status: %v", err) + } + suite.Equal(filterStatus.ID, check.ID) + suite.NotZero(check.CreatedAt) + suite.NotZero(check.UpdatedAt) + suite.Equal(filterStatus.AccountID, check.AccountID) + suite.Equal(filterStatus.FilterID, check.FilterID) + suite.Equal(filterStatus.StatusID, check.StatusID) + + // Loading filter statuses by account ID should find the one we inserted. + all, err = suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter statuses: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterStatus.ID, all[0].ID) + + // Loading filter statuses by filter ID should also find the one we inserted. + all, err = suite.db.GetFilterStatusesForFilterID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter statuses: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterStatus.ID, all[0].ID) + + // Delete the filter status from the DB. + err = suite.db.DeleteFilterStatusByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error deleting filter status: %v", err) + } + + // Ensure we can't refetch it. + check, err = suite.db.GetFilterStatusByID(ctx, filter.ID) + if !errors.Is(err, db.ErrNoEntries) { + t.Fatalf("fetching deleted filter status returned unexpected error: %v", err) + } + suite.Nil(check) + + // Ensure the filter itself is still there. + checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter: %v", err) + } + suite.Equal(filter.ID, checkFilter.ID) +} diff --git a/internal/db/bundb/migrations/20240126064004_add_filters.go b/internal/db/bundb/migrations/20240126064004_add_filters.go new file mode 100644 index 000000000..3ad22f9d8 --- /dev/null +++ b/internal/db/bundb/migrations/20240126064004_add_filters.go @@ -0,0 +1,97 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( + "context" + + gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Filter table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.Filter{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Filter keyword table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.FilterKeyword{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Filter status table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.FilterStatus{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Add indexes to the filter tables. + for table, indexes := range map[string]map[string][]string{ + "filters": { + "filters_account_id_idx": {"account_id"}, + }, + "filter_keywords": { + "filter_keywords_account_id_idx": {"account_id"}, + "filter_keywords_filter_id_idx": {"filter_id"}, + }, + "filter_statuses": { + "filter_statuses_account_id_idx": {"account_id"}, + "filter_statuses_filter_id_idx": {"filter_id"}, + }, + } { + for index, columns := range indexes { + if _, err := tx. + NewCreateIndex(). + Table(table). + Index(index). + Column(columns...). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + } + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/upsert.go b/internal/db/bundb/upsert.go new file mode 100644 index 000000000..34724446c --- /dev/null +++ b/internal/db/bundb/upsert.go @@ -0,0 +1,230 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "database/sql" + "reflect" + "strings" + + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect" +) + +// UpsertQuery is a wrapper around an insert query that can update if an insert fails. +// Doesn't implement the full set of Bun query methods, but we can add more if we need them. +// See https://bun.uptrace.dev/guide/query-insert.html#upsert +type UpsertQuery struct { + db bun.IDB + model interface{} + constraints []string + columns []string +} + +func NewUpsert(idb bun.IDB) *UpsertQuery { + // note: passing in rawtx as conn iface so no double query-hook + // firing when passed through the bun.Tx.Query___() functions. + return &UpsertQuery{db: idb} +} + +// Model sets the model or models to upsert. +func (u *UpsertQuery) Model(model interface{}) *UpsertQuery { + u.model = model + return u +} + +// Constraint sets the columns or indexes that are used to check for conflicts. +// This is required. +func (u *UpsertQuery) Constraint(constraints ...string) *UpsertQuery { + u.constraints = constraints + return u +} + +// Column sets the columns to update if an insert does't happen. +// If empty, all columns not being used for constraints will be updated. +// Cannot overlap with Constraint. +func (u *UpsertQuery) Column(columns ...string) *UpsertQuery { + u.columns = columns + return u +} + +// insertDialect errors if we're using a dialect in which we don't know how to upsert. +func (u *UpsertQuery) insertDialect() error { + dialectName := u.db.Dialect().Name() + switch dialectName { + case dialect.PG, dialect.SQLite: + return nil + default: + // FUTURE: MySQL has its own variation on upserts, but the syntax is different. + return gtserror.Newf("UpsertQuery: upsert not supported by SQL dialect: %s", dialectName) + } +} + +// insertConstraints checks that we have constraints and returns them. +func (u *UpsertQuery) insertConstraints() ([]string, error) { + if len(u.constraints) == 0 { + return nil, gtserror.New("UpsertQuery: upserts require at least one constraint column or index, none provided") + } + return u.constraints, nil +} + +// insertColumns returns the non-constraint columns we'll be updating. +func (u *UpsertQuery) insertColumns(constraints []string) ([]string, error) { + // Constraints as a set. + constraintSet := make(map[string]struct{}, len(constraints)) + for _, constraint := range constraints { + constraintSet[constraint] = struct{}{} + } + + var columns []string + var err error + if len(u.columns) == 0 { + columns, err = u.insertColumnsDefault(constraintSet) + } else { + columns, err = u.insertColumnsSpecified(constraintSet) + } + if err != nil { + return nil, err + } + if len(columns) == 0 { + return nil, gtserror.New("UpsertQuery: there are no columns to update when upserting") + } + + return columns, nil +} + +// hasElem returns whether the type has an element and can call [reflect.Type.Elem] without panicking. +func hasElem(modelType reflect.Type) bool { + switch modelType.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Pointer, reflect.Slice: + return true + default: + return false + } +} + +// insertColumnsDefault returns all non-constraint columns from the model schema. +func (u *UpsertQuery) insertColumnsDefault(constraintSet map[string]struct{}) ([]string, error) { + // Get underlying struct type. + modelType := reflect.TypeOf(u.model) + for hasElem(modelType) { + modelType = modelType.Elem() + } + + table := u.db.Dialect().Tables().Get(modelType) + if table == nil { + return nil, gtserror.Newf("UpsertQuery: couldn't find the table schema for model: %v", u.model) + } + + columns := make([]string, 0, len(u.columns)) + for _, field := range table.Fields { + column := field.Name + if _, overlaps := constraintSet[column]; !overlaps { + columns = append(columns, column) + } + } + + return columns, nil +} + +// insertColumnsSpecified ensures constraints and specified columns to update don't overlap. +func (u *UpsertQuery) insertColumnsSpecified(constraintSet map[string]struct{}) ([]string, error) { + overlapping := make([]string, 0, min(len(u.constraints), len(u.columns))) + for _, column := range u.columns { + if _, overlaps := constraintSet[column]; overlaps { + overlapping = append(overlapping, column) + } + } + + if len(overlapping) > 0 { + return nil, gtserror.Newf( + "UpsertQuery: the following columns can't be used for both constraints and columns to update: %s", + strings.Join(overlapping, ", "), + ) + } + + return u.columns, nil +} + +// insert tries to create a Bun insert query from an upsert query. +func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) { + var err error + + err = u.insertDialect() + if err != nil { + return nil, err + } + + constraints, err := u.insertConstraints() + if err != nil { + return nil, err + } + + columns, err := u.insertColumns(constraints) + if err != nil { + return nil, err + } + + // Build the parts of the query that need us to generate SQL. + constraintIDPlaceholders := make([]string, 0, len(constraints)) + constraintIDs := make([]interface{}, 0, len(constraints)) + for _, constraint := range constraints { + constraintIDPlaceholders = append(constraintIDPlaceholders, "?") + constraintIDs = append(constraintIDs, bun.Ident(constraint)) + } + onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update" + + setClauses := make([]string, 0, len(columns)) + setIDs := make([]interface{}, 0, 2*len(columns)) + for _, column := range columns { + // "excluded" is a special table that contains only the row involved in a conflict. + setClauses = append(setClauses, "? = excluded.?") + setIDs = append(setIDs, bun.Ident(column), bun.Ident(column)) + } + setSQL := strings.Join(setClauses, ", ") + + insertQuery := u.db. + NewInsert(). + Model(u.model). + On(onSQL, constraintIDs...). + Set(setSQL, setIDs...) + + return insertQuery, nil +} + +// Exec builds a Bun insert query from the upsert query, and executes it. +func (u *UpsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + insertQuery, err := u.insertQuery() + if err != nil { + return nil, err + } + + return insertQuery.Exec(ctx, dest...) +} + +// Scan builds a Bun insert query from the upsert query, and scans it. +func (u *UpsertQuery) Scan(ctx context.Context, dest ...interface{}) error { + insertQuery, err := u.insertQuery() + if err != nil { + return err + } + + return insertQuery.Scan(ctx, dest...) +} diff --git a/internal/db/db.go b/internal/db/db.go index 361687e94..f23324777 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -32,6 +32,7 @@ type DB interface { Emoji HeaderFilter Instance + Filter List Marker Media diff --git a/internal/db/filter.go b/internal/db/filter.go new file mode 100644 index 000000000..18943b4f9 --- /dev/null +++ b/internal/db/filter.go @@ -0,0 +1,101 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Filter contains methods for creating, reading, updating, and deleting filters and their keyword and status entries. +type Filter interface { + //<editor-fold desc="Filter methods"> + + // GetFilterByID gets one filter with the given id. + GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error) + + // GetFiltersForAccountID gets all filters owned by the given accountID. + GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) + + // PutFilter puts a new filter in the database, adding any attached keywords or statuses. + // It uses a transaction to ensure no partial updates. + PutFilter(ctx context.Context, filter *gtsmodel.Filter) error + + // UpdateFilter updates the given filter, + // upserts any attached keywords and inserts any new statuses (existing statuses cannot be updated), + // and deletes indicated filter keywords and statuses by ID. + // It uses a transaction to ensure no partial updates. + // The column lists are optional; if not specified, all columns will be updated. + UpdateFilter( + ctx context.Context, + filter *gtsmodel.Filter, + filterColumns []string, + filterKeywordColumns []string, + deleteFilterKeywordIDs []string, + deleteFilterStatusIDs []string, + ) error + + // DeleteFilterByID deletes one filter with the given ID. + // It uses a transaction to ensure no partial updates. + DeleteFilterByID(ctx context.Context, id string) error + + //</editor-fold> + + //<editor-fold desc="Filter keyword methods"> + + // GetFilterKeywordByID gets one filter keyword with the given ID. + GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error) + + // GetFilterKeywordsForFilterID gets filter keywords from the given filterID. + GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error) + + // GetFilterKeywordsForAccountID gets filter keywords from the given accountID. + GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error) + + // PutFilterKeyword inserts a single filter keyword into the database. + PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error + + // UpdateFilterKeyword updates the given filter keyword. + // Columns is optional, if not specified all will be updated. + UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error + + // DeleteFilterKeywordByID deletes one filter keyword with the given id. + DeleteFilterKeywordByID(ctx context.Context, id string) error + + //</editor-fold> + + //<editor-fold desc="Filter status methods"> + + // GetFilterStatusByID gets one filter status with the given ID. + GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error) + + // GetFilterStatusesForFilterID gets filter statuses from the given filterID. + GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error) + + // GetFilterStatusesForAccountID gets filter keywords from the given accountID. + GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error) + + // PutFilterStatus inserts a single filter status into the database. + PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error + + // DeleteFilterStatusByID deletes one filter status with the given id. + DeleteFilterStatusByID(ctx context.Context, id string) error + + //</editor-fold> +} diff --git a/internal/gtsmodel/filter.go b/internal/gtsmodel/filter.go new file mode 100644 index 000000000..db0a15dfd --- /dev/null +++ b/internal/gtsmodel/filter.go @@ -0,0 +1,71 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package gtsmodel + +import "time" + +// Filter stores a filter created by a local account. +type Filter struct { + ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database + CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created + UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated + ExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // Time filter should expire. If null, should not expire. + AccountID string `bun:"type:CHAR(26),notnull,nullzero"` // ID of the local account that created the filter. + Title string `bun:",nullzero,notnull,unique"` // The name of the filter. + Action FilterAction `bun:",nullzero,notnull"` // The action to take. + Keywords []*FilterKeyword `bun:"-"` // Keywords for this filter. + Statuses []*FilterStatus `bun:"-"` // Statuses for this filter. + ContextHome *bool `bun:",nullzero,notnull,default:false"` // Apply filter to home timeline and lists. + ContextNotifications *bool `bun:",nullzero,notnull,default:false"` // Apply filter to notifications. + ContextPublic *bool `bun:",nullzero,notnull,default:false"` // Apply filter to home timeline and lists. + ContextThread *bool `bun:",nullzero,notnull,default:false"` // Apply filter when viewing a status's associated thread. + ContextAccount *bool `bun:",nullzero,notnull,default:false"` // Apply filter when viewing an account profile. +} + +// FilterKeyword stores a single keyword to filter statuses against. +type FilterKeyword struct { + ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database + CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created + UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated + AccountID string `bun:"type:CHAR(26),notnull,nullzero"` // ID of the local account that created the filter keyword. + FilterID string `bun:"type:CHAR(26),notnull,nullzero,unique:filter_keywords_filter_id_keyword_uniq"` // ID of the filter that this keyword belongs to. + Filter *Filter `bun:"-"` // Filter corresponding to FilterID + Keyword string `bun:",nullzero,notnull,unique:filter_keywords_filter_id_keyword_uniq"` // The keyword or phrase to filter against. + WholeWord *bool `bun:",nullzero,notnull,default:false"` // Should the filter consider word boundaries? +} + +// FilterStatus stores a single status to filter. +type FilterStatus struct { + ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database + CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created + UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated + AccountID string `bun:"type:CHAR(26),notnull,nullzero"` // ID of the local account that created the filter keyword. + FilterID string `bun:"type:CHAR(26),notnull,nullzero,unique:filter_statuses_filter_id_status_id_uniq"` // ID of the filter that this keyword belongs to. + Filter *Filter `bun:"-"` // Filter corresponding to FilterID + StatusID string `bun:"type:CHAR(26),notnull,nullzero,unique:filter_statuses_filter_id_status_id_uniq"` // ID of the status to filter. +} + +// FilterAction represents the action to take on a filtered status. +type FilterAction string + +const ( + // FilterActionWarn means that the status should be shown behind a warning. + FilterActionWarn FilterAction = "warn" + // FilterActionHide means that the status should be removed from timeline results entirely. + FilterActionHide FilterAction = "hide" +) diff --git a/internal/processing/filters/v1/convert.go b/internal/processing/filters/v1/convert.go new file mode 100644 index 000000000..1e0db5ff1 --- /dev/null +++ b/internal/processing/filters/v1/convert.go @@ -0,0 +1,38 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1 + +import ( + "context" + "fmt" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// apiFilter is a shortcut to return the API v1 filter version of the given +// filter keyword, or return an appropriate error if conversion fails. +func (p *Processor) apiFilter(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) (*apimodel.FilterV1, gtserror.WithCode) { + apiFilter, err := p.converter.FilterKeywordToAPIFilterV1(ctx, filterKeyword) + if err != nil { + return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting filter keyword to API v1 filter: %w", err)) + } + + return apiFilter, nil +} diff --git a/internal/processing/filters/v1/create.go b/internal/processing/filters/v1/create.go new file mode 100644 index 000000000..e36d6800a --- /dev/null +++ b/internal/processing/filters/v1/create.go @@ -0,0 +1,87 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1 + +import ( + "context" + "errors" + "fmt" + "time" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// Create a new filter and filter keyword for the given account, using the provided parameters. +// These params should have already been validated by the time they reach this function. +func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form *apimodel.FilterCreateUpdateRequestV1) (*apimodel.FilterV1, gtserror.WithCode) { + filter := >smodel.Filter{ + ID: id.NewULID(), + AccountID: account.ID, + Title: form.Phrase, + Action: gtsmodel.FilterActionWarn, + } + if *form.Irreversible { + filter.Action = gtsmodel.FilterActionHide + } + if form.ExpiresIn != nil { + filter.ExpiresAt = time.Now().Add(time.Second * time.Duration(*form.ExpiresIn)) + } + for _, context := range form.Context { + switch context { + case apimodel.FilterContextHome: + filter.ContextHome = util.Ptr(true) + case apimodel.FilterContextNotifications: + filter.ContextNotifications = util.Ptr(true) + case apimodel.FilterContextPublic: + filter.ContextPublic = util.Ptr(true) + case apimodel.FilterContextThread: + filter.ContextThread = util.Ptr(true) + case apimodel.FilterContextAccount: + filter.ContextAccount = util.Ptr(true) + default: + return nil, gtserror.NewErrorUnprocessableEntity( + fmt.Errorf("unsupported filter context '%s'", context), + ) + } + } + + filterKeyword := >smodel.FilterKeyword{ + ID: id.NewULID(), + AccountID: account.ID, + FilterID: filter.ID, + Filter: filter, + Keyword: form.Phrase, + WholeWord: util.Ptr(util.PtrValueOr(form.WholeWord, false)), + } + filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} + + if err := p.state.DB.PutFilter(ctx, filter); err != nil { + if errors.Is(err, db.ErrAlreadyExists) { + err = errors.New("you already have a filter with this title") + return nil, gtserror.NewErrorConflict(err, err.Error()) + } + return nil, gtserror.NewErrorInternalError(err) + } + + return p.apiFilter(ctx, filterKeyword) +} diff --git a/internal/processing/filters/v1/delete.go b/internal/processing/filters/v1/delete.go new file mode 100644 index 000000000..f2312f039 --- /dev/null +++ b/internal/processing/filters/v1/delete.go @@ -0,0 +1,67 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1 + +import ( + "context" + "errors" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Delete an existing filter keyword and (if empty afterwards) filter for the given account. +func (p *Processor) Delete( + ctx context.Context, + account *gtsmodel.Account, + filterKeywordID string, +) gtserror.WithCode { + // Get enough of the filter keyword that we can look up its filter ID. + filterKeyword, err := p.state.DB.GetFilterKeywordByID(gtscontext.SetBarebones(ctx), filterKeywordID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return gtserror.NewErrorNotFound(err) + } + return gtserror.NewErrorInternalError(err) + } + if filterKeyword.AccountID != account.ID { + return gtserror.NewErrorNotFound(nil) + } + + // Get the filter for this keyword. + filter, err := p.state.DB.GetFilterByID(ctx, filterKeyword.FilterID) + if err != nil { + return gtserror.NewErrorNotFound(err) + } + + if len(filter.Keywords) > 1 || len(filter.Statuses) > 0 { + // The filter has other keywords or statuses. Delete only the requested filter keyword. + if err := p.state.DB.DeleteFilterKeywordByID(ctx, filterKeyword.ID); err != nil { + return gtserror.NewErrorInternalError(err) + } + } else { + // Delete the entire filter. + if err := p.state.DB.DeleteFilterByID(ctx, filter.ID); err != nil { + return gtserror.NewErrorInternalError(err) + } + } + + return nil +} diff --git a/internal/processing/filters/v1/filters.go b/internal/processing/filters/v1/filters.go new file mode 100644 index 000000000..d46c9e72c --- /dev/null +++ b/internal/processing/filters/v1/filters.go @@ -0,0 +1,35 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1 + +import ( + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" +) + +type Processor struct { + state *state.State + converter *typeutils.Converter +} + +func New(state *state.State, converter *typeutils.Converter) Processor { + return Processor{ + state: state, + converter: converter, + } +} diff --git a/internal/processing/filters/v1/get.go b/internal/processing/filters/v1/get.go new file mode 100644 index 000000000..39575dd94 --- /dev/null +++ b/internal/processing/filters/v1/get.go @@ -0,0 +1,78 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1 + +import ( + "context" + "errors" + "slices" + "strings" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Get looks up a filter keyword by ID and returns it as a v1 filter. +func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, filterKeywordID string) (*apimodel.FilterV1, gtserror.WithCode) { + filterKeyword, err := p.state.DB.GetFilterKeywordByID(ctx, filterKeywordID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorNotFound(err) + } + return nil, gtserror.NewErrorInternalError(err) + } + if filterKeyword.AccountID != account.ID { + return nil, gtserror.NewErrorNotFound(nil) + } + + return p.apiFilter(ctx, filterKeyword) +} + +// GetAll looks up all filter keywords for the current account and returns them as v1 filters. +func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*apimodel.FilterV1, gtserror.WithCode) { + filters, err := p.state.DB.GetFilterKeywordsForAccountID( + ctx, + account.ID, + ) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, nil + } + return nil, gtserror.NewErrorInternalError(err) + } + + apiFilters := make([]*apimodel.FilterV1, 0, len(filters)) + for _, list := range filters { + apiFilter, errWithCode := p.apiFilter(ctx, list) + if errWithCode != nil { + return nil, errWithCode + } + + apiFilters = append(apiFilters, apiFilter) + } + + // Sort them by ID so that they're in a stable order. + // Clients may opt to sort them lexically in a locale-aware manner. + slices.SortFunc(apiFilters, func(lhs *apimodel.FilterV1, rhs *apimodel.FilterV1) int { + return strings.Compare(lhs.ID, rhs.ID) + }) + + return apiFilters, nil +} diff --git a/internal/processing/filters/v1/update.go b/internal/processing/filters/v1/update.go new file mode 100644 index 000000000..1fe49721b --- /dev/null +++ b/internal/processing/filters/v1/update.go @@ -0,0 +1,165 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package v1 + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// Update an existing filter and filter keyword for the given account, using the provided parameters. +// These params should have already been validated by the time they reach this function. +func (p *Processor) Update( + ctx context.Context, + account *gtsmodel.Account, + filterKeywordID string, + form *apimodel.FilterCreateUpdateRequestV1, +) (*apimodel.FilterV1, gtserror.WithCode) { + // Get enough of the filter keyword that we can look up its filter ID. + filterKeyword, err := p.state.DB.GetFilterKeywordByID(gtscontext.SetBarebones(ctx), filterKeywordID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorNotFound(err) + } + return nil, gtserror.NewErrorInternalError(err) + } + if filterKeyword.AccountID != account.ID { + return nil, gtserror.NewErrorNotFound(nil) + } + + // Get the filter for this keyword. + filter, err := p.state.DB.GetFilterByID(ctx, filterKeyword.FilterID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorNotFound(err) + } + return nil, gtserror.NewErrorInternalError(err) + } + + title := form.Phrase + action := gtsmodel.FilterActionWarn + if *form.Irreversible { + action = gtsmodel.FilterActionHide + } + expiresAt := time.Time{} + if form.ExpiresIn != nil { + expiresAt = time.Now().Add(time.Second * time.Duration(*form.ExpiresIn)) + } + contextHome := false + contextNotifications := false + contextPublic := false + contextThread := false + contextAccount := false + for _, context := range form.Context { + switch context { + case apimodel.FilterContextHome: + contextHome = true + case apimodel.FilterContextNotifications: + contextNotifications = true + case apimodel.FilterContextPublic: + contextPublic = true + case apimodel.FilterContextThread: + contextThread = true + case apimodel.FilterContextAccount: + contextAccount = true + default: + return nil, gtserror.NewErrorUnprocessableEntity( + fmt.Errorf("unsupported filter context '%s'", context), + ) + } + } + + // v1 filter APIs can't change certain fields for a filter with multiple keywords or any statuses, + // since it would be an unexpected side effect on filters that, to the v1 API, appear separate. + // See https://docs.joinmastodon.org/methods/filters/#update-v1 + if len(filter.Keywords) > 1 || len(filter.Statuses) > 0 { + forbiddenFields := make([]string, 0, 4) + if title != filter.Title { + forbiddenFields = append(forbiddenFields, "phrase") + } + if action != filter.Action { + forbiddenFields = append(forbiddenFields, "irreversible") + } + if expiresAt != filter.ExpiresAt { + forbiddenFields = append(forbiddenFields, "expires_in") + } + if contextHome != util.PtrValueOr(filter.ContextHome, false) || + contextNotifications != util.PtrValueOr(filter.ContextNotifications, false) || + contextPublic != util.PtrValueOr(filter.ContextPublic, false) || + contextThread != util.PtrValueOr(filter.ContextThread, false) || + contextAccount != util.PtrValueOr(filter.ContextAccount, false) { + forbiddenFields = append(forbiddenFields, "context") + } + if len(forbiddenFields) > 0 { + return nil, gtserror.NewErrorUnprocessableEntity( + fmt.Errorf("v1 filter backwards compatibility: can't change these fields for a filter with multiple keywords or any statuses: %s", strings.Join(forbiddenFields, ", ")), + ) + } + } + + // Now that we've checked that the changes are legal, apply them to the filter and keyword. + filter.Title = title + filter.Action = action + filter.ExpiresAt = expiresAt + filter.ContextHome = &contextHome + filter.ContextNotifications = &contextNotifications + filter.ContextPublic = &contextPublic + filter.ContextThread = &contextThread + filter.ContextAccount = &contextAccount + filterKeyword.Keyword = form.Phrase + filterKeyword.WholeWord = util.Ptr(util.PtrValueOr(form.WholeWord, false)) + + // We only want to update the relevant filter keyword. + filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} + filter.Statuses = nil + filterKeyword.Filter = filter + + filterColumns := []string{ + "title", + "action", + "expires_at", + "context_home", + "context_notifications", + "context_public", + "context_thread", + "context_account", + } + filterKeywordColumns := []string{ + "keyword", + "whole_word", + } + if err := p.state.DB.UpdateFilter(ctx, filter, filterColumns, filterKeywordColumns, nil, nil); err != nil { + if errors.Is(err, db.ErrAlreadyExists) { + err = errors.New("you already have a filter with this title") + return nil, gtserror.NewErrorConflict(err, err.Error()) + } + return nil, gtserror.NewErrorInternalError(err) + } + + return p.apiFilter(ctx, filterKeyword) +} diff --git a/internal/processing/processor.go b/internal/processing/processor.go index bb46d31a9..4aaa94fb7 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -29,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing/admin" "github.com/superseriousbusiness/gotosocial/internal/processing/common" "github.com/superseriousbusiness/gotosocial/internal/processing/fedi" + filtersv1 "github.com/superseriousbusiness/gotosocial/internal/processing/filters/v1" "github.com/superseriousbusiness/gotosocial/internal/processing/list" "github.com/superseriousbusiness/gotosocial/internal/processing/markers" "github.com/superseriousbusiness/gotosocial/internal/processing/media" @@ -68,20 +69,21 @@ type Processor struct { SUB-PROCESSORS */ - account account.Processor - admin admin.Processor - fedi fedi.Processor - list list.Processor - markers markers.Processor - media media.Processor - polls polls.Processor - report report.Processor - search search.Processor - status status.Processor - stream stream.Processor - timeline timeline.Processor - user user.Processor - workers workers.Processor + account account.Processor + admin admin.Processor + fedi fedi.Processor + filtersv1 filtersv1.Processor + list list.Processor + markers markers.Processor + media media.Processor + polls polls.Processor + report report.Processor + search search.Processor + status status.Processor + stream stream.Processor + timeline timeline.Processor + user user.Processor + workers workers.Processor } func (p *Processor) Account() *account.Processor { @@ -96,6 +98,10 @@ func (p *Processor) Fedi() *fedi.Processor { return &p.fedi } +func (p *Processor) FiltersV1() *filtersv1.Processor { + return &p.filtersv1 +} + func (p *Processor) List() *list.Processor { return &p.list } @@ -177,6 +183,7 @@ func NewProcessor( processor.account = account.New(&common, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) processor.admin = admin.New(state, cleaner, converter, mediaManager, federator.TransportController(), emailSender) processor.fedi = fedi.New(state, &common, converter, federator, filter) + processor.filtersv1 = filtersv1.New(state, converter) processor.list = list.New(state, converter) processor.markers = markers.New(state, converter) processor.polls = polls.New(&common, state, converter) diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go index 475ab0128..7256d2f82 100644 --- a/internal/processing/status/get.go +++ b/internal/processing/status/get.go @@ -111,7 +111,6 @@ func (p *Processor) contextGet( TopoSort(descendants, targetStatus.AccountID) - //goland:noinspection GoImportUsedAsName context := &apimodel.Context{ Ancestors: make([]apimodel.Status, 0, len(ancestors)), Descendants: make([]apimodel.Status, 0, len(descendants)), diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index d74f4d86e..df4598deb 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -1617,6 +1617,59 @@ func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, atta return apiAttachments, errs.Combine() } +// FilterToAPIFiltersV1 converts one GTS model filter into an API v1 filter list +func (c *Converter) FilterToAPIFiltersV1(ctx context.Context, filter *gtsmodel.Filter) ([]*apimodel.FilterV1, error) { + apiFilters := make([]*apimodel.FilterV1, 0, len(filter.Keywords)) + for _, filterKeyword := range filter.Keywords { + apiFilter, err := c.FilterKeywordToAPIFilterV1(ctx, filterKeyword) + if err != nil { + return nil, err + } + apiFilters = append(apiFilters, apiFilter) + } + return apiFilters, nil +} + +// FilterKeywordToAPIFilterV1 converts one GTS model filter and filter keyword into an API v1 filter +func (c *Converter) FilterKeywordToAPIFilterV1(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) (*apimodel.FilterV1, error) { + if filterKeyword.Filter == nil { + return nil, gtserror.New("FilterKeyword model's Filter field isn't populated, but needs to be") + } + filter := filterKeyword.Filter + + apiContexts := make([]apimodel.FilterContext, 0, apimodel.FilterContextNumValues) + if util.PtrValueOr(filter.ContextHome, false) { + apiContexts = append(apiContexts, apimodel.FilterContextHome) + } + if util.PtrValueOr(filter.ContextNotifications, false) { + apiContexts = append(apiContexts, apimodel.FilterContextNotifications) + } + if util.PtrValueOr(filter.ContextPublic, false) { + apiContexts = append(apiContexts, apimodel.FilterContextPublic) + } + if util.PtrValueOr(filter.ContextThread, false) { + apiContexts = append(apiContexts, apimodel.FilterContextThread) + } + if util.PtrValueOr(filter.ContextAccount, false) { + apiContexts = append(apiContexts, apimodel.FilterContextAccount) + } + + var expiresAt *string + if !filter.ExpiresAt.IsZero() { + expiresAt = util.Ptr(util.FormatISO8601(filter.ExpiresAt)) + } + + return &apimodel.FilterV1{ + // v1 filters have a single keyword each, so we use the filter keyword ID as the v1 filter ID. + ID: filterKeyword.ID, + Phrase: filterKeyword.Keyword, + Context: apiContexts, + WholeWord: util.PtrValueOr(filterKeyword.WholeWord, false), + ExpiresAt: expiresAt, + Irreversible: filter.Action == gtsmodel.FilterActionHide, + }, nil +} + // convertEmojisToAPIEmojis will convert a slice of GTS model emojis to frontend API model emojis, falling back to IDs if no GTS models supplied. func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsmodel.Emoji, emojiIDs []string) ([]apimodel.Emoji, error) { var errs gtserror.MultiError diff --git a/internal/validate/formvalidation.go b/internal/validate/formvalidation.go index 3d1d05072..b0332a572 100644 --- a/internal/validate/formvalidation.go +++ b/internal/validate/formvalidation.go @@ -44,6 +44,7 @@ const ( maximumProfileFieldLength = 255 maximumProfileFields = 6 maximumListTitleLength = 200 + maximumFilterKeywordLength = 40 ) // Password returns a helpful error if the given password @@ -306,3 +307,44 @@ func MarkerName(name string) error { } return fmt.Errorf("marker timeline name '%s' was not recognized, valid options are '%s', '%s'", name, apimodel.MarkerNameHome, apimodel.MarkerNameNotifications) } + +// FilterKeyword validates the title of a new or updated List. +func FilterKeyword(keyword string) error { + if keyword == "" { + return fmt.Errorf("filter keyword must be provided, and must be no more than %d chars", maximumFilterKeywordLength) + } + + if length := len([]rune(keyword)); length > maximumFilterKeywordLength { + return fmt.Errorf("filter keyword length must be no more than %d chars, provided keyword was %d chars", maximumFilterKeywordLength, length) + } + + return nil +} + +// FilterContexts validates the context of a new or updated filter. +func FilterContexts(contexts []apimodel.FilterContext) error { + if len(contexts) == 0 { + return fmt.Errorf("at least one filter context is required") + } + for _, context := range contexts { + switch context { + case apimodel.FilterContextHome, + apimodel.FilterContextNotifications, + apimodel.FilterContextPublic, + apimodel.FilterContextThread, + apimodel.FilterContextAccount: + continue + default: + return fmt.Errorf( + "filter context '%s' was not recognized, valid options are '%s', '%s', '%s', '%s', '%s'", + context, + apimodel.FilterContextHome, + apimodel.FilterContextNotifications, + apimodel.FilterContextPublic, + apimodel.FilterContextThread, + apimodel.FilterContextAccount, + ) + } + } + return nil +} |