summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/advancedmigration.go29
-rw-r--r--internal/db/bundb/advancedmigration.go52
-rw-r--r--internal/db/bundb/bundb.go11
-rw-r--r--internal/db/bundb/conversation.go494
-rw-r--r--internal/db/bundb/conversation_test.go115
-rw-r--r--internal/db/bundb/migrations/20240611190733_add_conversations.go78
-rw-r--r--internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go49
-rw-r--r--internal/db/bundb/status.go32
-rw-r--r--internal/db/conversation.go52
-rw-r--r--internal/db/db.go2
-rw-r--r--internal/db/status.go12
-rw-r--r--internal/db/test/conversation.go122
12 files changed, 1048 insertions, 0 deletions
diff --git a/internal/db/advancedmigration.go b/internal/db/advancedmigration.go
new file mode 100644
index 000000000..2b4601bdb
--- /dev/null
+++ b/internal/db/advancedmigration.go
@@ -0,0 +1,29 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type AdvancedMigration interface {
+ GetAdvancedMigration(ctx context.Context, id string) (*gtsmodel.AdvancedMigration, error)
+ PutAdvancedMigration(ctx context.Context, advancedMigration *gtsmodel.AdvancedMigration) error
+}
diff --git a/internal/db/bundb/advancedmigration.go b/internal/db/bundb/advancedmigration.go
new file mode 100644
index 000000000..2a0ec93e6
--- /dev/null
+++ b/internal/db/bundb/advancedmigration.go
@@ -0,0 +1,52 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/uptrace/bun"
+)
+
+type advancedMigrationDB struct {
+ db *bun.DB
+ state *state.State
+}
+
+func (a *advancedMigrationDB) GetAdvancedMigration(ctx context.Context, id string) (*gtsmodel.AdvancedMigration, error) {
+ var advancedMigration gtsmodel.AdvancedMigration
+ err := a.db.NewSelect().
+ Model(&advancedMigration).
+ Where("? = ?", bun.Ident("id"), id).
+ Limit(1).
+ Scan(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return &advancedMigration, nil
+}
+
+func (a *advancedMigrationDB) PutAdvancedMigration(ctx context.Context, advancedMigration *gtsmodel.AdvancedMigration) error {
+ _, err := NewUpsert(a.db).
+ Model(advancedMigration).
+ Constraint("id").
+ Exec(ctx)
+ return err
+}
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 57fb661df..070d4eb91 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -54,8 +54,10 @@ import (
type DBService struct {
db.Account
db.Admin
+ db.AdvancedMigration
db.Application
db.Basic
+ db.Conversation
db.Domain
db.Emoji
db.HeaderFilter
@@ -158,6 +160,7 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
// https://bun.uptrace.dev/orm/many-to-many-relation/
for _, t := range []interface{}{
&gtsmodel.AccountToEmoji{},
+ &gtsmodel.ConversationToStatus{},
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusToTag{},
&gtsmodel.ThreadToStatus{},
@@ -181,6 +184,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
+ AdvancedMigration: &advancedMigrationDB{
+ db: db,
+ state: state,
+ },
Application: &applicationDB{
db: db,
state: state,
@@ -188,6 +195,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
Basic: &basicDB{
db: db,
},
+ Conversation: &conversationDB{
+ db: db,
+ state: state,
+ },
Domain: &domainDB{
db: db,
state: state,
diff --git a/internal/db/bundb/conversation.go b/internal/db/bundb/conversation.go
new file mode 100644
index 000000000..1a3958a79
--- /dev/null
+++ b/internal/db/bundb/conversation.go
@@ -0,0 +1,494 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "slices"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
+ "github.com/uptrace/bun/dialect"
+)
+
+type conversationDB struct {
+ db *bun.DB
+ state *state.State
+}
+
+func (c *conversationDB) GetConversationByID(ctx context.Context, id string) (*gtsmodel.Conversation, error) {
+ return c.getConversation(
+ ctx,
+ "ID",
+ func(conversation *gtsmodel.Conversation) error {
+ return c.db.
+ NewSelect().
+ Model(conversation).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (c *conversationDB) GetConversationByThreadAndAccountIDs(ctx context.Context, threadID string, accountID string, otherAccountIDs []string) (*gtsmodel.Conversation, error) {
+ otherAccountsKey := gtsmodel.ConversationOtherAccountsKey(otherAccountIDs)
+ return c.getConversation(
+ ctx,
+ "ThreadID,AccountID,OtherAccountsKey",
+ func(conversation *gtsmodel.Conversation) error {
+ return c.db.
+ NewSelect().
+ Model(conversation).
+ Where("? = ?", bun.Ident("thread_id"), threadID).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Where("? = ?", bun.Ident("other_accounts_key"), otherAccountsKey).
+ Scan(ctx)
+ },
+ threadID,
+ accountID,
+ otherAccountsKey,
+ )
+}
+
+func (c *conversationDB) getConversation(
+ ctx context.Context,
+ lookup string,
+ dbQuery func(conversation *gtsmodel.Conversation) error,
+ keyParts ...any,
+) (*gtsmodel.Conversation, error) {
+ // Fetch conversation from cache with loader callback
+ conversation, err := c.state.Caches.GTS.Conversation.LoadOne(lookup, func() (*gtsmodel.Conversation, error) {
+ var conversation gtsmodel.Conversation
+
+ // Not cached! Perform database query
+ if err := dbQuery(&conversation); err != nil {
+ return nil, err
+ }
+
+ return &conversation, nil
+ }, keyParts...)
+ if err != nil {
+ // already processe
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // Only a barebones model was requested.
+ return conversation, nil
+ }
+
+ if err := c.populateConversation(ctx, conversation); err != nil {
+ return nil, err
+ }
+
+ return conversation, nil
+}
+
+func (c *conversationDB) populateConversation(ctx context.Context, conversation *gtsmodel.Conversation) error {
+ var (
+ errs gtserror.MultiError
+ err error
+ )
+
+ if conversation.Account == nil {
+ conversation.Account, err = c.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ conversation.AccountID,
+ )
+ if err != nil {
+ errs.Appendf("error populating conversation owner account: %w", err)
+ }
+ }
+
+ if conversation.OtherAccounts == nil {
+ conversation.OtherAccounts, err = c.state.DB.GetAccountsByIDs(
+ gtscontext.SetBarebones(ctx),
+ conversation.OtherAccountIDs,
+ )
+ if err != nil {
+ errs.Appendf("error populating other conversation accounts: %w", err)
+ }
+ }
+
+ if conversation.LastStatus == nil && conversation.LastStatusID != "" {
+ conversation.LastStatus, err = c.state.DB.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ conversation.LastStatusID,
+ )
+ if err != nil {
+ errs.Appendf("error populating conversation last status: %w", err)
+ }
+ }
+
+ return errs.Combine()
+}
+
+func (c *conversationDB) GetConversationsByOwnerAccountID(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Conversation, error) {
+ conversationLastStatusIDs, err := c.getAccountConversationLastStatusIDs(ctx, accountID, page)
+ if err != nil {
+ return nil, err
+ }
+ return c.getConversationsByLastStatusIDs(ctx, accountID, conversationLastStatusIDs)
+}
+
+func (c *conversationDB) getAccountConversationLastStatusIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+ return loadPagedIDs(&c.state.Caches.GTS.ConversationLastStatusIDs, accountID, page, func() ([]string, error) {
+ var conversationLastStatusIDs []string
+
+ // Conversation last status IDs not in cache. Perform DB query.
+ if _, err := c.db.
+ NewSelect().
+ Model((*gtsmodel.Conversation)(nil)).
+ Column("last_status_id").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("last_status_id")).
+ Exec(ctx, &conversationLastStatusIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, err
+ }
+
+ return conversationLastStatusIDs, nil
+ })
+}
+
+func (c *conversationDB) getConversationsByLastStatusIDs(
+ ctx context.Context,
+ accountID string,
+ conversationLastStatusIDs []string,
+) ([]*gtsmodel.Conversation, error) {
+ // Load all conversation IDs via cache loader callbacks.
+ conversations, err := c.state.Caches.GTS.Conversation.LoadIDs2Part(
+ "AccountID,LastStatusID",
+ accountID,
+ conversationLastStatusIDs,
+ func(accountID string, uncached []string) ([]*gtsmodel.Conversation, error) {
+ // Preallocate expected length of uncached conversations.
+ conversations := make([]*gtsmodel.Conversation, 0, len(uncached))
+
+ // Perform database query scanning the remaining (uncached) IDs.
+ if err := c.db.NewSelect().
+ Model(&conversations).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Where("? IN (?)", bun.Ident("last_status_id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return conversations, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the conversations by their last status IDs to ensure correct order.
+ getID := func(b *gtsmodel.Conversation) string { return b.ID }
+ util.OrderBy(conversations, conversationLastStatusIDs, getID)
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return conversations, nil
+ }
+
+ // Populate all loaded conversations, removing those we fail to populate.
+ conversations = slices.DeleteFunc(conversations, func(conversation *gtsmodel.Conversation) bool {
+ if err := c.populateConversation(ctx, conversation); err != nil {
+ log.Errorf(ctx, "error populating conversation %s: %v", conversation.ID, err)
+ return true
+ }
+ return false
+ })
+
+ return conversations, nil
+}
+
+func (c *conversationDB) UpsertConversation(ctx context.Context, conversation *gtsmodel.Conversation, columns ...string) error {
+ // If we're updating by column, ensure "updated_at" is included.
+ if len(columns) > 0 {
+ columns = append(columns, "updated_at")
+ }
+
+ return c.state.Caches.GTS.Conversation.Store(conversation, func() error {
+ _, err := NewUpsert(c.db).
+ Model(conversation).
+ Constraint("id").
+ Column(columns...).
+ Exec(ctx)
+ return err
+ })
+}
+
+func (c *conversationDB) LinkConversationToStatus(ctx context.Context, conversationID string, statusID string) error {
+ conversationToStatus := &gtsmodel.ConversationToStatus{
+ ConversationID: conversationID,
+ StatusID: statusID,
+ }
+
+ if _, err := c.db.NewInsert().
+ Model(conversationToStatus).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return err
+ }
+ return nil
+}
+
+func (c *conversationDB) DeleteConversationByID(ctx context.Context, id string) error {
+ // Load conversation into cache before attempting a delete,
+ // as we need it cached in order to trigger the invalidate
+ // callback. This in turn invalidates others.
+ _, err := c.GetConversationByID(gtscontext.SetBarebones(ctx), id)
+ if err != nil {
+ if errors.Is(err, db.ErrNoEntries) {
+ // not an issue.
+ err = nil
+ }
+ return err
+ }
+
+ // Drop this now-cached conversation on return after delete.
+ defer c.state.Caches.GTS.Conversation.Invalidate("ID", id)
+
+ // Finally delete conversation from DB.
+ _, err = c.db.NewDelete().
+ Model((*gtsmodel.Conversation)(nil)).
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx)
+ return err
+}
+
+func (c *conversationDB) DeleteConversationsByOwnerAccountID(ctx context.Context, accountID string) error {
+ defer func() {
+ // Invalidate any cached conversations and conversation IDs owned by this account on return.
+ // Conversation invalidate hooks only invalidate the conversation ID cache,
+ // so we don't need to load all conversations into the cache to run invalidation hooks,
+ // as with some other object types (blocks, for example).
+ c.state.Caches.GTS.Conversation.Invalidate("AccountID", accountID)
+ // In case there were no cached conversations,
+ // explicitly invalidate the user's conversation last status ID cache.
+ c.state.Caches.GTS.ConversationLastStatusIDs.Invalidate(accountID)
+ }()
+
+ return c.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Delete conversations matching the account ID.
+ deletedConversationIDs := []string{}
+ if err := tx.NewDelete().
+ Model((*gtsmodel.Conversation)(nil)).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Returning("?", bun.Ident("id")).
+ Scan(ctx, &deletedConversationIDs); // nocollapse
+ err != nil {
+ return gtserror.Newf("error deleting conversations for account %s: %w", accountID, err)
+ }
+
+ // Delete any conversation-to-status links matching the deleted conversation IDs.
+ if _, err := tx.NewDelete().
+ Model((*gtsmodel.ConversationToStatus)(nil)).
+ Where("? IN (?)", bun.Ident("conversation_id"), bun.In(deletedConversationIDs)).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return gtserror.Newf("error deleting conversation-to-status links for account %s: %w", accountID, err)
+ }
+
+ return nil
+ })
+}
+
+func (c *conversationDB) DeleteStatusFromConversations(ctx context.Context, statusID string) error {
+ // SQL returning the current time.
+ var nowSQL string
+ switch c.db.Dialect().Name() {
+ case dialect.SQLite:
+ nowSQL = "DATE('now')"
+ case dialect.PG:
+ nowSQL = "NOW()"
+ default:
+ log.Panicf(nil, "db conn %s was neither pg nor sqlite", c.db)
+ }
+
+ updatedConversationIDs := []string{}
+ deletedConversationIDs := []string{}
+
+ if err := c.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Delete this status from conversation-to-status links.
+ if _, err := tx.NewDelete().
+ Model((*gtsmodel.ConversationToStatus)(nil)).
+ Where("? = ?", bun.Ident("status_id"), statusID).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return gtserror.Newf("error deleting conversation-to-status links while deleting status %s: %w", statusID, err)
+ }
+
+ // Note: Bun doesn't currently support CREATE TABLE … AS SELECT … so we need to use raw queries here.
+
+ // Create a temporary table with all statuses other than the deleted status
+ // in each conversation for which the deleted status is the last status
+ // (if there are such statuses).
+ conversationStatusesTempTable := "conversation_statuses_" + id.NewULID()
+ if _, err := tx.NewRaw(
+ "CREATE TEMPORARY TABLE ? AS ?",
+ bun.Ident(conversationStatusesTempTable),
+ tx.NewSelect().
+ ColumnExpr(
+ "? AS ?",
+ bun.Ident("conversations.id"),
+ bun.Ident("conversation_id"),
+ ).
+ ColumnExpr(
+ "? AS ?",
+ bun.Ident("conversation_to_statuses.status_id"),
+ bun.Ident("id"),
+ ).
+ Column("statuses.created_at").
+ Table("conversations").
+ Join("LEFT JOIN ?", bun.Ident("conversation_to_statuses")).
+ JoinOn(
+ "? = ?",
+ bun.Ident("conversations.id"),
+ bun.Ident("conversation_to_statuses.conversation_id"),
+ ).
+ JoinOn(
+ "? != ?",
+ bun.Ident("conversation_to_statuses.status_id"),
+ statusID,
+ ).
+ Join("LEFT JOIN ?", bun.Ident("statuses")).
+ JoinOn(
+ "? = ?",
+ bun.Ident("conversation_to_statuses.status_id"),
+ bun.Ident("statuses.id"),
+ ).
+ Where(
+ "? = ?",
+ bun.Ident("conversations.last_status_id"),
+ statusID,
+ ),
+ ).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return gtserror.Newf("error creating conversationStatusesTempTable while deleting status %s: %w", statusID, err)
+ }
+
+ // Create a temporary table with the most recently created status in each conversation
+ // for which the deleted status is the last status (if there is such a status).
+ latestConversationStatusesTempTable := "latest_conversation_statuses_" + id.NewULID()
+ if _, err := tx.NewRaw(
+ "CREATE TEMPORARY TABLE ? AS ?",
+ bun.Ident(latestConversationStatusesTempTable),
+ tx.NewSelect().
+ Column(
+ "conversation_statuses.conversation_id",
+ "conversation_statuses.id",
+ ).
+ TableExpr(
+ "? AS ?",
+ bun.Ident(conversationStatusesTempTable),
+ bun.Ident("conversation_statuses"),
+ ).
+ Join(
+ "LEFT JOIN ? AS ?",
+ bun.Ident(conversationStatusesTempTable),
+ bun.Ident("later_statuses"),
+ ).
+ JoinOn(
+ "? = ?",
+ bun.Ident("conversation_statuses.conversation_id"),
+ bun.Ident("later_statuses.conversation_id"),
+ ).
+ JoinOn(
+ "? > ?",
+ bun.Ident("later_statuses.created_at"),
+ bun.Ident("conversation_statuses.created_at"),
+ ).
+ Where("? IS NULL", bun.Ident("later_statuses.id")),
+ ).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return gtserror.Newf("error creating latestConversationStatusesTempTable while deleting status %s: %w", statusID, err)
+ }
+
+ // For every conversation where the given status was the last one,
+ // reset its last status to the most recently created in the conversation other than that one,
+ // if there is such a status.
+ // Return conversation IDs for invalidation.
+ if err := tx.NewUpdate().
+ Model((*gtsmodel.Conversation)(nil)).
+ SetColumn("last_status_id", "?", bun.Ident("latest_conversation_statuses.id")).
+ SetColumn("updated_at", "?", bun.Safe(nowSQL)).
+ TableExpr("? AS ?", bun.Ident(latestConversationStatusesTempTable), bun.Ident("latest_conversation_statuses")).
+ Where("?TableAlias.? = ?", bun.Ident("id"), bun.Ident("latest_conversation_statuses.conversation_id")).
+ Where("? IS NOT NULL", bun.Ident("latest_conversation_statuses.id")).
+ Returning("?TableName.?", bun.Ident("id")).
+ Scan(ctx, &updatedConversationIDs); // nocollapse
+ err != nil {
+ return gtserror.Newf("error rolling back last status for conversation while deleting status %s: %w", statusID, err)
+ }
+
+ // If there is no such status, delete the conversation.
+ // Return conversation IDs for invalidation.
+ if err := tx.NewDelete().
+ Model((*gtsmodel.Conversation)(nil)).
+ Where(
+ "? IN (?)",
+ bun.Ident("id"),
+ tx.NewSelect().
+ Table(latestConversationStatusesTempTable).
+ Column("conversation_id").
+ Where("? IS NULL", bun.Ident("id")),
+ ).
+ Returning("?", bun.Ident("id")).
+ Scan(ctx, &deletedConversationIDs); // nocollapse
+ err != nil {
+ return gtserror.Newf("error deleting conversation while deleting status %s: %w", statusID, err)
+ }
+
+ // Clean up.
+ for _, tempTable := range []string{
+ conversationStatusesTempTable,
+ latestConversationStatusesTempTable,
+ } {
+ if _, err := tx.NewDropTable().Table(tempTable).Exec(ctx); err != nil {
+ return gtserror.Newf(
+ "error dropping temporary table %s after deleting status %s: %w",
+ tempTable,
+ statusID,
+ err,
+ )
+ }
+ }
+
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ updatedConversationIDs = append(updatedConversationIDs, deletedConversationIDs...)
+ c.state.Caches.GTS.Conversation.InvalidateIDs("ID", updatedConversationIDs)
+
+ return nil
+}
diff --git a/internal/db/bundb/conversation_test.go b/internal/db/bundb/conversation_test.go
new file mode 100644
index 000000000..24d35d482
--- /dev/null
+++ b/internal/db/bundb/conversation_test.go
@@ -0,0 +1,115 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/db/test"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type ConversationTestSuite struct {
+ BunDBStandardTestSuite
+
+ cf test.ConversationFactory
+
+ // testAccount is the owner of statuses and conversations in these tests (must be local).
+ testAccount *gtsmodel.Account
+ // threadID is the thread used for statuses in any given test.
+ threadID string
+}
+
+func (suite *ConversationTestSuite) SetupSuite() {
+ suite.BunDBStandardTestSuite.SetupSuite()
+
+ suite.cf.SetupSuite(suite)
+
+ suite.testAccount = suite.testAccounts["local_account_1"]
+}
+
+func (suite *ConversationTestSuite) SetupTest() {
+ suite.BunDBStandardTestSuite.SetupTest()
+
+ suite.cf.SetupTest(suite.db)
+
+ suite.threadID = suite.cf.NewULID(0)
+}
+
+// deleteStatus deletes a status from conversations and ends the test if that fails.
+func (suite *ConversationTestSuite) deleteStatus(statusID string) {
+ err := suite.db.DeleteStatusFromConversations(context.Background(), statusID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+}
+
+// getConversation fetches a conversation by ID and ends the test if that fails.
+func (suite *ConversationTestSuite) getConversation(conversationID string) *gtsmodel.Conversation {
+ conversation, err := suite.db.GetConversationByID(context.Background(), conversationID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ return conversation
+}
+
+// If we delete a status that is in a conversation but not the last status,
+// the conversation's last status should not change.
+func (suite *ConversationTestSuite) TestDeleteNonLastStatus() {
+ conversation := suite.cf.NewTestConversation(suite.testAccount, 0)
+ initial := conversation.LastStatus
+ reply := suite.cf.NewTestStatus(suite.testAccount, conversation.ThreadID, 1*time.Second, initial)
+ conversation = suite.cf.SetLastStatus(conversation, reply)
+
+ suite.deleteStatus(initial.ID)
+ conversation = suite.getConversation(conversation.ID)
+ suite.Equal(reply.ID, conversation.LastStatusID)
+}
+
+// If we delete the last status in a conversation that has other statuses,
+// a previous status should become the new last status.
+func (suite *ConversationTestSuite) TestDeleteLastStatus() {
+ conversation := suite.cf.NewTestConversation(suite.testAccount, 0)
+ initial := conversation.LastStatus
+ reply := suite.cf.NewTestStatus(suite.testAccount, conversation.ThreadID, 1*time.Second, initial)
+ conversation = suite.cf.SetLastStatus(conversation, reply)
+ conversation = suite.getConversation(conversation.ID)
+
+ suite.deleteStatus(reply.ID)
+ conversation = suite.getConversation(conversation.ID)
+ suite.Equal(initial.ID, conversation.LastStatusID)
+}
+
+// If we delete the only status in a conversation,
+// the conversation should be deleted as well.
+func (suite *ConversationTestSuite) TestDeleteOnlyStatus() {
+ conversation := suite.cf.NewTestConversation(suite.testAccount, 0)
+ initial := conversation.LastStatus
+
+ suite.deleteStatus(initial.ID)
+ _, err := suite.db.GetConversationByID(context.Background(), conversation.ID)
+ suite.ErrorIs(err, db.ErrNoEntries)
+}
+
+func TestConversationTestSuite(t *testing.T) {
+ suite.Run(t, new(ConversationTestSuite))
+}
diff --git a/internal/db/bundb/migrations/20240611190733_add_conversations.go b/internal/db/bundb/migrations/20240611190733_add_conversations.go
new file mode 100644
index 000000000..25b226aff
--- /dev/null
+++ b/internal/db/bundb/migrations/20240611190733_add_conversations.go
@@ -0,0 +1,78 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+// Note: this migration has an advanced migration followup.
+// See Conversations.MigrateDMs().
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ for _, model := range []interface{}{
+ &gtsmodel.Conversation{},
+ &gtsmodel.ConversationToStatus{},
+ } {
+ if _, err := tx.
+ NewCreateTable().
+ Model(model).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // Add indexes to the conversations table.
+ for index, columns := range map[string][]string{
+ "conversations_account_id_idx": {
+ "account_id",
+ },
+ "conversations_last_status_id_idx": {
+ "last_status_id",
+ },
+ } {
+ if _, err := tx.
+ NewCreateIndex().
+ Model(&gtsmodel.Conversation{}).
+ Index(index).
+ Column(columns...).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go b/internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go
new file mode 100644
index 000000000..183065285
--- /dev/null
+++ b/internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go
@@ -0,0 +1,49 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+// Create the advanced migrations table.
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ _, err := tx.
+ NewCreateTable().
+ Model((*gtsmodel.AdvancedMigration)(nil)).
+ IfNotExists().
+ Exec(ctx)
+ return err
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index dfb97cff1..b0ed32e0e 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -682,3 +682,35 @@ func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]st
return statusIDs, nil
})
}
+
+func (s *statusDB) MaxDirectStatusID(ctx context.Context) (string, error) {
+ maxID := ""
+ if err := s.db.
+ NewSelect().
+ Model((*gtsmodel.Status)(nil)).
+ ColumnExpr("COALESCE(MAX(?), '')", bun.Ident("id")).
+ Where("? = ?", bun.Ident("visibility"), gtsmodel.VisibilityDirect).
+ Scan(ctx, &maxID); // nocollapse
+ err != nil {
+ return "", err
+ }
+ return maxID, nil
+}
+
+func (s *statusDB) GetDirectStatusIDsBatch(ctx context.Context, minID string, maxIDInclusive string, count int) ([]string, error) {
+ var statusIDs []string
+ if err := s.db.
+ NewSelect().
+ Model((*gtsmodel.Status)(nil)).
+ Column("id").
+ Where("? = ?", bun.Ident("visibility"), gtsmodel.VisibilityDirect).
+ Where("? > ?", bun.Ident("id"), minID).
+ Where("? <= ?", bun.Ident("id"), maxIDInclusive).
+ Order("id ASC").
+ Limit(count).
+ Scan(ctx, &statusIDs); // nocollapse
+ err != nil {
+ return nil, err
+ }
+ return statusIDs, nil
+}
diff --git a/internal/db/conversation.go b/internal/db/conversation.go
new file mode 100644
index 000000000..3d0b4213e
--- /dev/null
+++ b/internal/db/conversation.go
@@ -0,0 +1,52 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
+)
+
+type Conversation interface {
+ // GetConversationByID gets a single conversation by ID.
+ GetConversationByID(ctx context.Context, id string) (*gtsmodel.Conversation, error)
+
+ // GetConversationByThreadAndAccountIDs retrieves a conversation by thread ID and participant account IDs, if it exists.
+ GetConversationByThreadAndAccountIDs(ctx context.Context, threadID string, accountID string, otherAccountIDs []string) (*gtsmodel.Conversation, error)
+
+ // GetConversationsByOwnerAccountID gets all conversations owned by the given account,
+ // with optional paging based on last status ID.
+ GetConversationsByOwnerAccountID(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Conversation, error)
+
+ // UpsertConversation creates or updates a conversation.
+ UpsertConversation(ctx context.Context, conversation *gtsmodel.Conversation, columns ...string) error
+
+ // LinkConversationToStatus creates a conversation-to-status link.
+ LinkConversationToStatus(ctx context.Context, statusID string, conversationID string) error
+
+ // DeleteConversationByID deletes a conversation, removing it from the owning account's conversation list.
+ DeleteConversationByID(ctx context.Context, id string) error
+
+ // DeleteConversationsByOwnerAccountID deletes all conversations owned by the given account.
+ DeleteConversationsByOwnerAccountID(ctx context.Context, accountID string) error
+
+ // DeleteStatusFromConversations handles when a status is deleted by updating or deleting conversations for which it was the last status.
+ DeleteStatusFromConversations(ctx context.Context, statusID string) error
+}
diff --git a/internal/db/db.go b/internal/db/db.go
index a148d778a..4b2152732 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -26,8 +26,10 @@ const (
type DB interface {
Account
Admin
+ AdvancedMigration
Application
Basic
+ Conversation
Domain
Emoji
HeaderFilter
diff --git a/internal/db/status.go b/internal/db/status.go
index 88ae12a12..ade900728 100644
--- a/internal/db/status.go
+++ b/internal/db/status.go
@@ -78,4 +78,16 @@ type Status interface {
// GetStatusChildren gets the child statuses of a given status.
GetStatusChildren(ctx context.Context, statusID string) ([]*gtsmodel.Status, error)
+
+ // MaxDirectStatusID returns the newest ID across all DM statuses.
+ // Returns the empty string with no error if there are no DM statuses yet.
+ // It is used only by the conversation advanced migration.
+ MaxDirectStatusID(ctx context.Context) (string, error)
+
+ // GetDirectStatusIDsBatch returns up to count DM status IDs strictly greater than minID
+ // and less than or equal to maxIDInclusive. Note that this is different from most of our paging,
+ // which uses a maxID and returns IDs strictly less than that, because it's called with the result of
+ // MaxDirectStatusID, and expects to eventually return the status with that ID.
+ // It is used only by the conversation advanced migration.
+ GetDirectStatusIDsBatch(ctx context.Context, minID string, maxIDInclusive string, count int) ([]string, error)
}
diff --git a/internal/db/test/conversation.go b/internal/db/test/conversation.go
new file mode 100644
index 000000000..95713927e
--- /dev/null
+++ b/internal/db/test/conversation.go
@@ -0,0 +1,122 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package test
+
+import (
+ "context"
+ "crypto/rand"
+ "time"
+
+ "github.com/oklog/ulid"
+ "github.com/superseriousbusiness/gotosocial/internal/ap"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+)
+
+type testSuite interface {
+ FailNow(string, ...interface{}) bool
+}
+
+// ConversationFactory can be embedded or included by test suites that want to generate statuses and conversations.
+type ConversationFactory struct {
+ // Test suite, or at least the methods from it that we care about.
+ suite testSuite
+ // Test DB.
+ db db.DB
+
+ // TestStart is the timestamp used as a base for timestamps and ULIDs in any given test.
+ TestStart time.Time
+}
+
+// SetupSuite should be called by the SetupSuite of test suites that use this mixin.
+func (f *ConversationFactory) SetupSuite(suite testSuite) {
+ f.suite = suite
+}
+
+// SetupTest should be called by the SetupTest of test suites that use this mixin.
+func (f *ConversationFactory) SetupTest(db db.DB) {
+ f.db = db
+ f.TestStart = time.Now()
+}
+
+// NewULID is a version of id.NewULID that uses the test start time and an offset instead of the real time.
+func (f *ConversationFactory) NewULID(offset time.Duration) string {
+ ulid, err := ulid.New(
+ ulid.Timestamp(f.TestStart.Add(offset)), rand.Reader,
+ )
+ if err != nil {
+ panic(err)
+ }
+ return ulid.String()
+}
+
+func (f *ConversationFactory) NewTestStatus(localAccount *gtsmodel.Account, threadID string, nowOffset time.Duration, inReplyToStatus *gtsmodel.Status) *gtsmodel.Status {
+ statusID := f.NewULID(nowOffset)
+ createdAt := f.TestStart.Add(nowOffset)
+ status := &gtsmodel.Status{
+ ID: statusID,
+ CreatedAt: createdAt,
+ UpdatedAt: createdAt,
+ URI: "http://localhost:8080/users/" + localAccount.Username + "/statuses/" + statusID,
+ AccountID: localAccount.ID,
+ AccountURI: localAccount.URI,
+ Local: util.Ptr(true),
+ ThreadID: threadID,
+ Visibility: gtsmodel.VisibilityDirect,
+ ActivityStreamsType: ap.ObjectNote,
+ Federated: util.Ptr(true),
+ }
+ if inReplyToStatus != nil {
+ status.InReplyToID = inReplyToStatus.ID
+ status.InReplyToURI = inReplyToStatus.URI
+ status.InReplyToAccountID = inReplyToStatus.AccountID
+ }
+ if err := f.db.PutStatus(context.Background(), status); err != nil {
+ f.suite.FailNow(err.Error())
+ }
+ return status
+}
+
+// NewTestConversation creates a new status and adds it to a new unread conversation, returning the conversation.
+func (f *ConversationFactory) NewTestConversation(localAccount *gtsmodel.Account, nowOffset time.Duration) *gtsmodel.Conversation {
+ threadID := f.NewULID(nowOffset)
+ status := f.NewTestStatus(localAccount, threadID, nowOffset, nil)
+ conversation := &gtsmodel.Conversation{
+ ID: f.NewULID(nowOffset),
+ AccountID: localAccount.ID,
+ ThreadID: status.ThreadID,
+ Read: util.Ptr(false),
+ }
+ f.SetLastStatus(conversation, status)
+ return conversation
+}
+
+// SetLastStatus sets an already stored status as the last status of a new or already stored conversation,
+// and returns the updated conversation.
+func (f *ConversationFactory) SetLastStatus(conversation *gtsmodel.Conversation, status *gtsmodel.Status) *gtsmodel.Conversation {
+ conversation.LastStatusID = status.ID
+ conversation.LastStatus = status
+ if err := f.db.UpsertConversation(context.Background(), conversation, "last_status_id"); err != nil {
+ f.suite.FailNow(err.Error())
+ }
+ if err := f.db.LinkConversationToStatus(context.Background(), conversation.ID, status.ID); err != nil {
+ f.suite.FailNow(err.Error())
+ }
+ return conversation
+}