summaryrefslogtreecommitdiff
path: root/internal/db/bundb
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb')
-rw-r--r--internal/db/bundb/advancedmigration.go52
-rw-r--r--internal/db/bundb/bundb.go11
-rw-r--r--internal/db/bundb/conversation.go494
-rw-r--r--internal/db/bundb/conversation_test.go115
-rw-r--r--internal/db/bundb/migrations/20240611190733_add_conversations.go78
-rw-r--r--internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go49
-rw-r--r--internal/db/bundb/status.go32
7 files changed, 831 insertions, 0 deletions
diff --git a/internal/db/bundb/advancedmigration.go b/internal/db/bundb/advancedmigration.go
new file mode 100644
index 000000000..2a0ec93e6
--- /dev/null
+++ b/internal/db/bundb/advancedmigration.go
@@ -0,0 +1,52 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/uptrace/bun"
+)
+
+type advancedMigrationDB struct {
+ db *bun.DB
+ state *state.State
+}
+
+func (a *advancedMigrationDB) GetAdvancedMigration(ctx context.Context, id string) (*gtsmodel.AdvancedMigration, error) {
+ var advancedMigration gtsmodel.AdvancedMigration
+ err := a.db.NewSelect().
+ Model(&advancedMigration).
+ Where("? = ?", bun.Ident("id"), id).
+ Limit(1).
+ Scan(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return &advancedMigration, nil
+}
+
+func (a *advancedMigrationDB) PutAdvancedMigration(ctx context.Context, advancedMigration *gtsmodel.AdvancedMigration) error {
+ _, err := NewUpsert(a.db).
+ Model(advancedMigration).
+ Constraint("id").
+ Exec(ctx)
+ return err
+}
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 57fb661df..070d4eb91 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -54,8 +54,10 @@ import (
type DBService struct {
db.Account
db.Admin
+ db.AdvancedMigration
db.Application
db.Basic
+ db.Conversation
db.Domain
db.Emoji
db.HeaderFilter
@@ -158,6 +160,7 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
// https://bun.uptrace.dev/orm/many-to-many-relation/
for _, t := range []interface{}{
&gtsmodel.AccountToEmoji{},
+ &gtsmodel.ConversationToStatus{},
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusToTag{},
&gtsmodel.ThreadToStatus{},
@@ -181,6 +184,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
+ AdvancedMigration: &advancedMigrationDB{
+ db: db,
+ state: state,
+ },
Application: &applicationDB{
db: db,
state: state,
@@ -188,6 +195,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
Basic: &basicDB{
db: db,
},
+ Conversation: &conversationDB{
+ db: db,
+ state: state,
+ },
Domain: &domainDB{
db: db,
state: state,
diff --git a/internal/db/bundb/conversation.go b/internal/db/bundb/conversation.go
new file mode 100644
index 000000000..1a3958a79
--- /dev/null
+++ b/internal/db/bundb/conversation.go
@@ -0,0 +1,494 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "slices"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
+ "github.com/uptrace/bun/dialect"
+)
+
+type conversationDB struct {
+ db *bun.DB
+ state *state.State
+}
+
+func (c *conversationDB) GetConversationByID(ctx context.Context, id string) (*gtsmodel.Conversation, error) {
+ return c.getConversation(
+ ctx,
+ "ID",
+ func(conversation *gtsmodel.Conversation) error {
+ return c.db.
+ NewSelect().
+ Model(conversation).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (c *conversationDB) GetConversationByThreadAndAccountIDs(ctx context.Context, threadID string, accountID string, otherAccountIDs []string) (*gtsmodel.Conversation, error) {
+ otherAccountsKey := gtsmodel.ConversationOtherAccountsKey(otherAccountIDs)
+ return c.getConversation(
+ ctx,
+ "ThreadID,AccountID,OtherAccountsKey",
+ func(conversation *gtsmodel.Conversation) error {
+ return c.db.
+ NewSelect().
+ Model(conversation).
+ Where("? = ?", bun.Ident("thread_id"), threadID).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Where("? = ?", bun.Ident("other_accounts_key"), otherAccountsKey).
+ Scan(ctx)
+ },
+ threadID,
+ accountID,
+ otherAccountsKey,
+ )
+}
+
+func (c *conversationDB) getConversation(
+ ctx context.Context,
+ lookup string,
+ dbQuery func(conversation *gtsmodel.Conversation) error,
+ keyParts ...any,
+) (*gtsmodel.Conversation, error) {
+ // Fetch conversation from cache with loader callback
+ conversation, err := c.state.Caches.GTS.Conversation.LoadOne(lookup, func() (*gtsmodel.Conversation, error) {
+ var conversation gtsmodel.Conversation
+
+ // Not cached! Perform database query
+ if err := dbQuery(&conversation); err != nil {
+ return nil, err
+ }
+
+ return &conversation, nil
+ }, keyParts...)
+ if err != nil {
+ // already processe
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // Only a barebones model was requested.
+ return conversation, nil
+ }
+
+ if err := c.populateConversation(ctx, conversation); err != nil {
+ return nil, err
+ }
+
+ return conversation, nil
+}
+
+func (c *conversationDB) populateConversation(ctx context.Context, conversation *gtsmodel.Conversation) error {
+ var (
+ errs gtserror.MultiError
+ err error
+ )
+
+ if conversation.Account == nil {
+ conversation.Account, err = c.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ conversation.AccountID,
+ )
+ if err != nil {
+ errs.Appendf("error populating conversation owner account: %w", err)
+ }
+ }
+
+ if conversation.OtherAccounts == nil {
+ conversation.OtherAccounts, err = c.state.DB.GetAccountsByIDs(
+ gtscontext.SetBarebones(ctx),
+ conversation.OtherAccountIDs,
+ )
+ if err != nil {
+ errs.Appendf("error populating other conversation accounts: %w", err)
+ }
+ }
+
+ if conversation.LastStatus == nil && conversation.LastStatusID != "" {
+ conversation.LastStatus, err = c.state.DB.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ conversation.LastStatusID,
+ )
+ if err != nil {
+ errs.Appendf("error populating conversation last status: %w", err)
+ }
+ }
+
+ return errs.Combine()
+}
+
+func (c *conversationDB) GetConversationsByOwnerAccountID(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Conversation, error) {
+ conversationLastStatusIDs, err := c.getAccountConversationLastStatusIDs(ctx, accountID, page)
+ if err != nil {
+ return nil, err
+ }
+ return c.getConversationsByLastStatusIDs(ctx, accountID, conversationLastStatusIDs)
+}
+
+func (c *conversationDB) getAccountConversationLastStatusIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+ return loadPagedIDs(&c.state.Caches.GTS.ConversationLastStatusIDs, accountID, page, func() ([]string, error) {
+ var conversationLastStatusIDs []string
+
+ // Conversation last status IDs not in cache. Perform DB query.
+ if _, err := c.db.
+ NewSelect().
+ Model((*gtsmodel.Conversation)(nil)).
+ Column("last_status_id").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("last_status_id")).
+ Exec(ctx, &conversationLastStatusIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, err
+ }
+
+ return conversationLastStatusIDs, nil
+ })
+}
+
+func (c *conversationDB) getConversationsByLastStatusIDs(
+ ctx context.Context,
+ accountID string,
+ conversationLastStatusIDs []string,
+) ([]*gtsmodel.Conversation, error) {
+ // Load all conversation IDs via cache loader callbacks.
+ conversations, err := c.state.Caches.GTS.Conversation.LoadIDs2Part(
+ "AccountID,LastStatusID",
+ accountID,
+ conversationLastStatusIDs,
+ func(accountID string, uncached []string) ([]*gtsmodel.Conversation, error) {
+ // Preallocate expected length of uncached conversations.
+ conversations := make([]*gtsmodel.Conversation, 0, len(uncached))
+
+ // Perform database query scanning the remaining (uncached) IDs.
+ if err := c.db.NewSelect().
+ Model(&conversations).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Where("? IN (?)", bun.Ident("last_status_id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return conversations, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the conversations by their last status IDs to ensure correct order.
+ getID := func(b *gtsmodel.Conversation) string { return b.ID }
+ util.OrderBy(conversations, conversationLastStatusIDs, getID)
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return conversations, nil
+ }
+
+ // Populate all loaded conversations, removing those we fail to populate.
+ conversations = slices.DeleteFunc(conversations, func(conversation *gtsmodel.Conversation) bool {
+ if err := c.populateConversation(ctx, conversation); err != nil {
+ log.Errorf(ctx, "error populating conversation %s: %v", conversation.ID, err)
+ return true
+ }
+ return false
+ })
+
+ return conversations, nil
+}
+
+func (c *conversationDB) UpsertConversation(ctx context.Context, conversation *gtsmodel.Conversation, columns ...string) error {
+ // If we're updating by column, ensure "updated_at" is included.
+ if len(columns) > 0 {
+ columns = append(columns, "updated_at")
+ }
+
+ return c.state.Caches.GTS.Conversation.Store(conversation, func() error {
+ _, err := NewUpsert(c.db).
+ Model(conversation).
+ Constraint("id").
+ Column(columns...).
+ Exec(ctx)
+ return err
+ })
+}
+
+func (c *conversationDB) LinkConversationToStatus(ctx context.Context, conversationID string, statusID string) error {
+ conversationToStatus := &gtsmodel.ConversationToStatus{
+ ConversationID: conversationID,
+ StatusID: statusID,
+ }
+
+ if _, err := c.db.NewInsert().
+ Model(conversationToStatus).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return err
+ }
+ return nil
+}
+
+func (c *conversationDB) DeleteConversationByID(ctx context.Context, id string) error {
+ // Load conversation into cache before attempting a delete,
+ // as we need it cached in order to trigger the invalidate
+ // callback. This in turn invalidates others.
+ _, err := c.GetConversationByID(gtscontext.SetBarebones(ctx), id)
+ if err != nil {
+ if errors.Is(err, db.ErrNoEntries) {
+ // not an issue.
+ err = nil
+ }
+ return err
+ }
+
+ // Drop this now-cached conversation on return after delete.
+ defer c.state.Caches.GTS.Conversation.Invalidate("ID", id)
+
+ // Finally delete conversation from DB.
+ _, err = c.db.NewDelete().
+ Model((*gtsmodel.Conversation)(nil)).
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx)
+ return err
+}
+
+func (c *conversationDB) DeleteConversationsByOwnerAccountID(ctx context.Context, accountID string) error {
+ defer func() {
+ // Invalidate any cached conversations and conversation IDs owned by this account on return.
+ // Conversation invalidate hooks only invalidate the conversation ID cache,
+ // so we don't need to load all conversations into the cache to run invalidation hooks,
+ // as with some other object types (blocks, for example).
+ c.state.Caches.GTS.Conversation.Invalidate("AccountID", accountID)
+ // In case there were no cached conversations,
+ // explicitly invalidate the user's conversation last status ID cache.
+ c.state.Caches.GTS.ConversationLastStatusIDs.Invalidate(accountID)
+ }()
+
+ return c.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Delete conversations matching the account ID.
+ deletedConversationIDs := []string{}
+ if err := tx.NewDelete().
+ Model((*gtsmodel.Conversation)(nil)).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Returning("?", bun.Ident("id")).
+ Scan(ctx, &deletedConversationIDs); // nocollapse
+ err != nil {
+ return gtserror.Newf("error deleting conversations for account %s: %w", accountID, err)
+ }
+
+ // Delete any conversation-to-status links matching the deleted conversation IDs.
+ if _, err := tx.NewDelete().
+ Model((*gtsmodel.ConversationToStatus)(nil)).
+ Where("? IN (?)", bun.Ident("conversation_id"), bun.In(deletedConversationIDs)).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return gtserror.Newf("error deleting conversation-to-status links for account %s: %w", accountID, err)
+ }
+
+ return nil
+ })
+}
+
+func (c *conversationDB) DeleteStatusFromConversations(ctx context.Context, statusID string) error {
+ // SQL returning the current time.
+ var nowSQL string
+ switch c.db.Dialect().Name() {
+ case dialect.SQLite:
+ nowSQL = "DATE('now')"
+ case dialect.PG:
+ nowSQL = "NOW()"
+ default:
+ log.Panicf(nil, "db conn %s was neither pg nor sqlite", c.db)
+ }
+
+ updatedConversationIDs := []string{}
+ deletedConversationIDs := []string{}
+
+ if err := c.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Delete this status from conversation-to-status links.
+ if _, err := tx.NewDelete().
+ Model((*gtsmodel.ConversationToStatus)(nil)).
+ Where("? = ?", bun.Ident("status_id"), statusID).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return gtserror.Newf("error deleting conversation-to-status links while deleting status %s: %w", statusID, err)
+ }
+
+ // Note: Bun doesn't currently support CREATE TABLE … AS SELECT … so we need to use raw queries here.
+
+ // Create a temporary table with all statuses other than the deleted status
+ // in each conversation for which the deleted status is the last status
+ // (if there are such statuses).
+ conversationStatusesTempTable := "conversation_statuses_" + id.NewULID()
+ if _, err := tx.NewRaw(
+ "CREATE TEMPORARY TABLE ? AS ?",
+ bun.Ident(conversationStatusesTempTable),
+ tx.NewSelect().
+ ColumnExpr(
+ "? AS ?",
+ bun.Ident("conversations.id"),
+ bun.Ident("conversation_id"),
+ ).
+ ColumnExpr(
+ "? AS ?",
+ bun.Ident("conversation_to_statuses.status_id"),
+ bun.Ident("id"),
+ ).
+ Column("statuses.created_at").
+ Table("conversations").
+ Join("LEFT JOIN ?", bun.Ident("conversation_to_statuses")).
+ JoinOn(
+ "? = ?",
+ bun.Ident("conversations.id"),
+ bun.Ident("conversation_to_statuses.conversation_id"),
+ ).
+ JoinOn(
+ "? != ?",
+ bun.Ident("conversation_to_statuses.status_id"),
+ statusID,
+ ).
+ Join("LEFT JOIN ?", bun.Ident("statuses")).
+ JoinOn(
+ "? = ?",
+ bun.Ident("conversation_to_statuses.status_id"),
+ bun.Ident("statuses.id"),
+ ).
+ Where(
+ "? = ?",
+ bun.Ident("conversations.last_status_id"),
+ statusID,
+ ),
+ ).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return gtserror.Newf("error creating conversationStatusesTempTable while deleting status %s: %w", statusID, err)
+ }
+
+ // Create a temporary table with the most recently created status in each conversation
+ // for which the deleted status is the last status (if there is such a status).
+ latestConversationStatusesTempTable := "latest_conversation_statuses_" + id.NewULID()
+ if _, err := tx.NewRaw(
+ "CREATE TEMPORARY TABLE ? AS ?",
+ bun.Ident(latestConversationStatusesTempTable),
+ tx.NewSelect().
+ Column(
+ "conversation_statuses.conversation_id",
+ "conversation_statuses.id",
+ ).
+ TableExpr(
+ "? AS ?",
+ bun.Ident(conversationStatusesTempTable),
+ bun.Ident("conversation_statuses"),
+ ).
+ Join(
+ "LEFT JOIN ? AS ?",
+ bun.Ident(conversationStatusesTempTable),
+ bun.Ident("later_statuses"),
+ ).
+ JoinOn(
+ "? = ?",
+ bun.Ident("conversation_statuses.conversation_id"),
+ bun.Ident("later_statuses.conversation_id"),
+ ).
+ JoinOn(
+ "? > ?",
+ bun.Ident("later_statuses.created_at"),
+ bun.Ident("conversation_statuses.created_at"),
+ ).
+ Where("? IS NULL", bun.Ident("later_statuses.id")),
+ ).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return gtserror.Newf("error creating latestConversationStatusesTempTable while deleting status %s: %w", statusID, err)
+ }
+
+ // For every conversation where the given status was the last one,
+ // reset its last status to the most recently created in the conversation other than that one,
+ // if there is such a status.
+ // Return conversation IDs for invalidation.
+ if err := tx.NewUpdate().
+ Model((*gtsmodel.Conversation)(nil)).
+ SetColumn("last_status_id", "?", bun.Ident("latest_conversation_statuses.id")).
+ SetColumn("updated_at", "?", bun.Safe(nowSQL)).
+ TableExpr("? AS ?", bun.Ident(latestConversationStatusesTempTable), bun.Ident("latest_conversation_statuses")).
+ Where("?TableAlias.? = ?", bun.Ident("id"), bun.Ident("latest_conversation_statuses.conversation_id")).
+ Where("? IS NOT NULL", bun.Ident("latest_conversation_statuses.id")).
+ Returning("?TableName.?", bun.Ident("id")).
+ Scan(ctx, &updatedConversationIDs); // nocollapse
+ err != nil {
+ return gtserror.Newf("error rolling back last status for conversation while deleting status %s: %w", statusID, err)
+ }
+
+ // If there is no such status, delete the conversation.
+ // Return conversation IDs for invalidation.
+ if err := tx.NewDelete().
+ Model((*gtsmodel.Conversation)(nil)).
+ Where(
+ "? IN (?)",
+ bun.Ident("id"),
+ tx.NewSelect().
+ Table(latestConversationStatusesTempTable).
+ Column("conversation_id").
+ Where("? IS NULL", bun.Ident("id")),
+ ).
+ Returning("?", bun.Ident("id")).
+ Scan(ctx, &deletedConversationIDs); // nocollapse
+ err != nil {
+ return gtserror.Newf("error deleting conversation while deleting status %s: %w", statusID, err)
+ }
+
+ // Clean up.
+ for _, tempTable := range []string{
+ conversationStatusesTempTable,
+ latestConversationStatusesTempTable,
+ } {
+ if _, err := tx.NewDropTable().Table(tempTable).Exec(ctx); err != nil {
+ return gtserror.Newf(
+ "error dropping temporary table %s after deleting status %s: %w",
+ tempTable,
+ statusID,
+ err,
+ )
+ }
+ }
+
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ updatedConversationIDs = append(updatedConversationIDs, deletedConversationIDs...)
+ c.state.Caches.GTS.Conversation.InvalidateIDs("ID", updatedConversationIDs)
+
+ return nil
+}
diff --git a/internal/db/bundb/conversation_test.go b/internal/db/bundb/conversation_test.go
new file mode 100644
index 000000000..24d35d482
--- /dev/null
+++ b/internal/db/bundb/conversation_test.go
@@ -0,0 +1,115 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/db/test"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type ConversationTestSuite struct {
+ BunDBStandardTestSuite
+
+ cf test.ConversationFactory
+
+ // testAccount is the owner of statuses and conversations in these tests (must be local).
+ testAccount *gtsmodel.Account
+ // threadID is the thread used for statuses in any given test.
+ threadID string
+}
+
+func (suite *ConversationTestSuite) SetupSuite() {
+ suite.BunDBStandardTestSuite.SetupSuite()
+
+ suite.cf.SetupSuite(suite)
+
+ suite.testAccount = suite.testAccounts["local_account_1"]
+}
+
+func (suite *ConversationTestSuite) SetupTest() {
+ suite.BunDBStandardTestSuite.SetupTest()
+
+ suite.cf.SetupTest(suite.db)
+
+ suite.threadID = suite.cf.NewULID(0)
+}
+
+// deleteStatus deletes a status from conversations and ends the test if that fails.
+func (suite *ConversationTestSuite) deleteStatus(statusID string) {
+ err := suite.db.DeleteStatusFromConversations(context.Background(), statusID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+}
+
+// getConversation fetches a conversation by ID and ends the test if that fails.
+func (suite *ConversationTestSuite) getConversation(conversationID string) *gtsmodel.Conversation {
+ conversation, err := suite.db.GetConversationByID(context.Background(), conversationID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ return conversation
+}
+
+// If we delete a status that is in a conversation but not the last status,
+// the conversation's last status should not change.
+func (suite *ConversationTestSuite) TestDeleteNonLastStatus() {
+ conversation := suite.cf.NewTestConversation(suite.testAccount, 0)
+ initial := conversation.LastStatus
+ reply := suite.cf.NewTestStatus(suite.testAccount, conversation.ThreadID, 1*time.Second, initial)
+ conversation = suite.cf.SetLastStatus(conversation, reply)
+
+ suite.deleteStatus(initial.ID)
+ conversation = suite.getConversation(conversation.ID)
+ suite.Equal(reply.ID, conversation.LastStatusID)
+}
+
+// If we delete the last status in a conversation that has other statuses,
+// a previous status should become the new last status.
+func (suite *ConversationTestSuite) TestDeleteLastStatus() {
+ conversation := suite.cf.NewTestConversation(suite.testAccount, 0)
+ initial := conversation.LastStatus
+ reply := suite.cf.NewTestStatus(suite.testAccount, conversation.ThreadID, 1*time.Second, initial)
+ conversation = suite.cf.SetLastStatus(conversation, reply)
+ conversation = suite.getConversation(conversation.ID)
+
+ suite.deleteStatus(reply.ID)
+ conversation = suite.getConversation(conversation.ID)
+ suite.Equal(initial.ID, conversation.LastStatusID)
+}
+
+// If we delete the only status in a conversation,
+// the conversation should be deleted as well.
+func (suite *ConversationTestSuite) TestDeleteOnlyStatus() {
+ conversation := suite.cf.NewTestConversation(suite.testAccount, 0)
+ initial := conversation.LastStatus
+
+ suite.deleteStatus(initial.ID)
+ _, err := suite.db.GetConversationByID(context.Background(), conversation.ID)
+ suite.ErrorIs(err, db.ErrNoEntries)
+}
+
+func TestConversationTestSuite(t *testing.T) {
+ suite.Run(t, new(ConversationTestSuite))
+}
diff --git a/internal/db/bundb/migrations/20240611190733_add_conversations.go b/internal/db/bundb/migrations/20240611190733_add_conversations.go
new file mode 100644
index 000000000..25b226aff
--- /dev/null
+++ b/internal/db/bundb/migrations/20240611190733_add_conversations.go
@@ -0,0 +1,78 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+// Note: this migration has an advanced migration followup.
+// See Conversations.MigrateDMs().
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ for _, model := range []interface{}{
+ &gtsmodel.Conversation{},
+ &gtsmodel.ConversationToStatus{},
+ } {
+ if _, err := tx.
+ NewCreateTable().
+ Model(model).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // Add indexes to the conversations table.
+ for index, columns := range map[string][]string{
+ "conversations_account_id_idx": {
+ "account_id",
+ },
+ "conversations_last_status_id_idx": {
+ "last_status_id",
+ },
+ } {
+ if _, err := tx.
+ NewCreateIndex().
+ Model(&gtsmodel.Conversation{}).
+ Index(index).
+ Column(columns...).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go b/internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go
new file mode 100644
index 000000000..183065285
--- /dev/null
+++ b/internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go
@@ -0,0 +1,49 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+// Create the advanced migrations table.
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ _, err := tx.
+ NewCreateTable().
+ Model((*gtsmodel.AdvancedMigration)(nil)).
+ IfNotExists().
+ Exec(ctx)
+ return err
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index dfb97cff1..b0ed32e0e 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -682,3 +682,35 @@ func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]st
return statusIDs, nil
})
}
+
+func (s *statusDB) MaxDirectStatusID(ctx context.Context) (string, error) {
+ maxID := ""
+ if err := s.db.
+ NewSelect().
+ Model((*gtsmodel.Status)(nil)).
+ ColumnExpr("COALESCE(MAX(?), '')", bun.Ident("id")).
+ Where("? = ?", bun.Ident("visibility"), gtsmodel.VisibilityDirect).
+ Scan(ctx, &maxID); // nocollapse
+ err != nil {
+ return "", err
+ }
+ return maxID, nil
+}
+
+func (s *statusDB) GetDirectStatusIDsBatch(ctx context.Context, minID string, maxIDInclusive string, count int) ([]string, error) {
+ var statusIDs []string
+ if err := s.db.
+ NewSelect().
+ Model((*gtsmodel.Status)(nil)).
+ Column("id").
+ Where("? = ?", bun.Ident("visibility"), gtsmodel.VisibilityDirect).
+ Where("? > ?", bun.Ident("id"), minID).
+ Where("? <= ?", bun.Ident("id"), maxIDInclusive).
+ Order("id ASC").
+ Limit(count).
+ Scan(ctx, &statusIDs); // nocollapse
+ err != nil {
+ return nil, err
+ }
+ return statusIDs, nil
+}