summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd/gotosocial/action/testrig/testrig.go4
-rw-r--r--internal/api/auth/token.go4
-rw-r--r--internal/cache/cache.go4
-rw-r--r--internal/cache/db.go70
-rw-r--r--internal/cache/invalidate.go10
-rw-r--r--internal/cache/size.go38
-rw-r--r--internal/config/config.go2
-rw-r--r--internal/config/defaults.go2
-rw-r--r--internal/config/helpers.gen.go50
-rw-r--r--internal/db/application.go36
-rw-r--r--internal/db/bundb/admin.go2
-rw-r--r--internal/db/bundb/application.go179
-rw-r--r--internal/oauth/clientstore.go25
-rw-r--r--internal/oauth/errors.go6
-rw-r--r--internal/oauth/server.go2
-rw-r--r--internal/oauth/tokenstore.go57
-rw-r--r--internal/processing/app.go2
-rwxr-xr-xtest/envparsing.sh2
18 files changed, 428 insertions, 67 deletions
diff --git a/cmd/gotosocial/action/testrig/testrig.go b/cmd/gotosocial/action/testrig/testrig.go
index 0769b8878..79361375b 100644
--- a/cmd/gotosocial/action/testrig/testrig.go
+++ b/cmd/gotosocial/action/testrig/testrig.go
@@ -98,8 +98,8 @@ var Start action.GTSAction = func(ctx context.Context) error {
testrig.StandardStorageSetup(state.Storage, "./testrig/media")
// Initialize workers.
- state.Workers.Start()
- defer state.Workers.Stop()
+ testrig.StartNoopWorkers(&state)
+ defer testrig.StopWorkers(&state)
// build backend handlers
transportController := testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {
diff --git a/internal/api/auth/token.go b/internal/api/auth/token.go
index cab9352fa..d9f0d8154 100644
--- a/internal/api/auth/token.go
+++ b/internal/api/auth/token.go
@@ -49,7 +49,7 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) {
form := &tokenRequestForm{}
if err := c.ShouldBind(form); err != nil {
- apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), err.Error()))
+ apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.ErrInvalidRequest, err.Error()))
return
}
@@ -98,7 +98,7 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) {
}
if len(help) != 0 {
- apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), help...))
+ apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.ErrInvalidRequest, help...))
return
}
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
index 3aa21cdd0..d35162172 100644
--- a/internal/cache/cache.go
+++ b/internal/cache/cache.go
@@ -59,6 +59,7 @@ func (c *Caches) Init() {
c.initBlock()
c.initBlockIDs()
c.initBoostOfIDs()
+ c.initClient()
c.initDomainAllow()
c.initDomainBlock()
c.initEmoji()
@@ -85,9 +86,10 @@ func (c *Caches) Init() {
c.initReport()
c.initStatus()
c.initStatusFave()
+ c.initStatusFaveIDs()
c.initTag()
c.initThreadMute()
- c.initStatusFaveIDs()
+ c.initToken()
c.initTombstone()
c.initUser()
c.initWebfinger()
diff --git a/internal/cache/db.go b/internal/cache/db.go
index c383ed6c7..cb0ed6712 100644
--- a/internal/cache/db.go
+++ b/internal/cache/db.go
@@ -58,6 +58,9 @@ type GTSCaches struct {
// BoostOfIDs provides access to the boost of IDs list database cache.
BoostOfIDs SliceCache[string]
+ // Client provides access to the gtsmodel Client database cache.
+ Client StructCache[*gtsmodel.Client]
+
// DomainAllow provides access to the domain allow database cache.
DomainAllow *domain.Cache
@@ -150,6 +153,9 @@ type GTSCaches struct {
// Tag provides access to the gtsmodel Tag database cache.
Tag StructCache[*gtsmodel.Tag]
+ // Token provides access to the gtsmodel Token database cache.
+ Token StructCache[*gtsmodel.Token]
+
// Tombstone provides access to the gtsmodel Tombstone database cache.
Tombstone StructCache[*gtsmodel.Tombstone]
@@ -309,9 +315,10 @@ func (c *Caches) initApplication() {
{Fields: "ID"},
{Fields: "ClientID"},
},
- MaxSize: cap,
- IgnoreErr: ignoreErrors,
- Copy: copyF,
+ MaxSize: cap,
+ IgnoreErr: ignoreErrors,
+ Copy: copyF,
+ Invalidate: c.OnInvalidateApplication,
})
}
@@ -374,6 +381,32 @@ func (c *Caches) initBoostOfIDs() {
c.GTS.BoostOfIDs.Init(0, cap)
}
+func (c *Caches) initClient() {
+ // Calculate maximum cache size.
+ cap := calculateResultCacheMax(
+ sizeofClient(), // model in-mem size.
+ config.GetCacheClientMemRatio(),
+ )
+
+ log.Infof(nil, "cache size = %d", cap)
+
+ copyF := func(c1 *gtsmodel.Client) *gtsmodel.Client {
+ c2 := new(gtsmodel.Client)
+ *c2 = *c1
+ return c2
+ }
+
+ c.GTS.Client.Init(structr.CacheConfig[*gtsmodel.Client]{
+ Indices: []structr.IndexConfig{
+ {Fields: "ID"},
+ },
+ MaxSize: cap,
+ IgnoreErr: ignoreErrors,
+ Copy: copyF,
+ Invalidate: c.OnInvalidateClient,
+ })
+}
+
func (c *Caches) initDomainAllow() {
c.GTS.DomainAllow = new(domain.Cache)
}
@@ -1135,7 +1168,7 @@ func (c *Caches) initTag() {
func (c *Caches) initThreadMute() {
cap := calculateResultCacheMax(
- sizeOfThreadMute(), // model in-mem size.
+ sizeofThreadMute(), // model in-mem size.
config.GetCacheThreadMuteMemRatio(),
)
@@ -1160,6 +1193,35 @@ func (c *Caches) initThreadMute() {
})
}
+func (c *Caches) initToken() {
+ // Calculate maximum cache size.
+ cap := calculateResultCacheMax(
+ sizeofToken(), // model in-mem size.
+ config.GetCacheTokenMemRatio(),
+ )
+
+ log.Infof(nil, "cache size = %d", cap)
+
+ copyF := func(t1 *gtsmodel.Token) *gtsmodel.Token {
+ t2 := new(gtsmodel.Token)
+ *t2 = *t1
+ return t2
+ }
+
+ c.GTS.Token.Init(structr.CacheConfig[*gtsmodel.Token]{
+ Indices: []structr.IndexConfig{
+ {Fields: "ID"},
+ {Fields: "Code"},
+ {Fields: "Access"},
+ {Fields: "Refresh"},
+ {Fields: "ClientID", Multiple: true},
+ },
+ MaxSize: cap,
+ IgnoreErr: ignoreErrors,
+ Copy: copyF,
+ })
+}
+
func (c *Caches) initTombstone() {
// Calculate maximum cache size.
cap := calculateResultCacheMax(
diff --git a/internal/cache/invalidate.go b/internal/cache/invalidate.go
index 746d8c7e7..547015eac 100644
--- a/internal/cache/invalidate.go
+++ b/internal/cache/invalidate.go
@@ -60,6 +60,11 @@ func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) {
c.GTS.Move.Invalidate("TargetURI", account.URI)
}
+func (c *Caches) OnInvalidateApplication(app *gtsmodel.Application) {
+ // Invalidate cached client of this application.
+ c.GTS.Client.Invalidate("ID", app.ClientID)
+}
+
func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) {
// Invalidate block origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.AccountID)
@@ -73,6 +78,11 @@ func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) {
c.GTS.BlockIDs.Invalidate(block.AccountID)
}
+func (c *Caches) OnInvalidateClient(client *gtsmodel.Client) {
+ // Invalidate any tokens under this client.
+ c.GTS.Token.Invalidate("ClientID", client.ID)
+}
+
func (c *Caches) OnInvalidateEmojiCategory(category *gtsmodel.EmojiCategory) {
// Invalidate any emoji in this category.
c.GTS.Emoji.Invalidate("CategoryID", category.ID)
diff --git a/internal/cache/size.go b/internal/cache/size.go
index 83b0da046..9c1a82abc 100644
--- a/internal/cache/size.go
+++ b/internal/cache/size.go
@@ -176,6 +176,7 @@ func totalOfRatios() float64 {
config.GetCacheBlockMemRatio() +
config.GetCacheBlockIDsMemRatio() +
config.GetCacheBoostOfIDsMemRatio() +
+ config.GetCacheClientMemRatio() +
config.GetCacheEmojiMemRatio() +
config.GetCacheEmojiCategoryMemRatio() +
config.GetCacheFollowMemRatio() +
@@ -198,6 +199,7 @@ func totalOfRatios() float64 {
config.GetCacheStatusFaveIDsMemRatio() +
config.GetCacheTagMemRatio() +
config.GetCacheThreadMuteMemRatio() +
+ config.GetCacheTokenMemRatio() +
config.GetCacheTombstoneMemRatio() +
config.GetCacheUserMemRatio() +
config.GetCacheWebfingerMemRatio() +
@@ -287,6 +289,17 @@ func sizeofBlock() uintptr {
}))
}
+func sizeofClient() uintptr {
+ return uintptr(size.Of(&gtsmodel.Client{
+ ID: exampleID,
+ CreatedAt: exampleTime,
+ UpdatedAt: exampleTime,
+ Secret: exampleID,
+ Domain: exampleURI,
+ UserID: exampleID,
+ }))
+}
+
func sizeofEmoji() uintptr {
return uintptr(size.Of(&gtsmodel.Emoji{
ID: exampleID,
@@ -591,7 +604,7 @@ func sizeofTag() uintptr {
}))
}
-func sizeOfThreadMute() uintptr {
+func sizeofThreadMute() uintptr {
return uintptr(size.Of(&gtsmodel.ThreadMute{
ID: exampleID,
CreatedAt: exampleTime,
@@ -601,6 +614,29 @@ func sizeOfThreadMute() uintptr {
}))
}
+func sizeofToken() uintptr {
+ return uintptr(size.Of(&gtsmodel.Token{
+ ID: exampleID,
+ CreatedAt: exampleTime,
+ UpdatedAt: exampleTime,
+ ClientID: exampleID,
+ UserID: exampleID,
+ RedirectURI: exampleURI,
+ Scope: "r:w",
+ Code: "", // TODO
+ CodeChallenge: "", // TODO
+ CodeChallengeMethod: "", // TODO
+ CodeCreateAt: exampleTime,
+ CodeExpiresAt: exampleTime,
+ Access: exampleID + exampleID,
+ AccessCreateAt: exampleTime,
+ AccessExpiresAt: exampleTime,
+ Refresh: "", // TODO: clients don't really support this very well yet
+ RefreshCreateAt: exampleTime,
+ RefreshExpiresAt: exampleTime,
+ }))
+}
+
func sizeofTombstone() uintptr {
return uintptr(size.Of(&gtsmodel.Tombstone{
ID: exampleID,
diff --git a/internal/config/config.go b/internal/config/config.go
index dee9e99de..3cd67525f 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -199,6 +199,7 @@ type CacheConfiguration struct {
BlockMemRatio float64 `name:"block-mem-ratio"`
BlockIDsMemRatio float64 `name:"block-mem-ratio"`
BoostOfIDsMemRatio float64 `name:"boost-of-ids-mem-ratio"`
+ ClientMemRatio float64 `name:"client-mem-ratio"`
EmojiMemRatio float64 `name:"emoji-mem-ratio"`
EmojiCategoryMemRatio float64 `name:"emoji-category-mem-ratio"`
FilterMemRatio float64 `name:"filter-mem-ratio"`
@@ -226,6 +227,7 @@ type CacheConfiguration struct {
StatusFaveIDsMemRatio float64 `name:"status-fave-ids-mem-ratio"`
TagMemRatio float64 `name:"tag-mem-ratio"`
ThreadMuteMemRatio float64 `name:"thread-mute-mem-ratio"`
+ TokenMemRatio float64 `name:"token-mem-ratio"`
TombstoneMemRatio float64 `name:"tombstone-mem-ratio"`
UserMemRatio float64 `name:"user-mem-ratio"`
WebfingerMemRatio float64 `name:"webfinger-mem-ratio"`
diff --git a/internal/config/defaults.go b/internal/config/defaults.go
index 64fff366a..f5f8fb6ac 100644
--- a/internal/config/defaults.go
+++ b/internal/config/defaults.go
@@ -163,6 +163,7 @@ var Defaults = Configuration{
BlockMemRatio: 2,
BlockIDsMemRatio: 3,
BoostOfIDsMemRatio: 3,
+ ClientMemRatio: 0.1,
EmojiMemRatio: 3,
EmojiCategoryMemRatio: 0.1,
FilterMemRatio: 0.5,
@@ -190,6 +191,7 @@ var Defaults = Configuration{
StatusFaveIDsMemRatio: 3,
TagMemRatio: 2,
ThreadMuteMemRatio: 0.2,
+ TokenMemRatio: 0.75,
TombstoneMemRatio: 0.5,
UserMemRatio: 0.25,
WebfingerMemRatio: 0.1,
diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go
index 39d26d13e..a8c919834 100644
--- a/internal/config/helpers.gen.go
+++ b/internal/config/helpers.gen.go
@@ -2925,6 +2925,31 @@ func GetCacheBoostOfIDsMemRatio() float64 { return global.GetCacheBoostOfIDsMemR
// SetCacheBoostOfIDsMemRatio safely sets the value for global configuration 'Cache.BoostOfIDsMemRatio' field
func SetCacheBoostOfIDsMemRatio(v float64) { global.SetCacheBoostOfIDsMemRatio(v) }
+// GetCacheClientMemRatio safely fetches the Configuration value for state's 'Cache.ClientMemRatio' field
+func (st *ConfigState) GetCacheClientMemRatio() (v float64) {
+ st.mutex.RLock()
+ v = st.config.Cache.ClientMemRatio
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheClientMemRatio safely sets the Configuration value for state's 'Cache.ClientMemRatio' field
+func (st *ConfigState) SetCacheClientMemRatio(v float64) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.ClientMemRatio = v
+ st.reloadToViper()
+}
+
+// CacheClientMemRatioFlag returns the flag name for the 'Cache.ClientMemRatio' field
+func CacheClientMemRatioFlag() string { return "cache-client-mem-ratio" }
+
+// GetCacheClientMemRatio safely fetches the value for global configuration 'Cache.ClientMemRatio' field
+func GetCacheClientMemRatio() float64 { return global.GetCacheClientMemRatio() }
+
+// SetCacheClientMemRatio safely sets the value for global configuration 'Cache.ClientMemRatio' field
+func SetCacheClientMemRatio(v float64) { global.SetCacheClientMemRatio(v) }
+
// GetCacheEmojiMemRatio safely fetches the Configuration value for state's 'Cache.EmojiMemRatio' field
func (st *ConfigState) GetCacheEmojiMemRatio() (v float64) {
st.mutex.RLock()
@@ -3600,6 +3625,31 @@ func GetCacheThreadMuteMemRatio() float64 { return global.GetCacheThreadMuteMemR
// SetCacheThreadMuteMemRatio safely sets the value for global configuration 'Cache.ThreadMuteMemRatio' field
func SetCacheThreadMuteMemRatio(v float64) { global.SetCacheThreadMuteMemRatio(v) }
+// GetCacheTokenMemRatio safely fetches the Configuration value for state's 'Cache.TokenMemRatio' field
+func (st *ConfigState) GetCacheTokenMemRatio() (v float64) {
+ st.mutex.RLock()
+ v = st.config.Cache.TokenMemRatio
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheTokenMemRatio safely sets the Configuration value for state's 'Cache.TokenMemRatio' field
+func (st *ConfigState) SetCacheTokenMemRatio(v float64) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.TokenMemRatio = v
+ st.reloadToViper()
+}
+
+// CacheTokenMemRatioFlag returns the flag name for the 'Cache.TokenMemRatio' field
+func CacheTokenMemRatioFlag() string { return "cache-token-mem-ratio" }
+
+// GetCacheTokenMemRatio safely fetches the value for global configuration 'Cache.TokenMemRatio' field
+func GetCacheTokenMemRatio() float64 { return global.GetCacheTokenMemRatio() }
+
+// SetCacheTokenMemRatio safely sets the value for global configuration 'Cache.TokenMemRatio' field
+func SetCacheTokenMemRatio(v float64) { global.SetCacheTokenMemRatio(v) }
+
// GetCacheTombstoneMemRatio safely fetches the Configuration value for state's 'Cache.TombstoneMemRatio' field
func (st *ConfigState) GetCacheTombstoneMemRatio() (v float64) {
st.mutex.RLock()
diff --git a/internal/db/application.go b/internal/db/application.go
index 34a857d3f..b71e593c2 100644
--- a/internal/db/application.go
+++ b/internal/db/application.go
@@ -35,4 +35,40 @@ type Application interface {
// DeleteApplicationByClientID deletes the application with corresponding client_id value from the database.
DeleteApplicationByClientID(ctx context.Context, clientID string) error
+
+ // GetClientByID ...
+ GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error)
+
+ // PutClient ...
+ PutClient(ctx context.Context, client *gtsmodel.Client) error
+
+ // DeleteClientByID ...
+ DeleteClientByID(ctx context.Context, id string) error
+
+ // GetAllTokens ...
+ GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error)
+
+ // GetTokenByCode ...
+ GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error)
+
+ // GetTokenByAccess ...
+ GetTokenByAccess(ctx context.Context, access string) (*gtsmodel.Token, error)
+
+ // GetTokenByRefresh ...
+ GetTokenByRefresh(ctx context.Context, refresh string) (*gtsmodel.Token, error)
+
+ // PutToken ...
+ PutToken(ctx context.Context, token *gtsmodel.Token) error
+
+ // DeleteTokenByID ...
+ DeleteTokenByID(ctx context.Context, id string) error
+
+ // DeleteTokenByCode ...
+ DeleteTokenByCode(ctx context.Context, code string) error
+
+ // DeleteTokenByAccess ...
+ DeleteTokenByAccess(ctx context.Context, access string) error
+
+ // DeleteTokenByRefresh ...
+ DeleteTokenByRefresh(ctx context.Context, refresh string) error
}
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go
index e52467b9b..e9191b7c7 100644
--- a/internal/db/bundb/admin.go
+++ b/internal/db/bundb/admin.go
@@ -397,7 +397,7 @@ func (a *adminDB) CreateInstanceApplication(ctx context.Context) error {
}
// Store it.
- return a.state.DB.Put(ctx, oc)
+ return a.state.DB.PutClient(ctx, oc)
}
func (a *adminDB) GetInstanceApplication(ctx context.Context) (*gtsmodel.Application, error) {
diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go
index f02632793..c2a957c93 100644
--- a/internal/db/bundb/application.go
+++ b/internal/db/bundb/application.go
@@ -22,6 +22,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -95,3 +96,181 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI
return nil
}
+
+func (a *applicationDB) GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error) {
+ return a.state.Caches.GTS.Client.LoadOne("ID", func() (*gtsmodel.Client, error) {
+ var client gtsmodel.Client
+
+ if err := a.db.NewSelect().
+ Model(&client).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return &client, nil
+ }, id)
+}
+
+func (a *applicationDB) PutClient(ctx context.Context, client *gtsmodel.Client) error {
+ return a.state.Caches.GTS.Client.Store(client, func() error {
+ _, err := a.db.NewInsert().Model(client).Exec(ctx)
+ return err
+ })
+}
+
+func (a *applicationDB) DeleteClientByID(ctx context.Context, id string) error {
+ _, err := a.db.NewDelete().
+ Table("clients").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ a.state.Caches.GTS.Client.Invalidate("ID", id)
+ return nil
+}
+
+func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) {
+ var tokenIDs []string
+
+ // Select ALL token IDs.
+ if err := a.db.NewSelect().
+ Table("tokens").
+ Column("id").
+ Scan(ctx, &tokenIDs); err != nil {
+ return nil, err
+ }
+
+ // Load all input token IDs via cache loader callback.
+ tokens, err := a.state.Caches.GTS.Token.LoadIDs("ID",
+ tokenIDs,
+ func(uncached []string) ([]*gtsmodel.Token, error) {
+ // Preallocate expected length of uncached tokens.
+ tokens := make([]*gtsmodel.Token, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) token IDs.
+ if err := a.db.NewSelect().
+ Model(tokens).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return tokens, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reoroder the tokens by their
+ // IDs to ensure in correct order.
+ getID := func(t *gtsmodel.Token) string { return t.ID }
+ util.OrderBy(tokens, tokenIDs, getID)
+
+ return tokens, nil
+}
+
+func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) {
+ return a.getTokenBy(
+ "Code",
+ func(t *gtsmodel.Token) error {
+ return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("code"), code).Scan(ctx)
+ },
+ code,
+ )
+}
+
+func (a *applicationDB) GetTokenByAccess(ctx context.Context, access string) (*gtsmodel.Token, error) {
+ return a.getTokenBy(
+ "Access",
+ func(t *gtsmodel.Token) error {
+ return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("access"), access).Scan(ctx)
+ },
+ access,
+ )
+}
+
+func (a *applicationDB) GetTokenByRefresh(ctx context.Context, refresh string) (*gtsmodel.Token, error) {
+ return a.getTokenBy(
+ "Refresh",
+ func(t *gtsmodel.Token) error {
+ return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("refresh"), refresh).Scan(ctx)
+ },
+ refresh,
+ )
+}
+
+func (a *applicationDB) getTokenBy(lookup string, dbQuery func(*gtsmodel.Token) error, keyParts ...any) (*gtsmodel.Token, error) {
+ return a.state.Caches.GTS.Token.LoadOne(lookup, func() (*gtsmodel.Token, error) {
+ var token gtsmodel.Token
+
+ if err := dbQuery(&token); err != nil {
+ return nil, err
+ }
+
+ return &token, nil
+ }, keyParts...)
+}
+
+func (a *applicationDB) PutToken(ctx context.Context, token *gtsmodel.Token) error {
+ return a.state.Caches.GTS.Token.Store(token, func() error {
+ _, err := a.db.NewInsert().Model(token).Exec(ctx)
+ return err
+ })
+}
+
+func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error {
+ _, err := a.db.NewDelete().
+ Table("tokens").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ a.state.Caches.GTS.Token.Invalidate("ID", id)
+ return nil
+}
+
+func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error {
+ _, err := a.db.NewDelete().
+ Table("tokens").
+ Where("? = ?", bun.Ident("code"), code).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ a.state.Caches.GTS.Token.Invalidate("Code", code)
+ return nil
+}
+
+func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error {
+ _, err := a.db.NewDelete().
+ Table("tokens").
+ Where("? = ?", bun.Ident("access"), access).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ a.state.Caches.GTS.Token.Invalidate("Access", access)
+ return nil
+}
+
+func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error {
+ _, err := a.db.NewDelete().
+ Table("tokens").
+ Where("? = ?", bun.Ident("refresh"), refresh).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ a.state.Caches.GTS.Token.Invalidate("Refresh", refresh)
+ return nil
+}
diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go
index 5bb600e70..bddb30b1b 100644
--- a/internal/oauth/clientstore.go
+++ b/internal/oauth/clientstore.go
@@ -27,11 +27,11 @@ import (
)
type clientStore struct {
- db db.Basic
+ db db.DB
}
// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend.
-func NewClientStore(db db.Basic) oauth2.ClientStore {
+func NewClientStore(db db.DB) oauth2.ClientStore {
pts := &clientStore{
db: db,
}
@@ -39,26 +39,27 @@ func NewClientStore(db db.Basic) oauth2.ClientStore {
}
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
- poc := &gtsmodel.Client{}
- if err := cs.db.GetByID(ctx, clientID, poc); err != nil {
+ client, err := cs.db.GetClientByID(ctx, clientID)
+ if err != nil {
return nil, err
}
- return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil
+ return models.New(
+ client.ID,
+ client.Secret,
+ client.Domain,
+ client.UserID,
+ ), nil
}
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
- poc := &gtsmodel.Client{
+ return cs.db.PutClient(ctx, &gtsmodel.Client{
ID: cli.GetID(),
Secret: cli.GetSecret(),
Domain: cli.GetDomain(),
UserID: cli.GetUserID(),
- }
- return cs.db.Put(ctx, poc)
+ })
}
func (cs *clientStore) Delete(ctx context.Context, id string) error {
- poc := &gtsmodel.Client{
- ID: id,
- }
- return cs.db.DeleteByID(ctx, id, poc)
+ return cs.db.DeleteClientByID(ctx, id)
}
diff --git a/internal/oauth/errors.go b/internal/oauth/errors.go
index dd61be28c..b16143e5c 100644
--- a/internal/oauth/errors.go
+++ b/internal/oauth/errors.go
@@ -19,7 +19,5 @@ package oauth
import "github.com/superseriousbusiness/oauth2/v4/errors"
-// InvalidRequest returns an oauth spec compliant 'invalid_request' error.
-func InvalidRequest() error {
- return errors.New("invalid_request")
-}
+// ErrInvalidRequest is an oauth spec compliant 'invalid_request' error.
+var ErrInvalidRequest = errors.New("invalid_request")
diff --git a/internal/oauth/server.go b/internal/oauth/server.go
index 3e4519479..4f2ed509b 100644
--- a/internal/oauth/server.go
+++ b/internal/oauth/server.go
@@ -75,7 +75,7 @@ type s struct {
}
// New returns a new oauth server that implements the Server interface
-func New(ctx context.Context, database db.Basic) Server {
+func New(ctx context.Context, database db.DB) Server {
ts := newTokenStore(ctx, database)
cs := NewClientStore(database)
diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go
index 3658f0aa9..14b91fa06 100644
--- a/internal/oauth/tokenstore.go
+++ b/internal/oauth/tokenstore.go
@@ -20,7 +20,6 @@ package oauth
import (
"context"
"errors"
- "fmt"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -34,14 +33,14 @@ import (
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
type tokenStore struct {
oauth2.TokenStore
- db db.Basic
+ db db.DB
}
// newTokenStore returns a token store that satisfies the oauth2.TokenStore interface.
//
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired.
-func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore {
+func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
ts := &tokenStore{
db: db,
}
@@ -69,19 +68,19 @@ func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore {
func (ts *tokenStore) sweep(ctx context.Context) error {
// select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
- tokens := new([]*gtsmodel.Token)
- if err := ts.db.GetAll(ctx, tokens); err != nil {
+ tokens, err := ts.db.GetAllTokens(ctx)
+ if err != nil {
return err
}
// iterate through and remove expired tokens
now := time.Now()
- for _, dbt := range *tokens {
+ for _, dbt := range tokens {
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
// we only want to check if a token expired before now if the expiry time is *not zero*;
// ie., if it's been explicity set.
if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
- if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil {
+ if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil {
return err
}
}
@@ -107,67 +106,49 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
dbt.ID = dbtID
}
- if err := ts.db.Put(ctx, dbt); err != nil {
- return fmt.Errorf("error in tokenstore create: %s", err)
- }
- return nil
+ return ts.db.PutToken(ctx, dbt)
}
// RemoveByCode deletes a token from the DB based on the Code field
func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
- return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &gtsmodel.Token{})
+ return ts.db.DeleteTokenByCode(ctx, code)
}
// RemoveByAccess deletes a token from the DB based on the Access field
func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
- return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &gtsmodel.Token{})
+ return ts.db.DeleteTokenByAccess(ctx, access)
}
// RemoveByRefresh deletes a token from the DB based on the Refresh field
func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
- return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &gtsmodel.Token{})
+ return ts.db.DeleteTokenByRefresh(ctx, refresh)
}
// GetByCode selects a token from the DB based on the Code field
func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
- if code == "" {
- return nil, nil
- }
- dbt := &gtsmodel.Token{
- Code: code,
- }
- if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil {
+ token, err := ts.db.GetTokenByCode(ctx, code)
+ if err != nil {
return nil, err
}
- return DBTokenToToken(dbt), nil
+ return DBTokenToToken(token), nil
}
// GetByAccess selects a token from the DB based on the Access field
func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
- if access == "" {
- return nil, nil
- }
- dbt := &gtsmodel.Token{
- Access: access,
- }
- if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil {
+ token, err := ts.db.GetTokenByAccess(ctx, access)
+ if err != nil {
return nil, err
}
- return DBTokenToToken(dbt), nil
+ return DBTokenToToken(token), nil
}
// GetByRefresh selects a token from the DB based on the Refresh field
func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
- if refresh == "" {
- return nil, nil
- }
- dbt := &gtsmodel.Token{
- Refresh: refresh,
- }
- if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil {
+ token, err := ts.db.GetTokenByRefresh(ctx, refresh)
+ if err != nil {
return nil, err
}
- return DBTokenToToken(dbt), nil
+ return DBTokenToToken(token), nil
}
/*
diff --git a/internal/processing/app.go b/internal/processing/app.go
index eef4fae0d..d492b3bc4 100644
--- a/internal/processing/app.go
+++ b/internal/processing/app.go
@@ -75,7 +75,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api
}
// chuck it in the db
- if err := p.state.DB.Put(ctx, oc); err != nil {
+ if err := p.state.DB.PutClient(ctx, oc); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/test/envparsing.sh b/test/envparsing.sh
index 19b86a818..a379750c0 100755
--- a/test/envparsing.sh
+++ b/test/envparsing.sh
@@ -29,6 +29,7 @@ EXPECT=$(cat << "EOF"
"application-mem-ratio": 0.1,
"block-mem-ratio": 3,
"boost-of-ids-mem-ratio": 3,
+ "client-mem-ratio": 0.1,
"emoji-category-mem-ratio": 0.1,
"emoji-mem-ratio": 3,
"filter-keyword-mem-ratio": 0.5,
@@ -57,6 +58,7 @@ EXPECT=$(cat << "EOF"
"status-mem-ratio": 5,
"tag-mem-ratio": 2,
"thread-mute-mem-ratio": 0.2,
+ "token-mem-ratio": 0.75,
"tombstone-mem-ratio": 0.5,
"user-mem-ratio": 0.25,
"visibility-mem-ratio": 2,