summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/bundb.go5
-rw-r--r--internal/db/bundb/filter.go339
-rw-r--r--internal/db/bundb/filter_test.go252
-rw-r--r--internal/db/bundb/filterkeyword.go191
-rw-r--r--internal/db/bundb/filterkeyword_test.go143
-rw-r--r--internal/db/bundb/filterstatus.go191
-rw-r--r--internal/db/bundb/filterstatus_test.go122
-rw-r--r--internal/db/bundb/migrations/20240126064004_add_filters.go97
-rw-r--r--internal/db/bundb/upsert.go230
-rw-r--r--internal/db/db.go1
-rw-r--r--internal/db/filter.go101
11 files changed, 1672 insertions, 0 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 4ecbec7b9..c49da272b 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -62,6 +62,7 @@ type DBService struct {
db.Emoji
db.HeaderFilter
db.Instance
+ db.Filter
db.List
db.Marker
db.Media
@@ -200,6 +201,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
+ Filter: &filterDB{
+ db: db,
+ state: state,
+ },
List: &listDB{
db: db,
state: state,
diff --git a/internal/db/bundb/filter.go b/internal/db/bundb/filter.go
new file mode 100644
index 000000000..bcd572f34
--- /dev/null
+++ b/internal/db/bundb/filter.go
@@ -0,0 +1,339 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "slices"
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
+)
+
+type filterDB struct {
+ db *bun.DB
+ state *state.State
+}
+
+func (f *filterDB) GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error) {
+ filter, err := f.state.Caches.GTS.Filter.LoadOne(
+ "ID",
+ func() (*gtsmodel.Filter, error) {
+ var filter gtsmodel.Filter
+ err := f.db.
+ NewSelect().
+ Model(&filter).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ return &filter, err
+ },
+ id,
+ )
+ if err != nil {
+ // already processed
+ return nil, err
+ }
+
+ if !gtscontext.Barebones(ctx) {
+ if err := f.populateFilter(ctx, filter); err != nil {
+ return nil, err
+ }
+ }
+
+ return filter, nil
+}
+
+func (f *filterDB) GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) {
+ // Fetch IDs of all filters owned by this account.
+ var filterIDs []string
+ if err := f.db.
+ NewSelect().
+ Model((*gtsmodel.Filter)(nil)).
+ Column("id").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Scan(ctx, &filterIDs); err != nil {
+ return nil, err
+ }
+ if len(filterIDs) == 0 {
+ return nil, nil
+ }
+
+ // Get each filter by ID from the cache or DB.
+ uncachedFilterIDs := make([]string, 0, len(filterIDs))
+ filters, err := f.state.Caches.GTS.Filter.Load(
+ "ID",
+ func(load func(keyParts ...any) bool) {
+ for _, id := range filterIDs {
+ if !load(id) {
+ uncachedFilterIDs = append(uncachedFilterIDs, id)
+ }
+ }
+ },
+ func() ([]*gtsmodel.Filter, error) {
+ uncachedFilters := make([]*gtsmodel.Filter, 0, len(uncachedFilterIDs))
+ if err := f.db.
+ NewSelect().
+ Model(&uncachedFilters).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterIDs)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return uncachedFilters, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Put the filter structs in the same order as the filter IDs.
+ util.OrderBy(filters, filterIDs, func(filter *gtsmodel.Filter) string { return filter.ID })
+
+ if gtscontext.Barebones(ctx) {
+ return filters, nil
+ }
+
+ // Populate the filters. Remove any that we can't populate from the return slice.
+ errs := gtserror.NewMultiError(len(filters))
+ filters = slices.DeleteFunc(filters, func(filter *gtsmodel.Filter) bool {
+ if err := f.populateFilter(ctx, filter); err != nil {
+ errs.Appendf("error populating filter %s: %w", filter.ID, err)
+ return true
+ }
+ return false
+ })
+
+ return filters, errs.Combine()
+}
+
+func (f *filterDB) populateFilter(ctx context.Context, filter *gtsmodel.Filter) error {
+ var err error
+ errs := gtserror.NewMultiError(2)
+
+ if filter.Keywords == nil {
+ // Filter keywords are not set, fetch from the database.
+ filter.Keywords, err = f.state.DB.GetFilterKeywordsForFilterID(
+ gtscontext.SetBarebones(ctx),
+ filter.ID,
+ )
+ if err != nil {
+ errs.Appendf("error populating filter keywords: %w", err)
+ }
+ for i := range filter.Keywords {
+ filter.Keywords[i].Filter = filter
+ }
+ }
+
+ if filter.Statuses == nil {
+ // Filter statuses are not set, fetch from the database.
+ filter.Statuses, err = f.state.DB.GetFilterStatusesForFilterID(
+ gtscontext.SetBarebones(ctx),
+ filter.ID,
+ )
+ if err != nil {
+ errs.Appendf("error populating filter statuses: %w", err)
+ }
+ for i := range filter.Statuses {
+ filter.Statuses[i].Filter = filter
+ }
+ }
+
+ return errs.Combine()
+}
+
+func (f *filterDB) PutFilter(ctx context.Context, filter *gtsmodel.Filter) error {
+ // Update database.
+ if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ if _, err := tx.
+ NewInsert().
+ Model(filter).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ if len(filter.Keywords) > 0 {
+ if _, err := tx.
+ NewInsert().
+ Model(&filter.Keywords).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ if len(filter.Statuses) > 0 {
+ if _, err := tx.
+ NewInsert().
+ Model(&filter.Statuses).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ // Update cache.
+ f.state.Caches.GTS.Filter.Put(filter)
+ f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...)
+ f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...)
+
+ return nil
+}
+
+func (f *filterDB) UpdateFilter(
+ ctx context.Context,
+ filter *gtsmodel.Filter,
+ filterColumns []string,
+ filterKeywordColumns []string,
+ deleteFilterKeywordIDs []string,
+ deleteFilterStatusIDs []string,
+) error {
+ updatedAt := time.Now()
+ filter.UpdatedAt = updatedAt
+ for _, filterKeyword := range filter.Keywords {
+ filterKeyword.UpdatedAt = updatedAt
+ }
+ for _, filterStatus := range filter.Statuses {
+ filterStatus.UpdatedAt = updatedAt
+ }
+
+ // If we're updating by column, ensure "updated_at" is included.
+ if len(filterColumns) > 0 {
+ filterColumns = append(filterColumns, "updated_at")
+ }
+ if len(filterKeywordColumns) > 0 {
+ filterKeywordColumns = append(filterKeywordColumns, "updated_at")
+ }
+
+ // Update database.
+ if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ if _, err := tx.
+ NewUpdate().
+ Model(filter).
+ Column(filterColumns...).
+ Where("? = ?", bun.Ident("id"), filter.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ if len(filter.Keywords) > 0 {
+ if _, err := NewUpsert(tx).
+ Model(&filter.Keywords).
+ Constraint("id").
+ Column(filterKeywordColumns...).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ if len(filter.Statuses) > 0 {
+ if _, err := tx.
+ NewInsert().
+ Ignore().
+ Model(&filter.Statuses).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ if len(deleteFilterKeywordIDs) > 0 {
+ if _, err := tx.
+ NewDelete().
+ Model((*gtsmodel.FilterKeyword)(nil)).
+ Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterKeywordIDs)).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ if len(deleteFilterStatusIDs) > 0 {
+ if _, err := tx.
+ NewDelete().
+ Model((*gtsmodel.FilterStatus)(nil)).
+ Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterStatusIDs)).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ // Update cache.
+ f.state.Caches.GTS.Filter.Put(filter)
+ f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...)
+ f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...)
+ // TODO: (Vyr) replace with cache multi-invalidate call
+ for _, id := range deleteFilterKeywordIDs {
+ f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id)
+ }
+ for _, id := range deleteFilterStatusIDs {
+ f.state.Caches.GTS.FilterStatus.Invalidate("ID", id)
+ }
+
+ return nil
+}
+
+func (f *filterDB) DeleteFilterByID(ctx context.Context, id string) error {
+ if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Delete all keywords attached to filter.
+ if _, err := tx.
+ NewDelete().
+ Model((*gtsmodel.FilterKeyword)(nil)).
+ Where("? = ?", bun.Ident("filter_id"), id).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Delete all statuses attached to filter.
+ if _, err := tx.
+ NewDelete().
+ Model((*gtsmodel.FilterStatus)(nil)).
+ Where("? = ?", bun.Ident("filter_id"), id).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Delete the filter itself.
+ _, err := tx.
+ NewDelete().
+ Model((*gtsmodel.Filter)(nil)).
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx)
+ return err
+ }); err != nil {
+ return err
+ }
+
+ // Invalidate this filter.
+ f.state.Caches.GTS.Filter.Invalidate("ID", id)
+
+ // Invalidate all keywords and statuses for this filter.
+ f.state.Caches.GTS.FilterKeyword.Invalidate("FilterID", id)
+ f.state.Caches.GTS.FilterStatus.Invalidate("FilterID", id)
+
+ return nil
+}
diff --git a/internal/db/bundb/filter_test.go b/internal/db/bundb/filter_test.go
new file mode 100644
index 000000000..7940b6651
--- /dev/null
+++ b/internal/db/bundb/filter_test.go
@@ -0,0 +1,252 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb_test
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+)
+
+type FilterTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+// TestFilterCRUD tests CRUD and read-all operations on filters.
+func (suite *FilterTestSuite) TestFilterCRUD() {
+ t := suite.T()
+
+ // Create new example filter with attached keyword.
+ filter := &gtsmodel.Filter{
+ ID: "01HNEJNVZZVXJTRB3FX3K2B1YF",
+ AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47",
+ Title: "foss jail",
+ Action: gtsmodel.FilterActionWarn,
+ ContextHome: util.Ptr(true),
+ ContextPublic: util.Ptr(true),
+ }
+ filterKeyword := &gtsmodel.FilterKeyword{
+ ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4",
+ AccountID: filter.AccountID,
+ FilterID: filter.ID,
+ Keyword: "GNU/Linux",
+ }
+ filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword}
+
+ // Create new cancellable test context.
+ ctx := context.Background()
+ ctx, cncl := context.WithCancel(ctx)
+ defer cncl()
+
+ // Insert the example filter into db.
+ if err := suite.db.PutFilter(ctx, filter); err != nil {
+ t.Fatalf("error inserting filter: %v", err)
+ }
+
+ // Now fetch newly created filter.
+ check, err := suite.db.GetFilterByID(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error fetching filter: %v", err)
+ }
+
+ // Check all expected fields match.
+ suite.Equal(filter.ID, check.ID)
+ suite.Equal(filter.AccountID, check.AccountID)
+ suite.Equal(filter.Title, check.Title)
+ suite.Equal(filter.Action, check.Action)
+ suite.Equal(filter.ContextHome, check.ContextHome)
+ suite.Equal(filter.ContextNotifications, check.ContextNotifications)
+ suite.Equal(filter.ContextPublic, check.ContextPublic)
+ suite.Equal(filter.ContextThread, check.ContextThread)
+ suite.Equal(filter.ContextAccount, check.ContextAccount)
+ suite.NotZero(check.CreatedAt)
+ suite.NotZero(check.UpdatedAt)
+
+ suite.Equal(len(filter.Keywords), len(check.Keywords))
+ suite.Equal(filter.Keywords[0].ID, check.Keywords[0].ID)
+ suite.Equal(filter.Keywords[0].AccountID, check.Keywords[0].AccountID)
+ suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID)
+ suite.Equal(filter.Keywords[0].Keyword, check.Keywords[0].Keyword)
+ suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID)
+ suite.NotZero(check.Keywords[0].CreatedAt)
+ suite.NotZero(check.Keywords[0].UpdatedAt)
+
+ suite.Equal(len(filter.Statuses), len(check.Statuses))
+
+ // Fetch all filters.
+ all, err := suite.db.GetFiltersForAccountID(ctx, filter.AccountID)
+ if err != nil {
+ t.Fatalf("error fetching filters: %v", err)
+ }
+
+ // Ensure the result contains our example filter.
+ suite.Len(all, 1)
+ suite.Equal(filter.ID, all[0].ID)
+
+ suite.Len(all[0].Keywords, 1)
+ suite.Equal(filter.Keywords[0].ID, all[0].Keywords[0].ID)
+
+ suite.Empty(all[0].Statuses)
+
+ // Update the filter context and add another keyword and a status.
+ check.ContextNotifications = util.Ptr(true)
+
+ newKeyword := &gtsmodel.FilterKeyword{
+ ID: "01HNEMY810E5XKWDDMN5ZRE749",
+ FilterID: filter.ID,
+ AccountID: filter.AccountID,
+ Keyword: "tux",
+ }
+ check.Keywords = append(check.Keywords, newKeyword)
+
+ newStatus := &gtsmodel.FilterStatus{
+ ID: "01HNEMYD5XE7C8HH8TNCZ76FN2",
+ FilterID: filter.ID,
+ AccountID: filter.AccountID,
+ StatusID: "01HNEKZW34SQZ8PSDQ0Z10NZES",
+ }
+ check.Statuses = append(check.Statuses, newStatus)
+
+ if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil {
+ t.Fatalf("error updating filter: %v", err)
+ }
+ // Now fetch newly updated filter.
+ check, err = suite.db.GetFilterByID(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error fetching updated filter: %v", err)
+ }
+
+ // Ensure expected fields were modified on check filter.
+ suite.True(check.UpdatedAt.After(filter.UpdatedAt))
+ if suite.NotNil(check.ContextHome) {
+ suite.True(*check.ContextHome)
+ }
+ if suite.NotNil(check.ContextNotifications) {
+ suite.True(*check.ContextNotifications)
+ }
+ if suite.NotNil(check.ContextPublic) {
+ suite.True(*check.ContextPublic)
+ }
+ if suite.NotNil(check.ContextThread) {
+ suite.False(*check.ContextThread)
+ }
+ if suite.NotNil(check.ContextAccount) {
+ suite.False(*check.ContextAccount)
+ }
+
+ // Ensure keyword entries were added.
+ suite.Len(check.Keywords, 2)
+ checkFilterKeywordIDs := make([]string, 0, 2)
+ for _, checkFilterKeyword := range check.Keywords {
+ checkFilterKeywordIDs = append(checkFilterKeywordIDs, checkFilterKeyword.ID)
+ }
+ suite.ElementsMatch([]string{filterKeyword.ID, newKeyword.ID}, checkFilterKeywordIDs)
+
+ // Ensure status entry was added.
+ suite.Len(check.Statuses, 1)
+ checkFilterStatusIDs := make([]string, 0, 1)
+ for _, checkFilterStatus := range check.Statuses {
+ checkFilterStatusIDs = append(checkFilterStatusIDs, checkFilterStatus.ID)
+ }
+ suite.ElementsMatch([]string{newStatus.ID}, checkFilterStatusIDs)
+
+ // Update one filter keyword and delete another. Don't change the filter or the filter status.
+ filterKeyword.WholeWord = util.Ptr(true)
+ check.Keywords = []*gtsmodel.FilterKeyword{filterKeyword}
+ check.Statuses = nil
+
+ if err := suite.db.UpdateFilter(ctx, check, nil, nil, []string{newKeyword.ID}, nil); err != nil {
+ t.Fatalf("error updating filter: %v", err)
+ }
+ check, err = suite.db.GetFilterByID(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error fetching updated filter: %v", err)
+ }
+
+ // Ensure expected fields were not modified.
+ suite.Equal(filter.Title, check.Title)
+ suite.Equal(gtsmodel.FilterActionWarn, check.Action)
+ if suite.NotNil(check.ContextHome) {
+ suite.True(*check.ContextHome)
+ }
+ if suite.NotNil(check.ContextNotifications) {
+ suite.True(*check.ContextNotifications)
+ }
+ if suite.NotNil(check.ContextPublic) {
+ suite.True(*check.ContextPublic)
+ }
+ if suite.NotNil(check.ContextThread) {
+ suite.False(*check.ContextThread)
+ }
+ if suite.NotNil(check.ContextAccount) {
+ suite.False(*check.ContextAccount)
+ }
+
+ // Ensure only changed field of keyword was modified, and other keyword was deleted.
+ suite.Len(check.Keywords, 1)
+ suite.Equal(filterKeyword.ID, check.Keywords[0].ID)
+ suite.Equal("GNU/Linux", check.Keywords[0].Keyword)
+ if suite.NotNil(check.Keywords[0].WholeWord) {
+ suite.True(*check.Keywords[0].WholeWord)
+ }
+
+ // Ensure status entry was not deleted.
+ suite.Len(check.Statuses, 1)
+ suite.Equal(newStatus.ID, check.Statuses[0].ID)
+
+ // Add another status entry for the same status ID. It should be ignored without problems.
+ redundantStatus := &gtsmodel.FilterStatus{
+ ID: "01HQXJ5Y405XZSQ67C2BSQ6HJ0",
+ FilterID: filter.ID,
+ AccountID: filter.AccountID,
+ StatusID: newStatus.StatusID,
+ }
+ check.Statuses = []*gtsmodel.FilterStatus{redundantStatus}
+ if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil {
+ t.Fatalf("error updating filter: %v", err)
+ }
+ check, err = suite.db.GetFilterByID(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error fetching updated filter: %v", err)
+ }
+
+ // Ensure status entry was not deleted, updated, or duplicated.
+ suite.Len(check.Statuses, 1)
+ suite.Equal(newStatus.ID, check.Statuses[0].ID)
+ suite.Equal(newStatus.StatusID, check.Statuses[0].StatusID)
+
+ // Now delete the filter from the DB.
+ if err := suite.db.DeleteFilterByID(ctx, filter.ID); err != nil {
+ t.Fatalf("error deleting filter: %v", err)
+ }
+
+ // Ensure we can't refetch it.
+ _, err = suite.db.GetFilterByID(ctx, filter.ID)
+ if !errors.Is(err, db.ErrNoEntries) {
+ t.Fatalf("fetching deleted filter returned unexpected error: %v", err)
+ }
+}
+
+func TestFilterTestSuite(t *testing.T) {
+ suite.Run(t, new(FilterTestSuite))
+}
diff --git a/internal/db/bundb/filterkeyword.go b/internal/db/bundb/filterkeyword.go
new file mode 100644
index 000000000..703d58d43
--- /dev/null
+++ b/internal/db/bundb/filterkeyword.go
@@ -0,0 +1,191 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "slices"
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
+)
+
+func (f *filterDB) GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error) {
+ filterKeyword, err := f.state.Caches.GTS.FilterKeyword.LoadOne(
+ "ID",
+ func() (*gtsmodel.FilterKeyword, error) {
+ var filterKeyword gtsmodel.FilterKeyword
+ err := f.db.
+ NewSelect().
+ Model(&filterKeyword).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ return &filterKeyword, err
+ },
+ id,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ if !gtscontext.Barebones(ctx) {
+ err = f.populateFilterKeyword(ctx, filterKeyword)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return filterKeyword, nil
+}
+
+func (f *filterDB) populateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error {
+ if filterKeyword.Filter == nil {
+ // Filter is not set, fetch from the cache or database.
+ filter, err := f.state.DB.GetFilterByID(
+ // Don't populate the filter with all of its keywords and statuses or we'll just end up back here.
+ gtscontext.SetBarebones(ctx),
+ filterKeyword.FilterID,
+ )
+ if err != nil {
+ return err
+ }
+ filterKeyword.Filter = filter
+ }
+
+ return nil
+}
+
+func (f *filterDB) GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error) {
+ return f.getFilterKeywords(ctx, "filter_id", filterID)
+}
+
+func (f *filterDB) GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error) {
+ return f.getFilterKeywords(ctx, "account_id", accountID)
+}
+
+func (f *filterDB) getFilterKeywords(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterKeyword, error) {
+ var filterKeywordIDs []string
+ if err := f.db.
+ NewSelect().
+ Model((*gtsmodel.FilterKeyword)(nil)).
+ Column("id").
+ Where("? = ?", bun.Ident(idColumn), id).
+ Scan(ctx, &filterKeywordIDs); err != nil {
+ return nil, err
+ }
+ if len(filterKeywordIDs) == 0 {
+ return nil, nil
+ }
+
+ // Get each filter keyword by ID from the cache or DB.
+ uncachedFilterKeywordIDs := make([]string, 0, len(filterKeywordIDs))
+ filterKeywords, err := f.state.Caches.GTS.FilterKeyword.Load(
+ "ID",
+ func(load func(keyParts ...any) bool) {
+ for _, id := range filterKeywordIDs {
+ if !load(id) {
+ uncachedFilterKeywordIDs = append(uncachedFilterKeywordIDs, id)
+ }
+ }
+ },
+ func() ([]*gtsmodel.FilterKeyword, error) {
+ uncachedFilterKeywords := make([]*gtsmodel.FilterKeyword, 0, len(uncachedFilterKeywordIDs))
+ if err := f.db.
+ NewSelect().
+ Model(&uncachedFilterKeywords).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterKeywordIDs)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return uncachedFilterKeywords, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Put the filter keyword structs in the same order as the filter keyword IDs.
+ util.OrderBy(filterKeywords, filterKeywordIDs, func(filterKeyword *gtsmodel.FilterKeyword) string {
+ return filterKeyword.ID
+ })
+
+ if gtscontext.Barebones(ctx) {
+ return filterKeywords, nil
+ }
+
+ // Populate the filter keywords. Remove any that we can't populate from the return slice.
+ errs := gtserror.NewMultiError(len(filterKeywords))
+ filterKeywords = slices.DeleteFunc(filterKeywords, func(filterKeyword *gtsmodel.FilterKeyword) bool {
+ if err := f.populateFilterKeyword(ctx, filterKeyword); err != nil {
+ errs.Appendf(
+ "error populating filter keyword %s: %w",
+ filterKeyword.ID,
+ err,
+ )
+ return true
+ }
+ return false
+ })
+
+ return filterKeywords, errs.Combine()
+}
+
+func (f *filterDB) PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error {
+ return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error {
+ _, err := f.db.
+ NewInsert().
+ Model(filterKeyword).
+ Exec(ctx)
+ return err
+ })
+}
+
+func (f *filterDB) UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error {
+ filterKeyword.UpdatedAt = time.Now()
+ if len(columns) > 0 {
+ columns = append(columns, "updated_at")
+ }
+
+ return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error {
+ _, err := f.db.
+ NewUpdate().
+ Model(filterKeyword).
+ Where("? = ?", bun.Ident("id"), filterKeyword.ID).
+ Column(columns...).
+ Exec(ctx)
+ return err
+ })
+}
+
+func (f *filterDB) DeleteFilterKeywordByID(ctx context.Context, id string) error {
+ if _, err := f.db.
+ NewDelete().
+ Model((*gtsmodel.FilterKeyword)(nil)).
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id)
+
+ return nil
+}
diff --git a/internal/db/bundb/filterkeyword_test.go b/internal/db/bundb/filterkeyword_test.go
new file mode 100644
index 000000000..91c8d192c
--- /dev/null
+++ b/internal/db/bundb/filterkeyword_test.go
@@ -0,0 +1,143 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb_test
+
+import (
+ "context"
+ "errors"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+)
+
+// TestFilterKeywordCRUD tests CRUD and read-all operations on filter keywords.
+func (suite *FilterTestSuite) TestFilterKeywordCRUD() {
+ t := suite.T()
+
+ // Create new filter.
+ filter := &gtsmodel.Filter{
+ ID: "01HNEJNVZZVXJTRB3FX3K2B1YF",
+ AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47",
+ Title: "foss jail",
+ Action: gtsmodel.FilterActionWarn,
+ ContextHome: util.Ptr(true),
+ ContextPublic: util.Ptr(true),
+ }
+
+ // Create new cancellable test context.
+ ctx := context.Background()
+ ctx, cncl := context.WithCancel(ctx)
+ defer cncl()
+
+ // Insert the new filter into the DB.
+ err := suite.db.PutFilter(ctx, filter)
+ if err != nil {
+ t.Fatalf("error inserting filter: %v", err)
+ }
+
+ // There should be no filter keywords yet.
+ all, err := suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID)
+ if err != nil {
+ t.Fatalf("error fetching filter keywords: %v", err)
+ }
+ suite.Empty(all)
+
+ // Add a filter keyword to it.
+ filterKeyword := &gtsmodel.FilterKeyword{
+ ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4",
+ AccountID: filter.AccountID,
+ FilterID: filter.ID,
+ Keyword: "GNU/Linux",
+ }
+
+ // Insert the new filter keyword into the DB.
+ err = suite.db.PutFilterKeyword(ctx, filterKeyword)
+ if err != nil {
+ t.Fatalf("error inserting filter keyword: %v", err)
+ }
+
+ // Try to find it again and ensure it has the fields we expect.
+ check, err := suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID)
+ if err != nil {
+ t.Fatalf("error fetching filter keyword: %v", err)
+ }
+ suite.Equal(filterKeyword.ID, check.ID)
+ suite.NotZero(check.CreatedAt)
+ suite.NotZero(check.UpdatedAt)
+ suite.Equal(filterKeyword.AccountID, check.AccountID)
+ suite.Equal(filterKeyword.FilterID, check.FilterID)
+ suite.Equal(filterKeyword.Keyword, check.Keyword)
+ suite.Equal(filterKeyword.WholeWord, check.WholeWord)
+
+ // Loading filter keywords by account ID should find the one we inserted.
+ all, err = suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID)
+ if err != nil {
+ t.Fatalf("error fetching filter keywords: %v", err)
+ }
+ suite.Len(all, 1)
+ suite.Equal(filterKeyword.ID, all[0].ID)
+
+ // Loading filter keywords by filter ID should also find the one we inserted.
+ all, err = suite.db.GetFilterKeywordsForFilterID(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error fetching filter keywords: %v", err)
+ }
+ suite.Len(all, 1)
+ suite.Equal(filterKeyword.ID, all[0].ID)
+
+ // Modify the filter keyword.
+ filterKeyword.WholeWord = util.Ptr(true)
+ err = suite.db.UpdateFilterKeyword(ctx, filterKeyword)
+ if err != nil {
+ t.Fatalf("error updating filter keyword: %v", err)
+ }
+
+ // Try to find it again and ensure it has the updated fields we expect.
+ check, err = suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID)
+ if err != nil {
+ t.Fatalf("error fetching filter keyword: %v", err)
+ }
+ suite.Equal(filterKeyword.ID, check.ID)
+ suite.NotZero(check.CreatedAt)
+ suite.True(check.UpdatedAt.After(check.CreatedAt))
+ suite.Equal(filterKeyword.AccountID, check.AccountID)
+ suite.Equal(filterKeyword.FilterID, check.FilterID)
+ suite.Equal(filterKeyword.Keyword, check.Keyword)
+ suite.Equal(filterKeyword.WholeWord, check.WholeWord)
+
+ // Delete the filter keyword from the DB.
+ err = suite.db.DeleteFilterKeywordByID(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error deleting filter keyword: %v", err)
+ }
+
+ // Ensure we can't refetch it.
+ check, err = suite.db.GetFilterKeywordByID(ctx, filter.ID)
+ if !errors.Is(err, db.ErrNoEntries) {
+ t.Fatalf("fetching deleted filter keyword returned unexpected error: %v", err)
+ }
+ suite.Nil(check)
+
+ // Ensure the filter itself is still there.
+ checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error fetching filter: %v", err)
+ }
+ suite.Equal(filter.ID, checkFilter.ID)
+}
diff --git a/internal/db/bundb/filterstatus.go b/internal/db/bundb/filterstatus.go
new file mode 100644
index 000000000..1e98f5958
--- /dev/null
+++ b/internal/db/bundb/filterstatus.go
@@ -0,0 +1,191 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "slices"
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
+)
+
+func (f *filterDB) GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error) {
+ filterStatus, err := f.state.Caches.GTS.FilterStatus.LoadOne(
+ "ID",
+ func() (*gtsmodel.FilterStatus, error) {
+ var filterStatus gtsmodel.FilterStatus
+ err := f.db.
+ NewSelect().
+ Model(&filterStatus).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ return &filterStatus, err
+ },
+ id,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ if !gtscontext.Barebones(ctx) {
+ err = f.populateFilterStatus(ctx, filterStatus)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return filterStatus, nil
+}
+
+func (f *filterDB) populateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error {
+ if filterStatus.Filter == nil {
+ // Filter is not set, fetch from the cache or database.
+ filter, err := f.state.DB.GetFilterByID(
+ // Don't populate the filter with all of its keywords and statuses or we'll just end up back here.
+ gtscontext.SetBarebones(ctx),
+ filterStatus.FilterID,
+ )
+ if err != nil {
+ return err
+ }
+ filterStatus.Filter = filter
+ }
+
+ return nil
+}
+
+func (f *filterDB) GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error) {
+ return f.getFilterStatuses(ctx, "filter_id", filterID)
+}
+
+func (f *filterDB) GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error) {
+ return f.getFilterStatuses(ctx, "account_id", accountID)
+}
+
+func (f *filterDB) getFilterStatuses(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterStatus, error) {
+ var filterStatusIDs []string
+ if err := f.db.
+ NewSelect().
+ Model((*gtsmodel.FilterStatus)(nil)).
+ Column("id").
+ Where("? = ?", bun.Ident(idColumn), id).
+ Scan(ctx, &filterStatusIDs); err != nil {
+ return nil, err
+ }
+ if len(filterStatusIDs) == 0 {
+ return nil, nil
+ }
+
+ // Get each filter status by ID from the cache or DB.
+ uncachedFilterStatusIDs := make([]string, 0, len(filterStatusIDs))
+ filterStatuses, err := f.state.Caches.GTS.FilterStatus.Load(
+ "ID",
+ func(load func(keyParts ...any) bool) {
+ for _, id := range filterStatusIDs {
+ if !load(id) {
+ uncachedFilterStatusIDs = append(uncachedFilterStatusIDs, id)
+ }
+ }
+ },
+ func() ([]*gtsmodel.FilterStatus, error) {
+ uncachedFilterStatuses := make([]*gtsmodel.FilterStatus, 0, len(uncachedFilterStatusIDs))
+ if err := f.db.
+ NewSelect().
+ Model(&uncachedFilterStatuses).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterStatusIDs)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return uncachedFilterStatuses, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Put the filter status structs in the same order as the filter status IDs.
+ util.OrderBy(filterStatuses, filterStatusIDs, func(filterStatus *gtsmodel.FilterStatus) string {
+ return filterStatus.ID
+ })
+
+ if gtscontext.Barebones(ctx) {
+ return filterStatuses, nil
+ }
+
+ // Populate the filter statuses. Remove any that we can't populate from the return slice.
+ errs := gtserror.NewMultiError(len(filterStatuses))
+ filterStatuses = slices.DeleteFunc(filterStatuses, func(filterStatus *gtsmodel.FilterStatus) bool {
+ if err := f.populateFilterStatus(ctx, filterStatus); err != nil {
+ errs.Appendf(
+ "error populating filter status %s: %w",
+ filterStatus.ID,
+ err,
+ )
+ return true
+ }
+ return false
+ })
+
+ return filterStatuses, errs.Combine()
+}
+
+func (f *filterDB) PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error {
+ return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error {
+ _, err := f.db.
+ NewInsert().
+ Model(filterStatus).
+ Exec(ctx)
+ return err
+ })
+}
+
+func (f *filterDB) UpdateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus, columns ...string) error {
+ filterStatus.UpdatedAt = time.Now()
+ if len(columns) > 0 {
+ columns = append(columns, "updated_at")
+ }
+
+ return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error {
+ _, err := f.db.
+ NewUpdate().
+ Model(filterStatus).
+ Where("? = ?", bun.Ident("id"), filterStatus.ID).
+ Column(columns...).
+ Exec(ctx)
+ return err
+ })
+}
+
+func (f *filterDB) DeleteFilterStatusByID(ctx context.Context, id string) error {
+ if _, err := f.db.
+ NewDelete().
+ Model((*gtsmodel.FilterStatus)(nil)).
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ f.state.Caches.GTS.FilterStatus.Invalidate("ID", id)
+
+ return nil
+}
diff --git a/internal/db/bundb/filterstatus_test.go b/internal/db/bundb/filterstatus_test.go
new file mode 100644
index 000000000..48ddb1bed
--- /dev/null
+++ b/internal/db/bundb/filterstatus_test.go
@@ -0,0 +1,122 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb_test
+
+import (
+ "context"
+ "errors"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+)
+
+// TestFilterStatusCRD tests CRD (no U) and read-all operations on filter statuses.
+func (suite *FilterTestSuite) TestFilterStatusCRD() {
+ t := suite.T()
+
+ // Create new filter.
+ filter := &gtsmodel.Filter{
+ ID: "01HNEJNVZZVXJTRB3FX3K2B1YF",
+ AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47",
+ Title: "foss jail",
+ Action: gtsmodel.FilterActionWarn,
+ ContextHome: util.Ptr(true),
+ ContextPublic: util.Ptr(true),
+ }
+
+ // Create new cancellable test context.
+ ctx := context.Background()
+ ctx, cncl := context.WithCancel(ctx)
+ defer cncl()
+
+ // Insert the new filter into the DB.
+ err := suite.db.PutFilter(ctx, filter)
+ if err != nil {
+ t.Fatalf("error inserting filter: %v", err)
+ }
+
+ // There should be no filter statuses yet.
+ all, err := suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID)
+ if err != nil {
+ t.Fatalf("error fetching filter statuses: %v", err)
+ }
+ suite.Empty(all)
+
+ // Add a filter status to it.
+ filterStatus := &gtsmodel.FilterStatus{
+ ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4",
+ AccountID: filter.AccountID,
+ FilterID: filter.ID,
+ StatusID: "01HQXGMQ3QFXRT4GX9WNQ8KC0X",
+ }
+
+ // Insert the new filter status into the DB.
+ err = suite.db.PutFilterStatus(ctx, filterStatus)
+ if err != nil {
+ t.Fatalf("error inserting filter status: %v", err)
+ }
+
+ // Try to find it again and ensure it has the fields we expect.
+ check, err := suite.db.GetFilterStatusByID(ctx, filterStatus.ID)
+ if err != nil {
+ t.Fatalf("error fetching filter status: %v", err)
+ }
+ suite.Equal(filterStatus.ID, check.ID)
+ suite.NotZero(check.CreatedAt)
+ suite.NotZero(check.UpdatedAt)
+ suite.Equal(filterStatus.AccountID, check.AccountID)
+ suite.Equal(filterStatus.FilterID, check.FilterID)
+ suite.Equal(filterStatus.StatusID, check.StatusID)
+
+ // Loading filter statuses by account ID should find the one we inserted.
+ all, err = suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID)
+ if err != nil {
+ t.Fatalf("error fetching filter statuses: %v", err)
+ }
+ suite.Len(all, 1)
+ suite.Equal(filterStatus.ID, all[0].ID)
+
+ // Loading filter statuses by filter ID should also find the one we inserted.
+ all, err = suite.db.GetFilterStatusesForFilterID(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error fetching filter statuses: %v", err)
+ }
+ suite.Len(all, 1)
+ suite.Equal(filterStatus.ID, all[0].ID)
+
+ // Delete the filter status from the DB.
+ err = suite.db.DeleteFilterStatusByID(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error deleting filter status: %v", err)
+ }
+
+ // Ensure we can't refetch it.
+ check, err = suite.db.GetFilterStatusByID(ctx, filter.ID)
+ if !errors.Is(err, db.ErrNoEntries) {
+ t.Fatalf("fetching deleted filter status returned unexpected error: %v", err)
+ }
+ suite.Nil(check)
+
+ // Ensure the filter itself is still there.
+ checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error fetching filter: %v", err)
+ }
+ suite.Equal(filter.ID, checkFilter.ID)
+}
diff --git a/internal/db/bundb/migrations/20240126064004_add_filters.go b/internal/db/bundb/migrations/20240126064004_add_filters.go
new file mode 100644
index 000000000..3ad22f9d8
--- /dev/null
+++ b/internal/db/bundb/migrations/20240126064004_add_filters.go
@@ -0,0 +1,97 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Filter table.
+ if _, err := tx.
+ NewCreateTable().
+ Model(&gtsmodel.Filter{}).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Filter keyword table.
+ if _, err := tx.
+ NewCreateTable().
+ Model(&gtsmodel.FilterKeyword{}).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Filter status table.
+ if _, err := tx.
+ NewCreateTable().
+ Model(&gtsmodel.FilterStatus{}).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Add indexes to the filter tables.
+ for table, indexes := range map[string]map[string][]string{
+ "filters": {
+ "filters_account_id_idx": {"account_id"},
+ },
+ "filter_keywords": {
+ "filter_keywords_account_id_idx": {"account_id"},
+ "filter_keywords_filter_id_idx": {"filter_id"},
+ },
+ "filter_statuses": {
+ "filter_statuses_account_id_idx": {"account_id"},
+ "filter_statuses_filter_id_idx": {"filter_id"},
+ },
+ } {
+ for index, columns := range indexes {
+ if _, err := tx.
+ NewCreateIndex().
+ Table(table).
+ Index(index).
+ Column(columns...).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/upsert.go b/internal/db/bundb/upsert.go
new file mode 100644
index 000000000..34724446c
--- /dev/null
+++ b/internal/db/bundb/upsert.go
@@ -0,0 +1,230 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "database/sql"
+ "reflect"
+ "strings"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/uptrace/bun"
+ "github.com/uptrace/bun/dialect"
+)
+
+// UpsertQuery is a wrapper around an insert query that can update if an insert fails.
+// Doesn't implement the full set of Bun query methods, but we can add more if we need them.
+// See https://bun.uptrace.dev/guide/query-insert.html#upsert
+type UpsertQuery struct {
+ db bun.IDB
+ model interface{}
+ constraints []string
+ columns []string
+}
+
+func NewUpsert(idb bun.IDB) *UpsertQuery {
+ // note: passing in rawtx as conn iface so no double query-hook
+ // firing when passed through the bun.Tx.Query___() functions.
+ return &UpsertQuery{db: idb}
+}
+
+// Model sets the model or models to upsert.
+func (u *UpsertQuery) Model(model interface{}) *UpsertQuery {
+ u.model = model
+ return u
+}
+
+// Constraint sets the columns or indexes that are used to check for conflicts.
+// This is required.
+func (u *UpsertQuery) Constraint(constraints ...string) *UpsertQuery {
+ u.constraints = constraints
+ return u
+}
+
+// Column sets the columns to update if an insert does't happen.
+// If empty, all columns not being used for constraints will be updated.
+// Cannot overlap with Constraint.
+func (u *UpsertQuery) Column(columns ...string) *UpsertQuery {
+ u.columns = columns
+ return u
+}
+
+// insertDialect errors if we're using a dialect in which we don't know how to upsert.
+func (u *UpsertQuery) insertDialect() error {
+ dialectName := u.db.Dialect().Name()
+ switch dialectName {
+ case dialect.PG, dialect.SQLite:
+ return nil
+ default:
+ // FUTURE: MySQL has its own variation on upserts, but the syntax is different.
+ return gtserror.Newf("UpsertQuery: upsert not supported by SQL dialect: %s", dialectName)
+ }
+}
+
+// insertConstraints checks that we have constraints and returns them.
+func (u *UpsertQuery) insertConstraints() ([]string, error) {
+ if len(u.constraints) == 0 {
+ return nil, gtserror.New("UpsertQuery: upserts require at least one constraint column or index, none provided")
+ }
+ return u.constraints, nil
+}
+
+// insertColumns returns the non-constraint columns we'll be updating.
+func (u *UpsertQuery) insertColumns(constraints []string) ([]string, error) {
+ // Constraints as a set.
+ constraintSet := make(map[string]struct{}, len(constraints))
+ for _, constraint := range constraints {
+ constraintSet[constraint] = struct{}{}
+ }
+
+ var columns []string
+ var err error
+ if len(u.columns) == 0 {
+ columns, err = u.insertColumnsDefault(constraintSet)
+ } else {
+ columns, err = u.insertColumnsSpecified(constraintSet)
+ }
+ if err != nil {
+ return nil, err
+ }
+ if len(columns) == 0 {
+ return nil, gtserror.New("UpsertQuery: there are no columns to update when upserting")
+ }
+
+ return columns, nil
+}
+
+// hasElem returns whether the type has an element and can call [reflect.Type.Elem] without panicking.
+func hasElem(modelType reflect.Type) bool {
+ switch modelType.Kind() {
+ case reflect.Array, reflect.Chan, reflect.Map, reflect.Pointer, reflect.Slice:
+ return true
+ default:
+ return false
+ }
+}
+
+// insertColumnsDefault returns all non-constraint columns from the model schema.
+func (u *UpsertQuery) insertColumnsDefault(constraintSet map[string]struct{}) ([]string, error) {
+ // Get underlying struct type.
+ modelType := reflect.TypeOf(u.model)
+ for hasElem(modelType) {
+ modelType = modelType.Elem()
+ }
+
+ table := u.db.Dialect().Tables().Get(modelType)
+ if table == nil {
+ return nil, gtserror.Newf("UpsertQuery: couldn't find the table schema for model: %v", u.model)
+ }
+
+ columns := make([]string, 0, len(u.columns))
+ for _, field := range table.Fields {
+ column := field.Name
+ if _, overlaps := constraintSet[column]; !overlaps {
+ columns = append(columns, column)
+ }
+ }
+
+ return columns, nil
+}
+
+// insertColumnsSpecified ensures constraints and specified columns to update don't overlap.
+func (u *UpsertQuery) insertColumnsSpecified(constraintSet map[string]struct{}) ([]string, error) {
+ overlapping := make([]string, 0, min(len(u.constraints), len(u.columns)))
+ for _, column := range u.columns {
+ if _, overlaps := constraintSet[column]; overlaps {
+ overlapping = append(overlapping, column)
+ }
+ }
+
+ if len(overlapping) > 0 {
+ return nil, gtserror.Newf(
+ "UpsertQuery: the following columns can't be used for both constraints and columns to update: %s",
+ strings.Join(overlapping, ", "),
+ )
+ }
+
+ return u.columns, nil
+}
+
+// insert tries to create a Bun insert query from an upsert query.
+func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) {
+ var err error
+
+ err = u.insertDialect()
+ if err != nil {
+ return nil, err
+ }
+
+ constraints, err := u.insertConstraints()
+ if err != nil {
+ return nil, err
+ }
+
+ columns, err := u.insertColumns(constraints)
+ if err != nil {
+ return nil, err
+ }
+
+ // Build the parts of the query that need us to generate SQL.
+ constraintIDPlaceholders := make([]string, 0, len(constraints))
+ constraintIDs := make([]interface{}, 0, len(constraints))
+ for _, constraint := range constraints {
+ constraintIDPlaceholders = append(constraintIDPlaceholders, "?")
+ constraintIDs = append(constraintIDs, bun.Ident(constraint))
+ }
+ onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update"
+
+ setClauses := make([]string, 0, len(columns))
+ setIDs := make([]interface{}, 0, 2*len(columns))
+ for _, column := range columns {
+ // "excluded" is a special table that contains only the row involved in a conflict.
+ setClauses = append(setClauses, "? = excluded.?")
+ setIDs = append(setIDs, bun.Ident(column), bun.Ident(column))
+ }
+ setSQL := strings.Join(setClauses, ", ")
+
+ insertQuery := u.db.
+ NewInsert().
+ Model(u.model).
+ On(onSQL, constraintIDs...).
+ Set(setSQL, setIDs...)
+
+ return insertQuery, nil
+}
+
+// Exec builds a Bun insert query from the upsert query, and executes it.
+func (u *UpsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ insertQuery, err := u.insertQuery()
+ if err != nil {
+ return nil, err
+ }
+
+ return insertQuery.Exec(ctx, dest...)
+}
+
+// Scan builds a Bun insert query from the upsert query, and scans it.
+func (u *UpsertQuery) Scan(ctx context.Context, dest ...interface{}) error {
+ insertQuery, err := u.insertQuery()
+ if err != nil {
+ return err
+ }
+
+ return insertQuery.Scan(ctx, dest...)
+}
diff --git a/internal/db/db.go b/internal/db/db.go
index 361687e94..f23324777 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -32,6 +32,7 @@ type DB interface {
Emoji
HeaderFilter
Instance
+ Filter
List
Marker
Media
diff --git a/internal/db/filter.go b/internal/db/filter.go
new file mode 100644
index 000000000..18943b4f9
--- /dev/null
+++ b/internal/db/filter.go
@@ -0,0 +1,101 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+// Filter contains methods for creating, reading, updating, and deleting filters and their keyword and status entries.
+type Filter interface {
+ //<editor-fold desc="Filter methods">
+
+ // GetFilterByID gets one filter with the given id.
+ GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error)
+
+ // GetFiltersForAccountID gets all filters owned by the given accountID.
+ GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error)
+
+ // PutFilter puts a new filter in the database, adding any attached keywords or statuses.
+ // It uses a transaction to ensure no partial updates.
+ PutFilter(ctx context.Context, filter *gtsmodel.Filter) error
+
+ // UpdateFilter updates the given filter,
+ // upserts any attached keywords and inserts any new statuses (existing statuses cannot be updated),
+ // and deletes indicated filter keywords and statuses by ID.
+ // It uses a transaction to ensure no partial updates.
+ // The column lists are optional; if not specified, all columns will be updated.
+ UpdateFilter(
+ ctx context.Context,
+ filter *gtsmodel.Filter,
+ filterColumns []string,
+ filterKeywordColumns []string,
+ deleteFilterKeywordIDs []string,
+ deleteFilterStatusIDs []string,
+ ) error
+
+ // DeleteFilterByID deletes one filter with the given ID.
+ // It uses a transaction to ensure no partial updates.
+ DeleteFilterByID(ctx context.Context, id string) error
+
+ //</editor-fold>
+
+ //<editor-fold desc="Filter keyword methods">
+
+ // GetFilterKeywordByID gets one filter keyword with the given ID.
+ GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error)
+
+ // GetFilterKeywordsForFilterID gets filter keywords from the given filterID.
+ GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error)
+
+ // GetFilterKeywordsForAccountID gets filter keywords from the given accountID.
+ GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error)
+
+ // PutFilterKeyword inserts a single filter keyword into the database.
+ PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error
+
+ // UpdateFilterKeyword updates the given filter keyword.
+ // Columns is optional, if not specified all will be updated.
+ UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error
+
+ // DeleteFilterKeywordByID deletes one filter keyword with the given id.
+ DeleteFilterKeywordByID(ctx context.Context, id string) error
+
+ //</editor-fold>
+
+ //<editor-fold desc="Filter status methods">
+
+ // GetFilterStatusByID gets one filter status with the given ID.
+ GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error)
+
+ // GetFilterStatusesForFilterID gets filter statuses from the given filterID.
+ GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error)
+
+ // GetFilterStatusesForAccountID gets filter keywords from the given accountID.
+ GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error)
+
+ // PutFilterStatus inserts a single filter status into the database.
+ PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error
+
+ // DeleteFilterStatusByID deletes one filter status with the given id.
+ DeleteFilterStatusByID(ctx context.Context, id string) error
+
+ //</editor-fold>
+}