diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/bundb.go | 5 | ||||
-rw-r--r-- | internal/db/bundb/filter.go | 339 | ||||
-rw-r--r-- | internal/db/bundb/filter_test.go | 252 | ||||
-rw-r--r-- | internal/db/bundb/filterkeyword.go | 191 | ||||
-rw-r--r-- | internal/db/bundb/filterkeyword_test.go | 143 | ||||
-rw-r--r-- | internal/db/bundb/filterstatus.go | 191 | ||||
-rw-r--r-- | internal/db/bundb/filterstatus_test.go | 122 | ||||
-rw-r--r-- | internal/db/bundb/migrations/20240126064004_add_filters.go | 97 | ||||
-rw-r--r-- | internal/db/bundb/upsert.go | 230 | ||||
-rw-r--r-- | internal/db/db.go | 1 | ||||
-rw-r--r-- | internal/db/filter.go | 101 |
11 files changed, 1672 insertions, 0 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 4ecbec7b9..c49da272b 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -62,6 +62,7 @@ type DBService struct { db.Emoji db.HeaderFilter db.Instance + db.Filter db.List db.Marker db.Media @@ -200,6 +201,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { db: db, state: state, }, + Filter: &filterDB{ + db: db, + state: state, + }, List: &listDB{ db: db, state: state, diff --git a/internal/db/bundb/filter.go b/internal/db/bundb/filter.go new file mode 100644 index 000000000..bcd572f34 --- /dev/null +++ b/internal/db/bundb/filter.go @@ -0,0 +1,339 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" +) + +type filterDB struct { + db *bun.DB + state *state.State +} + +func (f *filterDB) GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error) { + filter, err := f.state.Caches.GTS.Filter.LoadOne( + "ID", + func() (*gtsmodel.Filter, error) { + var filter gtsmodel.Filter + err := f.db. + NewSelect(). + Model(&filter). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + return &filter, err + }, + id, + ) + if err != nil { + // already processed + return nil, err + } + + if !gtscontext.Barebones(ctx) { + if err := f.populateFilter(ctx, filter); err != nil { + return nil, err + } + } + + return filter, nil +} + +func (f *filterDB) GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) { + // Fetch IDs of all filters owned by this account. + var filterIDs []string + if err := f.db. + NewSelect(). + Model((*gtsmodel.Filter)(nil)). + Column("id"). + Where("? = ?", bun.Ident("account_id"), accountID). + Scan(ctx, &filterIDs); err != nil { + return nil, err + } + if len(filterIDs) == 0 { + return nil, nil + } + + // Get each filter by ID from the cache or DB. + uncachedFilterIDs := make([]string, 0, len(filterIDs)) + filters, err := f.state.Caches.GTS.Filter.Load( + "ID", + func(load func(keyParts ...any) bool) { + for _, id := range filterIDs { + if !load(id) { + uncachedFilterIDs = append(uncachedFilterIDs, id) + } + } + }, + func() ([]*gtsmodel.Filter, error) { + uncachedFilters := make([]*gtsmodel.Filter, 0, len(uncachedFilterIDs)) + if err := f.db. + NewSelect(). + Model(&uncachedFilters). + Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterIDs)). + Scan(ctx); err != nil { + return nil, err + } + return uncachedFilters, nil + }, + ) + if err != nil { + return nil, err + } + + // Put the filter structs in the same order as the filter IDs. + util.OrderBy(filters, filterIDs, func(filter *gtsmodel.Filter) string { return filter.ID }) + + if gtscontext.Barebones(ctx) { + return filters, nil + } + + // Populate the filters. Remove any that we can't populate from the return slice. + errs := gtserror.NewMultiError(len(filters)) + filters = slices.DeleteFunc(filters, func(filter *gtsmodel.Filter) bool { + if err := f.populateFilter(ctx, filter); err != nil { + errs.Appendf("error populating filter %s: %w", filter.ID, err) + return true + } + return false + }) + + return filters, errs.Combine() +} + +func (f *filterDB) populateFilter(ctx context.Context, filter *gtsmodel.Filter) error { + var err error + errs := gtserror.NewMultiError(2) + + if filter.Keywords == nil { + // Filter keywords are not set, fetch from the database. + filter.Keywords, err = f.state.DB.GetFilterKeywordsForFilterID( + gtscontext.SetBarebones(ctx), + filter.ID, + ) + if err != nil { + errs.Appendf("error populating filter keywords: %w", err) + } + for i := range filter.Keywords { + filter.Keywords[i].Filter = filter + } + } + + if filter.Statuses == nil { + // Filter statuses are not set, fetch from the database. + filter.Statuses, err = f.state.DB.GetFilterStatusesForFilterID( + gtscontext.SetBarebones(ctx), + filter.ID, + ) + if err != nil { + errs.Appendf("error populating filter statuses: %w", err) + } + for i := range filter.Statuses { + filter.Statuses[i].Filter = filter + } + } + + return errs.Combine() +} + +func (f *filterDB) PutFilter(ctx context.Context, filter *gtsmodel.Filter) error { + // Update database. + if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + if _, err := tx. + NewInsert(). + Model(filter). + Exec(ctx); err != nil { + return err + } + + if len(filter.Keywords) > 0 { + if _, err := tx. + NewInsert(). + Model(&filter.Keywords). + Exec(ctx); err != nil { + return err + } + } + + if len(filter.Statuses) > 0 { + if _, err := tx. + NewInsert(). + Model(&filter.Statuses). + Exec(ctx); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + + // Update cache. + f.state.Caches.GTS.Filter.Put(filter) + f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...) + f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...) + + return nil +} + +func (f *filterDB) UpdateFilter( + ctx context.Context, + filter *gtsmodel.Filter, + filterColumns []string, + filterKeywordColumns []string, + deleteFilterKeywordIDs []string, + deleteFilterStatusIDs []string, +) error { + updatedAt := time.Now() + filter.UpdatedAt = updatedAt + for _, filterKeyword := range filter.Keywords { + filterKeyword.UpdatedAt = updatedAt + } + for _, filterStatus := range filter.Statuses { + filterStatus.UpdatedAt = updatedAt + } + + // If we're updating by column, ensure "updated_at" is included. + if len(filterColumns) > 0 { + filterColumns = append(filterColumns, "updated_at") + } + if len(filterKeywordColumns) > 0 { + filterKeywordColumns = append(filterKeywordColumns, "updated_at") + } + + // Update database. + if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + if _, err := tx. + NewUpdate(). + Model(filter). + Column(filterColumns...). + Where("? = ?", bun.Ident("id"), filter.ID). + Exec(ctx); err != nil { + return err + } + + if len(filter.Keywords) > 0 { + if _, err := NewUpsert(tx). + Model(&filter.Keywords). + Constraint("id"). + Column(filterKeywordColumns...). + Exec(ctx); err != nil { + return err + } + } + + if len(filter.Statuses) > 0 { + if _, err := tx. + NewInsert(). + Ignore(). + Model(&filter.Statuses). + Exec(ctx); err != nil { + return err + } + } + + if len(deleteFilterKeywordIDs) > 0 { + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterKeyword)(nil)). + Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterKeywordIDs)). + Exec(ctx); err != nil { + return err + } + } + + if len(deleteFilterStatusIDs) > 0 { + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterStatus)(nil)). + Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterStatusIDs)). + Exec(ctx); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + + // Update cache. + f.state.Caches.GTS.Filter.Put(filter) + f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...) + f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...) + // TODO: (Vyr) replace with cache multi-invalidate call + for _, id := range deleteFilterKeywordIDs { + f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id) + } + for _, id := range deleteFilterStatusIDs { + f.state.Caches.GTS.FilterStatus.Invalidate("ID", id) + } + + return nil +} + +func (f *filterDB) DeleteFilterByID(ctx context.Context, id string) error { + if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Delete all keywords attached to filter. + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterKeyword)(nil)). + Where("? = ?", bun.Ident("filter_id"), id). + Exec(ctx); err != nil { + return err + } + + // Delete all statuses attached to filter. + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterStatus)(nil)). + Where("? = ?", bun.Ident("filter_id"), id). + Exec(ctx); err != nil { + return err + } + + // Delete the filter itself. + _, err := tx. + NewDelete(). + Model((*gtsmodel.Filter)(nil)). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx) + return err + }); err != nil { + return err + } + + // Invalidate this filter. + f.state.Caches.GTS.Filter.Invalidate("ID", id) + + // Invalidate all keywords and statuses for this filter. + f.state.Caches.GTS.FilterKeyword.Invalidate("FilterID", id) + f.state.Caches.GTS.FilterStatus.Invalidate("FilterID", id) + + return nil +} diff --git a/internal/db/bundb/filter_test.go b/internal/db/bundb/filter_test.go new file mode 100644 index 000000000..7940b6651 --- /dev/null +++ b/internal/db/bundb/filter_test.go @@ -0,0 +1,252 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +type FilterTestSuite struct { + BunDBStandardTestSuite +} + +// TestFilterCRUD tests CRUD and read-all operations on filters. +func (suite *FilterTestSuite) TestFilterCRUD() { + t := suite.T() + + // Create new example filter with attached keyword. + filter := >smodel.Filter{ + ID: "01HNEJNVZZVXJTRB3FX3K2B1YF", + AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47", + Title: "foss jail", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + } + filterKeyword := >smodel.FilterKeyword{ + ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4", + AccountID: filter.AccountID, + FilterID: filter.ID, + Keyword: "GNU/Linux", + } + filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} + + // Create new cancellable test context. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Insert the example filter into db. + if err := suite.db.PutFilter(ctx, filter); err != nil { + t.Fatalf("error inserting filter: %v", err) + } + + // Now fetch newly created filter. + check, err := suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter: %v", err) + } + + // Check all expected fields match. + suite.Equal(filter.ID, check.ID) + suite.Equal(filter.AccountID, check.AccountID) + suite.Equal(filter.Title, check.Title) + suite.Equal(filter.Action, check.Action) + suite.Equal(filter.ContextHome, check.ContextHome) + suite.Equal(filter.ContextNotifications, check.ContextNotifications) + suite.Equal(filter.ContextPublic, check.ContextPublic) + suite.Equal(filter.ContextThread, check.ContextThread) + suite.Equal(filter.ContextAccount, check.ContextAccount) + suite.NotZero(check.CreatedAt) + suite.NotZero(check.UpdatedAt) + + suite.Equal(len(filter.Keywords), len(check.Keywords)) + suite.Equal(filter.Keywords[0].ID, check.Keywords[0].ID) + suite.Equal(filter.Keywords[0].AccountID, check.Keywords[0].AccountID) + suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID) + suite.Equal(filter.Keywords[0].Keyword, check.Keywords[0].Keyword) + suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID) + suite.NotZero(check.Keywords[0].CreatedAt) + suite.NotZero(check.Keywords[0].UpdatedAt) + + suite.Equal(len(filter.Statuses), len(check.Statuses)) + + // Fetch all filters. + all, err := suite.db.GetFiltersForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filters: %v", err) + } + + // Ensure the result contains our example filter. + suite.Len(all, 1) + suite.Equal(filter.ID, all[0].ID) + + suite.Len(all[0].Keywords, 1) + suite.Equal(filter.Keywords[0].ID, all[0].Keywords[0].ID) + + suite.Empty(all[0].Statuses) + + // Update the filter context and add another keyword and a status. + check.ContextNotifications = util.Ptr(true) + + newKeyword := >smodel.FilterKeyword{ + ID: "01HNEMY810E5XKWDDMN5ZRE749", + FilterID: filter.ID, + AccountID: filter.AccountID, + Keyword: "tux", + } + check.Keywords = append(check.Keywords, newKeyword) + + newStatus := >smodel.FilterStatus{ + ID: "01HNEMYD5XE7C8HH8TNCZ76FN2", + FilterID: filter.ID, + AccountID: filter.AccountID, + StatusID: "01HNEKZW34SQZ8PSDQ0Z10NZES", + } + check.Statuses = append(check.Statuses, newStatus) + + if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil { + t.Fatalf("error updating filter: %v", err) + } + // Now fetch newly updated filter. + check, err = suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching updated filter: %v", err) + } + + // Ensure expected fields were modified on check filter. + suite.True(check.UpdatedAt.After(filter.UpdatedAt)) + if suite.NotNil(check.ContextHome) { + suite.True(*check.ContextHome) + } + if suite.NotNil(check.ContextNotifications) { + suite.True(*check.ContextNotifications) + } + if suite.NotNil(check.ContextPublic) { + suite.True(*check.ContextPublic) + } + if suite.NotNil(check.ContextThread) { + suite.False(*check.ContextThread) + } + if suite.NotNil(check.ContextAccount) { + suite.False(*check.ContextAccount) + } + + // Ensure keyword entries were added. + suite.Len(check.Keywords, 2) + checkFilterKeywordIDs := make([]string, 0, 2) + for _, checkFilterKeyword := range check.Keywords { + checkFilterKeywordIDs = append(checkFilterKeywordIDs, checkFilterKeyword.ID) + } + suite.ElementsMatch([]string{filterKeyword.ID, newKeyword.ID}, checkFilterKeywordIDs) + + // Ensure status entry was added. + suite.Len(check.Statuses, 1) + checkFilterStatusIDs := make([]string, 0, 1) + for _, checkFilterStatus := range check.Statuses { + checkFilterStatusIDs = append(checkFilterStatusIDs, checkFilterStatus.ID) + } + suite.ElementsMatch([]string{newStatus.ID}, checkFilterStatusIDs) + + // Update one filter keyword and delete another. Don't change the filter or the filter status. + filterKeyword.WholeWord = util.Ptr(true) + check.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} + check.Statuses = nil + + if err := suite.db.UpdateFilter(ctx, check, nil, nil, []string{newKeyword.ID}, nil); err != nil { + t.Fatalf("error updating filter: %v", err) + } + check, err = suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching updated filter: %v", err) + } + + // Ensure expected fields were not modified. + suite.Equal(filter.Title, check.Title) + suite.Equal(gtsmodel.FilterActionWarn, check.Action) + if suite.NotNil(check.ContextHome) { + suite.True(*check.ContextHome) + } + if suite.NotNil(check.ContextNotifications) { + suite.True(*check.ContextNotifications) + } + if suite.NotNil(check.ContextPublic) { + suite.True(*check.ContextPublic) + } + if suite.NotNil(check.ContextThread) { + suite.False(*check.ContextThread) + } + if suite.NotNil(check.ContextAccount) { + suite.False(*check.ContextAccount) + } + + // Ensure only changed field of keyword was modified, and other keyword was deleted. + suite.Len(check.Keywords, 1) + suite.Equal(filterKeyword.ID, check.Keywords[0].ID) + suite.Equal("GNU/Linux", check.Keywords[0].Keyword) + if suite.NotNil(check.Keywords[0].WholeWord) { + suite.True(*check.Keywords[0].WholeWord) + } + + // Ensure status entry was not deleted. + suite.Len(check.Statuses, 1) + suite.Equal(newStatus.ID, check.Statuses[0].ID) + + // Add another status entry for the same status ID. It should be ignored without problems. + redundantStatus := >smodel.FilterStatus{ + ID: "01HQXJ5Y405XZSQ67C2BSQ6HJ0", + FilterID: filter.ID, + AccountID: filter.AccountID, + StatusID: newStatus.StatusID, + } + check.Statuses = []*gtsmodel.FilterStatus{redundantStatus} + if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil { + t.Fatalf("error updating filter: %v", err) + } + check, err = suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching updated filter: %v", err) + } + + // Ensure status entry was not deleted, updated, or duplicated. + suite.Len(check.Statuses, 1) + suite.Equal(newStatus.ID, check.Statuses[0].ID) + suite.Equal(newStatus.StatusID, check.Statuses[0].StatusID) + + // Now delete the filter from the DB. + if err := suite.db.DeleteFilterByID(ctx, filter.ID); err != nil { + t.Fatalf("error deleting filter: %v", err) + } + + // Ensure we can't refetch it. + _, err = suite.db.GetFilterByID(ctx, filter.ID) + if !errors.Is(err, db.ErrNoEntries) { + t.Fatalf("fetching deleted filter returned unexpected error: %v", err) + } +} + +func TestFilterTestSuite(t *testing.T) { + suite.Run(t, new(FilterTestSuite)) +} diff --git a/internal/db/bundb/filterkeyword.go b/internal/db/bundb/filterkeyword.go new file mode 100644 index 000000000..703d58d43 --- /dev/null +++ b/internal/db/bundb/filterkeyword.go @@ -0,0 +1,191 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" +) + +func (f *filterDB) GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error) { + filterKeyword, err := f.state.Caches.GTS.FilterKeyword.LoadOne( + "ID", + func() (*gtsmodel.FilterKeyword, error) { + var filterKeyword gtsmodel.FilterKeyword + err := f.db. + NewSelect(). + Model(&filterKeyword). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + return &filterKeyword, err + }, + id, + ) + if err != nil { + return nil, err + } + + if !gtscontext.Barebones(ctx) { + err = f.populateFilterKeyword(ctx, filterKeyword) + if err != nil { + return nil, err + } + } + + return filterKeyword, nil +} + +func (f *filterDB) populateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error { + if filterKeyword.Filter == nil { + // Filter is not set, fetch from the cache or database. + filter, err := f.state.DB.GetFilterByID( + // Don't populate the filter with all of its keywords and statuses or we'll just end up back here. + gtscontext.SetBarebones(ctx), + filterKeyword.FilterID, + ) + if err != nil { + return err + } + filterKeyword.Filter = filter + } + + return nil +} + +func (f *filterDB) GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error) { + return f.getFilterKeywords(ctx, "filter_id", filterID) +} + +func (f *filterDB) GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error) { + return f.getFilterKeywords(ctx, "account_id", accountID) +} + +func (f *filterDB) getFilterKeywords(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterKeyword, error) { + var filterKeywordIDs []string + if err := f.db. + NewSelect(). + Model((*gtsmodel.FilterKeyword)(nil)). + Column("id"). + Where("? = ?", bun.Ident(idColumn), id). + Scan(ctx, &filterKeywordIDs); err != nil { + return nil, err + } + if len(filterKeywordIDs) == 0 { + return nil, nil + } + + // Get each filter keyword by ID from the cache or DB. + uncachedFilterKeywordIDs := make([]string, 0, len(filterKeywordIDs)) + filterKeywords, err := f.state.Caches.GTS.FilterKeyword.Load( + "ID", + func(load func(keyParts ...any) bool) { + for _, id := range filterKeywordIDs { + if !load(id) { + uncachedFilterKeywordIDs = append(uncachedFilterKeywordIDs, id) + } + } + }, + func() ([]*gtsmodel.FilterKeyword, error) { + uncachedFilterKeywords := make([]*gtsmodel.FilterKeyword, 0, len(uncachedFilterKeywordIDs)) + if err := f.db. + NewSelect(). + Model(&uncachedFilterKeywords). + Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterKeywordIDs)). + Scan(ctx); err != nil { + return nil, err + } + return uncachedFilterKeywords, nil + }, + ) + if err != nil { + return nil, err + } + + // Put the filter keyword structs in the same order as the filter keyword IDs. + util.OrderBy(filterKeywords, filterKeywordIDs, func(filterKeyword *gtsmodel.FilterKeyword) string { + return filterKeyword.ID + }) + + if gtscontext.Barebones(ctx) { + return filterKeywords, nil + } + + // Populate the filter keywords. Remove any that we can't populate from the return slice. + errs := gtserror.NewMultiError(len(filterKeywords)) + filterKeywords = slices.DeleteFunc(filterKeywords, func(filterKeyword *gtsmodel.FilterKeyword) bool { + if err := f.populateFilterKeyword(ctx, filterKeyword); err != nil { + errs.Appendf( + "error populating filter keyword %s: %w", + filterKeyword.ID, + err, + ) + return true + } + return false + }) + + return filterKeywords, errs.Combine() +} + +func (f *filterDB) PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error { + return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error { + _, err := f.db. + NewInsert(). + Model(filterKeyword). + Exec(ctx) + return err + }) +} + +func (f *filterDB) UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error { + filterKeyword.UpdatedAt = time.Now() + if len(columns) > 0 { + columns = append(columns, "updated_at") + } + + return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error { + _, err := f.db. + NewUpdate(). + Model(filterKeyword). + Where("? = ?", bun.Ident("id"), filterKeyword.ID). + Column(columns...). + Exec(ctx) + return err + }) +} + +func (f *filterDB) DeleteFilterKeywordByID(ctx context.Context, id string) error { + if _, err := f.db. + NewDelete(). + Model((*gtsmodel.FilterKeyword)(nil)). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + + f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id) + + return nil +} diff --git a/internal/db/bundb/filterkeyword_test.go b/internal/db/bundb/filterkeyword_test.go new file mode 100644 index 000000000..91c8d192c --- /dev/null +++ b/internal/db/bundb/filterkeyword_test.go @@ -0,0 +1,143 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( + "context" + "errors" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// TestFilterKeywordCRUD tests CRUD and read-all operations on filter keywords. +func (suite *FilterTestSuite) TestFilterKeywordCRUD() { + t := suite.T() + + // Create new filter. + filter := >smodel.Filter{ + ID: "01HNEJNVZZVXJTRB3FX3K2B1YF", + AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47", + Title: "foss jail", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + } + + // Create new cancellable test context. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Insert the new filter into the DB. + err := suite.db.PutFilter(ctx, filter) + if err != nil { + t.Fatalf("error inserting filter: %v", err) + } + + // There should be no filter keywords yet. + all, err := suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter keywords: %v", err) + } + suite.Empty(all) + + // Add a filter keyword to it. + filterKeyword := >smodel.FilterKeyword{ + ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4", + AccountID: filter.AccountID, + FilterID: filter.ID, + Keyword: "GNU/Linux", + } + + // Insert the new filter keyword into the DB. + err = suite.db.PutFilterKeyword(ctx, filterKeyword) + if err != nil { + t.Fatalf("error inserting filter keyword: %v", err) + } + + // Try to find it again and ensure it has the fields we expect. + check, err := suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID) + if err != nil { + t.Fatalf("error fetching filter keyword: %v", err) + } + suite.Equal(filterKeyword.ID, check.ID) + suite.NotZero(check.CreatedAt) + suite.NotZero(check.UpdatedAt) + suite.Equal(filterKeyword.AccountID, check.AccountID) + suite.Equal(filterKeyword.FilterID, check.FilterID) + suite.Equal(filterKeyword.Keyword, check.Keyword) + suite.Equal(filterKeyword.WholeWord, check.WholeWord) + + // Loading filter keywords by account ID should find the one we inserted. + all, err = suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter keywords: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterKeyword.ID, all[0].ID) + + // Loading filter keywords by filter ID should also find the one we inserted. + all, err = suite.db.GetFilterKeywordsForFilterID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter keywords: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterKeyword.ID, all[0].ID) + + // Modify the filter keyword. + filterKeyword.WholeWord = util.Ptr(true) + err = suite.db.UpdateFilterKeyword(ctx, filterKeyword) + if err != nil { + t.Fatalf("error updating filter keyword: %v", err) + } + + // Try to find it again and ensure it has the updated fields we expect. + check, err = suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID) + if err != nil { + t.Fatalf("error fetching filter keyword: %v", err) + } + suite.Equal(filterKeyword.ID, check.ID) + suite.NotZero(check.CreatedAt) + suite.True(check.UpdatedAt.After(check.CreatedAt)) + suite.Equal(filterKeyword.AccountID, check.AccountID) + suite.Equal(filterKeyword.FilterID, check.FilterID) + suite.Equal(filterKeyword.Keyword, check.Keyword) + suite.Equal(filterKeyword.WholeWord, check.WholeWord) + + // Delete the filter keyword from the DB. + err = suite.db.DeleteFilterKeywordByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error deleting filter keyword: %v", err) + } + + // Ensure we can't refetch it. + check, err = suite.db.GetFilterKeywordByID(ctx, filter.ID) + if !errors.Is(err, db.ErrNoEntries) { + t.Fatalf("fetching deleted filter keyword returned unexpected error: %v", err) + } + suite.Nil(check) + + // Ensure the filter itself is still there. + checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter: %v", err) + } + suite.Equal(filter.ID, checkFilter.ID) +} diff --git a/internal/db/bundb/filterstatus.go b/internal/db/bundb/filterstatus.go new file mode 100644 index 000000000..1e98f5958 --- /dev/null +++ b/internal/db/bundb/filterstatus.go @@ -0,0 +1,191 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" +) + +func (f *filterDB) GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error) { + filterStatus, err := f.state.Caches.GTS.FilterStatus.LoadOne( + "ID", + func() (*gtsmodel.FilterStatus, error) { + var filterStatus gtsmodel.FilterStatus + err := f.db. + NewSelect(). + Model(&filterStatus). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + return &filterStatus, err + }, + id, + ) + if err != nil { + return nil, err + } + + if !gtscontext.Barebones(ctx) { + err = f.populateFilterStatus(ctx, filterStatus) + if err != nil { + return nil, err + } + } + + return filterStatus, nil +} + +func (f *filterDB) populateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error { + if filterStatus.Filter == nil { + // Filter is not set, fetch from the cache or database. + filter, err := f.state.DB.GetFilterByID( + // Don't populate the filter with all of its keywords and statuses or we'll just end up back here. + gtscontext.SetBarebones(ctx), + filterStatus.FilterID, + ) + if err != nil { + return err + } + filterStatus.Filter = filter + } + + return nil +} + +func (f *filterDB) GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error) { + return f.getFilterStatuses(ctx, "filter_id", filterID) +} + +func (f *filterDB) GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error) { + return f.getFilterStatuses(ctx, "account_id", accountID) +} + +func (f *filterDB) getFilterStatuses(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterStatus, error) { + var filterStatusIDs []string + if err := f.db. + NewSelect(). + Model((*gtsmodel.FilterStatus)(nil)). + Column("id"). + Where("? = ?", bun.Ident(idColumn), id). + Scan(ctx, &filterStatusIDs); err != nil { + return nil, err + } + if len(filterStatusIDs) == 0 { + return nil, nil + } + + // Get each filter status by ID from the cache or DB. + uncachedFilterStatusIDs := make([]string, 0, len(filterStatusIDs)) + filterStatuses, err := f.state.Caches.GTS.FilterStatus.Load( + "ID", + func(load func(keyParts ...any) bool) { + for _, id := range filterStatusIDs { + if !load(id) { + uncachedFilterStatusIDs = append(uncachedFilterStatusIDs, id) + } + } + }, + func() ([]*gtsmodel.FilterStatus, error) { + uncachedFilterStatuses := make([]*gtsmodel.FilterStatus, 0, len(uncachedFilterStatusIDs)) + if err := f.db. + NewSelect(). + Model(&uncachedFilterStatuses). + Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterStatusIDs)). + Scan(ctx); err != nil { + return nil, err + } + return uncachedFilterStatuses, nil + }, + ) + if err != nil { + return nil, err + } + + // Put the filter status structs in the same order as the filter status IDs. + util.OrderBy(filterStatuses, filterStatusIDs, func(filterStatus *gtsmodel.FilterStatus) string { + return filterStatus.ID + }) + + if gtscontext.Barebones(ctx) { + return filterStatuses, nil + } + + // Populate the filter statuses. Remove any that we can't populate from the return slice. + errs := gtserror.NewMultiError(len(filterStatuses)) + filterStatuses = slices.DeleteFunc(filterStatuses, func(filterStatus *gtsmodel.FilterStatus) bool { + if err := f.populateFilterStatus(ctx, filterStatus); err != nil { + errs.Appendf( + "error populating filter status %s: %w", + filterStatus.ID, + err, + ) + return true + } + return false + }) + + return filterStatuses, errs.Combine() +} + +func (f *filterDB) PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error { + return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error { + _, err := f.db. + NewInsert(). + Model(filterStatus). + Exec(ctx) + return err + }) +} + +func (f *filterDB) UpdateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus, columns ...string) error { + filterStatus.UpdatedAt = time.Now() + if len(columns) > 0 { + columns = append(columns, "updated_at") + } + + return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error { + _, err := f.db. + NewUpdate(). + Model(filterStatus). + Where("? = ?", bun.Ident("id"), filterStatus.ID). + Column(columns...). + Exec(ctx) + return err + }) +} + +func (f *filterDB) DeleteFilterStatusByID(ctx context.Context, id string) error { + if _, err := f.db. + NewDelete(). + Model((*gtsmodel.FilterStatus)(nil)). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + + f.state.Caches.GTS.FilterStatus.Invalidate("ID", id) + + return nil +} diff --git a/internal/db/bundb/filterstatus_test.go b/internal/db/bundb/filterstatus_test.go new file mode 100644 index 000000000..48ddb1bed --- /dev/null +++ b/internal/db/bundb/filterstatus_test.go @@ -0,0 +1,122 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( + "context" + "errors" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// TestFilterStatusCRD tests CRD (no U) and read-all operations on filter statuses. +func (suite *FilterTestSuite) TestFilterStatusCRD() { + t := suite.T() + + // Create new filter. + filter := >smodel.Filter{ + ID: "01HNEJNVZZVXJTRB3FX3K2B1YF", + AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47", + Title: "foss jail", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + } + + // Create new cancellable test context. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Insert the new filter into the DB. + err := suite.db.PutFilter(ctx, filter) + if err != nil { + t.Fatalf("error inserting filter: %v", err) + } + + // There should be no filter statuses yet. + all, err := suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter statuses: %v", err) + } + suite.Empty(all) + + // Add a filter status to it. + filterStatus := >smodel.FilterStatus{ + ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4", + AccountID: filter.AccountID, + FilterID: filter.ID, + StatusID: "01HQXGMQ3QFXRT4GX9WNQ8KC0X", + } + + // Insert the new filter status into the DB. + err = suite.db.PutFilterStatus(ctx, filterStatus) + if err != nil { + t.Fatalf("error inserting filter status: %v", err) + } + + // Try to find it again and ensure it has the fields we expect. + check, err := suite.db.GetFilterStatusByID(ctx, filterStatus.ID) + if err != nil { + t.Fatalf("error fetching filter status: %v", err) + } + suite.Equal(filterStatus.ID, check.ID) + suite.NotZero(check.CreatedAt) + suite.NotZero(check.UpdatedAt) + suite.Equal(filterStatus.AccountID, check.AccountID) + suite.Equal(filterStatus.FilterID, check.FilterID) + suite.Equal(filterStatus.StatusID, check.StatusID) + + // Loading filter statuses by account ID should find the one we inserted. + all, err = suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter statuses: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterStatus.ID, all[0].ID) + + // Loading filter statuses by filter ID should also find the one we inserted. + all, err = suite.db.GetFilterStatusesForFilterID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter statuses: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterStatus.ID, all[0].ID) + + // Delete the filter status from the DB. + err = suite.db.DeleteFilterStatusByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error deleting filter status: %v", err) + } + + // Ensure we can't refetch it. + check, err = suite.db.GetFilterStatusByID(ctx, filter.ID) + if !errors.Is(err, db.ErrNoEntries) { + t.Fatalf("fetching deleted filter status returned unexpected error: %v", err) + } + suite.Nil(check) + + // Ensure the filter itself is still there. + checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter: %v", err) + } + suite.Equal(filter.ID, checkFilter.ID) +} diff --git a/internal/db/bundb/migrations/20240126064004_add_filters.go b/internal/db/bundb/migrations/20240126064004_add_filters.go new file mode 100644 index 000000000..3ad22f9d8 --- /dev/null +++ b/internal/db/bundb/migrations/20240126064004_add_filters.go @@ -0,0 +1,97 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( + "context" + + gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Filter table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.Filter{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Filter keyword table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.FilterKeyword{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Filter status table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.FilterStatus{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Add indexes to the filter tables. + for table, indexes := range map[string]map[string][]string{ + "filters": { + "filters_account_id_idx": {"account_id"}, + }, + "filter_keywords": { + "filter_keywords_account_id_idx": {"account_id"}, + "filter_keywords_filter_id_idx": {"filter_id"}, + }, + "filter_statuses": { + "filter_statuses_account_id_idx": {"account_id"}, + "filter_statuses_filter_id_idx": {"filter_id"}, + }, + } { + for index, columns := range indexes { + if _, err := tx. + NewCreateIndex(). + Table(table). + Index(index). + Column(columns...). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + } + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/upsert.go b/internal/db/bundb/upsert.go new file mode 100644 index 000000000..34724446c --- /dev/null +++ b/internal/db/bundb/upsert.go @@ -0,0 +1,230 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "database/sql" + "reflect" + "strings" + + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect" +) + +// UpsertQuery is a wrapper around an insert query that can update if an insert fails. +// Doesn't implement the full set of Bun query methods, but we can add more if we need them. +// See https://bun.uptrace.dev/guide/query-insert.html#upsert +type UpsertQuery struct { + db bun.IDB + model interface{} + constraints []string + columns []string +} + +func NewUpsert(idb bun.IDB) *UpsertQuery { + // note: passing in rawtx as conn iface so no double query-hook + // firing when passed through the bun.Tx.Query___() functions. + return &UpsertQuery{db: idb} +} + +// Model sets the model or models to upsert. +func (u *UpsertQuery) Model(model interface{}) *UpsertQuery { + u.model = model + return u +} + +// Constraint sets the columns or indexes that are used to check for conflicts. +// This is required. +func (u *UpsertQuery) Constraint(constraints ...string) *UpsertQuery { + u.constraints = constraints + return u +} + +// Column sets the columns to update if an insert does't happen. +// If empty, all columns not being used for constraints will be updated. +// Cannot overlap with Constraint. +func (u *UpsertQuery) Column(columns ...string) *UpsertQuery { + u.columns = columns + return u +} + +// insertDialect errors if we're using a dialect in which we don't know how to upsert. +func (u *UpsertQuery) insertDialect() error { + dialectName := u.db.Dialect().Name() + switch dialectName { + case dialect.PG, dialect.SQLite: + return nil + default: + // FUTURE: MySQL has its own variation on upserts, but the syntax is different. + return gtserror.Newf("UpsertQuery: upsert not supported by SQL dialect: %s", dialectName) + } +} + +// insertConstraints checks that we have constraints and returns them. +func (u *UpsertQuery) insertConstraints() ([]string, error) { + if len(u.constraints) == 0 { + return nil, gtserror.New("UpsertQuery: upserts require at least one constraint column or index, none provided") + } + return u.constraints, nil +} + +// insertColumns returns the non-constraint columns we'll be updating. +func (u *UpsertQuery) insertColumns(constraints []string) ([]string, error) { + // Constraints as a set. + constraintSet := make(map[string]struct{}, len(constraints)) + for _, constraint := range constraints { + constraintSet[constraint] = struct{}{} + } + + var columns []string + var err error + if len(u.columns) == 0 { + columns, err = u.insertColumnsDefault(constraintSet) + } else { + columns, err = u.insertColumnsSpecified(constraintSet) + } + if err != nil { + return nil, err + } + if len(columns) == 0 { + return nil, gtserror.New("UpsertQuery: there are no columns to update when upserting") + } + + return columns, nil +} + +// hasElem returns whether the type has an element and can call [reflect.Type.Elem] without panicking. +func hasElem(modelType reflect.Type) bool { + switch modelType.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Pointer, reflect.Slice: + return true + default: + return false + } +} + +// insertColumnsDefault returns all non-constraint columns from the model schema. +func (u *UpsertQuery) insertColumnsDefault(constraintSet map[string]struct{}) ([]string, error) { + // Get underlying struct type. + modelType := reflect.TypeOf(u.model) + for hasElem(modelType) { + modelType = modelType.Elem() + } + + table := u.db.Dialect().Tables().Get(modelType) + if table == nil { + return nil, gtserror.Newf("UpsertQuery: couldn't find the table schema for model: %v", u.model) + } + + columns := make([]string, 0, len(u.columns)) + for _, field := range table.Fields { + column := field.Name + if _, overlaps := constraintSet[column]; !overlaps { + columns = append(columns, column) + } + } + + return columns, nil +} + +// insertColumnsSpecified ensures constraints and specified columns to update don't overlap. +func (u *UpsertQuery) insertColumnsSpecified(constraintSet map[string]struct{}) ([]string, error) { + overlapping := make([]string, 0, min(len(u.constraints), len(u.columns))) + for _, column := range u.columns { + if _, overlaps := constraintSet[column]; overlaps { + overlapping = append(overlapping, column) + } + } + + if len(overlapping) > 0 { + return nil, gtserror.Newf( + "UpsertQuery: the following columns can't be used for both constraints and columns to update: %s", + strings.Join(overlapping, ", "), + ) + } + + return u.columns, nil +} + +// insert tries to create a Bun insert query from an upsert query. +func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) { + var err error + + err = u.insertDialect() + if err != nil { + return nil, err + } + + constraints, err := u.insertConstraints() + if err != nil { + return nil, err + } + + columns, err := u.insertColumns(constraints) + if err != nil { + return nil, err + } + + // Build the parts of the query that need us to generate SQL. + constraintIDPlaceholders := make([]string, 0, len(constraints)) + constraintIDs := make([]interface{}, 0, len(constraints)) + for _, constraint := range constraints { + constraintIDPlaceholders = append(constraintIDPlaceholders, "?") + constraintIDs = append(constraintIDs, bun.Ident(constraint)) + } + onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update" + + setClauses := make([]string, 0, len(columns)) + setIDs := make([]interface{}, 0, 2*len(columns)) + for _, column := range columns { + // "excluded" is a special table that contains only the row involved in a conflict. + setClauses = append(setClauses, "? = excluded.?") + setIDs = append(setIDs, bun.Ident(column), bun.Ident(column)) + } + setSQL := strings.Join(setClauses, ", ") + + insertQuery := u.db. + NewInsert(). + Model(u.model). + On(onSQL, constraintIDs...). + Set(setSQL, setIDs...) + + return insertQuery, nil +} + +// Exec builds a Bun insert query from the upsert query, and executes it. +func (u *UpsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + insertQuery, err := u.insertQuery() + if err != nil { + return nil, err + } + + return insertQuery.Exec(ctx, dest...) +} + +// Scan builds a Bun insert query from the upsert query, and scans it. +func (u *UpsertQuery) Scan(ctx context.Context, dest ...interface{}) error { + insertQuery, err := u.insertQuery() + if err != nil { + return err + } + + return insertQuery.Scan(ctx, dest...) +} diff --git a/internal/db/db.go b/internal/db/db.go index 361687e94..f23324777 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -32,6 +32,7 @@ type DB interface { Emoji HeaderFilter Instance + Filter List Marker Media diff --git a/internal/db/filter.go b/internal/db/filter.go new file mode 100644 index 000000000..18943b4f9 --- /dev/null +++ b/internal/db/filter.go @@ -0,0 +1,101 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Filter contains methods for creating, reading, updating, and deleting filters and their keyword and status entries. +type Filter interface { + //<editor-fold desc="Filter methods"> + + // GetFilterByID gets one filter with the given id. + GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error) + + // GetFiltersForAccountID gets all filters owned by the given accountID. + GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) + + // PutFilter puts a new filter in the database, adding any attached keywords or statuses. + // It uses a transaction to ensure no partial updates. + PutFilter(ctx context.Context, filter *gtsmodel.Filter) error + + // UpdateFilter updates the given filter, + // upserts any attached keywords and inserts any new statuses (existing statuses cannot be updated), + // and deletes indicated filter keywords and statuses by ID. + // It uses a transaction to ensure no partial updates. + // The column lists are optional; if not specified, all columns will be updated. + UpdateFilter( + ctx context.Context, + filter *gtsmodel.Filter, + filterColumns []string, + filterKeywordColumns []string, + deleteFilterKeywordIDs []string, + deleteFilterStatusIDs []string, + ) error + + // DeleteFilterByID deletes one filter with the given ID. + // It uses a transaction to ensure no partial updates. + DeleteFilterByID(ctx context.Context, id string) error + + //</editor-fold> + + //<editor-fold desc="Filter keyword methods"> + + // GetFilterKeywordByID gets one filter keyword with the given ID. + GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error) + + // GetFilterKeywordsForFilterID gets filter keywords from the given filterID. + GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error) + + // GetFilterKeywordsForAccountID gets filter keywords from the given accountID. + GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error) + + // PutFilterKeyword inserts a single filter keyword into the database. + PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error + + // UpdateFilterKeyword updates the given filter keyword. + // Columns is optional, if not specified all will be updated. + UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error + + // DeleteFilterKeywordByID deletes one filter keyword with the given id. + DeleteFilterKeywordByID(ctx context.Context, id string) error + + //</editor-fold> + + //<editor-fold desc="Filter status methods"> + + // GetFilterStatusByID gets one filter status with the given ID. + GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error) + + // GetFilterStatusesForFilterID gets filter statuses from the given filterID. + GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error) + + // GetFilterStatusesForAccountID gets filter keywords from the given accountID. + GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error) + + // PutFilterStatus inserts a single filter status into the database. + PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error + + // DeleteFilterStatusByID deletes one filter status with the given id. + DeleteFilterStatusByID(ctx context.Context, id string) error + + //</editor-fold> +} |