summaryrefslogtreecommitdiff
path: root/internal/db/bundb/bundb.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/bundb.go')
-rw-r--r--internal/db/bundb/bundb.go17
1 files changed, 12 insertions, 5 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index b944ae3ea..2fc65364f 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -67,12 +67,13 @@ const (
)
var registerTables = []interface{}{
+ &gtsmodel.AccountToEmoji{},
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusToTag{},
}
-// bunDBService satisfies the DB interface
-type bunDBService struct {
+// DBService satisfies the DB interface
+type DBService struct {
db.Account
db.Admin
db.Basic
@@ -89,6 +90,12 @@ type bunDBService struct {
conn *DBConn
}
+// GetConn returns the underlying bun connection.
+// Should only be used in testing + exceptional circumstance.
+func (dbService *DBService) GetConn() *DBConn {
+ return dbService.conn
+}
+
func doMigration(ctx context.Context, db *bun.DB) error {
migrator := migrate.NewMigrator(db, migrations.Migrations)
@@ -177,7 +184,7 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
// Prepare domain block cache
blockCache := cache.NewDomainBlockCache()
- ps := &bunDBService{
+ ps := &DBService{
Account: accounts,
Admin: &adminDB{
conn: conn,
@@ -399,7 +406,7 @@ func tweakConnectionValues(sqldb *sql.DB) {
CONVERSION FUNCTIONS
*/
-func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) {
+func (dbService *DBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) {
protocol := config.GetProtocol()
host := config.GetHost()
@@ -408,7 +415,7 @@ func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, ori
tag := &gtsmodel.Tag{}
// we can use selectorinsert here to create the new tag if it doesn't exist already
// inserted will be true if this is a new tag we just created
- if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
+ if err := dbService.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
if err == sql.ErrNoRows {
// tag doesn't exist yet so populate it
newID, err := id.NewRandomULID()