diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/auth/token.go | 4 | ||||
| -rw-r--r-- | internal/cache/cache.go | 4 | ||||
| -rw-r--r-- | internal/cache/db.go | 70 | ||||
| -rw-r--r-- | internal/cache/invalidate.go | 10 | ||||
| -rw-r--r-- | internal/cache/size.go | 38 | ||||
| -rw-r--r-- | internal/config/config.go | 2 | ||||
| -rw-r--r-- | internal/config/defaults.go | 2 | ||||
| -rw-r--r-- | internal/config/helpers.gen.go | 50 | ||||
| -rw-r--r-- | internal/db/application.go | 36 | ||||
| -rw-r--r-- | internal/db/bundb/admin.go | 2 | ||||
| -rw-r--r-- | internal/db/bundb/application.go | 179 | ||||
| -rw-r--r-- | internal/oauth/clientstore.go | 25 | ||||
| -rw-r--r-- | internal/oauth/errors.go | 6 | ||||
| -rw-r--r-- | internal/oauth/server.go | 2 | ||||
| -rw-r--r-- | internal/oauth/tokenstore.go | 57 | ||||
| -rw-r--r-- | internal/processing/app.go | 2 | 
16 files changed, 424 insertions, 65 deletions
diff --git a/internal/api/auth/token.go b/internal/api/auth/token.go index cab9352fa..d9f0d8154 100644 --- a/internal/api/auth/token.go +++ b/internal/api/auth/token.go @@ -49,7 +49,7 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) {  	form := &tokenRequestForm{}  	if err := c.ShouldBind(form); err != nil { -		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), err.Error())) +		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.ErrInvalidRequest, err.Error()))  		return  	} @@ -98,7 +98,7 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) {  	}  	if len(help) != 0 { -		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), help...)) +		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.ErrInvalidRequest, help...))  		return  	} diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 3aa21cdd0..d35162172 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -59,6 +59,7 @@ func (c *Caches) Init() {  	c.initBlock()  	c.initBlockIDs()  	c.initBoostOfIDs() +	c.initClient()  	c.initDomainAllow()  	c.initDomainBlock()  	c.initEmoji() @@ -85,9 +86,10 @@ func (c *Caches) Init() {  	c.initReport()  	c.initStatus()  	c.initStatusFave() +	c.initStatusFaveIDs()  	c.initTag()  	c.initThreadMute() -	c.initStatusFaveIDs() +	c.initToken()  	c.initTombstone()  	c.initUser()  	c.initWebfinger() diff --git a/internal/cache/db.go b/internal/cache/db.go index c383ed6c7..cb0ed6712 100644 --- a/internal/cache/db.go +++ b/internal/cache/db.go @@ -58,6 +58,9 @@ type GTSCaches struct {  	// BoostOfIDs provides access to the boost of IDs list database cache.  	BoostOfIDs SliceCache[string] +	// Client provides access to the gtsmodel Client database cache. +	Client StructCache[*gtsmodel.Client] +  	// DomainAllow provides access to the domain allow database cache.  	DomainAllow *domain.Cache @@ -150,6 +153,9 @@ type GTSCaches struct {  	// Tag provides access to the gtsmodel Tag database cache.  	Tag StructCache[*gtsmodel.Tag] +	// Token provides access to the gtsmodel Token database cache. +	Token StructCache[*gtsmodel.Token] +  	// Tombstone provides access to the gtsmodel Tombstone database cache.  	Tombstone StructCache[*gtsmodel.Tombstone] @@ -309,9 +315,10 @@ func (c *Caches) initApplication() {  			{Fields: "ID"},  			{Fields: "ClientID"},  		}, -		MaxSize:   cap, -		IgnoreErr: ignoreErrors, -		Copy:      copyF, +		MaxSize:    cap, +		IgnoreErr:  ignoreErrors, +		Copy:       copyF, +		Invalidate: c.OnInvalidateApplication,  	})  } @@ -374,6 +381,32 @@ func (c *Caches) initBoostOfIDs() {  	c.GTS.BoostOfIDs.Init(0, cap)  } +func (c *Caches) initClient() { +	// Calculate maximum cache size. +	cap := calculateResultCacheMax( +		sizeofClient(), // model in-mem size. +		config.GetCacheClientMemRatio(), +	) + +	log.Infof(nil, "cache size = %d", cap) + +	copyF := func(c1 *gtsmodel.Client) *gtsmodel.Client { +		c2 := new(gtsmodel.Client) +		*c2 = *c1 +		return c2 +	} + +	c.GTS.Client.Init(structr.CacheConfig[*gtsmodel.Client]{ +		Indices: []structr.IndexConfig{ +			{Fields: "ID"}, +		}, +		MaxSize:    cap, +		IgnoreErr:  ignoreErrors, +		Copy:       copyF, +		Invalidate: c.OnInvalidateClient, +	}) +} +  func (c *Caches) initDomainAllow() {  	c.GTS.DomainAllow = new(domain.Cache)  } @@ -1135,7 +1168,7 @@ func (c *Caches) initTag() {  func (c *Caches) initThreadMute() {  	cap := calculateResultCacheMax( -		sizeOfThreadMute(), // model in-mem size. +		sizeofThreadMute(), // model in-mem size.  		config.GetCacheThreadMuteMemRatio(),  	) @@ -1160,6 +1193,35 @@ func (c *Caches) initThreadMute() {  	})  } +func (c *Caches) initToken() { +	// Calculate maximum cache size. +	cap := calculateResultCacheMax( +		sizeofToken(), // model in-mem size. +		config.GetCacheTokenMemRatio(), +	) + +	log.Infof(nil, "cache size = %d", cap) + +	copyF := func(t1 *gtsmodel.Token) *gtsmodel.Token { +		t2 := new(gtsmodel.Token) +		*t2 = *t1 +		return t2 +	} + +	c.GTS.Token.Init(structr.CacheConfig[*gtsmodel.Token]{ +		Indices: []structr.IndexConfig{ +			{Fields: "ID"}, +			{Fields: "Code"}, +			{Fields: "Access"}, +			{Fields: "Refresh"}, +			{Fields: "ClientID", Multiple: true}, +		}, +		MaxSize:   cap, +		IgnoreErr: ignoreErrors, +		Copy:      copyF, +	}) +} +  func (c *Caches) initTombstone() {  	// Calculate maximum cache size.  	cap := calculateResultCacheMax( diff --git a/internal/cache/invalidate.go b/internal/cache/invalidate.go index 746d8c7e7..547015eac 100644 --- a/internal/cache/invalidate.go +++ b/internal/cache/invalidate.go @@ -60,6 +60,11 @@ func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) {  	c.GTS.Move.Invalidate("TargetURI", account.URI)  } +func (c *Caches) OnInvalidateApplication(app *gtsmodel.Application) { +	// Invalidate cached client of this application. +	c.GTS.Client.Invalidate("ID", app.ClientID) +} +  func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) {  	// Invalidate block origin account ID cached visibility.  	c.Visibility.Invalidate("ItemID", block.AccountID) @@ -73,6 +78,11 @@ func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) {  	c.GTS.BlockIDs.Invalidate(block.AccountID)  } +func (c *Caches) OnInvalidateClient(client *gtsmodel.Client) { +	// Invalidate any tokens under this client. +	c.GTS.Token.Invalidate("ClientID", client.ID) +} +  func (c *Caches) OnInvalidateEmojiCategory(category *gtsmodel.EmojiCategory) {  	// Invalidate any emoji in this category.  	c.GTS.Emoji.Invalidate("CategoryID", category.ID) diff --git a/internal/cache/size.go b/internal/cache/size.go index 83b0da046..9c1a82abc 100644 --- a/internal/cache/size.go +++ b/internal/cache/size.go @@ -176,6 +176,7 @@ func totalOfRatios() float64 {  		config.GetCacheBlockMemRatio() +  		config.GetCacheBlockIDsMemRatio() +  		config.GetCacheBoostOfIDsMemRatio() + +		config.GetCacheClientMemRatio() +  		config.GetCacheEmojiMemRatio() +  		config.GetCacheEmojiCategoryMemRatio() +  		config.GetCacheFollowMemRatio() + @@ -198,6 +199,7 @@ func totalOfRatios() float64 {  		config.GetCacheStatusFaveIDsMemRatio() +  		config.GetCacheTagMemRatio() +  		config.GetCacheThreadMuteMemRatio() + +		config.GetCacheTokenMemRatio() +  		config.GetCacheTombstoneMemRatio() +  		config.GetCacheUserMemRatio() +  		config.GetCacheWebfingerMemRatio() + @@ -287,6 +289,17 @@ func sizeofBlock() uintptr {  	}))  } +func sizeofClient() uintptr { +	return uintptr(size.Of(>smodel.Client{ +		ID:        exampleID, +		CreatedAt: exampleTime, +		UpdatedAt: exampleTime, +		Secret:    exampleID, +		Domain:    exampleURI, +		UserID:    exampleID, +	})) +} +  func sizeofEmoji() uintptr {  	return uintptr(size.Of(>smodel.Emoji{  		ID:                     exampleID, @@ -591,7 +604,7 @@ func sizeofTag() uintptr {  	}))  } -func sizeOfThreadMute() uintptr { +func sizeofThreadMute() uintptr {  	return uintptr(size.Of(>smodel.ThreadMute{  		ID:        exampleID,  		CreatedAt: exampleTime, @@ -601,6 +614,29 @@ func sizeOfThreadMute() uintptr {  	}))  } +func sizeofToken() uintptr { +	return uintptr(size.Of(>smodel.Token{ +		ID:                  exampleID, +		CreatedAt:           exampleTime, +		UpdatedAt:           exampleTime, +		ClientID:            exampleID, +		UserID:              exampleID, +		RedirectURI:         exampleURI, +		Scope:               "r:w", +		Code:                "", // TODO +		CodeChallenge:       "", // TODO +		CodeChallengeMethod: "", // TODO +		CodeCreateAt:        exampleTime, +		CodeExpiresAt:       exampleTime, +		Access:              exampleID + exampleID, +		AccessCreateAt:      exampleTime, +		AccessExpiresAt:     exampleTime, +		Refresh:             "", // TODO: clients don't really support this very well yet +		RefreshCreateAt:     exampleTime, +		RefreshExpiresAt:    exampleTime, +	})) +} +  func sizeofTombstone() uintptr {  	return uintptr(size.Of(>smodel.Tombstone{  		ID:        exampleID, diff --git a/internal/config/config.go b/internal/config/config.go index dee9e99de..3cd67525f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -199,6 +199,7 @@ type CacheConfiguration struct {  	BlockMemRatio            float64       `name:"block-mem-ratio"`  	BlockIDsMemRatio         float64       `name:"block-mem-ratio"`  	BoostOfIDsMemRatio       float64       `name:"boost-of-ids-mem-ratio"` +	ClientMemRatio           float64       `name:"client-mem-ratio"`  	EmojiMemRatio            float64       `name:"emoji-mem-ratio"`  	EmojiCategoryMemRatio    float64       `name:"emoji-category-mem-ratio"`  	FilterMemRatio           float64       `name:"filter-mem-ratio"` @@ -226,6 +227,7 @@ type CacheConfiguration struct {  	StatusFaveIDsMemRatio    float64       `name:"status-fave-ids-mem-ratio"`  	TagMemRatio              float64       `name:"tag-mem-ratio"`  	ThreadMuteMemRatio       float64       `name:"thread-mute-mem-ratio"` +	TokenMemRatio            float64       `name:"token-mem-ratio"`  	TombstoneMemRatio        float64       `name:"tombstone-mem-ratio"`  	UserMemRatio             float64       `name:"user-mem-ratio"`  	WebfingerMemRatio        float64       `name:"webfinger-mem-ratio"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 64fff366a..f5f8fb6ac 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -163,6 +163,7 @@ var Defaults = Configuration{  		BlockMemRatio:            2,  		BlockIDsMemRatio:         3,  		BoostOfIDsMemRatio:       3, +		ClientMemRatio:           0.1,  		EmojiMemRatio:            3,  		EmojiCategoryMemRatio:    0.1,  		FilterMemRatio:           0.5, @@ -190,6 +191,7 @@ var Defaults = Configuration{  		StatusFaveIDsMemRatio:    3,  		TagMemRatio:              2,  		ThreadMuteMemRatio:       0.2, +		TokenMemRatio:            0.75,  		TombstoneMemRatio:        0.5,  		UserMemRatio:             0.25,  		WebfingerMemRatio:        0.1, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index 39d26d13e..a8c919834 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2925,6 +2925,31 @@ func GetCacheBoostOfIDsMemRatio() float64 { return global.GetCacheBoostOfIDsMemR  // SetCacheBoostOfIDsMemRatio safely sets the value for global configuration 'Cache.BoostOfIDsMemRatio' field  func SetCacheBoostOfIDsMemRatio(v float64) { global.SetCacheBoostOfIDsMemRatio(v) } +// GetCacheClientMemRatio safely fetches the Configuration value for state's 'Cache.ClientMemRatio' field +func (st *ConfigState) GetCacheClientMemRatio() (v float64) { +	st.mutex.RLock() +	v = st.config.Cache.ClientMemRatio +	st.mutex.RUnlock() +	return +} + +// SetCacheClientMemRatio safely sets the Configuration value for state's 'Cache.ClientMemRatio' field +func (st *ConfigState) SetCacheClientMemRatio(v float64) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.Cache.ClientMemRatio = v +	st.reloadToViper() +} + +// CacheClientMemRatioFlag returns the flag name for the 'Cache.ClientMemRatio' field +func CacheClientMemRatioFlag() string { return "cache-client-mem-ratio" } + +// GetCacheClientMemRatio safely fetches the value for global configuration 'Cache.ClientMemRatio' field +func GetCacheClientMemRatio() float64 { return global.GetCacheClientMemRatio() } + +// SetCacheClientMemRatio safely sets the value for global configuration 'Cache.ClientMemRatio' field +func SetCacheClientMemRatio(v float64) { global.SetCacheClientMemRatio(v) } +  // GetCacheEmojiMemRatio safely fetches the Configuration value for state's 'Cache.EmojiMemRatio' field  func (st *ConfigState) GetCacheEmojiMemRatio() (v float64) {  	st.mutex.RLock() @@ -3600,6 +3625,31 @@ func GetCacheThreadMuteMemRatio() float64 { return global.GetCacheThreadMuteMemR  // SetCacheThreadMuteMemRatio safely sets the value for global configuration 'Cache.ThreadMuteMemRatio' field  func SetCacheThreadMuteMemRatio(v float64) { global.SetCacheThreadMuteMemRatio(v) } +// GetCacheTokenMemRatio safely fetches the Configuration value for state's 'Cache.TokenMemRatio' field +func (st *ConfigState) GetCacheTokenMemRatio() (v float64) { +	st.mutex.RLock() +	v = st.config.Cache.TokenMemRatio +	st.mutex.RUnlock() +	return +} + +// SetCacheTokenMemRatio safely sets the Configuration value for state's 'Cache.TokenMemRatio' field +func (st *ConfigState) SetCacheTokenMemRatio(v float64) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.Cache.TokenMemRatio = v +	st.reloadToViper() +} + +// CacheTokenMemRatioFlag returns the flag name for the 'Cache.TokenMemRatio' field +func CacheTokenMemRatioFlag() string { return "cache-token-mem-ratio" } + +// GetCacheTokenMemRatio safely fetches the value for global configuration 'Cache.TokenMemRatio' field +func GetCacheTokenMemRatio() float64 { return global.GetCacheTokenMemRatio() } + +// SetCacheTokenMemRatio safely sets the value for global configuration 'Cache.TokenMemRatio' field +func SetCacheTokenMemRatio(v float64) { global.SetCacheTokenMemRatio(v) } +  // GetCacheTombstoneMemRatio safely fetches the Configuration value for state's 'Cache.TombstoneMemRatio' field  func (st *ConfigState) GetCacheTombstoneMemRatio() (v float64) {  	st.mutex.RLock() diff --git a/internal/db/application.go b/internal/db/application.go index 34a857d3f..b71e593c2 100644 --- a/internal/db/application.go +++ b/internal/db/application.go @@ -35,4 +35,40 @@ type Application interface {  	// DeleteApplicationByClientID deletes the application with corresponding client_id value from the database.  	DeleteApplicationByClientID(ctx context.Context, clientID string) error + +	// GetClientByID ... +	GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error) + +	// PutClient ... +	PutClient(ctx context.Context, client *gtsmodel.Client) error + +	// DeleteClientByID ... +	DeleteClientByID(ctx context.Context, id string) error + +	// GetAllTokens ... +	GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) + +	// GetTokenByCode ... +	GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) + +	// GetTokenByAccess ... +	GetTokenByAccess(ctx context.Context, access string) (*gtsmodel.Token, error) + +	// GetTokenByRefresh ... +	GetTokenByRefresh(ctx context.Context, refresh string) (*gtsmodel.Token, error) + +	// PutToken ... +	PutToken(ctx context.Context, token *gtsmodel.Token) error + +	// DeleteTokenByID ... +	DeleteTokenByID(ctx context.Context, id string) error + +	// DeleteTokenByCode ... +	DeleteTokenByCode(ctx context.Context, code string) error + +	// DeleteTokenByAccess ... +	DeleteTokenByAccess(ctx context.Context, access string) error + +	// DeleteTokenByRefresh ... +	DeleteTokenByRefresh(ctx context.Context, refresh string) error  } diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index e52467b9b..e9191b7c7 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -397,7 +397,7 @@ func (a *adminDB) CreateInstanceApplication(ctx context.Context) error {  	}  	// Store it. -	return a.state.DB.Put(ctx, oc) +	return a.state.DB.PutClient(ctx, oc)  }  func (a *adminDB) GetInstanceApplication(ctx context.Context) (*gtsmodel.Application, error) { diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go index f02632793..c2a957c93 100644 --- a/internal/db/bundb/application.go +++ b/internal/db/bundb/application.go @@ -22,6 +22,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -95,3 +96,181 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI  	return nil  } + +func (a *applicationDB) GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error) { +	return a.state.Caches.GTS.Client.LoadOne("ID", func() (*gtsmodel.Client, error) { +		var client gtsmodel.Client + +		if err := a.db.NewSelect(). +			Model(&client). +			Where("? = ?", bun.Ident("id"), id). +			Scan(ctx); err != nil { +			return nil, err +		} + +		return &client, nil +	}, id) +} + +func (a *applicationDB) PutClient(ctx context.Context, client *gtsmodel.Client) error { +	return a.state.Caches.GTS.Client.Store(client, func() error { +		_, err := a.db.NewInsert().Model(client).Exec(ctx) +		return err +	}) +} + +func (a *applicationDB) DeleteClientByID(ctx context.Context, id string) error { +	_, err := a.db.NewDelete(). +		Table("clients"). +		Where("? = ?", bun.Ident("id"), id). +		Exec(ctx) +	if err != nil { +		return err +	} + +	a.state.Caches.GTS.Client.Invalidate("ID", id) +	return nil +} + +func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) { +	var tokenIDs []string + +	// Select ALL token IDs. +	if err := a.db.NewSelect(). +		Table("tokens"). +		Column("id"). +		Scan(ctx, &tokenIDs); err != nil { +		return nil, err +	} + +	// Load all input token IDs via cache loader callback. +	tokens, err := a.state.Caches.GTS.Token.LoadIDs("ID", +		tokenIDs, +		func(uncached []string) ([]*gtsmodel.Token, error) { +			// Preallocate expected length of uncached tokens. +			tokens := make([]*gtsmodel.Token, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) token IDs. +			if err := a.db.NewSelect(). +				Model(tokens). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return tokens, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Reoroder the tokens by their +	// IDs to ensure in correct order. +	getID := func(t *gtsmodel.Token) string { return t.ID } +	util.OrderBy(tokens, tokenIDs, getID) + +	return tokens, nil +} + +func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) { +	return a.getTokenBy( +		"Code", +		func(t *gtsmodel.Token) error { +			return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("code"), code).Scan(ctx) +		}, +		code, +	) +} + +func (a *applicationDB) GetTokenByAccess(ctx context.Context, access string) (*gtsmodel.Token, error) { +	return a.getTokenBy( +		"Access", +		func(t *gtsmodel.Token) error { +			return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("access"), access).Scan(ctx) +		}, +		access, +	) +} + +func (a *applicationDB) GetTokenByRefresh(ctx context.Context, refresh string) (*gtsmodel.Token, error) { +	return a.getTokenBy( +		"Refresh", +		func(t *gtsmodel.Token) error { +			return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("refresh"), refresh).Scan(ctx) +		}, +		refresh, +	) +} + +func (a *applicationDB) getTokenBy(lookup string, dbQuery func(*gtsmodel.Token) error, keyParts ...any) (*gtsmodel.Token, error) { +	return a.state.Caches.GTS.Token.LoadOne(lookup, func() (*gtsmodel.Token, error) { +		var token gtsmodel.Token + +		if err := dbQuery(&token); err != nil { +			return nil, err +		} + +		return &token, nil +	}, keyParts...) +} + +func (a *applicationDB) PutToken(ctx context.Context, token *gtsmodel.Token) error { +	return a.state.Caches.GTS.Token.Store(token, func() error { +		_, err := a.db.NewInsert().Model(token).Exec(ctx) +		return err +	}) +} + +func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error { +	_, err := a.db.NewDelete(). +		Table("tokens"). +		Where("? = ?", bun.Ident("id"), id). +		Exec(ctx) +	if err != nil { +		return err +	} + +	a.state.Caches.GTS.Token.Invalidate("ID", id) +	return nil +} + +func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error { +	_, err := a.db.NewDelete(). +		Table("tokens"). +		Where("? = ?", bun.Ident("code"), code). +		Exec(ctx) +	if err != nil { +		return err +	} + +	a.state.Caches.GTS.Token.Invalidate("Code", code) +	return nil +} + +func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error { +	_, err := a.db.NewDelete(). +		Table("tokens"). +		Where("? = ?", bun.Ident("access"), access). +		Exec(ctx) +	if err != nil { +		return err +	} + +	a.state.Caches.GTS.Token.Invalidate("Access", access) +	return nil +} + +func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error { +	_, err := a.db.NewDelete(). +		Table("tokens"). +		Where("? = ?", bun.Ident("refresh"), refresh). +		Exec(ctx) +	if err != nil { +		return err +	} + +	a.state.Caches.GTS.Token.Invalidate("Refresh", refresh) +	return nil +} diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go index 5bb600e70..bddb30b1b 100644 --- a/internal/oauth/clientstore.go +++ b/internal/oauth/clientstore.go @@ -27,11 +27,11 @@ import (  )  type clientStore struct { -	db db.Basic +	db db.DB  }  // NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend. -func NewClientStore(db db.Basic) oauth2.ClientStore { +func NewClientStore(db db.DB) oauth2.ClientStore {  	pts := &clientStore{  		db: db,  	} @@ -39,26 +39,27 @@ func NewClientStore(db db.Basic) oauth2.ClientStore {  }  func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { -	poc := >smodel.Client{} -	if err := cs.db.GetByID(ctx, clientID, poc); err != nil { +	client, err := cs.db.GetClientByID(ctx, clientID) +	if err != nil {  		return nil, err  	} -	return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil +	return models.New( +		client.ID, +		client.Secret, +		client.Domain, +		client.UserID, +	), nil  }  func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { -	poc := >smodel.Client{ +	return cs.db.PutClient(ctx, >smodel.Client{  		ID:     cli.GetID(),  		Secret: cli.GetSecret(),  		Domain: cli.GetDomain(),  		UserID: cli.GetUserID(), -	} -	return cs.db.Put(ctx, poc) +	})  }  func (cs *clientStore) Delete(ctx context.Context, id string) error { -	poc := >smodel.Client{ -		ID: id, -	} -	return cs.db.DeleteByID(ctx, id, poc) +	return cs.db.DeleteClientByID(ctx, id)  } diff --git a/internal/oauth/errors.go b/internal/oauth/errors.go index dd61be28c..b16143e5c 100644 --- a/internal/oauth/errors.go +++ b/internal/oauth/errors.go @@ -19,7 +19,5 @@ package oauth  import "github.com/superseriousbusiness/oauth2/v4/errors" -// InvalidRequest returns an oauth spec compliant 'invalid_request' error. -func InvalidRequest() error { -	return errors.New("invalid_request") -} +// ErrInvalidRequest is an oauth spec compliant 'invalid_request' error. +var ErrInvalidRequest = errors.New("invalid_request") diff --git a/internal/oauth/server.go b/internal/oauth/server.go index 3e4519479..4f2ed509b 100644 --- a/internal/oauth/server.go +++ b/internal/oauth/server.go @@ -75,7 +75,7 @@ type s struct {  }  // New returns a new oauth server that implements the Server interface -func New(ctx context.Context, database db.Basic) Server { +func New(ctx context.Context, database db.DB) Server {  	ts := newTokenStore(ctx, database)  	cs := NewClientStore(database) diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go index 3658f0aa9..14b91fa06 100644 --- a/internal/oauth/tokenstore.go +++ b/internal/oauth/tokenstore.go @@ -20,7 +20,6 @@ package oauth  import (  	"context"  	"errors" -	"fmt"  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/db" @@ -34,14 +33,14 @@ import (  // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.  type tokenStore struct {  	oauth2.TokenStore -	db db.Basic +	db db.DB  }  // newTokenStore returns a token store that satisfies the oauth2.TokenStore interface.  //  // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through  // the tokens in the DB once per minute and deletes any that have expired. -func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore { +func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {  	ts := &tokenStore{  		db: db,  	} @@ -69,19 +68,19 @@ func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore {  func (ts *tokenStore) sweep(ctx context.Context) error {  	// select *all* tokens from the db  	// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. -	tokens := new([]*gtsmodel.Token) -	if err := ts.db.GetAll(ctx, tokens); err != nil { +	tokens, err := ts.db.GetAllTokens(ctx) +	if err != nil {  		return err  	}  	// iterate through and remove expired tokens  	now := time.Now() -	for _, dbt := range *tokens { +	for _, dbt := range tokens {  		// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:  		// we only want to check if a token expired before now if the expiry time is *not zero*;  		// ie., if it's been explicity set.  		if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { -			if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil { +			if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil {  				return err  			}  		} @@ -107,67 +106,49 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {  		dbt.ID = dbtID  	} -	if err := ts.db.Put(ctx, dbt); err != nil { -		return fmt.Errorf("error in tokenstore create: %s", err) -	} -	return nil +	return ts.db.PutToken(ctx, dbt)  }  // RemoveByCode deletes a token from the DB based on the Code field  func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { -	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, >smodel.Token{}) +	return ts.db.DeleteTokenByCode(ctx, code)  }  // RemoveByAccess deletes a token from the DB based on the Access field  func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { -	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, >smodel.Token{}) +	return ts.db.DeleteTokenByAccess(ctx, access)  }  // RemoveByRefresh deletes a token from the DB based on the Refresh field  func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { -	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, >smodel.Token{}) +	return ts.db.DeleteTokenByRefresh(ctx, refresh)  }  // GetByCode selects a token from the DB based on the Code field  func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { -	if code == "" { -		return nil, nil -	} -	dbt := >smodel.Token{ -		Code: code, -	} -	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil { +	token, err := ts.db.GetTokenByCode(ctx, code) +	if err != nil {  		return nil, err  	} -	return DBTokenToToken(dbt), nil +	return DBTokenToToken(token), nil  }  // GetByAccess selects a token from the DB based on the Access field  func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { -	if access == "" { -		return nil, nil -	} -	dbt := >smodel.Token{ -		Access: access, -	} -	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil { +	token, err := ts.db.GetTokenByAccess(ctx, access) +	if err != nil {  		return nil, err  	} -	return DBTokenToToken(dbt), nil +	return DBTokenToToken(token), nil  }  // GetByRefresh selects a token from the DB based on the Refresh field  func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { -	if refresh == "" { -		return nil, nil -	} -	dbt := >smodel.Token{ -		Refresh: refresh, -	} -	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil { +	token, err := ts.db.GetTokenByRefresh(ctx, refresh) +	if err != nil {  		return nil, err  	} -	return DBTokenToToken(dbt), nil +	return DBTokenToToken(token), nil  }  /* diff --git a/internal/processing/app.go b/internal/processing/app.go index eef4fae0d..d492b3bc4 100644 --- a/internal/processing/app.go +++ b/internal/processing/app.go @@ -75,7 +75,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api  	}  	// chuck it in the db -	if err := p.state.DB.Put(ctx, oc); err != nil { +	if err := p.state.DB.PutClient(ctx, oc); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	}  | 
