diff options
author | 2021-08-29 15:41:41 +0100 | |
---|---|---|
committer | 2021-08-29 16:41:41 +0200 | |
commit | ed462245730bd7832019bd43e0bc1c9d1c055e8e (patch) | |
tree | 1caad78ea6aabf5ea93c93a8ade97176b4889500 /internal | |
parent | Mention fixup (#167) (diff) | |
download | gotosocial-ed462245730bd7832019bd43e0bc1c9d1c055e8e.tar.xz |
Add SQLite support, fix un-thread-safe DB caches, small performance f… (#172)
* Add SQLite support, fix un-thread-safe DB caches, small performance fixes
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
* add SQLite licenses to README
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
* appease the linter, and fix my dumbass-ery
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
* make requested changes
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
* add back comment
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
Diffstat (limited to 'internal')
32 files changed, 604 insertions, 416 deletions
diff --git a/internal/api/client/account/sqlite-test.db b/internal/api/client/account/sqlite-test.db Binary files differnew file mode 100644 index 000000000..eab8315d9 --- /dev/null +++ b/internal/api/client/account/sqlite-test.db diff --git a/internal/api/client/fileserver/sqlite-test.db b/internal/api/client/fileserver/sqlite-test.db Binary files differnew file mode 100644 index 000000000..5689e7edb --- /dev/null +++ b/internal/api/client/fileserver/sqlite-test.db diff --git a/internal/api/client/media/sqlite-test.db b/internal/api/client/media/sqlite-test.db Binary files differnew file mode 100644 index 000000000..1ed985248 --- /dev/null +++ b/internal/api/client/media/sqlite-test.db diff --git a/internal/api/client/status/sqlite-test.db b/internal/api/client/status/sqlite-test.db Binary files differnew file mode 100644 index 000000000..448d10813 --- /dev/null +++ b/internal/api/client/status/sqlite-test.db diff --git a/internal/api/s2s/user/sqlite-test.db b/internal/api/s2s/user/sqlite-test.db Binary files differnew file mode 100644 index 000000000..b67967b30 --- /dev/null +++ b/internal/api/s2s/user/sqlite-test.db diff --git a/internal/cache/status.go b/internal/cache/status.go new file mode 100644 index 000000000..895a5692c --- /dev/null +++ b/internal/cache/status.go @@ -0,0 +1,106 @@ +package cache + +import ( + "sync" + + "github.com/ReneKroon/ttlcache" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// statusCache is a wrapper around ttlcache.Cache to provide URL and URI lookups for gtsmodel.Status +type StatusCache struct { + cache *ttlcache.Cache // map of IDs -> cached statuses + urls map[string]string // map of status URLs -> IDs + uris map[string]string // map of status URIs -> IDs + mutex sync.Mutex +} + +// newStatusCache returns a new instantiated statusCache object +func NewStatusCache() *StatusCache { + c := StatusCache{ + cache: ttlcache.NewCache(), + urls: make(map[string]string, 100), + uris: make(map[string]string, 100), + mutex: sync.Mutex{}, + } + + // Set callback to purge lookup maps on expiration + c.cache.SetExpirationCallback(func(key string, value interface{}) { + status := value.(*gtsmodel.Status) + + c.mutex.Lock() + delete(c.urls, status.URL) + delete(c.uris, status.URI) + c.mutex.Unlock() + }) + + return &c +} + +// GetByID attempts to fetch a status from the cache by its ID +func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) { + c.mutex.Lock() + status, ok := c.getByID(id) + c.mutex.Unlock() + return status, ok +} + +// GetByURL attempts to fetch a status from the cache by its URL +func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) { + // Perform safe ID lookup + c.mutex.Lock() + id, ok := c.urls[url] + + // Not found, unlock early + if !ok { + c.mutex.Unlock() + return nil, false + } + + // Attempt status lookup + status, ok := c.getByID(id) + c.mutex.Unlock() + return status, ok +} + +// GetByURI attempts to fetch a status from the cache by its URI +func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) { + // Perform safe ID lookup + c.mutex.Lock() + id, ok := c.uris[uri] + + // Not found, unlock early + if !ok { + c.mutex.Unlock() + return nil, false + } + + // Attempt status lookup + status, ok := c.getByID(id) + c.mutex.Unlock() + return status, ok +} + +// getByID performs an unsafe (no mutex locks) lookup of status by ID +func (c *StatusCache) getByID(id string) (*gtsmodel.Status, bool) { + v, ok := c.cache.Get(id) + if !ok { + return nil, false + } + return v.(*gtsmodel.Status), true +} + +// Put places a status in the cache +func (c *StatusCache) Put(status *gtsmodel.Status) { + if status == nil || status.ID == "" || + status.URL == "" || + status.URI == "" { + panic("invalid status") + } + + c.mutex.Lock() + c.cache.Set(status.ID, status) + c.urls[status.URL] = status.ID + c.uris[status.URI] = status.ID + c.mutex.Unlock() +} diff --git a/internal/cache/status_test.go b/internal/cache/status_test.go new file mode 100644 index 000000000..10dee5bca --- /dev/null +++ b/internal/cache/status_test.go @@ -0,0 +1,41 @@ +package cache_test + +import ( + "testing" + + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func TestStatusCache(t *testing.T) { + cache := cache.NewStatusCache() + + // Attempt to place a status + status := gtsmodel.Status{ + ID: "id", + URI: "uri", + URL: "url", + } + cache.Put(&status) + + var ok bool + var check *gtsmodel.Status + + // Check we can retrieve + check, ok = cache.GetByID(status.ID) + if !ok || !statusIs(&status, check) { + t.Fatal("Could not find expected status") + } + check, ok = cache.GetByURI(status.URI) + if !ok || !statusIs(&status, check) { + t.Fatal("Could not find expected status") + } + check, ok = cache.GetByURL(status.URL) + if !ok || !statusIs(&status, check) { + t.Fatal("Could not find expected status") + } +} + +func statusIs(status1, status2 *gtsmodel.Status) bool { + return status1.ID == status2.ID && status1.URI == status2.URI && status1.URL == status2.URL +} diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index aef1f3281..d7d45a739 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -25,7 +25,6 @@ import ( "strings" "time" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -34,8 +33,7 @@ import ( type accountDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { @@ -52,9 +50,11 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac q := a.newAccountQ(account). Where("account.id = ?", id) - err := processErrorResponse(q.Scan(ctx)) - - return account, err + err := q.Scan(ctx) + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { @@ -63,9 +63,11 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. q := a.newAccountQ(account). Where("account.uri = ?", uri) - err := processErrorResponse(q.Scan(ctx)) - - return account, err + err := q.Scan(ctx) + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { @@ -74,9 +76,11 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel. q := a.newAccountQ(account). Where("account.url = ?", uri) - err := processErrorResponse(q.Scan(ctx)) - - return account, err + err := q.Scan(ctx) + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { @@ -92,10 +96,10 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account WherePK() _, err := q.Exec(ctx) - - err = processErrorResponse(err) - - return account, err + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { @@ -113,9 +117,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts WhereGroup(" AND ", whereEmptyOrNull("domain")) } - err := processErrorResponse(q.Scan(ctx)) - - return account, err + err := q.Scan(ctx) + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) { @@ -129,9 +135,11 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) Where("account_id = ?", accountID). Column("created_at") - err := processErrorResponse(q.Scan(ctx)) - - return status.CreatedAt, err + err := q.Scan(ctx) + if err != nil { + return time.Time{}, a.conn.ProcessError(err) + } + return status.CreatedAt, nil } func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { @@ -153,17 +161,17 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen NewInsert(). Model(mediaAttachment). Exec(ctx); err != nil { - return err + return a.conn.ProcessError(err) } - if _, err := a.conn. NewUpdate(). Model(>smodel.Account{}). Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID). Where("id = ?", accountID). Exec(ctx); err != nil { - return err + return a.conn.ProcessError(err) } + return nil } @@ -174,9 +182,11 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri Where("username = ?", username). WhereGroup(" AND ", whereEmptyOrNull("domain")) - err := processErrorResponse(q.Scan(ctx)) - - return account, err + err := q.Scan(ctx) + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) { @@ -187,8 +197,9 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g Model(faves). Where("account_id = ?", accountID). Scan(ctx); err != nil { - return nil, err + return nil, a.conn.ProcessError(err) } + return *faves, nil } @@ -201,7 +212,6 @@ func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) } func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) { - a.log.Debugf("getting statuses for account %s", accountID) statuses := []*gtsmodel.Status{} q := a.conn. @@ -238,14 +248,13 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li } if err := q.Scan(ctx); err != nil { - return nil, err + return nil, a.conn.ProcessError(err) } if len(statuses) == 0 { return nil, db.ErrNoEntries } - a.log.Debugf("returning statuses for account %s", accountID) return statuses, nil } @@ -273,7 +282,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI err := fq.Scan(ctx) if err != nil { - return nil, "", "", err + return nil, "", "", a.conn.ProcessError(err) } if len(blocks) == 0 { diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index d43501444..6a51ffeb1 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -29,20 +29,17 @@ import ( "strings" "time" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/util" - "github.com/uptrace/bun" "golang.org/x/crypto/bcrypt" ) type adminDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { @@ -52,7 +49,7 @@ func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (boo Where("username = ?", username). Where("domain = ?", nil) - return notExists(ctx, q) + return a.conn.NotExists(ctx, q) } func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) { @@ -72,7 +69,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db. // fail because we found something return false, fmt.Errorf("email domain %s is blocked", domain) } else if err != sql.ErrNoRows { - return false, processErrorResponse(err) + return false, a.conn.ProcessError(err) } // check if this email is associated with a user already @@ -82,13 +79,13 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db. Where("email = ?", email). WhereOr("unconfirmed_email = ?", email) - return notExists(ctx, q) + return a.conn.NotExists(ctx, q) } func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - a.log.Errorf("error creating new rsa key: %s", err) + a.conn.log.Errorf("error creating new rsa key: %s", err) return nil, err } @@ -128,7 +125,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, NewInsert(). Model(acct). Exec(ctx); err != nil { - return nil, err + return nil, a.conn.ProcessError(err) } } @@ -167,7 +164,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, NewInsert(). Model(u). Exec(ctx); err != nil { - return nil, err + return nil, a.conn.ProcessError(err) } return u, nil @@ -184,15 +181,15 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { WhereGroup(" AND ", whereEmptyOrNull("domain")) count, err := existsQ.Count(ctx) if err != nil && count == 1 { - a.log.Infof("instance account %s already exists", username) + a.conn.log.Infof("instance account %s already exists", username) return nil } else if err != sql.ErrNoRows { - return processErrorResponse(err) + return a.conn.ProcessError(err) } key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - a.log.Errorf("error creating new rsa key: %s", err) + a.conn.log.Errorf("error creating new rsa key: %s", err) return err } @@ -224,10 +221,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { Model(acct) if _, err := insertQ.Exec(ctx); err != nil { - return err + return a.conn.ProcessError(err) } - a.log.Infof("instance account %s CREATED with id %s", username, acct.ID) + a.conn.log.Infof("instance account %s CREATED with id %s", username, acct.ID) return nil } @@ -240,12 +237,12 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { Model(>smodel.Instance{}). Where("domain = ?", domain) - exists, err := exists(ctx, q) + exists, err := a.conn.Exists(ctx, q) if err != nil { return err } if exists { - a.log.Infof("instance entry already exists") + a.conn.log.Infof("instance entry already exists") return nil } @@ -266,10 +263,10 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { Model(i) _, err = insertQ.Exec(ctx) - err = processErrorResponse(err) - - if err == nil { - a.log.Infof("created instance instance %s with id %s", domain, i.ID) + if err != nil { + return a.conn.ProcessError(err) } - return err + + a.conn.log.Infof("created instance instance %s with id %s", domain, i.ID) + return nil } diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index 983b6b810..a3a8d0ae9 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -21,9 +21,7 @@ package bundb import ( "context" "errors" - "strings" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/uptrace/bun" @@ -31,16 +29,12 @@ import ( type basicDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error { _, err := b.conn.NewInsert().Model(i).Exec(ctx) - if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists - } - return err + return b.conn.ProcessError(err) } func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error { @@ -49,7 +43,8 @@ func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Erro Model(i). Where("id = ?", id) - return processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) + return b.conn.ProcessError(err) } func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { @@ -59,7 +54,6 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) q := b.conn.NewSelect().Model(i) for _, w := range where { - if w.Value == nil { q = q.Where("? IS NULL", bun.Ident(w.Key)) } else { @@ -71,7 +65,8 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) } } - return processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) + return b.conn.ProcessError(err) } func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error { @@ -79,7 +74,8 @@ func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error { NewSelect(). Model(i) - return processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) + return b.conn.ProcessError(err) } func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error { @@ -89,8 +85,7 @@ func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.E Where("id = ?", id) _, err := q.Exec(ctx) - - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { @@ -107,8 +102,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface } _, err := q.Exec(ctx) - - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error { @@ -118,8 +112,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E WherePK() _, err := q.Exec(ctx) - - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error { @@ -129,8 +122,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu WherePK() _, err := q.Exec(ctx) - - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error { @@ -151,8 +143,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, q = q.Set("? = ?", bun.Safe(key), value) _, err := q.Exec(ctx) - - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { @@ -162,7 +153,7 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error { _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx) - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) IsHealthy(ctx context.Context) db.Error { @@ -170,10 +161,6 @@ func (b *basicDB) IsHealthy(ctx context.Context) db.Error { } func (b *basicDB) Stop(ctx context.Context) db.Error { - b.log.Info("closing db connection") - if err := b.conn.Close(); err != nil { - // only cancel if there's a problem closing the db - return err - } - return nil + b.conn.log.Info("closing db connection") + return b.conn.Close() } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 49ed09cbd..248232fe3 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -30,15 +30,19 @@ import ( "strings" "time" + "github.com/ReneKroon/ttlcache" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/stdlib" "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/uptrace/bun" "github.com/uptrace/bun/dialect/pgdialect" + "github.com/uptrace/bun/dialect/sqlitedialect" + _ "modernc.org/sqlite" ) const ( @@ -66,15 +70,14 @@ type bunDBService struct { db.Status db.Timeline config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } // NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. // Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { var sqldb *sql.DB - var conn *bun.DB + var conn *DBConn // depending on the database type we're trying to create, we need to use a different driver... switch strings.ToLower(c.DBConfig.Type) { @@ -85,10 +88,24 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) return nil, fmt.Errorf("could not create bundb postgres options: %s", err) } sqldb = stdlib.OpenDB(*opts) - conn = bun.NewDB(sqldb, pgdialect.New()) + conn = WrapDBConn(bun.NewDB(sqldb, pgdialect.New()), log) case dbTypeSqlite: // SQLITE - // TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite + var err error + sqldb, err = sql.Open("sqlite", c.DBConfig.Address) + if err != nil { + return nil, fmt.Errorf("could not open sqlite db: %s", err) + } + conn = WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()), log) + + if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") { + log.Warn("sqlite in-memory database should only be used for debugging") + + // don't close connections on disconnect -- otherwise + // the SQLite database will be deleted when there + // are no active connections + sqldb.SetConnMaxLifetime(0) + } default: return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type)) } @@ -108,66 +125,56 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) Account: &accountDB{ config: c, conn: conn, - log: log, }, Admin: &adminDB{ config: c, conn: conn, - log: log, }, Basic: &basicDB{ config: c, conn: conn, - log: log, }, Domain: &domainDB{ config: c, conn: conn, - log: log, }, Instance: &instanceDB{ config: c, conn: conn, - log: log, }, Media: &mediaDB{ config: c, conn: conn, - log: log, }, Mention: &mentionDB{ config: c, conn: conn, - log: log, + cache: ttlcache.NewCache(), }, Notification: ¬ificationDB{ config: c, conn: conn, - log: log, + cache: ttlcache.NewCache(), }, Relationship: &relationshipDB{ config: c, conn: conn, - log: log, }, Session: &sessionDB{ config: c, conn: conn, - log: log, }, Status: &statusDB{ config: c, conn: conn, - log: log, + cache: cache.NewStatusCache(), }, Timeline: &timelineDB{ config: c, conn: conn, - log: log, }, config: c, conn: conn, - log: log, } // we can confidently return this useable service now @@ -332,7 +339,7 @@ func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAcco if err != nil { if err == sql.ErrNoRows { // no result found for this username/domain so just don't include it as a mencho and carry on about our business - ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain) + ps.conn.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain) continue } // a serious error has happened so bail @@ -398,7 +405,7 @@ func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []strin if err != nil { if err == sql.ErrNoRows { // no result found for this username/domain so just don't include it as an emoji and carry on about our business - ps.log.Debugf("no emoji found with shortcode %s, skipping it", e) + ps.conn.log.Debugf("no emoji found with shortcode %s, skipping it", e) continue } // a serious error has happened so bail diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go new file mode 100644 index 000000000..698adff3d --- /dev/null +++ b/internal/db/bundb/conn.go @@ -0,0 +1,72 @@ +package bundb + +import ( + "context" + "database/sql" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect" +) + +// dbConn wrapps a bun.DB conn to provide SQL-type specific additional functionality +type DBConn struct { + errProc func(error) db.Error // errProc is the SQL-type specific error processor + log *logrus.Logger // log is the logger passed with this DBConn + *bun.DB // DB is the underlying bun.DB connection +} + +// WrapDBConn @TODO +func WrapDBConn(dbConn *bun.DB, log *logrus.Logger) *DBConn { + var errProc func(error) db.Error + switch dbConn.Dialect().Name() { + case dialect.PG: + errProc = processPostgresError + case dialect.SQLite: + errProc = processSQLiteError + default: + panic("unknown dialect name: " + dbConn.Dialect().Name().String()) + } + return &DBConn{ + errProc: errProc, + log: log, + DB: dbConn, + } +} + +// ProcessError processes an error to replace any known values with our own db.Error types, +// making it easier to catch specific situations (e.g. no rows, already exists, etc) +func (conn *DBConn) ProcessError(err error) db.Error { + switch { + case err == nil: + return nil + case err == sql.ErrNoRows: + return db.ErrNoEntries + default: + return conn.errProc(err) + } +} + +// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors +func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { + // Get the select query result + count, err := query.Count(ctx) + + // Process error as our own and check if it exists + switch err := conn.ProcessError(err); err { + case nil: + return (count != 0), nil + case db.ErrNoEntries: + return false, nil + default: + return false, err + } +} + +// NotExists is the functional opposite of conn.Exists() +func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { + // Simply inverse of conn.exists() + exists, err := conn.Exists(ctx, query) + return !exists, err +} diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 6aa2b8ffe..5cb98e87e 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -22,18 +22,15 @@ import ( "context" "net/url" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/util" - "github.com/uptrace/bun" ) type domainDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { @@ -47,7 +44,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db Where("LOWER(domain) = LOWER(?)", domain). Limit(1) - return exists(ctx, q) + return d.conn.Exists(ctx, q) } func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { diff --git a/internal/db/bundb/errors.go b/internal/db/bundb/errors.go new file mode 100644 index 000000000..7602d5e1d --- /dev/null +++ b/internal/db/bundb/errors.go @@ -0,0 +1,43 @@ +package bundb + +import ( + "github.com/jackc/pgconn" + "github.com/superseriousbusiness/gotosocial/internal/db" + "modernc.org/sqlite" + sqlite3 "modernc.org/sqlite/lib" +) + +// processPostgresError processes an error, replacing any postgres specific errors with our own error type +func processPostgresError(err error) db.Error { + // Attempt to cast as postgres + pgErr, ok := err.(*pgconn.PgError) + if !ok { + return err + } + + // Handle supplied error code: + // (https://www.postgresql.org/docs/10/errcodes-appendix.html) + switch pgErr.Code { + case "23505" /* unique_violation */ : + return db.ErrAlreadyExists + default: + return err + } +} + +// processSQLiteError processes an error, replacing any sqlite specific errors with our own error type +func processSQLiteError(err error) db.Error { + // Attempt to cast as sqlite + sqliteErr, ok := err.(*sqlite.Error) + if !ok { + return err + } + + // Handle supplied error code: + switch sqliteErr.Code() { + case sqlite3.SQLITE_CONSTRAINT_UNIQUE: + return db.ErrAlreadyExists + default: + return err + } +} diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 2813e7e1d..4e26fc7c4 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -21,7 +21,6 @@ package bundb import ( "context" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -30,8 +29,7 @@ import ( type instanceDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { @@ -49,8 +47,10 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int } count, err := q.Count(ctx) - - return count, processErrorResponse(err) + if err != nil { + return 0, i.conn.ProcessError(err) + } + return count, nil } func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { @@ -68,8 +68,10 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) ( } count, err := q.Count(ctx) - - return count, processErrorResponse(err) + if err != nil { + return 0, i.conn.ProcessError(err) + } + return count, nil } func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { @@ -89,12 +91,14 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i } count, err := q.Count(ctx) - - return count, processErrorResponse(err) + if err != nil { + return 0, i.conn.ProcessError(err) + } + return count, nil } func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { - i.log.Debug("GetAccountsForInstance") + i.conn.log.Debug("GetAccountsForInstance") accounts := []*gtsmodel.Account{} @@ -111,7 +115,9 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max q = q.Limit(limit) } - err := processErrorResponse(q.Scan(ctx)) - - return accounts, err + err := q.Scan(ctx) + if err != nil { + return nil, i.conn.ProcessError(err) + } + return accounts, nil } diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index 04e55ca62..3c9ee587d 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -21,7 +21,6 @@ package bundb import ( "context" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -30,8 +29,7 @@ import ( type mediaDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery { @@ -47,7 +45,9 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M q := m.newMediaQ(attachment). Where("media_attachment.id = ?", id) - err := processErrorResponse(q.Scan(ctx)) - - return attachment, err + err := q.Scan(ctx) + if err != nil { + return nil, m.conn.ProcessError(err) + } + return attachment, nil } diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index a444f9b5f..3c2c64cfd 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -21,8 +21,7 @@ package bundb import ( "context" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/ReneKroon/ttlcache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -31,38 +30,8 @@ import ( type mentionDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger - cache cache.Cache -} - -func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) { - if m.cache == nil { - m.cache = cache.New() - } - - if err := m.cache.Store(id, mention); err != nil { - m.log.Panicf("mentionDB: error storing in cache: %s", err) - } -} - -func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) { - if m.cache == nil { - m.cache = cache.New() - return nil, false - } - - mI, err := m.cache.Fetch(id) - if err != nil || mI == nil { - return nil, false - } - - mention, ok := mI.(*gtsmodel.Mention) - if !ok { - m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id) - } - - return mention, true + conn *DBConn + cache *ttlcache.Cache } func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { @@ -74,33 +43,57 @@ func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { Relation("TargetAccount") } -func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { - if mention, cached := m.mentionCached(id); cached { - return mention, nil +func (m *mentionDB) getMentionCached(id string) (*gtsmodel.Mention, bool) { + v, ok := m.cache.Get(id) + if !ok { + return nil, false } + return v.(*gtsmodel.Mention), true +} + +func (m *mentionDB) putMentionCache(mention *gtsmodel.Mention) { + m.cache.Set(mention.ID, mention) +} +func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { mention := >smodel.Mention{} q := m.newMentionQ(mention). Where("mention.id = ?", id) - err := processErrorResponse(q.Scan(ctx)) - - if err == nil && mention != nil { - m.cacheMention(id, mention) + err := q.Scan(ctx) + if err != nil { + return nil, m.conn.ProcessError(err) } - return mention, err + m.putMentionCache(mention) + return mention, nil +} + +func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { + if mention, cached := m.getMentionCached(id); cached { + return mention, nil + } + return m.getMentionDB(ctx, id) } func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) { - mentions := []*gtsmodel.Mention{} + mentions := make([]*gtsmodel.Mention, 0, len(ids)) + + for _, id := range ids { + // Attempt fetch from cache + mention, cached := m.getMentionCached(id) + if cached { + mentions = append(mentions, mention) + } - for _, i := range ids { - mention, err := m.GetMention(ctx, i) + // Attempt fetch from DB + mention, err := m.getMentionDB(ctx, id) if err != nil { - return nil, processErrorResponse(err) + return nil, err } + + // Append mention mentions = append(mentions, mention) } diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 1c30837ec..d3be16168 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -21,8 +21,7 @@ package bundb import ( "context" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/ReneKroon/ttlcache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -31,38 +30,8 @@ import ( type notificationDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger - cache cache.Cache -} - -func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) { - if n.cache == nil { - n.cache = cache.New() - } - - if err := n.cache.Store(id, notification); err != nil { - n.log.Panicf("notificationDB: error storing in cache: %s", err) - } -} - -func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) { - if n.cache == nil { - n.cache = cache.New() - return nil, false - } - - nI, err := n.cache.Fetch(id) - if err != nil || nI == nil { - return nil, false - } - - notification, ok := nI.(*gtsmodel.Notification) - if !ok { - n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id) - } - - return notification, true + conn *DBConn + cache *ttlcache.Cache } func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery { @@ -75,30 +44,30 @@ func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery { } func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { - if notification, cached := n.notificationCached(id); cached { + if notification, cached := n.getNotificationCache(id); cached { return notification, nil } - notification := >smodel.Notification{} - - q := n.newNotificationQ(notification). - Where("notification.id = ?", id) - - err := processErrorResponse(q.Scan(ctx)) - - if err == nil && notification != nil { - n.cacheNotification(id, notification) + notif := >smodel.Notification{} + err := n.getNotificationDB(ctx, id, notif) + if err != nil { + return nil, err } - - return notification, err + return notif, nil } func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { - // begin by selecting just the IDs - notifIDs := []*gtsmodel.Notification{} + // Ensure reasonable + if limit < 0 { + limit = 0 + } + + // Make a guess for slice size + notifications := make([]*gtsmodel.Notification, 0, limit) + q := n.conn. NewSelect(). - Model(¬ifIDs). + Model(¬ifications). Column("id"). Where("target_account_id = ?", accountID). Order("id DESC") @@ -115,22 +84,52 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, q = q.Limit(limit) } - err := processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) if err != nil { - return nil, err + return nil, n.conn.ProcessError(err) } // now we have the IDs, select the notifs one by one // reason for this is that for each notif, we can instead get it from our cache if it's cached - notifications := []*gtsmodel.Notification{} - for _, notifID := range notifIDs { - notif, err := n.GetNotification(ctx, notifID.ID) - errP := processErrorResponse(err) - if errP != nil { - return nil, errP + for i, notif := range notifications { + // Check cache for notification + nn, cached := n.getNotificationCache(notif.ID) + if cached { + notifications[i] = nn + continue + } + + // Check DB for notification + err := n.getNotificationDB(ctx, notif.ID, notif) + if err != nil { + return nil, err } - notifications = append(notifications, notif) } return notifications, nil } + +func (n *notificationDB) getNotificationCache(id string) (*gtsmodel.Notification, bool) { + v, ok := n.cache.Get(id) + if !ok { + return nil, false + } + return v.(*gtsmodel.Notification), true +} + +func (n *notificationDB) putNotificationCache(notif *gtsmodel.Notification) { + n.cache.Set(notif.ID, notif) +} + +func (n *notificationDB) getNotificationDB(ctx context.Context, id string, dst *gtsmodel.Notification) error { + q := n.newNotificationQ(dst). + Where("notification.id = ?", id) + + err := q.Scan(ctx) + if err != nil { + return n.conn.ProcessError(err) + } + + n.putNotificationCache(dst) + return nil +} diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 95426f122..56b752593 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -23,7 +23,6 @@ import ( "database/sql" "fmt" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -32,8 +31,7 @@ import ( type relationshipDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery { @@ -66,7 +64,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account Where("account_id = ?", account2) } - return exists(ctx, q) + return r.conn.Exists(ctx, q) } func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) { @@ -76,9 +74,11 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 Where("block.account_id = ?", account1). Where("block.target_account_id = ?", account2) - err := processErrorResponse(q.Scan(ctx)) - - return block, err + err := q.Scan(ctx) + if err != nil { + return nil, r.conn.ProcessError(err) + } + return block, nil } func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { @@ -176,7 +176,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode Where("target_account_id = ?", targetAccount.ID). Limit(1) - return exists(ctx, q) + return r.conn.Exists(ctx, q) } func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { @@ -190,7 +190,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g Where("account_id = ?", sourceAccount.ID). Where("target_account_id = ?", targetAccount.ID) - return exists(ctx, q) + return r.conn.Exists(ctx, q) } func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { @@ -201,13 +201,13 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod // make sure account 1 follows account 2 f1, err := r.IsFollowing(ctx, account1, account2) if err != nil { - return false, processErrorResponse(err) + return false, err } // make sure account 2 follows account 1 f2, err := r.IsFollowing(ctx, account2, account1) if err != nil { - return false, processErrorResponse(err) + return false, err } return f1 && f2, nil @@ -222,7 +222,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI Where("account_id = ?", originAccountID). Where("target_account_id = ?", targetAccountID). Scan(ctx); err != nil { - return nil, processErrorResponse(err) + return nil, r.conn.ProcessError(err) } // create a new follow to 'replace' the request with @@ -239,7 +239,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI Model(follow). On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI). Exec(ctx); err != nil { - return nil, processErrorResponse(err) + return nil, r.conn.ProcessError(err) } // now remove the follow request @@ -249,7 +249,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI Where("account_id = ?", originAccountID). Where("target_account_id = ?", targetAccountID). Exec(ctx); err != nil { - return nil, processErrorResponse(err) + return nil, r.conn.ProcessError(err) } return follow, nil @@ -261,9 +261,11 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID q := r.newFollowQ(&followRequests). Where("target_account_id = ?", accountID) - err := processErrorResponse(q.Scan(ctx)) - - return followRequests, err + err := q.Scan(ctx) + if err != nil { + return nil, r.conn.ProcessError(err) + } + return followRequests, nil } func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) { @@ -272,9 +274,11 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string q := r.newFollowQ(&follows). Where("account_id = ?", accountID) - err := processErrorResponse(q.Scan(ctx)) - - return follows, err + err := q.Scan(ctx) + if err != nil { + return nil, r.conn.ProcessError(err) + } + return follows, nil } func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { @@ -286,7 +290,6 @@ func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID stri } func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { - follows := []*gtsmodel.Follow{} q := r.conn. @@ -302,11 +305,9 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str q = q.Where("target_account_id = ?", accountID) } - if err := q.Scan(ctx); err != nil { - if err == sql.ErrNoRows { - return follows, nil - } - return nil, processErrorResponse(err) + err := q.Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, r.conn.ProcessError(err) } return follows, nil } diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go index 55efd21d4..c8b09ec86 100644 --- a/internal/db/bundb/session.go +++ b/internal/db/bundb/session.go @@ -23,22 +23,19 @@ import ( "crypto/rand" "errors" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" - "github.com/uptrace/bun" ) type sessionDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { - rss := []*gtsmodel.RouterSession{} + rss := make([]*gtsmodel.RouterSession, 0, 1) _, err := s.conn. NewSelect(). @@ -47,7 +44,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db Order("id DESC"). Exec(ctx) if err != nil { - return nil, processErrorResponse(err) + return nil, s.conn.ProcessError(err) } if len(rss) <= 0 { @@ -92,8 +89,8 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, Model(rs) _, err = q.Exec(ctx) - - err = processErrorResponse(err) - - return rs, err + if err != nil { + return nil, s.conn.ProcessError(err) + } + return rs, nil } diff --git a/internal/db/bundb/sqlite-test.db b/internal/db/bundb/sqlite-test.db Binary files differnew file mode 100644 index 000000000..ed3b25ee3 --- /dev/null +++ b/internal/db/bundb/sqlite-test.db diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 2019322ac..1d5acf0fc 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -24,7 +24,6 @@ import ( "errors" "time" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -34,38 +33,8 @@ import ( type statusDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger - cache cache.Cache -} - -func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) { - if s.cache == nil { - s.cache = cache.New() - } - - if err := s.cache.Store(id, status); err != nil { - s.log.Panicf("statusDB: error storing in cache: %s", err) - } -} - -func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) { - if s.cache == nil { - s.cache = cache.New() - return nil, false - } - - sI, err := s.cache.Fetch(id) - if err != nil || sI == nil { - return nil, false - } - - status, ok := sI.(*gtsmodel.Status) - if !ok { - s.log.Panicf("statusDB: cached interface with key %s was not a status", id) - } - - return status, true + conn *DBConn + cache *cache.StatusCache } func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { @@ -84,7 +53,9 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status { if status.InReplyToID != "" && status.InReplyTo == nil { - if inReplyTo, cached := s.statusCached(status.InReplyToID); cached { + // TODO: do we want to keep this possibly recursive strategy? + + if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached { status.InReplyTo = inReplyTo } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil { status.InReplyTo = inReplyTo @@ -92,7 +63,9 @@ func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Sta } if status.BoostOfID != "" && status.BoostOf == nil { - if boostOf, cached := s.statusCached(status.BoostOfID); cached { + // TODO: do we want to keep this possibly recursive strategy? + + if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached { status.BoostOf = boostOf } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil { status.BoostOf = boostOf @@ -112,29 +85,26 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { } func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(id); cached { + if status, cached := s.cache.GetByID(id); cached { return status, nil } - status := new(gtsmodel.Status) + status := >smodel.Status{} q := s.newStatusQ(status). Where("status.id = ?", id) - err := processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) if err != nil { - return nil, err - } - - if status != nil { - s.cacheStatus(id, status) + return nil, s.conn.ProcessError(err) } - return s.getAttachedStatuses(ctx, status), err + s.cache.Put(status) + return s.getAttachedStatuses(ctx, status), nil } func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(uri); cached { + if status, cached := s.cache.GetByURI(uri); cached { return status, nil } @@ -143,38 +113,32 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St q := s.newStatusQ(status). Where("LOWER(status.uri) = LOWER(?)", uri) - err := processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) if err != nil { - return nil, err + return nil, s.conn.ProcessError(err) } - if status != nil { - s.cacheStatus(uri, status) - } - - return s.getAttachedStatuses(ctx, status), err + s.cache.Put(status) + return s.getAttachedStatuses(ctx, status), nil } -func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(uri); cached { +func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { + if status, cached := s.cache.GetByURL(url); cached { return status, nil } status := >smodel.Status{} q := s.newStatusQ(status). - Where("LOWER(status.url) = LOWER(?)", uri) + Where("LOWER(status.url) = LOWER(?)", url) - err := processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) if err != nil { - return nil, err - } - - if status != nil { - s.cacheStatus(uri, status) + return nil, s.conn.ProcessError(err) } - return s.getAttachedStatuses(ctx, status), err + s.cache.Put(status) + return s.getAttachedStatuses(ctx, status), nil } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { @@ -213,14 +177,12 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er _, err := tx.NewInsert().Model(status).Exec(ctx) return err } - - return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction)) + return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction)) } func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { parents := []*gtsmodel.Status{} s.statusParent(ctx, status, &parents, onlyDirect) - return parents, nil } @@ -318,7 +280,7 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, Where("status_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { @@ -328,7 +290,7 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta Where("boost_of_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { @@ -338,7 +300,7 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, Where("status_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { @@ -348,7 +310,7 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St Where("status_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { @@ -357,8 +319,11 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) q := s.newFaveQ(&faves). Where("status_id = ?", status.ID) - err := processErrorResponse(q.Scan(ctx)) - return faves, err + err := q.Scan(ctx) + if err != nil { + return nil, s.conn.ProcessError(err) + } + return faves, nil } func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { @@ -367,6 +332,9 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status q := s.newStatusQ(&reblogs). Where("boost_of_id = ?", status.ID) - err := processErrorResponse(q.Scan(ctx)) - return reblogs, err + err := q.Scan(ctx) + if err != nil { + return nil, s.conn.ProcessError(err) + } + return reblogs, nil } diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go index 513000577..4f846441b 100644 --- a/internal/db/bundb/status_test.go +++ b/internal/db/bundb/status_test.go @@ -59,7 +59,6 @@ func (suite *StatusTestSuite) TearDownTest() { func (suite *StatusTestSuite) TestGetStatusByID() { status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID) if err != nil { - fmt.Println(err.Error()) suite.FailNow(err.Error()) } suite.NotNil(status) diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index cd202f436..0c4619ae8 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -23,7 +23,6 @@ import ( "database/sql" "sort" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -32,12 +31,18 @@ import ( type timelineDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { - statuses := []*gtsmodel.Status{} + // Ensure reasonable + if limit < 0 { + limit = 0 + } + + // Make educated guess for slice size + statuses := make([]*gtsmodel.Status, 0, limit) + q := t.conn. NewSelect(). Model(&statuses) @@ -86,11 +91,21 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI q = q.WhereGroup(" AND ", whereGroup) - return statuses, processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) + if err != nil { + return nil, t.conn.ProcessError(err) + } + return statuses, nil } func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { - statuses := []*gtsmodel.Status{} + // Ensure reasonable + if limit < 0 { + limit = 0 + } + + // Make educated guess for slice size + statuses := make([]*gtsmodel.Status, 0, limit) q := t.conn. NewSelect(). @@ -121,14 +136,23 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma q = q.Limit(limit) } - return statuses, processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) + if err != nil { + return nil, t.conn.ProcessError(err) + } + return statuses, nil } // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! // It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds. func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) { + // Ensure reasonable + if limit < 0 { + limit = 0 + } - faves := []*gtsmodel.StatusFave{} + // Make educated guess for slice size + faves := make([]*gtsmodel.StatusFave, 0, limit) fq := t.conn. NewSelect(). @@ -160,26 +184,23 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max return nil, "", "", db.ErrNoEntries } - // map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID - statusesFavesMap := map[string]string{} - - in := []string{} + // map[statusID]faveID -- we need this to sort statuses by fave ID rather than status ID + statusesFavesMap := make(map[string]string, len(faves)) + statusIDs := make([]string, 0, len(faves)) for _, f := range faves { statusesFavesMap[f.StatusID] = f.ID - in = append(in, f.StatusID) + statusIDs = append(statusIDs, f.StatusID) } - statuses := []*gtsmodel.Status{} + statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) + err = t.conn. NewSelect(). Model(&statuses). - Where("id IN (?)", bun.In(in)). + Where("id IN (?)", bun.In(statusIDs)). Scan(ctx) if err != nil { - if err == sql.ErrNoRows { - return nil, "", "", db.ErrNoEntries - } - return nil, "", "", err + return nil, "", "", t.conn.ProcessError(err) } if len(statuses) == 0 { diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index faa80221f..9e1afb87e 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -19,64 +19,9 @@ package bundb import ( - "context" - "strings" - - "database/sql" - - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/uptrace/bun" ) -// processErrorResponse parses the given error and returns an appropriate DBError. -func processErrorResponse(err error) db.Error { - switch err { - case nil: - return nil - case sql.ErrNoRows: - return db.ErrNoEntries - default: - if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists - } - return err - } -} - -func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) { - count, err := q.Count(ctx) - - exists := count != 0 - - err = processErrorResponse(err) - - if err != nil { - if err == db.ErrNoEntries { - return false, nil - } - return false, err - } - - return exists, nil -} - -func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) { - count, err := q.Count(ctx) - - notExists := count == 0 - - err = processErrorResponse(err) - - if err != nil { - if err == db.ErrNoEntries { - return true, nil - } - return false, err - } - - return notExists, nil -} - // whereEmptyOrNull is a convenience function to return a bun WhereGroup that specifies // that the given column should be EITHER an empty string OR null. // diff --git a/internal/federation/dereferencing/sqlite-test.db b/internal/federation/dereferencing/sqlite-test.db Binary files differnew file mode 100644 index 000000000..bef45b3af --- /dev/null +++ b/internal/federation/dereferencing/sqlite-test.db diff --git a/internal/federation/sqlite-test.db b/internal/federation/sqlite-test.db Binary files differnew file mode 100644 index 000000000..d34adbfe9 --- /dev/null +++ b/internal/federation/sqlite-test.db diff --git a/internal/oauth/sqlite-test.db b/internal/oauth/sqlite-test.db Binary files differnew file mode 100644 index 000000000..429e3d860 --- /dev/null +++ b/internal/oauth/sqlite-test.db diff --git a/internal/processing/status/sqlite-test.db b/internal/processing/status/sqlite-test.db Binary files differnew file mode 100644 index 000000000..d266d6b1d --- /dev/null +++ b/internal/processing/status/sqlite-test.db diff --git a/internal/text/sqlite-test.db b/internal/text/sqlite-test.db Binary files differnew file mode 100644 index 000000000..08b0a8909 --- /dev/null +++ b/internal/text/sqlite-test.db diff --git a/internal/timeline/sqlite-test.db b/internal/timeline/sqlite-test.db Binary files differnew file mode 100644 index 000000000..224027d43 --- /dev/null +++ b/internal/timeline/sqlite-test.db diff --git a/internal/typeutils/sqlite-test.db b/internal/typeutils/sqlite-test.db Binary files differnew file mode 100644 index 000000000..2775172f1 --- /dev/null +++ b/internal/typeutils/sqlite-test.db |