diff options
Diffstat (limited to 'internal/db/bundb')
-rw-r--r-- | internal/db/bundb/basic.go | 1 | ||||
-rw-r--r-- | internal/db/bundb/bundb.go | 48 | ||||
-rw-r--r-- | internal/db/bundb/bundb_test.go | 4 | ||||
-rw-r--r-- | internal/db/bundb/migrations/20230718161520_hashtaggery.go | 76 | ||||
-rw-r--r-- | internal/db/bundb/search.go | 99 | ||||
-rw-r--r-- | internal/db/bundb/search_test.go | 17 | ||||
-rw-r--r-- | internal/db/bundb/status.go | 13 | ||||
-rw-r--r-- | internal/db/bundb/tag.go | 119 | ||||
-rw-r--r-- | internal/db/bundb/tag_test.go | 91 | ||||
-rw-r--r-- | internal/db/bundb/timeline.go | 108 | ||||
-rw-r--r-- | internal/db/bundb/timeline_test.go | 15 | ||||
-rw-r--r-- | internal/db/bundb/util.go | 31 |
12 files changed, 571 insertions, 51 deletions
diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index 4991dcf69..33d6c6cb5 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -133,7 +133,6 @@ func (b *basicDB) CreateAllTables(ctx context.Context) error { >smodel.Mention{}, >smodel.Status{}, >smodel.StatusToEmoji{}, - >smodel.StatusToTag{}, >smodel.StatusFave{}, >smodel.StatusBookmark{}, >smodel.StatusMute{}, diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 6a6ff2224..8387bb8d1 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -39,7 +39,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/tracing" @@ -77,6 +76,7 @@ type DBService struct { db.Status db.StatusBookmark db.StatusFave + db.Tag db.Timeline db.User db.Tombstone @@ -230,6 +230,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { db: db, state: state, }, + Tag: &tagDB{ + conn: db, + state: state, + }, Timeline: &timelineDB{ db: db, state: state, @@ -494,45 +498,3 @@ func sqlitePragmas(ctx context.Context, db *WrappedDB) error { return nil } - -/* - CONVERSION FUNCTIONS -*/ - -func (dbService *DBService) TagStringToTag(ctx context.Context, t string, originAccountID string) (*gtsmodel.Tag, error) { - protocol := config.GetProtocol() - host := config.GetHost() - now := time.Now() - - tag := >smodel.Tag{} - // we can use selectorinsert here to create the new tag if it doesn't exist already - // inserted will be true if this is a new tag we just created - if err := dbService.db.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil && err != sql.ErrNoRows { - return nil, fmt.Errorf("error getting tag with name %s: %s", t, err) - } - - if tag.ID == "" { - // tag doesn't exist yet so populate it - newID, err := id.NewRandomULID() - if err != nil { - return nil, err - } - tag.ID = newID - tag.URL = protocol + "://" + host + "/tags/" + t - tag.Name = t - tag.FirstSeenFromAccountID = originAccountID - tag.CreatedAt = now - tag.UpdatedAt = now - useable := true - tag.Useable = &useable - listable := true - tag.Listable = &listable - } - - // bail already if the tag isn't useable - if !*tag.Useable { - return nil, fmt.Errorf("tag %s is not useable", t) - } - tag.LastStatusAt = now - return tag, nil -} diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go index d608f7bc4..0cdbb5cce 100644 --- a/internal/db/bundb/bundb_test.go +++ b/internal/db/bundb/bundb_test.go @@ -84,5 +84,7 @@ func (suite *BunDBStandardTestSuite) SetupTest() { } func (suite *BunDBStandardTestSuite) TearDownTest() { - testrig.StandardDBTeardown(suite.db) + if suite.db != nil { + testrig.StandardDBTeardown(suite.db) + } } diff --git a/internal/db/bundb/migrations/20230718161520_hashtaggery.go b/internal/db/bundb/migrations/20230718161520_hashtaggery.go new file mode 100644 index 000000000..1b2c8edc9 --- /dev/null +++ b/internal/db/bundb/migrations/20230718161520_hashtaggery.go @@ -0,0 +1,76 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( + "context" + + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Drop now unused columns from tags table. + for _, column := range []string{ + "url", + "first_seen_from_account_id", + "last_status_at", + } { + if _, err := tx. + NewDropColumn(). + Table("tags"). + Column(column). + Exec(ctx); err != nil { + return err + } + } + + // Index status_to_tags table properly. + for index, columns := range map[string][]string{ + // Index for tag timeline paging. + "status_to_tags_tag_timeline_idx": {"tag_id", "status_id"}, + // These indexes were only implicit + // before, make them explicit now. + "status_to_tags_tag_id_idx": {"tag_id"}, + "status_to_tags_status_id_idx": {"status_id"}, + } { + if _, err := tx. + NewCreateIndex(). + Table("status_to_tags"). + Index(index). + Column(columns...). + Exec(ctx); err != nil { + return err + } + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/search.go b/internal/db/bundb/search.go index f4e41d0f4..755f60e7d 100644 --- a/internal/db/bundb/search.go +++ b/internal/db/bundb/search.go @@ -19,6 +19,7 @@ package bundb import ( "context" + "strings" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" @@ -385,3 +386,101 @@ func (s *searchDB) statusText() *bun.SelectQuery { return statusText } + +// Query example (SQLite): +// +// SELECT "tag"."id" FROM "tags" AS "tag" +// WHERE ("tag"."id" < 'ZZZZZZZZZZZZZZZZZZZZZZZZZZ') +// AND (("tag"."name") LIKE 'welcome%' ESCAPE '\') +// ORDER BY "tag"."id" DESC LIMIT 10 +func (s *searchDB) SearchForTags( + ctx context.Context, + query string, + maxID string, + minID string, + limit int, + offset int, +) ([]*gtsmodel.Tag, error) { + // Ensure reasonable + if limit < 0 { + limit = 0 + } + + // Make educated guess for slice size + var ( + tagIDs = make([]string, 0, limit) + frontToBack = true + ) + + q := s.db. + NewSelect(). + TableExpr("? AS ?", bun.Ident("tags"), bun.Ident("tag")). + // Select only IDs from table + Column("tag.id") + + // Return only items with a LOWER id than maxID. + if maxID == "" { + maxID = id.Highest + } + q = q.Where("? < ?", bun.Ident("tag.id"), maxID) + + if minID != "" { + // return only tags HIGHER (ie., newer) than minID + q = q.Where("? > ?", bun.Ident("tag.id"), minID) + + // page up + frontToBack = false + } + + // Normalize tag 'name' string. + name := strings.TrimSpace(query) + name = strings.ToLower(name) + + // Search using LIKE for tags that start with `name`. + q = whereStartsLike(q, bun.Ident("tag.name"), name) + + if limit > 0 { + // Limit amount of tags returned. + q = q.Limit(limit) + } + + if frontToBack { + // Page down. + q = q.Order("tag.id DESC") + } else { + // Page up. + q = q.Order("tag.id ASC") + } + + if err := q.Scan(ctx, &tagIDs); err != nil { + return nil, s.db.ProcessError(err) + } + + if len(tagIDs) == 0 { + return nil, nil + } + + // If we're paging up, we still want tags + // to be sorted by ID desc, so reverse slice. + // https://zchee.github.io/golang-wiki/SliceTricks/#reversing + if !frontToBack { + for l, r := 0, len(tagIDs)-1; l < r; l, r = l+1, r-1 { + tagIDs[l], tagIDs[r] = tagIDs[r], tagIDs[l] + } + } + + tags := make([]*gtsmodel.Tag, 0, len(tagIDs)) + for _, id := range tagIDs { + // Fetch tag from db for ID + tag, err := s.state.DB.GetTag(ctx, id) + if err != nil { + log.Errorf(ctx, "error fetching tag %q: %v", id, err) + continue + } + + // Append status to slice + tags = append(tags, tag) + } + + return tags, nil +} diff --git a/internal/db/bundb/search_test.go b/internal/db/bundb/search_test.go index d670c90d6..f84704df2 100644 --- a/internal/db/bundb/search_test.go +++ b/internal/db/bundb/search_test.go @@ -77,6 +77,23 @@ func (suite *SearchTestSuite) TestSearchStatuses() { suite.Len(statuses, 1) } +func (suite *SearchTestSuite) TestSearchTags() { + // Search with full tag string. + tags, err := suite.db.SearchForTags(context.Background(), "welcome", "", "", 10, 0) + suite.NoError(err) + suite.Len(tags, 1) + + // Search with partial tag string. + tags, err = suite.db.SearchForTags(context.Background(), "wel", "", "", 10, 0) + suite.NoError(err) + suite.Len(tags, 1) + + // Search with end of tag string. + tags, err = suite.db.SearchForTags(context.Background(), "come", "", "", 10, 0) + suite.NoError(err) + suite.Len(tags, 0) +} + func TestSearchTestSuite(t *testing.T) { suite.Run(t, new(SearchTestSuite)) } diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 4dc7d8468..0fef01736 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -214,9 +214,16 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) } } - // TODO: once we don't fetch using relations. - // if !status.TagsPopulated() { - // } + if !status.TagsPopulated() { + // Status tags are out-of-date with IDs, repopulate. + status.Tags, err = s.state.DB.GetTags( + ctx, + status.TagIDs, + ) + if err != nil { + errs.Append(fmt.Errorf("error populating status tags: %w", err)) + } + } if !status.MentionsPopulated() { // Status mentions are out-of-date with IDs, repopulate. diff --git a/internal/db/bundb/tag.go b/internal/db/bundb/tag.go new file mode 100644 index 000000000..043af5728 --- /dev/null +++ b/internal/db/bundb/tag.go @@ -0,0 +1,119 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "strings" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/uptrace/bun" +) + +type tagDB struct { + conn *WrappedDB + state *state.State +} + +func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) { + return m.state.Caches.GTS.Tag().Load("ID", func() (*gtsmodel.Tag, error) { + var tag gtsmodel.Tag + + q := m.conn. + NewSelect(). + Model(&tag). + Where("? = ?", bun.Ident("tag.id"), id) + + if err := q.Scan(ctx); err != nil { + return nil, m.conn.ProcessError(err) + } + + return &tag, nil + }, id) +} + +func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) { + // Normalize 'name' string. + name = strings.TrimSpace(name) + name = strings.ToLower(name) + + return m.state.Caches.GTS.Tag().Load("Name", func() (*gtsmodel.Tag, error) { + var tag gtsmodel.Tag + + q := m.conn. + NewSelect(). + Model(&tag). + Where("? = ?", bun.Ident("tag.name"), name) + + if err := q.Scan(ctx); err != nil { + return nil, m.conn.ProcessError(err) + } + + return &tag, nil + }, name) +} + +func (m *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) { + tags := make([]*gtsmodel.Tag, 0, len(ids)) + + for _, id := range ids { + // Attempt fetch from DB + tag, err := m.GetTag(ctx, id) + if err != nil { + log.Errorf(ctx, "error getting tag %q: %v", id, err) + continue + } + + // Append tag + tags = append(tags, tag) + } + + return tags, nil +} + +func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error { + // Normalize 'name' string before it enters + // the db, without changing tag we were given. + // + // First copy tag to new pointer. + t2 := new(gtsmodel.Tag) + *t2 = *tag + + // Normalize name on new pointer. + t2.Name = strings.TrimSpace(t2.Name) + t2.Name = strings.ToLower(t2.Name) + + // Insert the copy. + if err := m.state.Caches.GTS.Tag().Store(t2, func() error { + _, err := m.conn.NewInsert().Model(t2).Exec(ctx) + return m.conn.ProcessError(err) + }); err != nil { + return err // err already processed + } + + // Update original tag with + // field values populated by db. + tag.CreatedAt = t2.CreatedAt + tag.UpdatedAt = t2.UpdatedAt + tag.Useable = t2.Useable + tag.Listable = t2.Listable + + return nil +} diff --git a/internal/db/bundb/tag_test.go b/internal/db/bundb/tag_test.go new file mode 100644 index 000000000..324398d27 --- /dev/null +++ b/internal/db/bundb/tag_test.go @@ -0,0 +1,91 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" +) + +type TagTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *TagTestSuite) TestGetTag() { + testTag := suite.testTags["welcome"] + + dbTag, err := suite.db.GetTag(context.Background(), testTag.ID) + suite.NoError(err) + suite.NotNil(dbTag) + suite.Equal(testTag.ID, dbTag.ID) +} + +func (suite *TagTestSuite) TestGetTagByName() { + testTag := suite.testTags["welcome"] + + // Name is normalized when doing + // selects from the db, so these + // should all yield the same result. + for _, name := range []string{ + "WELCOME", + "welcome", + "Welcome", + "WELCoME ", + } { + dbTag, err := suite.db.GetTagByName(context.Background(), name) + suite.NoError(err) + suite.NotNil(dbTag) + suite.Equal(testTag.ID, dbTag.ID) + } +} + +func (suite *TagTestSuite) TestPutTag() { + // Name is normalized when doing + // inserts to the db, so these + // should all yield the same result. + for i, name := range []string{ + "NewTag", + "newtag", + "NEWtag", + "NEWTAG ", + } { + err := suite.db.PutTag(context.Background(), >smodel.Tag{ + ID: id.NewULID(), + Name: name, + }) + if i == 0 { + // This is the first one, so it + // should have just been created. + suite.NoError(err) + continue + } + + // Subsequent inserts should fail + // since all these tags are equivalent. + suite.ErrorIs(err, db.ErrAlreadyExists) + } +} + +func TestTagTestSuite(t *testing.T) { + suite.Run(t, new(TagTestSuite)) +} diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index 6aa4989d9..62f1f642d 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -410,3 +410,111 @@ func (t *timelineDB) GetListTimeline( return statuses, nil } + +func (t *timelineDB) GetTagTimeline( + ctx context.Context, + tagID string, + maxID string, + sinceID string, + minID string, + limit int, +) ([]*gtsmodel.Status, error) { + // Ensure reasonable + if limit < 0 { + limit = 0 + } + + // Make educated guess for slice size + var ( + statusIDs = make([]string, 0, limit) + frontToBack = true + ) + + q := t.db. + NewSelect(). + TableExpr("? AS ?", bun.Ident("status_to_tags"), bun.Ident("status_to_tag")). + Column("status_to_tag.status_id"). + // Join with statuses for filtering. + Join( + "INNER JOIN ? AS ? ON ? = ?", + bun.Ident("statuses"), bun.Ident("status"), + bun.Ident("status.id"), bun.Ident("status_to_tag.status_id"), + ). + // Public only. + Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). + // This tag only. + Where("? = ?", bun.Ident("status_to_tag.tag_id"), tagID) + + if maxID == "" || maxID >= id.Highest { + const future = 24 * time.Hour + + var err error + + // don't return statuses more than 24hr in the future + maxID, err = id.NewULIDFromTime(time.Now().Add(future)) + if err != nil { + return nil, err + } + } + + // return only statuses LOWER (ie., older) than maxID + q = q.Where("? < ?", bun.Ident("status_to_tag.status_id"), maxID) + + if sinceID != "" { + // return only statuses HIGHER (ie., newer) than sinceID + q = q.Where("? > ?", bun.Ident("status_to_tag.status_id"), sinceID) + } + + if minID != "" { + // return only statuses HIGHER (ie., newer) than minID + q = q.Where("? > ?", bun.Ident("status_to_tag.status_id"), minID) + + // page up + frontToBack = false + } + + if limit > 0 { + // limit amount of statuses returned + q = q.Limit(limit) + } + + if frontToBack { + // Page down. + q = q.Order("status_to_tag.status_id DESC") + } else { + // Page up. + q = q.Order("status_to_tag.status_id ASC") + } + + if err := q.Scan(ctx, &statusIDs); err != nil { + return nil, t.db.ProcessError(err) + } + + if len(statusIDs) == 0 { + return nil, nil + } + + // If we're paging up, we still want statuses + // to be sorted by ID desc, so reverse ids slice. + // https://zchee.github.io/golang-wiki/SliceTricks/#reversing + if !frontToBack { + for l, r := 0, len(statusIDs)-1; l < r; l, r = l+1, r-1 { + statusIDs[l], statusIDs[r] = statusIDs[r], statusIDs[l] + } + } + + statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) + for _, id := range statusIDs { + // Fetch status from db for ID + status, err := t.state.DB.GetStatusByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error fetching status %q: %v", id, err) + continue + } + + // Append status to slice + statuses = append(statuses, status) + } + + return statuses, nil +} diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index 7e8fd0838..43407bc69 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -272,6 +272,21 @@ func (suite *TimelineTestSuite) TestGetListTimelineMinIDPagingUp() { suite.Equal("01F8MHCP5P2NWYQ416SBA0XSEV", s[len(s)-1].ID) } +func (suite *TimelineTestSuite) TestGetTagTimelineNoParams() { + var ( + ctx = context.Background() + tag = suite.testTags["welcome"] + ) + + s, err := suite.db.GetTagTimeline(ctx, tag.ID, "", "", "", 1) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, id.Lowest, 1) + suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[0].ID) +} + func TestTimelineTestSuite(t *testing.T) { suite.Run(t, new(TimelineTestSuite)) } diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index bdd45d1e7..3c3249daf 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -34,9 +34,10 @@ var likeEscaper = strings.NewReplacer( `_`, `\_`, // Exactly one char. ) -// whereSubqueryLike appends a WHERE clause to the -// given SelectQuery, which searches for matches -// of `search` in the given subQuery using LIKE. +// whereLike appends a WHERE clause to the +// given SelectQuery, which searches for +// matches of `search` in the given subQuery +// using LIKE. func whereLike( query *bun.SelectQuery, subject interface{}, @@ -58,6 +59,30 @@ func whereLike( ) } +// whereStartsLike is like whereLike, +// but only searches for strings that +// START WITH `search`. +func whereStartsLike( + query *bun.SelectQuery, + subject interface{}, + search string, +) *bun.SelectQuery { + // Escape existing wildcard + escape + // chars in the search query string. + search = likeEscaper.Replace(search) + + // Add our own wildcards back in; search + // zero or more chars after the query. + search += `%` + + // Append resulting WHERE + // clause to the main query. + return query.Where( + "(?) LIKE ? ESCAPE ?", + subject, search, `\`, + ) +} + // updateWhere parses []db.Where and adds it to the given update query. func updateWhere(q *bun.UpdateQuery, where []db.Where) { for _, w := range where { |