diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/filter.go | 19 | ||||
-rw-r--r-- | internal/db/bundb/filter_test.go | 6 | ||||
-rw-r--r-- | internal/db/filter.go | 4 |
3 files changed, 19 insertions, 10 deletions
diff --git a/internal/db/bundb/filter.go b/internal/db/bundb/filter.go index d09a5067d..30a8494a7 100644 --- a/internal/db/bundb/filter.go +++ b/internal/db/bundb/filter.go @@ -19,6 +19,7 @@ package bundb import ( "context" + "errors" "slices" "time" @@ -197,10 +198,14 @@ func (f *filterDB) UpdateFilter( ctx context.Context, filter *gtsmodel.Filter, filterColumns []string, - filterKeywordColumns []string, + filterKeywordColumns [][]string, deleteFilterKeywordIDs []string, deleteFilterStatusIDs []string, ) error { + if len(filter.Keywords) != len(filterKeywordColumns) { + return errors.New("number of filter keywords must match number of lists of filter keyword columns") + } + updatedAt := time.Now() filter.UpdatedAt = updatedAt for _, filterKeyword := range filter.Keywords { @@ -214,8 +219,10 @@ func (f *filterDB) UpdateFilter( if len(filterColumns) > 0 { filterColumns = append(filterColumns, "updated_at") } - if len(filterKeywordColumns) > 0 { - filterKeywordColumns = append(filterKeywordColumns, "updated_at") + for i := range filterKeywordColumns { + if len(filterKeywordColumns[i]) > 0 { + filterKeywordColumns[i] = append(filterKeywordColumns[i], "updated_at") + } } // Update database. @@ -229,11 +236,11 @@ func (f *filterDB) UpdateFilter( return err } - if len(filter.Keywords) > 0 { + for i, filterKeyword := range filter.Keywords { if _, err := NewUpsert(tx). - Model(&filter.Keywords). + Model(filterKeyword). Constraint("id"). - Column(filterKeywordColumns...). + Column(filterKeywordColumns[i]...). Exec(ctx); err != nil { return err } diff --git a/internal/db/bundb/filter_test.go b/internal/db/bundb/filter_test.go index 7940b6651..d1249d16b 100644 --- a/internal/db/bundb/filter_test.go +++ b/internal/db/bundb/filter_test.go @@ -127,7 +127,7 @@ func (suite *FilterTestSuite) TestFilterCRUD() { } check.Statuses = append(check.Statuses, newStatus) - if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil { + if err := suite.db.UpdateFilter(ctx, check, nil, [][]string{nil, nil}, nil, nil); err != nil { t.Fatalf("error updating filter: %v", err) } // Now fetch newly updated filter. @@ -175,7 +175,7 @@ func (suite *FilterTestSuite) TestFilterCRUD() { check.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} check.Statuses = nil - if err := suite.db.UpdateFilter(ctx, check, nil, nil, []string{newKeyword.ID}, nil); err != nil { + if err := suite.db.UpdateFilter(ctx, check, nil, [][]string{{"whole_word"}}, []string{newKeyword.ID}, nil); err != nil { t.Fatalf("error updating filter: %v", err) } check, err = suite.db.GetFilterByID(ctx, filter.ID) @@ -222,7 +222,7 @@ func (suite *FilterTestSuite) TestFilterCRUD() { StatusID: newStatus.StatusID, } check.Statuses = []*gtsmodel.FilterStatus{redundantStatus} - if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil { + if err := suite.db.UpdateFilter(ctx, check, nil, [][]string{nil}, nil, nil); err != nil { t.Fatalf("error updating filter: %v", err) } check, err = suite.db.GetFilterByID(ctx, filter.ID) diff --git a/internal/db/filter.go b/internal/db/filter.go index 18943b4f9..eee61a99d 100644 --- a/internal/db/filter.go +++ b/internal/db/filter.go @@ -42,11 +42,13 @@ type Filter interface { // and deletes indicated filter keywords and statuses by ID. // It uses a transaction to ensure no partial updates. // The column lists are optional; if not specified, all columns will be updated. + // The filter keyword columns list is *per keyword*. + // To update all keyword columns, provide a list where every element is an empty list. UpdateFilter( ctx context.Context, filter *gtsmodel.Filter, filterColumns []string, - filterKeywordColumns []string, + filterKeywordColumns [][]string, deleteFilterKeywordIDs []string, deleteFilterStatusIDs []string, ) error |