summaryrefslogtreecommitdiff
path: root/internal/db/bundb/application.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/application.go')
-rw-r--r--internal/db/bundb/application.go179
1 files changed, 179 insertions, 0 deletions
diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go
index f02632793..c2a957c93 100644
--- a/internal/db/bundb/application.go
+++ b/internal/db/bundb/application.go
@@ -22,6 +22,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -95,3 +96,181 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI
return nil
}
+
+func (a *applicationDB) GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error) {
+ return a.state.Caches.GTS.Client.LoadOne("ID", func() (*gtsmodel.Client, error) {
+ var client gtsmodel.Client
+
+ if err := a.db.NewSelect().
+ Model(&client).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return &client, nil
+ }, id)
+}
+
+func (a *applicationDB) PutClient(ctx context.Context, client *gtsmodel.Client) error {
+ return a.state.Caches.GTS.Client.Store(client, func() error {
+ _, err := a.db.NewInsert().Model(client).Exec(ctx)
+ return err
+ })
+}
+
+func (a *applicationDB) DeleteClientByID(ctx context.Context, id string) error {
+ _, err := a.db.NewDelete().
+ Table("clients").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ a.state.Caches.GTS.Client.Invalidate("ID", id)
+ return nil
+}
+
+func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) {
+ var tokenIDs []string
+
+ // Select ALL token IDs.
+ if err := a.db.NewSelect().
+ Table("tokens").
+ Column("id").
+ Scan(ctx, &tokenIDs); err != nil {
+ return nil, err
+ }
+
+ // Load all input token IDs via cache loader callback.
+ tokens, err := a.state.Caches.GTS.Token.LoadIDs("ID",
+ tokenIDs,
+ func(uncached []string) ([]*gtsmodel.Token, error) {
+ // Preallocate expected length of uncached tokens.
+ tokens := make([]*gtsmodel.Token, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) token IDs.
+ if err := a.db.NewSelect().
+ Model(tokens).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return tokens, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reoroder the tokens by their
+ // IDs to ensure in correct order.
+ getID := func(t *gtsmodel.Token) string { return t.ID }
+ util.OrderBy(tokens, tokenIDs, getID)
+
+ return tokens, nil
+}
+
+func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) {
+ return a.getTokenBy(
+ "Code",
+ func(t *gtsmodel.Token) error {
+ return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("code"), code).Scan(ctx)
+ },
+ code,
+ )
+}
+
+func (a *applicationDB) GetTokenByAccess(ctx context.Context, access string) (*gtsmodel.Token, error) {
+ return a.getTokenBy(
+ "Access",
+ func(t *gtsmodel.Token) error {
+ return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("access"), access).Scan(ctx)
+ },
+ access,
+ )
+}
+
+func (a *applicationDB) GetTokenByRefresh(ctx context.Context, refresh string) (*gtsmodel.Token, error) {
+ return a.getTokenBy(
+ "Refresh",
+ func(t *gtsmodel.Token) error {
+ return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("refresh"), refresh).Scan(ctx)
+ },
+ refresh,
+ )
+}
+
+func (a *applicationDB) getTokenBy(lookup string, dbQuery func(*gtsmodel.Token) error, keyParts ...any) (*gtsmodel.Token, error) {
+ return a.state.Caches.GTS.Token.LoadOne(lookup, func() (*gtsmodel.Token, error) {
+ var token gtsmodel.Token
+
+ if err := dbQuery(&token); err != nil {
+ return nil, err
+ }
+
+ return &token, nil
+ }, keyParts...)
+}
+
+func (a *applicationDB) PutToken(ctx context.Context, token *gtsmodel.Token) error {
+ return a.state.Caches.GTS.Token.Store(token, func() error {
+ _, err := a.db.NewInsert().Model(token).Exec(ctx)
+ return err
+ })
+}
+
+func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error {
+ _, err := a.db.NewDelete().
+ Table("tokens").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ a.state.Caches.GTS.Token.Invalidate("ID", id)
+ return nil
+}
+
+func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error {
+ _, err := a.db.NewDelete().
+ Table("tokens").
+ Where("? = ?", bun.Ident("code"), code).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ a.state.Caches.GTS.Token.Invalidate("Code", code)
+ return nil
+}
+
+func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error {
+ _, err := a.db.NewDelete().
+ Table("tokens").
+ Where("? = ?", bun.Ident("access"), access).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ a.state.Caches.GTS.Token.Invalidate("Access", access)
+ return nil
+}
+
+func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error {
+ _, err := a.db.NewDelete().
+ Table("tokens").
+ Where("? = ?", bun.Ident("refresh"), refresh).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ a.state.Caches.GTS.Token.Invalidate("Refresh", refresh)
+ return nil
+}