summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/bundb.go2
-rw-r--r--internal/db/bundb/emoji.go119
-rw-r--r--internal/db/bundb/emoji_test.go18
-rw-r--r--internal/db/bundb/migrations/20221031145649_emoji_categories.go46
-rw-r--r--internal/db/emoji.go8
5 files changed, 181 insertions, 12 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 43e9a07c9..cf6643f6b 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -180,7 +180,7 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
// Create DB structs that require ptrs to each other
accounts := &accountDB{conn: conn, cache: accountCache}
status := &statusDB{conn: conn, cache: cache.NewStatusCache()}
- emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()}
+ emoji := &emojiDB{conn: conn, emojiCache: cache.NewEmojiCache(), categoryCache: cache.NewEmojiCategoryCache()}
timeline := &timelineDB{conn: conn}
tombstone := &tombstoneDB{conn: conn}
diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go
index 51d767a7b..81374ce78 100644
--- a/internal/db/bundb/emoji.go
+++ b/internal/db/bundb/emoji.go
@@ -32,14 +32,22 @@ import (
)
type emojiDB struct {
- conn *DBConn
- cache *cache.EmojiCache
+ conn *DBConn
+ emojiCache *cache.EmojiCache
+ categoryCache *cache.EmojiCategoryCache
}
func (e *emojiDB) newEmojiQ(emoji *gtsmodel.Emoji) *bun.SelectQuery {
return e.conn.
NewSelect().
- Model(emoji)
+ Model(emoji).
+ Relation("Category")
+}
+
+func (e *emojiDB) newEmojiCategoryQ(emojiCategory *gtsmodel.EmojiCategory) *bun.SelectQuery {
+ return e.conn.
+ NewSelect().
+ Model(emojiCategory)
}
func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error {
@@ -47,7 +55,7 @@ func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error
return e.conn.ProcessError(err)
}
- e.cache.Put(emoji)
+ e.emojiCache.Put(emoji)
return nil
}
@@ -64,7 +72,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column
return nil, e.conn.ProcessError(err)
}
- e.cache.Invalidate(emoji.ID)
+ e.emojiCache.Invalidate(emoji.ID)
return emoji, nil
}
@@ -101,7 +109,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error {
return err
}
- e.cache.Invalidate(id)
+ e.emojiCache.Invalidate(id)
return nil
}
@@ -245,7 +253,7 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji,
return e.getEmoji(
ctx,
func() (*gtsmodel.Emoji, bool) {
- return e.cache.GetByID(id)
+ return e.emojiCache.GetByID(id)
},
func(emoji *gtsmodel.Emoji) error {
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)
@@ -257,7 +265,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj
return e.getEmoji(
ctx,
func() (*gtsmodel.Emoji, bool) {
- return e.cache.GetByURI(uri)
+ return e.emojiCache.GetByURI(uri)
},
func(emoji *gtsmodel.Emoji) error {
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)
@@ -269,7 +277,7 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin
return e.getEmoji(
ctx,
func() (*gtsmodel.Emoji, bool) {
- return e.cache.GetByShortcodeDomain(shortcode, domain)
+ return e.emojiCache.GetByShortcodeDomain(shortcode, domain)
},
func(emoji *gtsmodel.Emoji) error {
q := e.newEmojiQ(emoji)
@@ -291,7 +299,7 @@ func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string
return e.getEmoji(
ctx,
func() (*gtsmodel.Emoji, bool) {
- return e.cache.GetByImageStaticURL(imageStaticURL)
+ return e.emojiCache.GetByImageStaticURL(imageStaticURL)
},
func(emoji *gtsmodel.Emoji) error {
return e.
@@ -302,6 +310,55 @@ func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string
)
}
+func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) db.Error {
+ if _, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx); err != nil {
+ return e.conn.ProcessError(err)
+ }
+
+ e.categoryCache.Put(emojiCategory)
+ return nil
+}
+
+func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, db.Error) {
+ emojiCategoryIDs := []string{}
+
+ q := e.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("emoji_categories"), bun.Ident("emoji_category")).
+ Column("emoji_category.id").
+ Order("emoji_category.name ASC")
+
+ if err := q.Scan(ctx, &emojiCategoryIDs); err != nil {
+ return nil, e.conn.ProcessError(err)
+ }
+
+ return e.emojiCategoriesFromIDs(ctx, emojiCategoryIDs)
+}
+
+func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, db.Error) {
+ return e.getEmojiCategory(
+ ctx,
+ func() (*gtsmodel.EmojiCategory, bool) {
+ return e.categoryCache.GetByID(id)
+ },
+ func(emojiCategory *gtsmodel.EmojiCategory) error {
+ return e.newEmojiCategoryQ(emojiCategory).Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx)
+ },
+ )
+}
+
+func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, db.Error) {
+ return e.getEmojiCategory(
+ ctx,
+ func() (*gtsmodel.EmojiCategory, bool) {
+ return e.categoryCache.GetByName(name)
+ },
+ func(emojiCategory *gtsmodel.EmojiCategory) error {
+ return e.newEmojiCategoryQ(emojiCategory).Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx)
+ },
+ )
+}
+
func (e *emojiDB) getEmoji(ctx context.Context, cacheGet func() (*gtsmodel.Emoji, bool), dbQuery func(*gtsmodel.Emoji) error) (*gtsmodel.Emoji, db.Error) {
// Attempt to fetch cached emoji
emoji, cached := cacheGet()
@@ -316,7 +373,7 @@ func (e *emojiDB) getEmoji(ctx context.Context, cacheGet func() (*gtsmodel.Emoji
}
// Place in the cache
- e.cache.Put(emoji)
+ e.emojiCache.Put(emoji)
}
return emoji, nil
@@ -341,3 +398,43 @@ func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsm
return emojis, nil
}
+
+func (e *emojiDB) getEmojiCategory(ctx context.Context, cacheGet func() (*gtsmodel.EmojiCategory, bool), dbQuery func(*gtsmodel.EmojiCategory) error) (*gtsmodel.EmojiCategory, db.Error) {
+ // Attempt to fetch cached emoji categories
+ emojiCategory, cached := cacheGet()
+
+ if !cached {
+ emojiCategory = &gtsmodel.EmojiCategory{}
+
+ // Not cached! Perform database query
+ err := dbQuery(emojiCategory)
+ if err != nil {
+ return nil, e.conn.ProcessError(err)
+ }
+
+ // Place in the cache
+ e.categoryCache.Put(emojiCategory)
+ }
+
+ return emojiCategory, nil
+}
+
+func (e *emojiDB) emojiCategoriesFromIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, db.Error) {
+ // Catch case of no emoji categories early
+ if len(emojiCategoryIDs) == 0 {
+ return nil, db.ErrNoEntries
+ }
+
+ emojiCategories := make([]*gtsmodel.EmojiCategory, 0, len(emojiCategoryIDs))
+
+ for _, id := range emojiCategoryIDs {
+ emojiCategory, err := e.GetEmojiCategory(ctx, id)
+ if err != nil {
+ log.Errorf("emojiCategoriesFromIDs: error getting emoji category %q: %v", id, err)
+ }
+
+ emojiCategories = append(emojiCategories, emojiCategory)
+ }
+
+ return emojiCategories, nil
+}
diff --git a/internal/db/bundb/emoji_test.go b/internal/db/bundb/emoji_test.go
index b542f9b67..786d41e5d 100644
--- a/internal/db/bundb/emoji_test.go
+++ b/internal/db/bundb/emoji_test.go
@@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/testrig"
)
type EmojiTestSuite struct {
@@ -54,6 +55,8 @@ func (suite *EmojiTestSuite) TestGetEmojiByStaticURL() {
suite.NoError(err)
suite.NotNil(emoji)
suite.Equal("rainbow", emoji.Shortcode)
+ suite.NotNil(emoji.Category)
+ suite.Equal("reactions", emoji.Category.Name)
}
func (suite *EmojiTestSuite) TestGetAllEmojis() {
@@ -143,6 +146,21 @@ func (suite *EmojiTestSuite) TestGetSpecificEmojisFromDomain2() {
suite.Equal("yell", emojis[0].Shortcode)
}
+func (suite *EmojiTestSuite) TestGetEmojiCategories() {
+ categories, err := suite.db.GetEmojiCategories(context.Background())
+ suite.NoError(err)
+ suite.Len(categories, 2)
+ // check alphabetical order
+ suite.Equal(categories[0].Name, "cute stuff")
+ suite.Equal(categories[1].Name, "reactions")
+}
+
+func (suite *EmojiTestSuite) TestGetEmojiCategory() {
+ category, err := suite.db.GetEmojiCategory(context.Background(), testrig.NewTestEmojiCategories()["reactions"].ID)
+ suite.NoError(err)
+ suite.NotNil(category)
+}
+
func TestEmojiTestSuite(t *testing.T) {
suite.Run(t, new(EmojiTestSuite))
}
diff --git a/internal/db/bundb/migrations/20221031145649_emoji_categories.go b/internal/db/bundb/migrations/20221031145649_emoji_categories.go
new file mode 100644
index 000000000..02e4a1f3a
--- /dev/null
+++ b/internal/db/bundb/migrations/20221031145649_emoji_categories.go
@@ -0,0 +1,46 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package migrations
+
+import (
+ "context"
+
+ gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ if _, err := db.NewCreateTable().Model(&gtsmodel.EmojiCategory{}).IfNotExists().Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/emoji.go b/internal/db/emoji.go
index d2f66a377..267213b2d 100644
--- a/internal/db/emoji.go
+++ b/internal/db/emoji.go
@@ -50,4 +50,12 @@ type Emoji interface {
GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, Error)
// GetEmojiByStaticURL gets an emoji using the URL of the static version of the emoji image.
GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, Error)
+ // PutEmojiCategory puts one new emoji category in the database.
+ PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) Error
+ // GetEmojiCategories gets a slice of the names of all existing emoji categories.
+ GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, Error)
+ // GetEmojiCategory gets one emoji category by its id.
+ GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, Error)
+ // GetEmojiCategoryByName gets one emoji category by its name.
+ GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, Error)
}