diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/advancedmigration.go | 29 | ||||
| -rw-r--r-- | internal/db/bundb/advancedmigration.go | 52 | ||||
| -rw-r--r-- | internal/db/bundb/bundb.go | 11 | ||||
| -rw-r--r-- | internal/db/bundb/conversation.go | 494 | ||||
| -rw-r--r-- | internal/db/bundb/conversation_test.go | 115 | ||||
| -rw-r--r-- | internal/db/bundb/migrations/20240611190733_add_conversations.go | 78 | ||||
| -rw-r--r-- | internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go | 49 | ||||
| -rw-r--r-- | internal/db/bundb/status.go | 32 | ||||
| -rw-r--r-- | internal/db/conversation.go | 52 | ||||
| -rw-r--r-- | internal/db/db.go | 2 | ||||
| -rw-r--r-- | internal/db/status.go | 12 | ||||
| -rw-r--r-- | internal/db/test/conversation.go | 122 | 
12 files changed, 1048 insertions, 0 deletions
| diff --git a/internal/db/advancedmigration.go b/internal/db/advancedmigration.go new file mode 100644 index 000000000..2b4601bdb --- /dev/null +++ b/internal/db/advancedmigration.go @@ -0,0 +1,29 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package db + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type AdvancedMigration interface { +	GetAdvancedMigration(ctx context.Context, id string) (*gtsmodel.AdvancedMigration, error) +	PutAdvancedMigration(ctx context.Context, advancedMigration *gtsmodel.AdvancedMigration) error +} diff --git a/internal/db/bundb/advancedmigration.go b/internal/db/bundb/advancedmigration.go new file mode 100644 index 000000000..2a0ec93e6 --- /dev/null +++ b/internal/db/bundb/advancedmigration.go @@ -0,0 +1,52 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/uptrace/bun" +) + +type advancedMigrationDB struct { +	db    *bun.DB +	state *state.State +} + +func (a *advancedMigrationDB) GetAdvancedMigration(ctx context.Context, id string) (*gtsmodel.AdvancedMigration, error) { +	var advancedMigration gtsmodel.AdvancedMigration +	err := a.db.NewSelect(). +		Model(&advancedMigration). +		Where("? = ?", bun.Ident("id"), id). +		Limit(1). +		Scan(ctx) +	if err != nil { +		return nil, err +	} +	return &advancedMigration, nil +} + +func (a *advancedMigrationDB) PutAdvancedMigration(ctx context.Context, advancedMigration *gtsmodel.AdvancedMigration) error { +	_, err := NewUpsert(a.db). +		Model(advancedMigration). +		Constraint("id"). +		Exec(ctx) +	return err +} diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 57fb661df..070d4eb91 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -54,8 +54,10 @@ import (  type DBService struct {  	db.Account  	db.Admin +	db.AdvancedMigration  	db.Application  	db.Basic +	db.Conversation  	db.Domain  	db.Emoji  	db.HeaderFilter @@ -158,6 +160,7 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  	// https://bun.uptrace.dev/orm/many-to-many-relation/  	for _, t := range []interface{}{  		>smodel.AccountToEmoji{}, +		>smodel.ConversationToStatus{},  		>smodel.StatusToEmoji{},  		>smodel.StatusToTag{},  		>smodel.ThreadToStatus{}, @@ -181,6 +184,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  			db:    db,  			state: state,  		}, +		AdvancedMigration: &advancedMigrationDB{ +			db:    db, +			state: state, +		},  		Application: &applicationDB{  			db:    db,  			state: state, @@ -188,6 +195,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  		Basic: &basicDB{  			db: db,  		}, +		Conversation: &conversationDB{ +			db:    db, +			state: state, +		},  		Domain: &domainDB{  			db:    db,  			state: state, diff --git a/internal/db/bundb/conversation.go b/internal/db/bundb/conversation.go new file mode 100644 index 000000000..1a3958a79 --- /dev/null +++ b/internal/db/bundb/conversation.go @@ -0,0 +1,494 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( +	"context" +	"errors" +	"slices" + +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/id" +	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/paging" +	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util" +	"github.com/uptrace/bun" +	"github.com/uptrace/bun/dialect" +) + +type conversationDB struct { +	db    *bun.DB +	state *state.State +} + +func (c *conversationDB) GetConversationByID(ctx context.Context, id string) (*gtsmodel.Conversation, error) { +	return c.getConversation( +		ctx, +		"ID", +		func(conversation *gtsmodel.Conversation) error { +			return c.db. +				NewSelect(). +				Model(conversation). +				Where("? = ?", bun.Ident("id"), id). +				Scan(ctx) +		}, +		id, +	) +} + +func (c *conversationDB) GetConversationByThreadAndAccountIDs(ctx context.Context, threadID string, accountID string, otherAccountIDs []string) (*gtsmodel.Conversation, error) { +	otherAccountsKey := gtsmodel.ConversationOtherAccountsKey(otherAccountIDs) +	return c.getConversation( +		ctx, +		"ThreadID,AccountID,OtherAccountsKey", +		func(conversation *gtsmodel.Conversation) error { +			return c.db. +				NewSelect(). +				Model(conversation). +				Where("? = ?", bun.Ident("thread_id"), threadID). +				Where("? = ?", bun.Ident("account_id"), accountID). +				Where("? = ?", bun.Ident("other_accounts_key"), otherAccountsKey). +				Scan(ctx) +		}, +		threadID, +		accountID, +		otherAccountsKey, +	) +} + +func (c *conversationDB) getConversation( +	ctx context.Context, +	lookup string, +	dbQuery func(conversation *gtsmodel.Conversation) error, +	keyParts ...any, +) (*gtsmodel.Conversation, error) { +	// Fetch conversation from cache with loader callback +	conversation, err := c.state.Caches.GTS.Conversation.LoadOne(lookup, func() (*gtsmodel.Conversation, error) { +		var conversation gtsmodel.Conversation + +		// Not cached! Perform database query +		if err := dbQuery(&conversation); err != nil { +			return nil, err +		} + +		return &conversation, nil +	}, keyParts...) +	if err != nil { +		// already processe +		return nil, err +	} + +	if gtscontext.Barebones(ctx) { +		// Only a barebones model was requested. +		return conversation, nil +	} + +	if err := c.populateConversation(ctx, conversation); err != nil { +		return nil, err +	} + +	return conversation, nil +} + +func (c *conversationDB) populateConversation(ctx context.Context, conversation *gtsmodel.Conversation) error { +	var ( +		errs gtserror.MultiError +		err  error +	) + +	if conversation.Account == nil { +		conversation.Account, err = c.state.DB.GetAccountByID( +			gtscontext.SetBarebones(ctx), +			conversation.AccountID, +		) +		if err != nil { +			errs.Appendf("error populating conversation owner account: %w", err) +		} +	} + +	if conversation.OtherAccounts == nil { +		conversation.OtherAccounts, err = c.state.DB.GetAccountsByIDs( +			gtscontext.SetBarebones(ctx), +			conversation.OtherAccountIDs, +		) +		if err != nil { +			errs.Appendf("error populating other conversation accounts: %w", err) +		} +	} + +	if conversation.LastStatus == nil && conversation.LastStatusID != "" { +		conversation.LastStatus, err = c.state.DB.GetStatusByID( +			gtscontext.SetBarebones(ctx), +			conversation.LastStatusID, +		) +		if err != nil { +			errs.Appendf("error populating conversation last status: %w", err) +		} +	} + +	return errs.Combine() +} + +func (c *conversationDB) GetConversationsByOwnerAccountID(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Conversation, error) { +	conversationLastStatusIDs, err := c.getAccountConversationLastStatusIDs(ctx, accountID, page) +	if err != nil { +		return nil, err +	} +	return c.getConversationsByLastStatusIDs(ctx, accountID, conversationLastStatusIDs) +} + +func (c *conversationDB) getAccountConversationLastStatusIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +	return loadPagedIDs(&c.state.Caches.GTS.ConversationLastStatusIDs, accountID, page, func() ([]string, error) { +		var conversationLastStatusIDs []string + +		// Conversation last status IDs not in cache. Perform DB query. +		if _, err := c.db. +			NewSelect(). +			Model((*gtsmodel.Conversation)(nil)). +			Column("last_status_id"). +			Where("? = ?", bun.Ident("account_id"), accountID). +			OrderExpr("? DESC", bun.Ident("last_status_id")). +			Exec(ctx, &conversationLastStatusIDs); // nocollapse +		err != nil && !errors.Is(err, db.ErrNoEntries) { +			return nil, err +		} + +		return conversationLastStatusIDs, nil +	}) +} + +func (c *conversationDB) getConversationsByLastStatusIDs( +	ctx context.Context, +	accountID string, +	conversationLastStatusIDs []string, +) ([]*gtsmodel.Conversation, error) { +	// Load all conversation IDs via cache loader callbacks. +	conversations, err := c.state.Caches.GTS.Conversation.LoadIDs2Part( +		"AccountID,LastStatusID", +		accountID, +		conversationLastStatusIDs, +		func(accountID string, uncached []string) ([]*gtsmodel.Conversation, error) { +			// Preallocate expected length of uncached conversations. +			conversations := make([]*gtsmodel.Conversation, 0, len(uncached)) + +			// Perform database query scanning the remaining (uncached) IDs. +			if err := c.db.NewSelect(). +				Model(&conversations). +				Where("? = ?", bun.Ident("account_id"), accountID). +				Where("? IN (?)", bun.Ident("last_status_id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return conversations, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Reorder the conversations by their last status IDs to ensure correct order. +	getID := func(b *gtsmodel.Conversation) string { return b.ID } +	util.OrderBy(conversations, conversationLastStatusIDs, getID) + +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return conversations, nil +	} + +	// Populate all loaded conversations, removing those we fail to populate. +	conversations = slices.DeleteFunc(conversations, func(conversation *gtsmodel.Conversation) bool { +		if err := c.populateConversation(ctx, conversation); err != nil { +			log.Errorf(ctx, "error populating conversation %s: %v", conversation.ID, err) +			return true +		} +		return false +	}) + +	return conversations, nil +} + +func (c *conversationDB) UpsertConversation(ctx context.Context, conversation *gtsmodel.Conversation, columns ...string) error { +	// If we're updating by column, ensure "updated_at" is included. +	if len(columns) > 0 { +		columns = append(columns, "updated_at") +	} + +	return c.state.Caches.GTS.Conversation.Store(conversation, func() error { +		_, err := NewUpsert(c.db). +			Model(conversation). +			Constraint("id"). +			Column(columns...). +			Exec(ctx) +		return err +	}) +} + +func (c *conversationDB) LinkConversationToStatus(ctx context.Context, conversationID string, statusID string) error { +	conversationToStatus := >smodel.ConversationToStatus{ +		ConversationID: conversationID, +		StatusID:       statusID, +	} + +	if _, err := c.db.NewInsert(). +		Model(conversationToStatus). +		Exec(ctx); // nocollapse +	err != nil { +		return err +	} +	return nil +} + +func (c *conversationDB) DeleteConversationByID(ctx context.Context, id string) error { +	// Load conversation into cache before attempting a delete, +	// as we need it cached in order to trigger the invalidate +	// callback. This in turn invalidates others. +	_, err := c.GetConversationByID(gtscontext.SetBarebones(ctx), id) +	if err != nil { +		if errors.Is(err, db.ErrNoEntries) { +			// not an issue. +			err = nil +		} +		return err +	} + +	// Drop this now-cached conversation on return after delete. +	defer c.state.Caches.GTS.Conversation.Invalidate("ID", id) + +	// Finally delete conversation from DB. +	_, err = c.db.NewDelete(). +		Model((*gtsmodel.Conversation)(nil)). +		Where("? = ?", bun.Ident("id"), id). +		Exec(ctx) +	return err +} + +func (c *conversationDB) DeleteConversationsByOwnerAccountID(ctx context.Context, accountID string) error { +	defer func() { +		// Invalidate any cached conversations and conversation IDs owned by this account on return. +		// Conversation invalidate hooks only invalidate the conversation ID cache, +		// so we don't need to load all conversations into the cache to run invalidation hooks, +		// as with some other object types (blocks, for example). +		c.state.Caches.GTS.Conversation.Invalidate("AccountID", accountID) +		// In case there were no cached conversations, +		// explicitly invalidate the user's conversation last status ID cache. +		c.state.Caches.GTS.ConversationLastStatusIDs.Invalidate(accountID) +	}() + +	return c.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +		// Delete conversations matching the account ID. +		deletedConversationIDs := []string{} +		if err := tx.NewDelete(). +			Model((*gtsmodel.Conversation)(nil)). +			Where("? = ?", bun.Ident("account_id"), accountID). +			Returning("?", bun.Ident("id")). +			Scan(ctx, &deletedConversationIDs); // nocollapse +		err != nil { +			return gtserror.Newf("error deleting conversations for account %s: %w", accountID, err) +		} + +		// Delete any conversation-to-status links matching the deleted conversation IDs. +		if _, err := tx.NewDelete(). +			Model((*gtsmodel.ConversationToStatus)(nil)). +			Where("? IN (?)", bun.Ident("conversation_id"), bun.In(deletedConversationIDs)). +			Exec(ctx); // nocollapse +		err != nil { +			return gtserror.Newf("error deleting conversation-to-status links for account %s: %w", accountID, err) +		} + +		return nil +	}) +} + +func (c *conversationDB) DeleteStatusFromConversations(ctx context.Context, statusID string) error { +	// SQL returning the current time. +	var nowSQL string +	switch c.db.Dialect().Name() { +	case dialect.SQLite: +		nowSQL = "DATE('now')" +	case dialect.PG: +		nowSQL = "NOW()" +	default: +		log.Panicf(nil, "db conn %s was neither pg nor sqlite", c.db) +	} + +	updatedConversationIDs := []string{} +	deletedConversationIDs := []string{} + +	if err := c.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +		// Delete this status from conversation-to-status links. +		if _, err := tx.NewDelete(). +			Model((*gtsmodel.ConversationToStatus)(nil)). +			Where("? = ?", bun.Ident("status_id"), statusID). +			Exec(ctx); // nocollapse +		err != nil { +			return gtserror.Newf("error deleting conversation-to-status links while deleting status %s: %w", statusID, err) +		} + +		// Note: Bun doesn't currently support CREATE TABLE … AS SELECT … so we need to use raw queries here. + +		// Create a temporary table with all statuses other than the deleted status +		// in each conversation for which the deleted status is the last status +		// (if there are such statuses). +		conversationStatusesTempTable := "conversation_statuses_" + id.NewULID() +		if _, err := tx.NewRaw( +			"CREATE TEMPORARY TABLE ? AS ?", +			bun.Ident(conversationStatusesTempTable), +			tx.NewSelect(). +				ColumnExpr( +					"? AS ?", +					bun.Ident("conversations.id"), +					bun.Ident("conversation_id"), +				). +				ColumnExpr( +					"? AS ?", +					bun.Ident("conversation_to_statuses.status_id"), +					bun.Ident("id"), +				). +				Column("statuses.created_at"). +				Table("conversations"). +				Join("LEFT JOIN ?", bun.Ident("conversation_to_statuses")). +				JoinOn( +					"? = ?", +					bun.Ident("conversations.id"), +					bun.Ident("conversation_to_statuses.conversation_id"), +				). +				JoinOn( +					"? != ?", +					bun.Ident("conversation_to_statuses.status_id"), +					statusID, +				). +				Join("LEFT JOIN ?", bun.Ident("statuses")). +				JoinOn( +					"? = ?", +					bun.Ident("conversation_to_statuses.status_id"), +					bun.Ident("statuses.id"), +				). +				Where( +					"? = ?", +					bun.Ident("conversations.last_status_id"), +					statusID, +				), +		). +			Exec(ctx); // nocollapse +		err != nil { +			return gtserror.Newf("error creating conversationStatusesTempTable while deleting status %s: %w", statusID, err) +		} + +		// Create a temporary table with the most recently created status in each conversation +		// for which the deleted status is the last status (if there is such a status). +		latestConversationStatusesTempTable := "latest_conversation_statuses_" + id.NewULID() +		if _, err := tx.NewRaw( +			"CREATE TEMPORARY TABLE ? AS ?", +			bun.Ident(latestConversationStatusesTempTable), +			tx.NewSelect(). +				Column( +					"conversation_statuses.conversation_id", +					"conversation_statuses.id", +				). +				TableExpr( +					"? AS ?", +					bun.Ident(conversationStatusesTempTable), +					bun.Ident("conversation_statuses"), +				). +				Join( +					"LEFT JOIN ? AS ?", +					bun.Ident(conversationStatusesTempTable), +					bun.Ident("later_statuses"), +				). +				JoinOn( +					"? = ?", +					bun.Ident("conversation_statuses.conversation_id"), +					bun.Ident("later_statuses.conversation_id"), +				). +				JoinOn( +					"? > ?", +					bun.Ident("later_statuses.created_at"), +					bun.Ident("conversation_statuses.created_at"), +				). +				Where("? IS NULL", bun.Ident("later_statuses.id")), +		). +			Exec(ctx); // nocollapse +		err != nil { +			return gtserror.Newf("error creating latestConversationStatusesTempTable while deleting status %s: %w", statusID, err) +		} + +		// For every conversation where the given status was the last one, +		// reset its last status to the most recently created in the conversation other than that one, +		// if there is such a status. +		// Return conversation IDs for invalidation. +		if err := tx.NewUpdate(). +			Model((*gtsmodel.Conversation)(nil)). +			SetColumn("last_status_id", "?", bun.Ident("latest_conversation_statuses.id")). +			SetColumn("updated_at", "?", bun.Safe(nowSQL)). +			TableExpr("? AS ?", bun.Ident(latestConversationStatusesTempTable), bun.Ident("latest_conversation_statuses")). +			Where("?TableAlias.? = ?", bun.Ident("id"), bun.Ident("latest_conversation_statuses.conversation_id")). +			Where("? IS NOT NULL", bun.Ident("latest_conversation_statuses.id")). +			Returning("?TableName.?", bun.Ident("id")). +			Scan(ctx, &updatedConversationIDs); // nocollapse +		err != nil { +			return gtserror.Newf("error rolling back last status for conversation while deleting status %s: %w", statusID, err) +		} + +		// If there is no such status, delete the conversation. +		// Return conversation IDs for invalidation. +		if err := tx.NewDelete(). +			Model((*gtsmodel.Conversation)(nil)). +			Where( +				"? IN (?)", +				bun.Ident("id"), +				tx.NewSelect(). +					Table(latestConversationStatusesTempTable). +					Column("conversation_id"). +					Where("? IS NULL", bun.Ident("id")), +			). +			Returning("?", bun.Ident("id")). +			Scan(ctx, &deletedConversationIDs); // nocollapse +		err != nil { +			return gtserror.Newf("error deleting conversation while deleting status %s: %w", statusID, err) +		} + +		// Clean up. +		for _, tempTable := range []string{ +			conversationStatusesTempTable, +			latestConversationStatusesTempTable, +		} { +			if _, err := tx.NewDropTable().Table(tempTable).Exec(ctx); err != nil { +				return gtserror.Newf( +					"error dropping temporary table %s after deleting status %s: %w", +					tempTable, +					statusID, +					err, +				) +			} +		} + +		return nil +	}); err != nil { +		return err +	} + +	updatedConversationIDs = append(updatedConversationIDs, deletedConversationIDs...) +	c.state.Caches.GTS.Conversation.InvalidateIDs("ID", updatedConversationIDs) + +	return nil +} diff --git a/internal/db/bundb/conversation_test.go b/internal/db/bundb/conversation_test.go new file mode 100644 index 000000000..24d35d482 --- /dev/null +++ b/internal/db/bundb/conversation_test.go @@ -0,0 +1,115 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( +	"context" +	"testing" +	"time" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/db/test" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type ConversationTestSuite struct { +	BunDBStandardTestSuite + +	cf test.ConversationFactory + +	// testAccount is the owner of statuses and conversations in these tests (must be local). +	testAccount *gtsmodel.Account +	// threadID is the thread used for statuses in any given test. +	threadID string +} + +func (suite *ConversationTestSuite) SetupSuite() { +	suite.BunDBStandardTestSuite.SetupSuite() + +	suite.cf.SetupSuite(suite) + +	suite.testAccount = suite.testAccounts["local_account_1"] +} + +func (suite *ConversationTestSuite) SetupTest() { +	suite.BunDBStandardTestSuite.SetupTest() + +	suite.cf.SetupTest(suite.db) + +	suite.threadID = suite.cf.NewULID(0) +} + +// deleteStatus deletes a status from conversations and ends the test if that fails. +func (suite *ConversationTestSuite) deleteStatus(statusID string) { +	err := suite.db.DeleteStatusFromConversations(context.Background(), statusID) +	if err != nil { +		suite.FailNow(err.Error()) +	} +} + +// getConversation fetches a conversation by ID and ends the test if that fails. +func (suite *ConversationTestSuite) getConversation(conversationID string) *gtsmodel.Conversation { +	conversation, err := suite.db.GetConversationByID(context.Background(), conversationID) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	return conversation +} + +// If we delete a status that is in a conversation but not the last status, +// the conversation's last status should not change. +func (suite *ConversationTestSuite) TestDeleteNonLastStatus() { +	conversation := suite.cf.NewTestConversation(suite.testAccount, 0) +	initial := conversation.LastStatus +	reply := suite.cf.NewTestStatus(suite.testAccount, conversation.ThreadID, 1*time.Second, initial) +	conversation = suite.cf.SetLastStatus(conversation, reply) + +	suite.deleteStatus(initial.ID) +	conversation = suite.getConversation(conversation.ID) +	suite.Equal(reply.ID, conversation.LastStatusID) +} + +// If we delete the last status in a conversation that has other statuses, +// a previous status should become the new last status. +func (suite *ConversationTestSuite) TestDeleteLastStatus() { +	conversation := suite.cf.NewTestConversation(suite.testAccount, 0) +	initial := conversation.LastStatus +	reply := suite.cf.NewTestStatus(suite.testAccount, conversation.ThreadID, 1*time.Second, initial) +	conversation = suite.cf.SetLastStatus(conversation, reply) +	conversation = suite.getConversation(conversation.ID) + +	suite.deleteStatus(reply.ID) +	conversation = suite.getConversation(conversation.ID) +	suite.Equal(initial.ID, conversation.LastStatusID) +} + +// If we delete the only status in a conversation, +// the conversation should be deleted as well. +func (suite *ConversationTestSuite) TestDeleteOnlyStatus() { +	conversation := suite.cf.NewTestConversation(suite.testAccount, 0) +	initial := conversation.LastStatus + +	suite.deleteStatus(initial.ID) +	_, err := suite.db.GetConversationByID(context.Background(), conversation.ID) +	suite.ErrorIs(err, db.ErrNoEntries) +} + +func TestConversationTestSuite(t *testing.T) { +	suite.Run(t, new(ConversationTestSuite)) +} diff --git a/internal/db/bundb/migrations/20240611190733_add_conversations.go b/internal/db/bundb/migrations/20240611190733_add_conversations.go new file mode 100644 index 000000000..25b226aff --- /dev/null +++ b/internal/db/bundb/migrations/20240611190733_add_conversations.go @@ -0,0 +1,78 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( +	"context" + +	gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun" +) + +// Note: this migration has an advanced migration followup. +// See Conversations.MigrateDMs(). +func init() { +	up := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			for _, model := range []interface{}{ +				>smodel.Conversation{}, +				>smodel.ConversationToStatus{}, +			} { +				if _, err := tx. +					NewCreateTable(). +					Model(model). +					IfNotExists(). +					Exec(ctx); err != nil { +					return err +				} +			} + +			// Add indexes to the conversations table. +			for index, columns := range map[string][]string{ +				"conversations_account_id_idx": { +					"account_id", +				}, +				"conversations_last_status_id_idx": { +					"last_status_id", +				}, +			} { +				if _, err := tx. +					NewCreateIndex(). +					Model(>smodel.Conversation{}). +					Index(index). +					Column(columns...). +					IfNotExists(). +					Exec(ctx); err != nil { +					return err +				} +			} + +			return nil +		}) +	} + +	down := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			return nil +		}) +	} + +	if err := Migrations.Register(up, down); err != nil { +		panic(err) +	} +} diff --git a/internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go b/internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go new file mode 100644 index 000000000..183065285 --- /dev/null +++ b/internal/db/bundb/migrations/20240712005536_add_advanced_migrations.go @@ -0,0 +1,49 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( +	"context" + +	gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun" +) + +// Create the advanced migrations table. +func init() { +	up := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			_, err := tx. +				NewCreateTable(). +				Model((*gtsmodel.AdvancedMigration)(nil)). +				IfNotExists(). +				Exec(ctx) +			return err +		}) +	} + +	down := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			return nil +		}) +	} + +	if err := Migrations.Register(up, down); err != nil { +		panic(err) +	} +} diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index dfb97cff1..b0ed32e0e 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -682,3 +682,35 @@ func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]st  		return statusIDs, nil  	})  } + +func (s *statusDB) MaxDirectStatusID(ctx context.Context) (string, error) { +	maxID := "" +	if err := s.db. +		NewSelect(). +		Model((*gtsmodel.Status)(nil)). +		ColumnExpr("COALESCE(MAX(?), '')", bun.Ident("id")). +		Where("? = ?", bun.Ident("visibility"), gtsmodel.VisibilityDirect). +		Scan(ctx, &maxID); // nocollapse +	err != nil { +		return "", err +	} +	return maxID, nil +} + +func (s *statusDB) GetDirectStatusIDsBatch(ctx context.Context, minID string, maxIDInclusive string, count int) ([]string, error) { +	var statusIDs []string +	if err := s.db. +		NewSelect(). +		Model((*gtsmodel.Status)(nil)). +		Column("id"). +		Where("? = ?", bun.Ident("visibility"), gtsmodel.VisibilityDirect). +		Where("? > ?", bun.Ident("id"), minID). +		Where("? <= ?", bun.Ident("id"), maxIDInclusive). +		Order("id ASC"). +		Limit(count). +		Scan(ctx, &statusIDs); // nocollapse +	err != nil { +		return nil, err +	} +	return statusIDs, nil +} diff --git a/internal/db/conversation.go b/internal/db/conversation.go new file mode 100644 index 000000000..3d0b4213e --- /dev/null +++ b/internal/db/conversation.go @@ -0,0 +1,52 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package db + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/paging" +) + +type Conversation interface { +	// GetConversationByID gets a single conversation by ID. +	GetConversationByID(ctx context.Context, id string) (*gtsmodel.Conversation, error) + +	// GetConversationByThreadAndAccountIDs retrieves a conversation by thread ID and participant account IDs, if it exists. +	GetConversationByThreadAndAccountIDs(ctx context.Context, threadID string, accountID string, otherAccountIDs []string) (*gtsmodel.Conversation, error) + +	// GetConversationsByOwnerAccountID gets all conversations owned by the given account, +	// with optional paging based on last status ID. +	GetConversationsByOwnerAccountID(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Conversation, error) + +	// UpsertConversation creates or updates a conversation. +	UpsertConversation(ctx context.Context, conversation *gtsmodel.Conversation, columns ...string) error + +	// LinkConversationToStatus creates a conversation-to-status link. +	LinkConversationToStatus(ctx context.Context, statusID string, conversationID string) error + +	// DeleteConversationByID deletes a conversation, removing it from the owning account's conversation list. +	DeleteConversationByID(ctx context.Context, id string) error + +	// DeleteConversationsByOwnerAccountID deletes all conversations owned by the given account. +	DeleteConversationsByOwnerAccountID(ctx context.Context, accountID string) error + +	// DeleteStatusFromConversations handles when a status is deleted by updating or deleting conversations for which it was the last status. +	DeleteStatusFromConversations(ctx context.Context, statusID string) error +} diff --git a/internal/db/db.go b/internal/db/db.go index a148d778a..4b2152732 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -26,8 +26,10 @@ const (  type DB interface {  	Account  	Admin +	AdvancedMigration  	Application  	Basic +	Conversation  	Domain  	Emoji  	HeaderFilter diff --git a/internal/db/status.go b/internal/db/status.go index 88ae12a12..ade900728 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -78,4 +78,16 @@ type Status interface {  	// GetStatusChildren gets the child statuses of a given status.  	GetStatusChildren(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) + +	// MaxDirectStatusID returns the newest ID across all DM statuses. +	// Returns the empty string with no error if there are no DM statuses yet. +	// It is used only by the conversation advanced migration. +	MaxDirectStatusID(ctx context.Context) (string, error) + +	// GetDirectStatusIDsBatch returns up to count DM status IDs strictly greater than minID +	// and less than or equal to maxIDInclusive. Note that this is different from most of our paging, +	// which uses a maxID and returns IDs strictly less than that, because it's called with the result of +	// MaxDirectStatusID, and expects to eventually return the status with that ID. +	// It is used only by the conversation advanced migration. +	GetDirectStatusIDsBatch(ctx context.Context, minID string, maxIDInclusive string, count int) ([]string, error)  } diff --git a/internal/db/test/conversation.go b/internal/db/test/conversation.go new file mode 100644 index 000000000..95713927e --- /dev/null +++ b/internal/db/test/conversation.go @@ -0,0 +1,122 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package test + +import ( +	"context" +	"crypto/rand" +	"time" + +	"github.com/oklog/ulid" +	"github.com/superseriousbusiness/gotosocial/internal/ap" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/util" +) + +type testSuite interface { +	FailNow(string, ...interface{}) bool +} + +// ConversationFactory can be embedded or included by test suites that want to generate statuses and conversations. +type ConversationFactory struct { +	// Test suite, or at least the methods from it that we care about. +	suite testSuite +	// Test DB. +	db db.DB + +	// TestStart is the timestamp used as a base for timestamps and ULIDs in any given test. +	TestStart time.Time +} + +// SetupSuite should be called by the SetupSuite of test suites that use this mixin. +func (f *ConversationFactory) SetupSuite(suite testSuite) { +	f.suite = suite +} + +// SetupTest should be called by the SetupTest of test suites that use this mixin. +func (f *ConversationFactory) SetupTest(db db.DB) { +	f.db = db +	f.TestStart = time.Now() +} + +// NewULID is a version of id.NewULID that uses the test start time and an offset instead of the real time. +func (f *ConversationFactory) NewULID(offset time.Duration) string { +	ulid, err := ulid.New( +		ulid.Timestamp(f.TestStart.Add(offset)), rand.Reader, +	) +	if err != nil { +		panic(err) +	} +	return ulid.String() +} + +func (f *ConversationFactory) NewTestStatus(localAccount *gtsmodel.Account, threadID string, nowOffset time.Duration, inReplyToStatus *gtsmodel.Status) *gtsmodel.Status { +	statusID := f.NewULID(nowOffset) +	createdAt := f.TestStart.Add(nowOffset) +	status := >smodel.Status{ +		ID:                  statusID, +		CreatedAt:           createdAt, +		UpdatedAt:           createdAt, +		URI:                 "http://localhost:8080/users/" + localAccount.Username + "/statuses/" + statusID, +		AccountID:           localAccount.ID, +		AccountURI:          localAccount.URI, +		Local:               util.Ptr(true), +		ThreadID:            threadID, +		Visibility:          gtsmodel.VisibilityDirect, +		ActivityStreamsType: ap.ObjectNote, +		Federated:           util.Ptr(true), +	} +	if inReplyToStatus != nil { +		status.InReplyToID = inReplyToStatus.ID +		status.InReplyToURI = inReplyToStatus.URI +		status.InReplyToAccountID = inReplyToStatus.AccountID +	} +	if err := f.db.PutStatus(context.Background(), status); err != nil { +		f.suite.FailNow(err.Error()) +	} +	return status +} + +// NewTestConversation creates a new status and adds it to a new unread conversation, returning the conversation. +func (f *ConversationFactory) NewTestConversation(localAccount *gtsmodel.Account, nowOffset time.Duration) *gtsmodel.Conversation { +	threadID := f.NewULID(nowOffset) +	status := f.NewTestStatus(localAccount, threadID, nowOffset, nil) +	conversation := >smodel.Conversation{ +		ID:        f.NewULID(nowOffset), +		AccountID: localAccount.ID, +		ThreadID:  status.ThreadID, +		Read:      util.Ptr(false), +	} +	f.SetLastStatus(conversation, status) +	return conversation +} + +// SetLastStatus sets an already stored status as the last status of a new or already stored conversation, +// and returns the updated conversation. +func (f *ConversationFactory) SetLastStatus(conversation *gtsmodel.Conversation, status *gtsmodel.Status) *gtsmodel.Conversation { +	conversation.LastStatusID = status.ID +	conversation.LastStatus = status +	if err := f.db.UpsertConversation(context.Background(), conversation, "last_status_id"); err != nil { +		f.suite.FailNow(err.Error()) +	} +	if err := f.db.LinkConversationToStatus(context.Background(), conversation.ID, status.ID); err != nil { +		f.suite.FailNow(err.Error()) +	} +	return conversation +} | 
