summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/filter.go19
-rw-r--r--internal/db/bundb/filter_test.go6
-rw-r--r--internal/db/filter.go4
3 files changed, 19 insertions, 10 deletions
diff --git a/internal/db/bundb/filter.go b/internal/db/bundb/filter.go
index d09a5067d..30a8494a7 100644
--- a/internal/db/bundb/filter.go
+++ b/internal/db/bundb/filter.go
@@ -19,6 +19,7 @@ package bundb
import (
"context"
+ "errors"
"slices"
"time"
@@ -197,10 +198,14 @@ func (f *filterDB) UpdateFilter(
ctx context.Context,
filter *gtsmodel.Filter,
filterColumns []string,
- filterKeywordColumns []string,
+ filterKeywordColumns [][]string,
deleteFilterKeywordIDs []string,
deleteFilterStatusIDs []string,
) error {
+ if len(filter.Keywords) != len(filterKeywordColumns) {
+ return errors.New("number of filter keywords must match number of lists of filter keyword columns")
+ }
+
updatedAt := time.Now()
filter.UpdatedAt = updatedAt
for _, filterKeyword := range filter.Keywords {
@@ -214,8 +219,10 @@ func (f *filterDB) UpdateFilter(
if len(filterColumns) > 0 {
filterColumns = append(filterColumns, "updated_at")
}
- if len(filterKeywordColumns) > 0 {
- filterKeywordColumns = append(filterKeywordColumns, "updated_at")
+ for i := range filterKeywordColumns {
+ if len(filterKeywordColumns[i]) > 0 {
+ filterKeywordColumns[i] = append(filterKeywordColumns[i], "updated_at")
+ }
}
// Update database.
@@ -229,11 +236,11 @@ func (f *filterDB) UpdateFilter(
return err
}
- if len(filter.Keywords) > 0 {
+ for i, filterKeyword := range filter.Keywords {
if _, err := NewUpsert(tx).
- Model(&filter.Keywords).
+ Model(filterKeyword).
Constraint("id").
- Column(filterKeywordColumns...).
+ Column(filterKeywordColumns[i]...).
Exec(ctx); err != nil {
return err
}
diff --git a/internal/db/bundb/filter_test.go b/internal/db/bundb/filter_test.go
index 7940b6651..d1249d16b 100644
--- a/internal/db/bundb/filter_test.go
+++ b/internal/db/bundb/filter_test.go
@@ -127,7 +127,7 @@ func (suite *FilterTestSuite) TestFilterCRUD() {
}
check.Statuses = append(check.Statuses, newStatus)
- if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil {
+ if err := suite.db.UpdateFilter(ctx, check, nil, [][]string{nil, nil}, nil, nil); err != nil {
t.Fatalf("error updating filter: %v", err)
}
// Now fetch newly updated filter.
@@ -175,7 +175,7 @@ func (suite *FilterTestSuite) TestFilterCRUD() {
check.Keywords = []*gtsmodel.FilterKeyword{filterKeyword}
check.Statuses = nil
- if err := suite.db.UpdateFilter(ctx, check, nil, nil, []string{newKeyword.ID}, nil); err != nil {
+ if err := suite.db.UpdateFilter(ctx, check, nil, [][]string{{"whole_word"}}, []string{newKeyword.ID}, nil); err != nil {
t.Fatalf("error updating filter: %v", err)
}
check, err = suite.db.GetFilterByID(ctx, filter.ID)
@@ -222,7 +222,7 @@ func (suite *FilterTestSuite) TestFilterCRUD() {
StatusID: newStatus.StatusID,
}
check.Statuses = []*gtsmodel.FilterStatus{redundantStatus}
- if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil {
+ if err := suite.db.UpdateFilter(ctx, check, nil, [][]string{nil}, nil, nil); err != nil {
t.Fatalf("error updating filter: %v", err)
}
check, err = suite.db.GetFilterByID(ctx, filter.ID)
diff --git a/internal/db/filter.go b/internal/db/filter.go
index 18943b4f9..eee61a99d 100644
--- a/internal/db/filter.go
+++ b/internal/db/filter.go
@@ -42,11 +42,13 @@ type Filter interface {
// and deletes indicated filter keywords and statuses by ID.
// It uses a transaction to ensure no partial updates.
// The column lists are optional; if not specified, all columns will be updated.
+ // The filter keyword columns list is *per keyword*.
+ // To update all keyword columns, provide a list where every element is an empty list.
UpdateFilter(
ctx context.Context,
filter *gtsmodel.Filter,
filterColumns []string,
- filterKeywordColumns []string,
+ filterKeywordColumns [][]string,
deleteFilterKeywordIDs []string,
deleteFilterStatusIDs []string,
) error