diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/bundb/bundb.go | 5 | ||||
| -rw-r--r-- | internal/db/bundb/filter.go | 339 | ||||
| -rw-r--r-- | internal/db/bundb/filter_test.go | 252 | ||||
| -rw-r--r-- | internal/db/bundb/filterkeyword.go | 191 | ||||
| -rw-r--r-- | internal/db/bundb/filterkeyword_test.go | 143 | ||||
| -rw-r--r-- | internal/db/bundb/filterstatus.go | 191 | ||||
| -rw-r--r-- | internal/db/bundb/filterstatus_test.go | 122 | ||||
| -rw-r--r-- | internal/db/bundb/migrations/20240126064004_add_filters.go | 97 | ||||
| -rw-r--r-- | internal/db/bundb/upsert.go | 230 | ||||
| -rw-r--r-- | internal/db/db.go | 1 | ||||
| -rw-r--r-- | internal/db/filter.go | 101 | 
11 files changed, 1672 insertions, 0 deletions
| diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 4ecbec7b9..c49da272b 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -62,6 +62,7 @@ type DBService struct {  	db.Emoji  	db.HeaderFilter  	db.Instance +	db.Filter  	db.List  	db.Marker  	db.Media @@ -200,6 +201,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  			db:    db,  			state: state,  		}, +		Filter: &filterDB{ +			db:    db, +			state: state, +		},  		List: &listDB{  			db:    db,  			state: state, diff --git a/internal/db/bundb/filter.go b/internal/db/bundb/filter.go new file mode 100644 index 000000000..bcd572f34 --- /dev/null +++ b/internal/db/bundb/filter.go @@ -0,0 +1,339 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( +	"context" +	"slices" +	"time" + +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util" +	"github.com/uptrace/bun" +) + +type filterDB struct { +	db    *bun.DB +	state *state.State +} + +func (f *filterDB) GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error) { +	filter, err := f.state.Caches.GTS.Filter.LoadOne( +		"ID", +		func() (*gtsmodel.Filter, error) { +			var filter gtsmodel.Filter +			err := f.db. +				NewSelect(). +				Model(&filter). +				Where("? = ?", bun.Ident("id"), id). +				Scan(ctx) +			return &filter, err +		}, +		id, +	) +	if err != nil { +		// already processed +		return nil, err +	} + +	if !gtscontext.Barebones(ctx) { +		if err := f.populateFilter(ctx, filter); err != nil { +			return nil, err +		} +	} + +	return filter, nil +} + +func (f *filterDB) GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) { +	// Fetch IDs of all filters owned by this account. +	var filterIDs []string +	if err := f.db. +		NewSelect(). +		Model((*gtsmodel.Filter)(nil)). +		Column("id"). +		Where("? = ?", bun.Ident("account_id"), accountID). +		Scan(ctx, &filterIDs); err != nil { +		return nil, err +	} +	if len(filterIDs) == 0 { +		return nil, nil +	} + +	// Get each filter by ID from the cache or DB. +	uncachedFilterIDs := make([]string, 0, len(filterIDs)) +	filters, err := f.state.Caches.GTS.Filter.Load( +		"ID", +		func(load func(keyParts ...any) bool) { +			for _, id := range filterIDs { +				if !load(id) { +					uncachedFilterIDs = append(uncachedFilterIDs, id) +				} +			} +		}, +		func() ([]*gtsmodel.Filter, error) { +			uncachedFilters := make([]*gtsmodel.Filter, 0, len(uncachedFilterIDs)) +			if err := f.db. +				NewSelect(). +				Model(&uncachedFilters). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterIDs)). +				Scan(ctx); err != nil { +				return nil, err +			} +			return uncachedFilters, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Put the filter structs in the same order as the filter IDs. +	util.OrderBy(filters, filterIDs, func(filter *gtsmodel.Filter) string { return filter.ID }) + +	if gtscontext.Barebones(ctx) { +		return filters, nil +	} + +	// Populate the filters. Remove any that we can't populate from the return slice. +	errs := gtserror.NewMultiError(len(filters)) +	filters = slices.DeleteFunc(filters, func(filter *gtsmodel.Filter) bool { +		if err := f.populateFilter(ctx, filter); err != nil { +			errs.Appendf("error populating filter %s: %w", filter.ID, err) +			return true +		} +		return false +	}) + +	return filters, errs.Combine() +} + +func (f *filterDB) populateFilter(ctx context.Context, filter *gtsmodel.Filter) error { +	var err error +	errs := gtserror.NewMultiError(2) + +	if filter.Keywords == nil { +		// Filter keywords are not set, fetch from the database. +		filter.Keywords, err = f.state.DB.GetFilterKeywordsForFilterID( +			gtscontext.SetBarebones(ctx), +			filter.ID, +		) +		if err != nil { +			errs.Appendf("error populating filter keywords: %w", err) +		} +		for i := range filter.Keywords { +			filter.Keywords[i].Filter = filter +		} +	} + +	if filter.Statuses == nil { +		// Filter statuses are not set, fetch from the database. +		filter.Statuses, err = f.state.DB.GetFilterStatusesForFilterID( +			gtscontext.SetBarebones(ctx), +			filter.ID, +		) +		if err != nil { +			errs.Appendf("error populating filter statuses: %w", err) +		} +		for i := range filter.Statuses { +			filter.Statuses[i].Filter = filter +		} +	} + +	return errs.Combine() +} + +func (f *filterDB) PutFilter(ctx context.Context, filter *gtsmodel.Filter) error { +	// Update database. +	if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +		if _, err := tx. +			NewInsert(). +			Model(filter). +			Exec(ctx); err != nil { +			return err +		} + +		if len(filter.Keywords) > 0 { +			if _, err := tx. +				NewInsert(). +				Model(&filter.Keywords). +				Exec(ctx); err != nil { +				return err +			} +		} + +		if len(filter.Statuses) > 0 { +			if _, err := tx. +				NewInsert(). +				Model(&filter.Statuses). +				Exec(ctx); err != nil { +				return err +			} +		} + +		return nil +	}); err != nil { +		return err +	} + +	// Update cache. +	f.state.Caches.GTS.Filter.Put(filter) +	f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...) +	f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...) + +	return nil +} + +func (f *filterDB) UpdateFilter( +	ctx context.Context, +	filter *gtsmodel.Filter, +	filterColumns []string, +	filterKeywordColumns []string, +	deleteFilterKeywordIDs []string, +	deleteFilterStatusIDs []string, +) error { +	updatedAt := time.Now() +	filter.UpdatedAt = updatedAt +	for _, filterKeyword := range filter.Keywords { +		filterKeyword.UpdatedAt = updatedAt +	} +	for _, filterStatus := range filter.Statuses { +		filterStatus.UpdatedAt = updatedAt +	} + +	// If we're updating by column, ensure "updated_at" is included. +	if len(filterColumns) > 0 { +		filterColumns = append(filterColumns, "updated_at") +	} +	if len(filterKeywordColumns) > 0 { +		filterKeywordColumns = append(filterKeywordColumns, "updated_at") +	} + +	// Update database. +	if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +		if _, err := tx. +			NewUpdate(). +			Model(filter). +			Column(filterColumns...). +			Where("? = ?", bun.Ident("id"), filter.ID). +			Exec(ctx); err != nil { +			return err +		} + +		if len(filter.Keywords) > 0 { +			if _, err := NewUpsert(tx). +				Model(&filter.Keywords). +				Constraint("id"). +				Column(filterKeywordColumns...). +				Exec(ctx); err != nil { +				return err +			} +		} + +		if len(filter.Statuses) > 0 { +			if _, err := tx. +				NewInsert(). +				Ignore(). +				Model(&filter.Statuses). +				Exec(ctx); err != nil { +				return err +			} +		} + +		if len(deleteFilterKeywordIDs) > 0 { +			if _, err := tx. +				NewDelete(). +				Model((*gtsmodel.FilterKeyword)(nil)). +				Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterKeywordIDs)). +				Exec(ctx); err != nil { +				return err +			} +		} + +		if len(deleteFilterStatusIDs) > 0 { +			if _, err := tx. +				NewDelete(). +				Model((*gtsmodel.FilterStatus)(nil)). +				Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterStatusIDs)). +				Exec(ctx); err != nil { +				return err +			} +		} + +		return nil +	}); err != nil { +		return err +	} + +	// Update cache. +	f.state.Caches.GTS.Filter.Put(filter) +	f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...) +	f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...) +	// TODO: (Vyr) replace with cache multi-invalidate call +	for _, id := range deleteFilterKeywordIDs { +		f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id) +	} +	for _, id := range deleteFilterStatusIDs { +		f.state.Caches.GTS.FilterStatus.Invalidate("ID", id) +	} + +	return nil +} + +func (f *filterDB) DeleteFilterByID(ctx context.Context, id string) error { +	if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +		// Delete all keywords attached to filter. +		if _, err := tx. +			NewDelete(). +			Model((*gtsmodel.FilterKeyword)(nil)). +			Where("? = ?", bun.Ident("filter_id"), id). +			Exec(ctx); err != nil { +			return err +		} + +		// Delete all statuses attached to filter. +		if _, err := tx. +			NewDelete(). +			Model((*gtsmodel.FilterStatus)(nil)). +			Where("? = ?", bun.Ident("filter_id"), id). +			Exec(ctx); err != nil { +			return err +		} + +		// Delete the filter itself. +		_, err := tx. +			NewDelete(). +			Model((*gtsmodel.Filter)(nil)). +			Where("? = ?", bun.Ident("id"), id). +			Exec(ctx) +		return err +	}); err != nil { +		return err +	} + +	// Invalidate this filter. +	f.state.Caches.GTS.Filter.Invalidate("ID", id) + +	// Invalidate all keywords and statuses for this filter. +	f.state.Caches.GTS.FilterKeyword.Invalidate("FilterID", id) +	f.state.Caches.GTS.FilterStatus.Invalidate("FilterID", id) + +	return nil +} diff --git a/internal/db/bundb/filter_test.go b/internal/db/bundb/filter_test.go new file mode 100644 index 000000000..7940b6651 --- /dev/null +++ b/internal/db/bundb/filter_test.go @@ -0,0 +1,252 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( +	"context" +	"errors" +	"testing" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/util" +) + +type FilterTestSuite struct { +	BunDBStandardTestSuite +} + +// TestFilterCRUD tests CRUD and read-all operations on filters. +func (suite *FilterTestSuite) TestFilterCRUD() { +	t := suite.T() + +	// Create new example filter with attached keyword. +	filter := >smodel.Filter{ +		ID:            "01HNEJNVZZVXJTRB3FX3K2B1YF", +		AccountID:     "01HNEJXCPRTJVJY9MV0VVHGD47", +		Title:         "foss jail", +		Action:        gtsmodel.FilterActionWarn, +		ContextHome:   util.Ptr(true), +		ContextPublic: util.Ptr(true), +	} +	filterKeyword := >smodel.FilterKeyword{ +		ID:        "01HNEK4RW5QEAMG9Y4ET6ST0J4", +		AccountID: filter.AccountID, +		FilterID:  filter.ID, +		Keyword:   "GNU/Linux", +	} +	filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} + +	// Create new cancellable test context. +	ctx := context.Background() +	ctx, cncl := context.WithCancel(ctx) +	defer cncl() + +	// Insert the example filter into db. +	if err := suite.db.PutFilter(ctx, filter); err != nil { +		t.Fatalf("error inserting filter: %v", err) +	} + +	// Now fetch newly created filter. +	check, err := suite.db.GetFilterByID(ctx, filter.ID) +	if err != nil { +		t.Fatalf("error fetching filter: %v", err) +	} + +	// Check all expected fields match. +	suite.Equal(filter.ID, check.ID) +	suite.Equal(filter.AccountID, check.AccountID) +	suite.Equal(filter.Title, check.Title) +	suite.Equal(filter.Action, check.Action) +	suite.Equal(filter.ContextHome, check.ContextHome) +	suite.Equal(filter.ContextNotifications, check.ContextNotifications) +	suite.Equal(filter.ContextPublic, check.ContextPublic) +	suite.Equal(filter.ContextThread, check.ContextThread) +	suite.Equal(filter.ContextAccount, check.ContextAccount) +	suite.NotZero(check.CreatedAt) +	suite.NotZero(check.UpdatedAt) + +	suite.Equal(len(filter.Keywords), len(check.Keywords)) +	suite.Equal(filter.Keywords[0].ID, check.Keywords[0].ID) +	suite.Equal(filter.Keywords[0].AccountID, check.Keywords[0].AccountID) +	suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID) +	suite.Equal(filter.Keywords[0].Keyword, check.Keywords[0].Keyword) +	suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID) +	suite.NotZero(check.Keywords[0].CreatedAt) +	suite.NotZero(check.Keywords[0].UpdatedAt) + +	suite.Equal(len(filter.Statuses), len(check.Statuses)) + +	// Fetch all filters. +	all, err := suite.db.GetFiltersForAccountID(ctx, filter.AccountID) +	if err != nil { +		t.Fatalf("error fetching filters: %v", err) +	} + +	// Ensure the result contains our example filter. +	suite.Len(all, 1) +	suite.Equal(filter.ID, all[0].ID) + +	suite.Len(all[0].Keywords, 1) +	suite.Equal(filter.Keywords[0].ID, all[0].Keywords[0].ID) + +	suite.Empty(all[0].Statuses) + +	// Update the filter context and add another keyword and a status. +	check.ContextNotifications = util.Ptr(true) + +	newKeyword := >smodel.FilterKeyword{ +		ID:        "01HNEMY810E5XKWDDMN5ZRE749", +		FilterID:  filter.ID, +		AccountID: filter.AccountID, +		Keyword:   "tux", +	} +	check.Keywords = append(check.Keywords, newKeyword) + +	newStatus := >smodel.FilterStatus{ +		ID:        "01HNEMYD5XE7C8HH8TNCZ76FN2", +		FilterID:  filter.ID, +		AccountID: filter.AccountID, +		StatusID:  "01HNEKZW34SQZ8PSDQ0Z10NZES", +	} +	check.Statuses = append(check.Statuses, newStatus) + +	if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil { +		t.Fatalf("error updating filter: %v", err) +	} +	// Now fetch newly updated filter. +	check, err = suite.db.GetFilterByID(ctx, filter.ID) +	if err != nil { +		t.Fatalf("error fetching updated filter: %v", err) +	} + +	// Ensure expected fields were modified on check filter. +	suite.True(check.UpdatedAt.After(filter.UpdatedAt)) +	if suite.NotNil(check.ContextHome) { +		suite.True(*check.ContextHome) +	} +	if suite.NotNil(check.ContextNotifications) { +		suite.True(*check.ContextNotifications) +	} +	if suite.NotNil(check.ContextPublic) { +		suite.True(*check.ContextPublic) +	} +	if suite.NotNil(check.ContextThread) { +		suite.False(*check.ContextThread) +	} +	if suite.NotNil(check.ContextAccount) { +		suite.False(*check.ContextAccount) +	} + +	// Ensure keyword entries were added. +	suite.Len(check.Keywords, 2) +	checkFilterKeywordIDs := make([]string, 0, 2) +	for _, checkFilterKeyword := range check.Keywords { +		checkFilterKeywordIDs = append(checkFilterKeywordIDs, checkFilterKeyword.ID) +	} +	suite.ElementsMatch([]string{filterKeyword.ID, newKeyword.ID}, checkFilterKeywordIDs) + +	// Ensure status entry was added. +	suite.Len(check.Statuses, 1) +	checkFilterStatusIDs := make([]string, 0, 1) +	for _, checkFilterStatus := range check.Statuses { +		checkFilterStatusIDs = append(checkFilterStatusIDs, checkFilterStatus.ID) +	} +	suite.ElementsMatch([]string{newStatus.ID}, checkFilterStatusIDs) + +	// Update one filter keyword and delete another. Don't change the filter or the filter status. +	filterKeyword.WholeWord = util.Ptr(true) +	check.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} +	check.Statuses = nil + +	if err := suite.db.UpdateFilter(ctx, check, nil, nil, []string{newKeyword.ID}, nil); err != nil { +		t.Fatalf("error updating filter: %v", err) +	} +	check, err = suite.db.GetFilterByID(ctx, filter.ID) +	if err != nil { +		t.Fatalf("error fetching updated filter: %v", err) +	} + +	// Ensure expected fields were not modified. +	suite.Equal(filter.Title, check.Title) +	suite.Equal(gtsmodel.FilterActionWarn, check.Action) +	if suite.NotNil(check.ContextHome) { +		suite.True(*check.ContextHome) +	} +	if suite.NotNil(check.ContextNotifications) { +		suite.True(*check.ContextNotifications) +	} +	if suite.NotNil(check.ContextPublic) { +		suite.True(*check.ContextPublic) +	} +	if suite.NotNil(check.ContextThread) { +		suite.False(*check.ContextThread) +	} +	if suite.NotNil(check.ContextAccount) { +		suite.False(*check.ContextAccount) +	} + +	// Ensure only changed field of keyword was modified, and other keyword was deleted. +	suite.Len(check.Keywords, 1) +	suite.Equal(filterKeyword.ID, check.Keywords[0].ID) +	suite.Equal("GNU/Linux", check.Keywords[0].Keyword) +	if suite.NotNil(check.Keywords[0].WholeWord) { +		suite.True(*check.Keywords[0].WholeWord) +	} + +	// Ensure status entry was not deleted. +	suite.Len(check.Statuses, 1) +	suite.Equal(newStatus.ID, check.Statuses[0].ID) + +	// Add another status entry for the same status ID. It should be ignored without problems. +	redundantStatus := >smodel.FilterStatus{ +		ID:        "01HQXJ5Y405XZSQ67C2BSQ6HJ0", +		FilterID:  filter.ID, +		AccountID: filter.AccountID, +		StatusID:  newStatus.StatusID, +	} +	check.Statuses = []*gtsmodel.FilterStatus{redundantStatus} +	if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil { +		t.Fatalf("error updating filter: %v", err) +	} +	check, err = suite.db.GetFilterByID(ctx, filter.ID) +	if err != nil { +		t.Fatalf("error fetching updated filter: %v", err) +	} + +	// Ensure status entry was not deleted, updated, or duplicated. +	suite.Len(check.Statuses, 1) +	suite.Equal(newStatus.ID, check.Statuses[0].ID) +	suite.Equal(newStatus.StatusID, check.Statuses[0].StatusID) + +	// Now delete the filter from the DB. +	if err := suite.db.DeleteFilterByID(ctx, filter.ID); err != nil { +		t.Fatalf("error deleting filter: %v", err) +	} + +	// Ensure we can't refetch it. +	_, err = suite.db.GetFilterByID(ctx, filter.ID) +	if !errors.Is(err, db.ErrNoEntries) { +		t.Fatalf("fetching deleted filter returned unexpected error: %v", err) +	} +} + +func TestFilterTestSuite(t *testing.T) { +	suite.Run(t, new(FilterTestSuite)) +} diff --git a/internal/db/bundb/filterkeyword.go b/internal/db/bundb/filterkeyword.go new file mode 100644 index 000000000..703d58d43 --- /dev/null +++ b/internal/db/bundb/filterkeyword.go @@ -0,0 +1,191 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( +	"context" +	"slices" +	"time" + +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/util" +	"github.com/uptrace/bun" +) + +func (f *filterDB) GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error) { +	filterKeyword, err := f.state.Caches.GTS.FilterKeyword.LoadOne( +		"ID", +		func() (*gtsmodel.FilterKeyword, error) { +			var filterKeyword gtsmodel.FilterKeyword +			err := f.db. +				NewSelect(). +				Model(&filterKeyword). +				Where("? = ?", bun.Ident("id"), id). +				Scan(ctx) +			return &filterKeyword, err +		}, +		id, +	) +	if err != nil { +		return nil, err +	} + +	if !gtscontext.Barebones(ctx) { +		err = f.populateFilterKeyword(ctx, filterKeyword) +		if err != nil { +			return nil, err +		} +	} + +	return filterKeyword, nil +} + +func (f *filterDB) populateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error { +	if filterKeyword.Filter == nil { +		// Filter is not set, fetch from the cache or database. +		filter, err := f.state.DB.GetFilterByID( +			// Don't populate the filter with all of its keywords and statuses or we'll just end up back here. +			gtscontext.SetBarebones(ctx), +			filterKeyword.FilterID, +		) +		if err != nil { +			return err +		} +		filterKeyword.Filter = filter +	} + +	return nil +} + +func (f *filterDB) GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error) { +	return f.getFilterKeywords(ctx, "filter_id", filterID) +} + +func (f *filterDB) GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error) { +	return f.getFilterKeywords(ctx, "account_id", accountID) +} + +func (f *filterDB) getFilterKeywords(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterKeyword, error) { +	var filterKeywordIDs []string +	if err := f.db. +		NewSelect(). +		Model((*gtsmodel.FilterKeyword)(nil)). +		Column("id"). +		Where("? = ?", bun.Ident(idColumn), id). +		Scan(ctx, &filterKeywordIDs); err != nil { +		return nil, err +	} +	if len(filterKeywordIDs) == 0 { +		return nil, nil +	} + +	// Get each filter keyword by ID from the cache or DB. +	uncachedFilterKeywordIDs := make([]string, 0, len(filterKeywordIDs)) +	filterKeywords, err := f.state.Caches.GTS.FilterKeyword.Load( +		"ID", +		func(load func(keyParts ...any) bool) { +			for _, id := range filterKeywordIDs { +				if !load(id) { +					uncachedFilterKeywordIDs = append(uncachedFilterKeywordIDs, id) +				} +			} +		}, +		func() ([]*gtsmodel.FilterKeyword, error) { +			uncachedFilterKeywords := make([]*gtsmodel.FilterKeyword, 0, len(uncachedFilterKeywordIDs)) +			if err := f.db. +				NewSelect(). +				Model(&uncachedFilterKeywords). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterKeywordIDs)). +				Scan(ctx); err != nil { +				return nil, err +			} +			return uncachedFilterKeywords, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Put the filter keyword structs in the same order as the filter keyword IDs. +	util.OrderBy(filterKeywords, filterKeywordIDs, func(filterKeyword *gtsmodel.FilterKeyword) string { +		return filterKeyword.ID +	}) + +	if gtscontext.Barebones(ctx) { +		return filterKeywords, nil +	} + +	// Populate the filter keywords. Remove any that we can't populate from the return slice. +	errs := gtserror.NewMultiError(len(filterKeywords)) +	filterKeywords = slices.DeleteFunc(filterKeywords, func(filterKeyword *gtsmodel.FilterKeyword) bool { +		if err := f.populateFilterKeyword(ctx, filterKeyword); err != nil { +			errs.Appendf( +				"error populating filter keyword %s: %w", +				filterKeyword.ID, +				err, +			) +			return true +		} +		return false +	}) + +	return filterKeywords, errs.Combine() +} + +func (f *filterDB) PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error { +	return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error { +		_, err := f.db. +			NewInsert(). +			Model(filterKeyword). +			Exec(ctx) +		return err +	}) +} + +func (f *filterDB) UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error { +	filterKeyword.UpdatedAt = time.Now() +	if len(columns) > 0 { +		columns = append(columns, "updated_at") +	} + +	return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error { +		_, err := f.db. +			NewUpdate(). +			Model(filterKeyword). +			Where("? = ?", bun.Ident("id"), filterKeyword.ID). +			Column(columns...). +			Exec(ctx) +		return err +	}) +} + +func (f *filterDB) DeleteFilterKeywordByID(ctx context.Context, id string) error { +	if _, err := f.db. +		NewDelete(). +		Model((*gtsmodel.FilterKeyword)(nil)). +		Where("? = ?", bun.Ident("id"), id). +		Exec(ctx); err != nil { +		return err +	} + +	f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id) + +	return nil +} diff --git a/internal/db/bundb/filterkeyword_test.go b/internal/db/bundb/filterkeyword_test.go new file mode 100644 index 000000000..91c8d192c --- /dev/null +++ b/internal/db/bundb/filterkeyword_test.go @@ -0,0 +1,143 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( +	"context" +	"errors" + +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/util" +) + +// TestFilterKeywordCRUD tests CRUD and read-all operations on filter keywords. +func (suite *FilterTestSuite) TestFilterKeywordCRUD() { +	t := suite.T() + +	// Create new filter. +	filter := >smodel.Filter{ +		ID:            "01HNEJNVZZVXJTRB3FX3K2B1YF", +		AccountID:     "01HNEJXCPRTJVJY9MV0VVHGD47", +		Title:         "foss jail", +		Action:        gtsmodel.FilterActionWarn, +		ContextHome:   util.Ptr(true), +		ContextPublic: util.Ptr(true), +	} + +	// Create new cancellable test context. +	ctx := context.Background() +	ctx, cncl := context.WithCancel(ctx) +	defer cncl() + +	// Insert the new filter into the DB. +	err := suite.db.PutFilter(ctx, filter) +	if err != nil { +		t.Fatalf("error inserting filter: %v", err) +	} + +	// There should be no filter keywords yet. +	all, err := suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID) +	if err != nil { +		t.Fatalf("error fetching filter keywords: %v", err) +	} +	suite.Empty(all) + +	// Add a filter keyword to it. +	filterKeyword := >smodel.FilterKeyword{ +		ID:        "01HNEK4RW5QEAMG9Y4ET6ST0J4", +		AccountID: filter.AccountID, +		FilterID:  filter.ID, +		Keyword:   "GNU/Linux", +	} + +	// Insert the new filter keyword into the DB. +	err = suite.db.PutFilterKeyword(ctx, filterKeyword) +	if err != nil { +		t.Fatalf("error inserting filter keyword: %v", err) +	} + +	// Try to find it again and ensure it has the fields we expect. +	check, err := suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID) +	if err != nil { +		t.Fatalf("error fetching filter keyword: %v", err) +	} +	suite.Equal(filterKeyword.ID, check.ID) +	suite.NotZero(check.CreatedAt) +	suite.NotZero(check.UpdatedAt) +	suite.Equal(filterKeyword.AccountID, check.AccountID) +	suite.Equal(filterKeyword.FilterID, check.FilterID) +	suite.Equal(filterKeyword.Keyword, check.Keyword) +	suite.Equal(filterKeyword.WholeWord, check.WholeWord) + +	// Loading filter keywords by account ID should find the one we inserted. +	all, err = suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID) +	if err != nil { +		t.Fatalf("error fetching filter keywords: %v", err) +	} +	suite.Len(all, 1) +	suite.Equal(filterKeyword.ID, all[0].ID) + +	// Loading filter keywords by filter ID should also find the one we inserted. +	all, err = suite.db.GetFilterKeywordsForFilterID(ctx, filter.ID) +	if err != nil { +		t.Fatalf("error fetching filter keywords: %v", err) +	} +	suite.Len(all, 1) +	suite.Equal(filterKeyword.ID, all[0].ID) + +	// Modify the filter keyword. +	filterKeyword.WholeWord = util.Ptr(true) +	err = suite.db.UpdateFilterKeyword(ctx, filterKeyword) +	if err != nil { +		t.Fatalf("error updating filter keyword: %v", err) +	} + +	// Try to find it again and ensure it has the updated fields we expect. +	check, err = suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID) +	if err != nil { +		t.Fatalf("error fetching filter keyword: %v", err) +	} +	suite.Equal(filterKeyword.ID, check.ID) +	suite.NotZero(check.CreatedAt) +	suite.True(check.UpdatedAt.After(check.CreatedAt)) +	suite.Equal(filterKeyword.AccountID, check.AccountID) +	suite.Equal(filterKeyword.FilterID, check.FilterID) +	suite.Equal(filterKeyword.Keyword, check.Keyword) +	suite.Equal(filterKeyword.WholeWord, check.WholeWord) + +	// Delete the filter keyword from the DB. +	err = suite.db.DeleteFilterKeywordByID(ctx, filter.ID) +	if err != nil { +		t.Fatalf("error deleting filter keyword: %v", err) +	} + +	// Ensure we can't refetch it. +	check, err = suite.db.GetFilterKeywordByID(ctx, filter.ID) +	if !errors.Is(err, db.ErrNoEntries) { +		t.Fatalf("fetching deleted filter keyword returned unexpected error: %v", err) +	} +	suite.Nil(check) + +	// Ensure the filter itself is still there. +	checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID) +	if err != nil { +		t.Fatalf("error fetching filter: %v", err) +	} +	suite.Equal(filter.ID, checkFilter.ID) +} diff --git a/internal/db/bundb/filterstatus.go b/internal/db/bundb/filterstatus.go new file mode 100644 index 000000000..1e98f5958 --- /dev/null +++ b/internal/db/bundb/filterstatus.go @@ -0,0 +1,191 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( +	"context" +	"slices" +	"time" + +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/util" +	"github.com/uptrace/bun" +) + +func (f *filterDB) GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error) { +	filterStatus, err := f.state.Caches.GTS.FilterStatus.LoadOne( +		"ID", +		func() (*gtsmodel.FilterStatus, error) { +			var filterStatus gtsmodel.FilterStatus +			err := f.db. +				NewSelect(). +				Model(&filterStatus). +				Where("? = ?", bun.Ident("id"), id). +				Scan(ctx) +			return &filterStatus, err +		}, +		id, +	) +	if err != nil { +		return nil, err +	} + +	if !gtscontext.Barebones(ctx) { +		err = f.populateFilterStatus(ctx, filterStatus) +		if err != nil { +			return nil, err +		} +	} + +	return filterStatus, nil +} + +func (f *filterDB) populateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error { +	if filterStatus.Filter == nil { +		// Filter is not set, fetch from the cache or database. +		filter, err := f.state.DB.GetFilterByID( +			// Don't populate the filter with all of its keywords and statuses or we'll just end up back here. +			gtscontext.SetBarebones(ctx), +			filterStatus.FilterID, +		) +		if err != nil { +			return err +		} +		filterStatus.Filter = filter +	} + +	return nil +} + +func (f *filterDB) GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error) { +	return f.getFilterStatuses(ctx, "filter_id", filterID) +} + +func (f *filterDB) GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error) { +	return f.getFilterStatuses(ctx, "account_id", accountID) +} + +func (f *filterDB) getFilterStatuses(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterStatus, error) { +	var filterStatusIDs []string +	if err := f.db. +		NewSelect(). +		Model((*gtsmodel.FilterStatus)(nil)). +		Column("id"). +		Where("? = ?", bun.Ident(idColumn), id). +		Scan(ctx, &filterStatusIDs); err != nil { +		return nil, err +	} +	if len(filterStatusIDs) == 0 { +		return nil, nil +	} + +	// Get each filter status by ID from the cache or DB. +	uncachedFilterStatusIDs := make([]string, 0, len(filterStatusIDs)) +	filterStatuses, err := f.state.Caches.GTS.FilterStatus.Load( +		"ID", +		func(load func(keyParts ...any) bool) { +			for _, id := range filterStatusIDs { +				if !load(id) { +					uncachedFilterStatusIDs = append(uncachedFilterStatusIDs, id) +				} +			} +		}, +		func() ([]*gtsmodel.FilterStatus, error) { +			uncachedFilterStatuses := make([]*gtsmodel.FilterStatus, 0, len(uncachedFilterStatusIDs)) +			if err := f.db. +				NewSelect(). +				Model(&uncachedFilterStatuses). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterStatusIDs)). +				Scan(ctx); err != nil { +				return nil, err +			} +			return uncachedFilterStatuses, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Put the filter status structs in the same order as the filter status IDs. +	util.OrderBy(filterStatuses, filterStatusIDs, func(filterStatus *gtsmodel.FilterStatus) string { +		return filterStatus.ID +	}) + +	if gtscontext.Barebones(ctx) { +		return filterStatuses, nil +	} + +	// Populate the filter statuses. Remove any that we can't populate from the return slice. +	errs := gtserror.NewMultiError(len(filterStatuses)) +	filterStatuses = slices.DeleteFunc(filterStatuses, func(filterStatus *gtsmodel.FilterStatus) bool { +		if err := f.populateFilterStatus(ctx, filterStatus); err != nil { +			errs.Appendf( +				"error populating filter status %s: %w", +				filterStatus.ID, +				err, +			) +			return true +		} +		return false +	}) + +	return filterStatuses, errs.Combine() +} + +func (f *filterDB) PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error { +	return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error { +		_, err := f.db. +			NewInsert(). +			Model(filterStatus). +			Exec(ctx) +		return err +	}) +} + +func (f *filterDB) UpdateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus, columns ...string) error { +	filterStatus.UpdatedAt = time.Now() +	if len(columns) > 0 { +		columns = append(columns, "updated_at") +	} + +	return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error { +		_, err := f.db. +			NewUpdate(). +			Model(filterStatus). +			Where("? = ?", bun.Ident("id"), filterStatus.ID). +			Column(columns...). +			Exec(ctx) +		return err +	}) +} + +func (f *filterDB) DeleteFilterStatusByID(ctx context.Context, id string) error { +	if _, err := f.db. +		NewDelete(). +		Model((*gtsmodel.FilterStatus)(nil)). +		Where("? = ?", bun.Ident("id"), id). +		Exec(ctx); err != nil { +		return err +	} + +	f.state.Caches.GTS.FilterStatus.Invalidate("ID", id) + +	return nil +} diff --git a/internal/db/bundb/filterstatus_test.go b/internal/db/bundb/filterstatus_test.go new file mode 100644 index 000000000..48ddb1bed --- /dev/null +++ b/internal/db/bundb/filterstatus_test.go @@ -0,0 +1,122 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( +	"context" +	"errors" + +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/util" +) + +// TestFilterStatusCRD tests CRD (no U) and read-all operations on filter statuses. +func (suite *FilterTestSuite) TestFilterStatusCRD() { +	t := suite.T() + +	// Create new filter. +	filter := >smodel.Filter{ +		ID:            "01HNEJNVZZVXJTRB3FX3K2B1YF", +		AccountID:     "01HNEJXCPRTJVJY9MV0VVHGD47", +		Title:         "foss jail", +		Action:        gtsmodel.FilterActionWarn, +		ContextHome:   util.Ptr(true), +		ContextPublic: util.Ptr(true), +	} + +	// Create new cancellable test context. +	ctx := context.Background() +	ctx, cncl := context.WithCancel(ctx) +	defer cncl() + +	// Insert the new filter into the DB. +	err := suite.db.PutFilter(ctx, filter) +	if err != nil { +		t.Fatalf("error inserting filter: %v", err) +	} + +	// There should be no filter statuses yet. +	all, err := suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID) +	if err != nil { +		t.Fatalf("error fetching filter statuses: %v", err) +	} +	suite.Empty(all) + +	// Add a filter status to it. +	filterStatus := >smodel.FilterStatus{ +		ID:        "01HNEK4RW5QEAMG9Y4ET6ST0J4", +		AccountID: filter.AccountID, +		FilterID:  filter.ID, +		StatusID:  "01HQXGMQ3QFXRT4GX9WNQ8KC0X", +	} + +	// Insert the new filter status into the DB. +	err = suite.db.PutFilterStatus(ctx, filterStatus) +	if err != nil { +		t.Fatalf("error inserting filter status: %v", err) +	} + +	// Try to find it again and ensure it has the fields we expect. +	check, err := suite.db.GetFilterStatusByID(ctx, filterStatus.ID) +	if err != nil { +		t.Fatalf("error fetching filter status: %v", err) +	} +	suite.Equal(filterStatus.ID, check.ID) +	suite.NotZero(check.CreatedAt) +	suite.NotZero(check.UpdatedAt) +	suite.Equal(filterStatus.AccountID, check.AccountID) +	suite.Equal(filterStatus.FilterID, check.FilterID) +	suite.Equal(filterStatus.StatusID, check.StatusID) + +	// Loading filter statuses by account ID should find the one we inserted. +	all, err = suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID) +	if err != nil { +		t.Fatalf("error fetching filter statuses: %v", err) +	} +	suite.Len(all, 1) +	suite.Equal(filterStatus.ID, all[0].ID) + +	// Loading filter statuses by filter ID should also find the one we inserted. +	all, err = suite.db.GetFilterStatusesForFilterID(ctx, filter.ID) +	if err != nil { +		t.Fatalf("error fetching filter statuses: %v", err) +	} +	suite.Len(all, 1) +	suite.Equal(filterStatus.ID, all[0].ID) + +	// Delete the filter status from the DB. +	err = suite.db.DeleteFilterStatusByID(ctx, filter.ID) +	if err != nil { +		t.Fatalf("error deleting filter status: %v", err) +	} + +	// Ensure we can't refetch it. +	check, err = suite.db.GetFilterStatusByID(ctx, filter.ID) +	if !errors.Is(err, db.ErrNoEntries) { +		t.Fatalf("fetching deleted filter status returned unexpected error: %v", err) +	} +	suite.Nil(check) + +	// Ensure the filter itself is still there. +	checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID) +	if err != nil { +		t.Fatalf("error fetching filter: %v", err) +	} +	suite.Equal(filter.ID, checkFilter.ID) +} diff --git a/internal/db/bundb/migrations/20240126064004_add_filters.go b/internal/db/bundb/migrations/20240126064004_add_filters.go new file mode 100644 index 000000000..3ad22f9d8 --- /dev/null +++ b/internal/db/bundb/migrations/20240126064004_add_filters.go @@ -0,0 +1,97 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( +	"context" + +	gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun" +) + +func init() { +	up := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			// Filter table. +			if _, err := tx. +				NewCreateTable(). +				Model(>smodel.Filter{}). +				IfNotExists(). +				Exec(ctx); err != nil { +				return err +			} + +			// Filter keyword table. +			if _, err := tx. +				NewCreateTable(). +				Model(>smodel.FilterKeyword{}). +				IfNotExists(). +				Exec(ctx); err != nil { +				return err +			} + +			// Filter status table. +			if _, err := tx. +				NewCreateTable(). +				Model(>smodel.FilterStatus{}). +				IfNotExists(). +				Exec(ctx); err != nil { +				return err +			} + +			// Add indexes to the filter tables. +			for table, indexes := range map[string]map[string][]string{ +				"filters": { +					"filters_account_id_idx": {"account_id"}, +				}, +				"filter_keywords": { +					"filter_keywords_account_id_idx": {"account_id"}, +					"filter_keywords_filter_id_idx":  {"filter_id"}, +				}, +				"filter_statuses": { +					"filter_statuses_account_id_idx": {"account_id"}, +					"filter_statuses_filter_id_idx":  {"filter_id"}, +				}, +			} { +				for index, columns := range indexes { +					if _, err := tx. +						NewCreateIndex(). +						Table(table). +						Index(index). +						Column(columns...). +						IfNotExists(). +						Exec(ctx); err != nil { +						return err +					} +				} +			} + +			return nil +		}) +	} + +	down := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			return nil +		}) +	} + +	if err := Migrations.Register(up, down); err != nil { +		panic(err) +	} +} diff --git a/internal/db/bundb/upsert.go b/internal/db/bundb/upsert.go new file mode 100644 index 000000000..34724446c --- /dev/null +++ b/internal/db/bundb/upsert.go @@ -0,0 +1,230 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( +	"context" +	"database/sql" +	"reflect" +	"strings" + +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/uptrace/bun" +	"github.com/uptrace/bun/dialect" +) + +// UpsertQuery is a wrapper around an insert query that can update if an insert fails. +// Doesn't implement the full set of Bun query methods, but we can add more if we need them. +// See https://bun.uptrace.dev/guide/query-insert.html#upsert +type UpsertQuery struct { +	db          bun.IDB +	model       interface{} +	constraints []string +	columns     []string +} + +func NewUpsert(idb bun.IDB) *UpsertQuery { +	// note: passing in rawtx as conn iface so no double query-hook +	// firing when passed through the bun.Tx.Query___() functions. +	return &UpsertQuery{db: idb} +} + +// Model sets the model or models to upsert. +func (u *UpsertQuery) Model(model interface{}) *UpsertQuery { +	u.model = model +	return u +} + +// Constraint sets the columns or indexes that are used to check for conflicts. +// This is required. +func (u *UpsertQuery) Constraint(constraints ...string) *UpsertQuery { +	u.constraints = constraints +	return u +} + +// Column sets the columns to update if an insert does't happen. +// If empty, all columns not being used for constraints will be updated. +// Cannot overlap with Constraint. +func (u *UpsertQuery) Column(columns ...string) *UpsertQuery { +	u.columns = columns +	return u +} + +// insertDialect errors if we're using a dialect in which we don't know how to upsert. +func (u *UpsertQuery) insertDialect() error { +	dialectName := u.db.Dialect().Name() +	switch dialectName { +	case dialect.PG, dialect.SQLite: +		return nil +	default: +		// FUTURE: MySQL has its own variation on upserts, but the syntax is different. +		return gtserror.Newf("UpsertQuery: upsert not supported by SQL dialect: %s", dialectName) +	} +} + +// insertConstraints checks that we have constraints and returns them. +func (u *UpsertQuery) insertConstraints() ([]string, error) { +	if len(u.constraints) == 0 { +		return nil, gtserror.New("UpsertQuery: upserts require at least one constraint column or index, none provided") +	} +	return u.constraints, nil +} + +// insertColumns returns the non-constraint columns we'll be updating. +func (u *UpsertQuery) insertColumns(constraints []string) ([]string, error) { +	// Constraints as a set. +	constraintSet := make(map[string]struct{}, len(constraints)) +	for _, constraint := range constraints { +		constraintSet[constraint] = struct{}{} +	} + +	var columns []string +	var err error +	if len(u.columns) == 0 { +		columns, err = u.insertColumnsDefault(constraintSet) +	} else { +		columns, err = u.insertColumnsSpecified(constraintSet) +	} +	if err != nil { +		return nil, err +	} +	if len(columns) == 0 { +		return nil, gtserror.New("UpsertQuery: there are no columns to update when upserting") +	} + +	return columns, nil +} + +// hasElem returns whether the type has an element and can call [reflect.Type.Elem] without panicking. +func hasElem(modelType reflect.Type) bool { +	switch modelType.Kind() { +	case reflect.Array, reflect.Chan, reflect.Map, reflect.Pointer, reflect.Slice: +		return true +	default: +		return false +	} +} + +// insertColumnsDefault returns all non-constraint columns from the model schema. +func (u *UpsertQuery) insertColumnsDefault(constraintSet map[string]struct{}) ([]string, error) { +	// Get underlying struct type. +	modelType := reflect.TypeOf(u.model) +	for hasElem(modelType) { +		modelType = modelType.Elem() +	} + +	table := u.db.Dialect().Tables().Get(modelType) +	if table == nil { +		return nil, gtserror.Newf("UpsertQuery: couldn't find the table schema for model: %v", u.model) +	} + +	columns := make([]string, 0, len(u.columns)) +	for _, field := range table.Fields { +		column := field.Name +		if _, overlaps := constraintSet[column]; !overlaps { +			columns = append(columns, column) +		} +	} + +	return columns, nil +} + +// insertColumnsSpecified ensures constraints and specified columns to update don't overlap. +func (u *UpsertQuery) insertColumnsSpecified(constraintSet map[string]struct{}) ([]string, error) { +	overlapping := make([]string, 0, min(len(u.constraints), len(u.columns))) +	for _, column := range u.columns { +		if _, overlaps := constraintSet[column]; overlaps { +			overlapping = append(overlapping, column) +		} +	} + +	if len(overlapping) > 0 { +		return nil, gtserror.Newf( +			"UpsertQuery: the following columns can't be used for both constraints and columns to update: %s", +			strings.Join(overlapping, ", "), +		) +	} + +	return u.columns, nil +} + +// insert tries to create a Bun insert query from an upsert query. +func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) { +	var err error + +	err = u.insertDialect() +	if err != nil { +		return nil, err +	} + +	constraints, err := u.insertConstraints() +	if err != nil { +		return nil, err +	} + +	columns, err := u.insertColumns(constraints) +	if err != nil { +		return nil, err +	} + +	// Build the parts of the query that need us to generate SQL. +	constraintIDPlaceholders := make([]string, 0, len(constraints)) +	constraintIDs := make([]interface{}, 0, len(constraints)) +	for _, constraint := range constraints { +		constraintIDPlaceholders = append(constraintIDPlaceholders, "?") +		constraintIDs = append(constraintIDs, bun.Ident(constraint)) +	} +	onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update" + +	setClauses := make([]string, 0, len(columns)) +	setIDs := make([]interface{}, 0, 2*len(columns)) +	for _, column := range columns { +		// "excluded" is a special table that contains only the row involved in a conflict. +		setClauses = append(setClauses, "? = excluded.?") +		setIDs = append(setIDs, bun.Ident(column), bun.Ident(column)) +	} +	setSQL := strings.Join(setClauses, ", ") + +	insertQuery := u.db. +		NewInsert(). +		Model(u.model). +		On(onSQL, constraintIDs...). +		Set(setSQL, setIDs...) + +	return insertQuery, nil +} + +// Exec builds a Bun insert query from the upsert query, and executes it. +func (u *UpsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { +	insertQuery, err := u.insertQuery() +	if err != nil { +		return nil, err +	} + +	return insertQuery.Exec(ctx, dest...) +} + +// Scan builds a Bun insert query from the upsert query, and scans it. +func (u *UpsertQuery) Scan(ctx context.Context, dest ...interface{}) error { +	insertQuery, err := u.insertQuery() +	if err != nil { +		return err +	} + +	return insertQuery.Scan(ctx, dest...) +} diff --git a/internal/db/db.go b/internal/db/db.go index 361687e94..f23324777 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -32,6 +32,7 @@ type DB interface {  	Emoji  	HeaderFilter  	Instance +	Filter  	List  	Marker  	Media diff --git a/internal/db/filter.go b/internal/db/filter.go new file mode 100644 index 000000000..18943b4f9 --- /dev/null +++ b/internal/db/filter.go @@ -0,0 +1,101 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package db + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Filter contains methods for creating, reading, updating, and deleting filters and their keyword and status entries. +type Filter interface { +	//<editor-fold desc="Filter methods"> + +	// GetFilterByID gets one filter with the given id. +	GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error) + +	// GetFiltersForAccountID gets all filters owned by the given accountID. +	GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) + +	// PutFilter puts a new filter in the database, adding any attached keywords or statuses. +	// It uses a transaction to ensure no partial updates. +	PutFilter(ctx context.Context, filter *gtsmodel.Filter) error + +	// UpdateFilter updates the given filter, +	// upserts any attached keywords and inserts any new statuses (existing statuses cannot be updated), +	// and deletes indicated filter keywords and statuses by ID. +	// It uses a transaction to ensure no partial updates. +	// The column lists are optional; if not specified, all columns will be updated. +	UpdateFilter( +		ctx context.Context, +		filter *gtsmodel.Filter, +		filterColumns []string, +		filterKeywordColumns []string, +		deleteFilterKeywordIDs []string, +		deleteFilterStatusIDs []string, +	) error + +	// DeleteFilterByID deletes one filter with the given ID. +	// It uses a transaction to ensure no partial updates. +	DeleteFilterByID(ctx context.Context, id string) error + +	//</editor-fold> + +	//<editor-fold desc="Filter keyword methods"> + +	// GetFilterKeywordByID gets one filter keyword with the given ID. +	GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error) + +	// GetFilterKeywordsForFilterID gets filter keywords from the given filterID. +	GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error) + +	// GetFilterKeywordsForAccountID gets filter keywords from the given accountID. +	GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error) + +	// PutFilterKeyword inserts a single filter keyword into the database. +	PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error + +	// UpdateFilterKeyword updates the given filter keyword. +	// Columns is optional, if not specified all will be updated. +	UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error + +	// DeleteFilterKeywordByID deletes one filter keyword with the given id. +	DeleteFilterKeywordByID(ctx context.Context, id string) error + +	//</editor-fold> + +	//<editor-fold desc="Filter status methods"> + +	// GetFilterStatusByID gets one filter status with the given ID. +	GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error) + +	// GetFilterStatusesForFilterID gets filter statuses from the given filterID. +	GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error) + +	// GetFilterStatusesForAccountID gets filter keywords from the given accountID. +	GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error) + +	// PutFilterStatus inserts a single filter status into the database. +	PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error + +	// DeleteFilterStatusByID deletes one filter status with the given id. +	DeleteFilterStatusByID(ctx context.Context, id string) error + +	//</editor-fold> +} | 
