diff options
29 files changed, 1277 insertions, 329 deletions
@@ -5,7 +5,7 @@ go 1.20  require (  	codeberg.org/gruf/go-bytesize v1.0.2  	codeberg.org/gruf/go-byteutil v1.1.2 -	codeberg.org/gruf/go-cache/v3 v3.4.3 +	codeberg.org/gruf/go-cache/v3 v3.4.4  	codeberg.org/gruf/go-debug v1.3.0  	codeberg.org/gruf/go-errors/v2 v2.2.0  	codeberg.org/gruf/go-fastcopy v1.1.2 @@ -48,8 +48,8 @@ codeberg.org/gruf/go-bytesize v1.0.2/go.mod h1:n/GU8HzL9f3UNp/mUKyr1qVmTlj7+xacp  codeberg.org/gruf/go-byteutil v1.0.0/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU=  codeberg.org/gruf/go-byteutil v1.1.2 h1:TQLZtTxTNca9xEfDIndmo7nBYxeS94nrv/9DS3Nk5Tw=  codeberg.org/gruf/go-byteutil v1.1.2/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU= -codeberg.org/gruf/go-cache/v3 v3.4.3 h1:GTNq01M17jUJ3B3ehrVTbElpvCqOKgz1x+VB9GEIxXA= -codeberg.org/gruf/go-cache/v3 v3.4.3/go.mod h1:pTeVPEb9DshXUkd8Dg76UcsLpU6EC/tXQ2qb+JrmxEc= +codeberg.org/gruf/go-cache/v3 v3.4.4 h1:V0A3EzjhzhULOydD16pwa2DRDwF67OuuP4ORnm//7p8= +codeberg.org/gruf/go-cache/v3 v3.4.4/go.mod h1:pTeVPEb9DshXUkd8Dg76UcsLpU6EC/tXQ2qb+JrmxEc=  codeberg.org/gruf/go-debug v1.3.0 h1:PIRxQiWUFKtGOGZFdZ3Y0pqyfI0Xr87j224IYe2snZs=  codeberg.org/gruf/go-debug v1.3.0/go.mod h1:N+vSy9uJBQgpQcJUqjctvqFz7tBHJf+S/PIjLILzpLg=  codeberg.org/gruf/go-errors/v2 v2.0.0/go.mod h1:ZRhbdhvgoUA3Yw6e56kd9Ox984RrvbEFC2pOXyHDJP4= diff --git a/internal/api/client/blocks/blocks.go b/internal/api/client/blocks/blocks.go index bff9a068e..0eeee2bf1 100644 --- a/internal/api/client/blocks/blocks.go +++ b/internal/api/client/blocks/blocks.go @@ -30,8 +30,10 @@ const (  	// MaxIDKey is the url query for setting a max ID to return  	MaxIDKey = "max_id" +  	// SinceIDKey is the url query for returning results newer than the given ID  	SinceIDKey = "since_id" +  	// LimitKey is for specifying maximum number of results to return.  	LimitKey = "limit"  ) diff --git a/internal/api/client/blocks/blocksget.go b/internal/api/client/blocks/blocksget.go index 7aec8b334..505c33db8 100644 --- a/internal/api/client/blocks/blocksget.go +++ b/internal/api/client/blocks/blocksget.go @@ -18,14 +18,13 @@  package blocks  import ( -	"fmt"  	"net/http" -	"strconv"  	"github.com/gin-gonic/gin"  	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/oauth" +	"github.com/superseriousbusiness/gotosocial/internal/paging"  )  // BlocksGETHandler swagger:operation GET /api/v1/blocks blocksGet @@ -104,31 +103,21 @@ func (m *Module) BlocksGETHandler(c *gin.Context) {  		return  	} -	maxID := "" -	maxIDString := c.Query(MaxIDKey) -	if maxIDString != "" { -		maxID = maxIDString -	} - -	sinceID := "" -	sinceIDString := c.Query(SinceIDKey) -	if sinceIDString != "" { -		sinceID = sinceIDString -	} - -	limit := 20 -	limitString := c.Query(LimitKey) -	if limitString != "" { -		i, err := strconv.ParseInt(limitString, 10, 32) -		if err != nil { -			err := fmt.Errorf("error parsing %s: %s", LimitKey, err) -			apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) -			return -		} -		limit = int(i) +	limit, errWithCode := apiutil.ParseLimit(c.Query(LimitKey), 20, 100, 2) +	if err != nil { +		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) +		return  	} -	resp, errWithCode := m.processor.BlocksGet(c.Request.Context(), authed, maxID, sinceID, limit) +	resp, errWithCode := m.processor.BlocksGet( +		c.Request.Context(), +		authed.Account, +		paging.Pager{ +			SinceID: c.Query(SinceIDKey), +			MaxID:   c.Query(MaxIDKey), +			Limit:   limit, +		}, +	)  	if errWithCode != nil {  		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)  		return @@ -137,5 +126,6 @@ func (m *Module) BlocksGETHandler(c *gin.Context) {  	if resp.LinkHeader != "" {  		c.Header("Link", resp.LinkHeader)  	} -	c.JSON(http.StatusOK, resp.Accounts) + +	c.JSON(http.StatusOK, resp.Items)  } diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 63564935e..e97dce6f9 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -80,6 +80,27 @@ func (c *Caches) setuphooks() {  		// Invalidate account ID cached visibility.  		c.Visibility.Invalidate("ItemID", account.ID)  		c.Visibility.Invalidate("RequesterID", account.ID) + +		// Invalidate this account's +		// following / follower lists. +		// (see FollowIDs() comment for details). +		c.GTS.FollowIDs().InvalidateAll( +			">"+account.ID, +			"l>"+account.ID, +			"<"+account.ID, +			"l<"+account.ID, +		) + +		// Invalidate this account's +		// follow requesting / request lists. +		// (see FollowRequestIDs() comment for details). +		c.GTS.FollowRequestIDs().InvalidateAll( +			">"+account.ID, +			"<"+account.ID, +		) + +		// Invalidate this account's block lists. +		c.GTS.BlockIDs().Invalidate(account.ID)  	})  	c.GTS.Block().SetInvalidateCallback(func(block *gtsmodel.Block) { @@ -90,6 +111,9 @@ func (c *Caches) setuphooks() {  		// Invalidate block target account ID cached visibility.  		c.Visibility.Invalidate("ItemID", block.TargetAccountID)  		c.Visibility.Invalidate("RequesterID", block.TargetAccountID) + +		// Invalidate source account's block lists. +		c.GTS.BlockIDs().Invalidate(block.AccountID)  	})  	c.GTS.EmojiCategory().SetInvalidateCallback(func(category *gtsmodel.EmojiCategory) { @@ -98,6 +122,9 @@ func (c *Caches) setuphooks() {  	})  	c.GTS.Follow().SetInvalidateCallback(func(follow *gtsmodel.Follow) { +		// Invalidate follow request with this same ID. +		c.GTS.FollowRequest().Invalidate("ID", follow.ID) +  		// Invalidate any related list entries.  		c.GTS.ListEntry().Invalidate("FollowID", follow.ID) @@ -108,19 +135,35 @@ func (c *Caches) setuphooks() {  		// Invalidate follow target account ID cached visibility.  		c.Visibility.Invalidate("ItemID", follow.TargetAccountID)  		c.Visibility.Invalidate("RequesterID", follow.TargetAccountID) + +		// Invalidate source account's following +		// lists, and destination's follwer lists. +		// (see FollowIDs() comment for details). +		c.GTS.FollowIDs().InvalidateAll( +			">"+follow.AccountID, +			"l>"+follow.AccountID, +			"<"+follow.AccountID, +			"l<"+follow.AccountID, +			"<"+follow.TargetAccountID, +			"l<"+follow.TargetAccountID, +			">"+follow.TargetAccountID, +			"l>"+follow.TargetAccountID, +		)  	})  	c.GTS.FollowRequest().SetInvalidateCallback(func(followReq *gtsmodel.FollowRequest) { -		// Invalidate follow request origin account ID cached visibility. -		c.Visibility.Invalidate("ItemID", followReq.AccountID) -		c.Visibility.Invalidate("RequesterID", followReq.AccountID) - -		// Invalidate follow request target account ID cached visibility. -		c.Visibility.Invalidate("ItemID", followReq.TargetAccountID) -		c.Visibility.Invalidate("RequesterID", followReq.TargetAccountID) - -		// Invalidate any cached follow with same ID. +		// Invalidate follow with this same ID.  		c.GTS.Follow().Invalidate("ID", followReq.ID) + +		// Invalidate source account's followreq +		// lists, and destinations follow req lists. +		// (see FollowRequestIDs() comment for details). +		c.GTS.FollowRequestIDs().InvalidateAll( +			">"+followReq.AccountID, +			"<"+followReq.AccountID, +			">"+followReq.TargetAccountID, +			"<"+followReq.TargetAccountID, +		)  	})  	c.GTS.List().SetInvalidateCallback(func(list *gtsmodel.List) { @@ -128,12 +171,29 @@ func (c *Caches) setuphooks() {  		c.GTS.ListEntry().Invalidate("ListID", list.ID)  	}) +	c.GTS.Media().SetInvalidateCallback(func(media *gtsmodel.MediaAttachment) { +		if *media.Avatar || *media.Header { +			// Invalidate cache of attaching account. +			c.GTS.Account().Invalidate("ID", media.AccountID) +		} + +		if media.StatusID != "" { +			// Invalidate cache of attaching status. +			c.GTS.Status().Invalidate("ID", media.StatusID) +		} +	}) +  	c.GTS.Status().SetInvalidateCallback(func(status *gtsmodel.Status) {  		// Invalidate status ID cached visibility.  		c.Visibility.Invalidate("ItemID", status.ID)  		for _, id := range status.AttachmentIDs { -			// Invalidate cache for attached media IDs, +			// Invalidate each media by the IDs we're aware of. +			// This must be done as the status table is aware of +			// the media IDs in use before the media table is +			// aware of the status ID they are linked to. +			// +			// c.GTS.Media().Invalidate("StatusID") will not work.  			c.GTS.Media().Invalidate("ID", id)  		}  	}) diff --git a/internal/cache/gts.go b/internal/cache/gts.go index dd43154ef..fefd02fff 100644 --- a/internal/cache/gts.go +++ b/internal/cache/gts.go @@ -26,29 +26,31 @@ import (  )  type GTSCaches struct { -	account     *result.Cache[*gtsmodel.Account] -	accountNote *result.Cache[*gtsmodel.AccountNote] -	block       *result.Cache[*gtsmodel.Block] -	// TODO: maybe should be moved out of here since it's -	// not actually doing anything with gtsmodel.DomainBlock. -	domainBlock   *domain.BlockCache -	emoji         *result.Cache[*gtsmodel.Emoji] -	emojiCategory *result.Cache[*gtsmodel.EmojiCategory] -	follow        *result.Cache[*gtsmodel.Follow] -	followRequest *result.Cache[*gtsmodel.FollowRequest] -	instance      *result.Cache[*gtsmodel.Instance] -	list          *result.Cache[*gtsmodel.List] -	listEntry     *result.Cache[*gtsmodel.ListEntry] -	marker        *result.Cache[*gtsmodel.Marker] -	media         *result.Cache[*gtsmodel.MediaAttachment] -	mention       *result.Cache[*gtsmodel.Mention] -	notification  *result.Cache[*gtsmodel.Notification] -	report        *result.Cache[*gtsmodel.Report] -	status        *result.Cache[*gtsmodel.Status] -	statusFave    *result.Cache[*gtsmodel.StatusFave] -	tombstone     *result.Cache[*gtsmodel.Tombstone] -	user          *result.Cache[*gtsmodel.User] -	// TODO: move out of GTS caches since not using database models. +	account          *result.Cache[*gtsmodel.Account] +	accountNote      *result.Cache[*gtsmodel.AccountNote] +	block            *result.Cache[*gtsmodel.Block] +	blockIDs         *SliceCache[string] +	domainBlock      *domain.BlockCache +	emoji            *result.Cache[*gtsmodel.Emoji] +	emojiCategory    *result.Cache[*gtsmodel.EmojiCategory] +	follow           *result.Cache[*gtsmodel.Follow] +	followIDs        *SliceCache[string] +	followRequest    *result.Cache[*gtsmodel.FollowRequest] +	followRequestIDs *SliceCache[string] +	instance         *result.Cache[*gtsmodel.Instance] +	list             *result.Cache[*gtsmodel.List] +	listEntry        *result.Cache[*gtsmodel.ListEntry] +	marker           *result.Cache[*gtsmodel.Marker] +	media            *result.Cache[*gtsmodel.MediaAttachment] +	mention          *result.Cache[*gtsmodel.Mention] +	notification     *result.Cache[*gtsmodel.Notification] +	report           *result.Cache[*gtsmodel.Report] +	status           *result.Cache[*gtsmodel.Status] +	statusFave       *result.Cache[*gtsmodel.StatusFave] +	tombstone        *result.Cache[*gtsmodel.Tombstone] +	user             *result.Cache[*gtsmodel.User] + +	// TODO: move out of GTS caches since unrelated to DB.  	webfinger *ttl.Cache[string, string]  } @@ -58,11 +60,14 @@ func (c *GTSCaches) Init() {  	c.initAccount()  	c.initAccountNote()  	c.initBlock() +	c.initBlockIDs()  	c.initDomainBlock()  	c.initEmoji()  	c.initEmojiCategory()  	c.initFollow() +	c.initFollowIDs()  	c.initFollowRequest() +	c.initFollowRequestIDs()  	c.initInstance()  	c.initList()  	c.initListEntry() @@ -83,10 +88,28 @@ func (c *GTSCaches) Start() {  	tryStart(c.account, config.GetCacheGTSAccountSweepFreq())  	tryStart(c.accountNote, config.GetCacheGTSAccountNoteSweepFreq())  	tryStart(c.block, config.GetCacheGTSBlockSweepFreq()) +	tryUntil("starting block IDs cache", 5, func() bool { +		if sweep := config.GetCacheGTSBlockIDsSweepFreq(); sweep > 0 { +			return c.blockIDs.Start(sweep) +		} +		return true +	})  	tryStart(c.emoji, config.GetCacheGTSEmojiSweepFreq())  	tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())  	tryStart(c.follow, config.GetCacheGTSFollowSweepFreq()) +	tryUntil("starting follow IDs cache", 5, func() bool { +		if sweep := config.GetCacheGTSFollowIDsSweepFreq(); sweep > 0 { +			return c.followIDs.Start(sweep) +		} +		return true +	})  	tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq()) +	tryUntil("starting follow request IDs cache", 5, func() bool { +		if sweep := config.GetCacheGTSFollowRequestIDsSweepFreq(); sweep > 0 { +			return c.followRequestIDs.Start(sweep) +		} +		return true +	})  	tryStart(c.instance, config.GetCacheGTSInstanceSweepFreq())  	tryStart(c.list, config.GetCacheGTSListSweepFreq())  	tryStart(c.listEntry, config.GetCacheGTSListEntrySweepFreq()) @@ -112,10 +135,28 @@ func (c *GTSCaches) Stop() {  	tryStop(c.account, config.GetCacheGTSAccountSweepFreq())  	tryStop(c.accountNote, config.GetCacheGTSAccountNoteSweepFreq())  	tryStop(c.block, config.GetCacheGTSBlockSweepFreq()) +	tryUntil("stopping block IDs cache", 5, func() bool { +		if config.GetCacheGTSBlockIDsSweepFreq() > 0 { +			return c.blockIDs.Stop() +		} +		return true +	})  	tryStop(c.emoji, config.GetCacheGTSEmojiSweepFreq())  	tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())  	tryStop(c.follow, config.GetCacheGTSFollowSweepFreq()) +	tryUntil("stopping follow IDs cache", 5, func() bool { +		if config.GetCacheGTSFollowIDsSweepFreq() > 0 { +			return c.followIDs.Stop() +		} +		return true +	})  	tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq()) +	tryUntil("stopping follow request IDs cache", 5, func() bool { +		if config.GetCacheGTSFollowRequestIDsSweepFreq() > 0 { +			return c.followRequestIDs.Stop() +		} +		return true +	})  	tryStop(c.instance, config.GetCacheGTSInstanceSweepFreq())  	tryStop(c.list, config.GetCacheGTSListSweepFreq())  	tryStop(c.listEntry, config.GetCacheGTSListEntrySweepFreq()) @@ -128,7 +169,12 @@ func (c *GTSCaches) Stop() {  	tryStop(c.statusFave, config.GetCacheGTSStatusFaveSweepFreq())  	tryStop(c.tombstone, config.GetCacheGTSTombstoneSweepFreq())  	tryStop(c.user, config.GetCacheGTSUserSweepFreq()) -	tryUntil("stopping *gtsmodel.Webfinger cache", 5, c.webfinger.Stop) +	tryUntil("stopping *gtsmodel.Webfinger cache", 5, func() bool { +		if config.GetCacheGTSWebfingerSweepFreq() > 0 { +			return c.webfinger.Stop() +		} +		return true +	})  }  // Account provides access to the gtsmodel Account database cache. @@ -146,6 +192,11 @@ func (c *GTSCaches) Block() *result.Cache[*gtsmodel.Block] {  	return c.block  } +// FollowIDs provides access to the block IDs database cache. +func (c *GTSCaches) BlockIDs() *SliceCache[string] { +	return c.blockIDs +} +  // DomainBlock provides access to the domain block database cache.  func (c *GTSCaches) DomainBlock() *domain.BlockCache {  	return c.domainBlock @@ -166,11 +217,29 @@ func (c *GTSCaches) Follow() *result.Cache[*gtsmodel.Follow] {  	return c.follow  } +// FollowIDs provides access to the follower / following IDs database cache. +// THIS CACHE IS KEYED AS THE FOLLOWING {prefix}{accountID} WHERE PREFIX IS: +// - '>'  for following IDs +// - 'l>' for local following IDs +// - '<'  for follower IDs +// - 'l<' for local follower IDs +func (c *GTSCaches) FollowIDs() *SliceCache[string] { +	return c.followIDs +} +  // FollowRequest provides access to the gtsmodel FollowRequest database cache.  func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] {  	return c.followRequest  } +// FollowRequestIDs provides access to the follow requester / requesting IDs database +// cache. THIS CACHE IS KEYED AS THE FOLLOWING {prefix}{accountID} WHERE PREFIX IS: +// - '>'  for following IDs +// - '<'  for follower IDs +func (c *GTSCaches) FollowRequestIDs() *SliceCache[string] { +	return c.followRequestIDs +} +  // Instance provides access to the gtsmodel Instance database cache.  func (c *GTSCaches) Instance() *result.Cache[*gtsmodel.Instance] {  	return c.instance @@ -274,6 +343,8 @@ func (c *GTSCaches) initBlock() {  		{Name: "ID"},  		{Name: "URI"},  		{Name: "AccountID.TargetAccountID"}, +		{Name: "AccountID", Multi: true}, +		{Name: "TargetAccountID", Multi: true},  	}, func(b1 *gtsmodel.Block) *gtsmodel.Block {  		b2 := new(gtsmodel.Block)  		*b2 = *b1 @@ -283,6 +354,14 @@ func (c *GTSCaches) initBlock() {  	c.block.IgnoreErrors(ignoreErrors)  } +func (c *GTSCaches) initBlockIDs() { +	c.blockIDs = &SliceCache[string]{Cache: ttl.New[string, []string]( +		0, +		config.GetCacheGTSBlockIDsMaxSize(), +		config.GetCacheGTSBlockIDsTTL(), +	)} +} +  func (c *GTSCaches) initDomainBlock() {  	c.domainBlock = new(domain.BlockCache)  } @@ -321,6 +400,8 @@ func (c *GTSCaches) initFollow() {  		{Name: "ID"},  		{Name: "URI"},  		{Name: "AccountID.TargetAccountID"}, +		{Name: "AccountID", Multi: true}, +		{Name: "TargetAccountID", Multi: true},  	}, func(f1 *gtsmodel.Follow) *gtsmodel.Follow {  		f2 := new(gtsmodel.Follow)  		*f2 = *f1 @@ -329,11 +410,21 @@ func (c *GTSCaches) initFollow() {  	c.follow.SetTTL(config.GetCacheGTSFollowTTL(), true)  } +func (c *GTSCaches) initFollowIDs() { +	c.followIDs = &SliceCache[string]{Cache: ttl.New[string, []string]( +		0, +		config.GetCacheGTSFollowIDsMaxSize(), +		config.GetCacheGTSFollowIDsTTL(), +	)} +} +  func (c *GTSCaches) initFollowRequest() {  	c.followRequest = result.New([]result.Lookup{  		{Name: "ID"},  		{Name: "URI"},  		{Name: "AccountID.TargetAccountID"}, +		{Name: "AccountID", Multi: true}, +		{Name: "TargetAccountID", Multi: true},  	}, func(f1 *gtsmodel.FollowRequest) *gtsmodel.FollowRequest {  		f2 := new(gtsmodel.FollowRequest)  		*f2 = *f1 @@ -342,6 +433,14 @@ func (c *GTSCaches) initFollowRequest() {  	c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true)  } +func (c *GTSCaches) initFollowRequestIDs() { +	c.followRequestIDs = &SliceCache[string]{Cache: ttl.New[string, []string]( +		0, +		config.GetCacheGTSFollowRequestIDsMaxSize(), +		config.GetCacheGTSFollowRequestIDsTTL(), +	)} +} +  func (c *GTSCaches) initInstance() {  	c.instance = result.New([]result.Lookup{  		{Name: "ID"}, @@ -502,5 +601,6 @@ func (c *GTSCaches) initWebfinger() {  	c.webfinger = ttl.New[string, string](  		0,  		config.GetCacheGTSWebfingerMaxSize(), -		config.GetCacheGTSWebfingerTTL()) +		config.GetCacheGTSWebfingerTTL(), +	)  } diff --git a/internal/cache/slice.go b/internal/cache/slice.go new file mode 100644 index 000000000..194f20d4b --- /dev/null +++ b/internal/cache/slice.go @@ -0,0 +1,76 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package cache + +import ( +	"codeberg.org/gruf/go-cache/v3/ttl" +	"golang.org/x/exp/slices" +) + +// SliceCache wraps a ttl.Cache to provide simple loader-callback +// functions for fetching + caching slices of objects (e.g. IDs). +type SliceCache[T any] struct { +	*ttl.Cache[string, []T] +} + +// Load will attempt to load an existing slice from the cache for the given key, else calling the provided load function and caching the result. +func (c *SliceCache[T]) Load(key string, load func() ([]T, error)) ([]T, error) { +	// Look for follow IDs list in cache under this key. +	data, ok := c.Get(key) + +	if !ok { +		var err error + +		// Not cached, load! +		data, err = load() +		if err != nil { +			return nil, err +		} + +		// Store the data. +		c.Set(key, data) +	} + +	// Return data clone for safety. +	return slices.Clone(data), nil +} + +// LoadRange is functionally the same as .Load(), but will pass the result through provided reslice function before returning a cloned result. +func (c *SliceCache[T]) LoadRange(key string, load func() ([]T, error), reslice func([]T) []T) ([]T, error) { +	// Look for follow IDs list in cache under this key. +	data, ok := c.Get(key) + +	if !ok { +		var err error + +		// Not cached, load! +		data, err = load() +		if err != nil { +			return nil, err +		} + +		// Store the data. +		c.Set(key, data) +	} + +	// Reslice to range. +	slice := reslice(data) + +	// Return range clone for safety. +	return slices.Clone(slice), nil +} diff --git a/internal/cache/util.go b/internal/cache/util.go index a0adfd366..f2357c904 100644 --- a/internal/cache/util.go +++ b/internal/cache/util.go @@ -18,28 +18,33 @@  package cache  import ( -	"context" +	"database/sql"  	"errors"  	"fmt"  	"time"  	"codeberg.org/gruf/go-cache/v3/result"  	errorsv2 "codeberg.org/gruf/go-errors/v2" +	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/log"  ) -// SentinelError is returned to indicate a non-permanent error return, -// i.e. a situation in which we do not want a cache a negative result. +// SentinelError is an error that can be returned and checked against to indicate a non-permanent +// error return from a cache loader callback, e.g. a temporary situation that will soon be fixed.  var SentinelError = errors.New("BUG: error should not be returned") //nolint:revive -// ignoreErrors is an error ignoring function capable of being passed to -// caches, which specifically catches and ignores our sentinel error type. +// ignoreErrors is an error matching function used to signal which errors +// the result caches should NOT hold onto. these amount to anything non-permanent.  func ignoreErrors(err error) bool { -	return errorsv2.Comparable( +	return !errorsv2.Comparable(  		err, -		SentinelError, -		context.DeadlineExceeded, -		context.Canceled, + +		// the only cacheable errs, +		// i.e anything permanent +		// (until invalidation). +		db.ErrNoEntries, +		db.ErrAlreadyExists, +		sql.ErrNoRows,  	)  } diff --git a/internal/config/config.go b/internal/config/config.go index bd9fc468c..99b07358e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -194,6 +194,10 @@ type GTSCacheConfiguration struct {  	BlockTTL       time.Duration `name:"block-ttl"`  	BlockSweepFreq time.Duration `name:"block-sweep-freq"` +	BlockIDsMaxSize   int           `name:"block-ids-max-size"` +	BlockIDsTTL       time.Duration `name:"block-ids-ttl"` +	BlockIDsSweepFreq time.Duration `name:"block-ids-sweep-freq"` +  	DomainBlockMaxSize   int           `name:"domain-block-max-size"`  	DomainBlockTTL       time.Duration `name:"domain-block-ttl"`  	DomainBlockSweepFreq time.Duration `name:"domain-block-sweep-freq"` @@ -210,10 +214,18 @@ type GTSCacheConfiguration struct {  	FollowTTL       time.Duration `name:"follow-ttl"`  	FollowSweepFreq time.Duration `name:"follow-sweep-freq"` +	FollowIDsMaxSize   int           `name:"follow-ids-max-size"` +	FollowIDsTTL       time.Duration `name:"follow-ids-ttl"` +	FollowIDsSweepFreq time.Duration `name:"follow-ids-sweep-freq"` +  	FollowRequestMaxSize   int           `name:"follow-request-max-size"`  	FollowRequestTTL       time.Duration `name:"follow-request-ttl"`  	FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"` +	FollowRequestIDsMaxSize   int           `name:"follow-request-ids-max-size"` +	FollowRequestIDsTTL       time.Duration `name:"follow-request-ids-ttl"` +	FollowRequestIDsSweepFreq time.Duration `name:"follow-request-ids-sweep-freq"` +  	InstanceMaxSize   int           `name:"instance-max-size"`  	InstanceTTL       time.Duration `name:"instance-ttl"`  	InstanceSweepFreq time.Duration `name:"instance-sweep-freq"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index ee20fb6a7..cb37838c1 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -139,6 +139,10 @@ var Defaults = Configuration{  			BlockTTL:       time.Minute * 30,  			BlockSweepFreq: time.Minute, +			BlockIDsMaxSize:   500, +			BlockIDsTTL:       time.Minute * 30, +			BlockIDsSweepFreq: time.Minute, +  			DomainBlockMaxSize:   2000,  			DomainBlockTTL:       time.Hour * 24,  			DomainBlockSweepFreq: time.Minute, @@ -155,10 +159,18 @@ var Defaults = Configuration{  			FollowTTL:       time.Minute * 30,  			FollowSweepFreq: time.Minute, +			FollowIDsMaxSize:   500, +			FollowIDsTTL:       time.Minute * 30, +			FollowIDsSweepFreq: time.Minute, +  			FollowRequestMaxSize:   2000,  			FollowRequestTTL:       time.Minute * 30,  			FollowRequestSweepFreq: time.Minute, +			FollowRequestIDsMaxSize:   500, +			FollowRequestIDsTTL:       time.Minute * 30, +			FollowRequestIDsSweepFreq: time.Minute, +  			InstanceMaxSize:   2000,  			InstanceTTL:       time.Minute * 30,  			InstanceSweepFreq: time.Minute, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index 5eed1b468..1bf8ec2bc 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2624,6 +2624,81 @@ func GetCacheGTSBlockSweepFreq() time.Duration { return global.GetCacheGTSBlockS  // SetCacheGTSBlockSweepFreq safely sets the value for global configuration 'Cache.GTS.BlockSweepFreq' field  func SetCacheGTSBlockSweepFreq(v time.Duration) { global.SetCacheGTSBlockSweepFreq(v) } +// GetCacheGTSBlockIDsMaxSize safely fetches the Configuration value for state's 'Cache.GTS.BlockIDsMaxSize' field +func (st *ConfigState) GetCacheGTSBlockIDsMaxSize() (v int) { +	st.mutex.RLock() +	v = st.config.Cache.GTS.BlockIDsMaxSize +	st.mutex.RUnlock() +	return +} + +// SetCacheGTSBlockIDsMaxSize safely sets the Configuration value for state's 'Cache.GTS.BlockIDsMaxSize' field +func (st *ConfigState) SetCacheGTSBlockIDsMaxSize(v int) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.Cache.GTS.BlockIDsMaxSize = v +	st.reloadToViper() +} + +// CacheGTSBlockIDsMaxSizeFlag returns the flag name for the 'Cache.GTS.BlockIDsMaxSize' field +func CacheGTSBlockIDsMaxSizeFlag() string { return "cache-gts-block-ids-max-size" } + +// GetCacheGTSBlockIDsMaxSize safely fetches the value for global configuration 'Cache.GTS.BlockIDsMaxSize' field +func GetCacheGTSBlockIDsMaxSize() int { return global.GetCacheGTSBlockIDsMaxSize() } + +// SetCacheGTSBlockIDsMaxSize safely sets the value for global configuration 'Cache.GTS.BlockIDsMaxSize' field +func SetCacheGTSBlockIDsMaxSize(v int) { global.SetCacheGTSBlockIDsMaxSize(v) } + +// GetCacheGTSBlockIDsTTL safely fetches the Configuration value for state's 'Cache.GTS.BlockIDsTTL' field +func (st *ConfigState) GetCacheGTSBlockIDsTTL() (v time.Duration) { +	st.mutex.RLock() +	v = st.config.Cache.GTS.BlockIDsTTL +	st.mutex.RUnlock() +	return +} + +// SetCacheGTSBlockIDsTTL safely sets the Configuration value for state's 'Cache.GTS.BlockIDsTTL' field +func (st *ConfigState) SetCacheGTSBlockIDsTTL(v time.Duration) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.Cache.GTS.BlockIDsTTL = v +	st.reloadToViper() +} + +// CacheGTSBlockIDsTTLFlag returns the flag name for the 'Cache.GTS.BlockIDsTTL' field +func CacheGTSBlockIDsTTLFlag() string { return "cache-gts-block-ids-ttl" } + +// GetCacheGTSBlockIDsTTL safely fetches the value for global configuration 'Cache.GTS.BlockIDsTTL' field +func GetCacheGTSBlockIDsTTL() time.Duration { return global.GetCacheGTSBlockIDsTTL() } + +// SetCacheGTSBlockIDsTTL safely sets the value for global configuration 'Cache.GTS.BlockIDsTTL' field +func SetCacheGTSBlockIDsTTL(v time.Duration) { global.SetCacheGTSBlockIDsTTL(v) } + +// GetCacheGTSBlockIDsSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.BlockIDsSweepFreq' field +func (st *ConfigState) GetCacheGTSBlockIDsSweepFreq() (v time.Duration) { +	st.mutex.RLock() +	v = st.config.Cache.GTS.BlockIDsSweepFreq +	st.mutex.RUnlock() +	return +} + +// SetCacheGTSBlockIDsSweepFreq safely sets the Configuration value for state's 'Cache.GTS.BlockIDsSweepFreq' field +func (st *ConfigState) SetCacheGTSBlockIDsSweepFreq(v time.Duration) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.Cache.GTS.BlockIDsSweepFreq = v +	st.reloadToViper() +} + +// CacheGTSBlockIDsSweepFreqFlag returns the flag name for the 'Cache.GTS.BlockIDsSweepFreq' field +func CacheGTSBlockIDsSweepFreqFlag() string { return "cache-gts-block-ids-sweep-freq" } + +// GetCacheGTSBlockIDsSweepFreq safely fetches the value for global configuration 'Cache.GTS.BlockIDsSweepFreq' field +func GetCacheGTSBlockIDsSweepFreq() time.Duration { return global.GetCacheGTSBlockIDsSweepFreq() } + +// SetCacheGTSBlockIDsSweepFreq safely sets the value for global configuration 'Cache.GTS.BlockIDsSweepFreq' field +func SetCacheGTSBlockIDsSweepFreq(v time.Duration) { global.SetCacheGTSBlockIDsSweepFreq(v) } +  // GetCacheGTSDomainBlockMaxSize safely fetches the Configuration value for state's 'Cache.GTS.DomainBlockMaxSize' field  func (st *ConfigState) GetCacheGTSDomainBlockMaxSize() (v int) {  	st.mutex.RLock() @@ -2926,6 +3001,81 @@ func GetCacheGTSFollowSweepFreq() time.Duration { return global.GetCacheGTSFollo  // SetCacheGTSFollowSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowSweepFreq' field  func SetCacheGTSFollowSweepFreq(v time.Duration) { global.SetCacheGTSFollowSweepFreq(v) } +// GetCacheGTSFollowIDsMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowIDsMaxSize' field +func (st *ConfigState) GetCacheGTSFollowIDsMaxSize() (v int) { +	st.mutex.RLock() +	v = st.config.Cache.GTS.FollowIDsMaxSize +	st.mutex.RUnlock() +	return +} + +// SetCacheGTSFollowIDsMaxSize safely sets the Configuration value for state's 'Cache.GTS.FollowIDsMaxSize' field +func (st *ConfigState) SetCacheGTSFollowIDsMaxSize(v int) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.Cache.GTS.FollowIDsMaxSize = v +	st.reloadToViper() +} + +// CacheGTSFollowIDsMaxSizeFlag returns the flag name for the 'Cache.GTS.FollowIDsMaxSize' field +func CacheGTSFollowIDsMaxSizeFlag() string { return "cache-gts-follow-ids-max-size" } + +// GetCacheGTSFollowIDsMaxSize safely fetches the value for global configuration 'Cache.GTS.FollowIDsMaxSize' field +func GetCacheGTSFollowIDsMaxSize() int { return global.GetCacheGTSFollowIDsMaxSize() } + +// SetCacheGTSFollowIDsMaxSize safely sets the value for global configuration 'Cache.GTS.FollowIDsMaxSize' field +func SetCacheGTSFollowIDsMaxSize(v int) { global.SetCacheGTSFollowIDsMaxSize(v) } + +// GetCacheGTSFollowIDsTTL safely fetches the Configuration value for state's 'Cache.GTS.FollowIDsTTL' field +func (st *ConfigState) GetCacheGTSFollowIDsTTL() (v time.Duration) { +	st.mutex.RLock() +	v = st.config.Cache.GTS.FollowIDsTTL +	st.mutex.RUnlock() +	return +} + +// SetCacheGTSFollowIDsTTL safely sets the Configuration value for state's 'Cache.GTS.FollowIDsTTL' field +func (st *ConfigState) SetCacheGTSFollowIDsTTL(v time.Duration) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.Cache.GTS.FollowIDsTTL = v +	st.reloadToViper() +} + +// CacheGTSFollowIDsTTLFlag returns the flag name for the 'Cache.GTS.FollowIDsTTL' field +func CacheGTSFollowIDsTTLFlag() string { return "cache-gts-follow-ids-ttl" } + +// GetCacheGTSFollowIDsTTL safely fetches the value for global configuration 'Cache.GTS.FollowIDsTTL' field +func GetCacheGTSFollowIDsTTL() time.Duration { return global.GetCacheGTSFollowIDsTTL() } + +// SetCacheGTSFollowIDsTTL safely sets the value for global configuration 'Cache.GTS.FollowIDsTTL' field +func SetCacheGTSFollowIDsTTL(v time.Duration) { global.SetCacheGTSFollowIDsTTL(v) } + +// GetCacheGTSFollowIDsSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.FollowIDsSweepFreq' field +func (st *ConfigState) GetCacheGTSFollowIDsSweepFreq() (v time.Duration) { +	st.mutex.RLock() +	v = st.config.Cache.GTS.FollowIDsSweepFreq +	st.mutex.RUnlock() +	return +} + +// SetCacheGTSFollowIDsSweepFreq safely sets the Configuration value for state's 'Cache.GTS.FollowIDsSweepFreq' field +func (st *ConfigState) SetCacheGTSFollowIDsSweepFreq(v time.Duration) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.Cache.GTS.FollowIDsSweepFreq = v +	st.reloadToViper() +} + +// CacheGTSFollowIDsSweepFreqFlag returns the flag name for the 'Cache.GTS.FollowIDsSweepFreq' field +func CacheGTSFollowIDsSweepFreqFlag() string { return "cache-gts-follow-ids-sweep-freq" } + +// GetCacheGTSFollowIDsSweepFreq safely fetches the value for global configuration 'Cache.GTS.FollowIDsSweepFreq' field +func GetCacheGTSFollowIDsSweepFreq() time.Duration { return global.GetCacheGTSFollowIDsSweepFreq() } + +// SetCacheGTSFollowIDsSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowIDsSweepFreq' field +func SetCacheGTSFollowIDsSweepFreq(v time.Duration) { global.SetCacheGTSFollowIDsSweepFreq(v) } +  // GetCacheGTSFollowRequestMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestMaxSize' field  func (st *ConfigState) GetCacheGTSFollowRequestMaxSize() (v int) {  	st.mutex.RLock() @@ -3003,6 +3153,85 @@ func GetCacheGTSFollowRequestSweepFreq() time.Duration {  // SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field  func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) } +// GetCacheGTSFollowRequestIDsMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestIDsMaxSize' field +func (st *ConfigState) GetCacheGTSFollowRequestIDsMaxSize() (v int) { +	st.mutex.RLock() +	v = st.config.Cache.GTS.FollowRequestIDsMaxSize +	st.mutex.RUnlock() +	return +} + +// SetCacheGTSFollowRequestIDsMaxSize safely sets the Configuration value for state's 'Cache.GTS.FollowRequestIDsMaxSize' field +func (st *ConfigState) SetCacheGTSFollowRequestIDsMaxSize(v int) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.Cache.GTS.FollowRequestIDsMaxSize = v +	st.reloadToViper() +} + +// CacheGTSFollowRequestIDsMaxSizeFlag returns the flag name for the 'Cache.GTS.FollowRequestIDsMaxSize' field +func CacheGTSFollowRequestIDsMaxSizeFlag() string { return "cache-gts-follow-request-ids-max-size" } + +// GetCacheGTSFollowRequestIDsMaxSize safely fetches the value for global configuration 'Cache.GTS.FollowRequestIDsMaxSize' field +func GetCacheGTSFollowRequestIDsMaxSize() int { return global.GetCacheGTSFollowRequestIDsMaxSize() } + +// SetCacheGTSFollowRequestIDsMaxSize safely sets the value for global configuration 'Cache.GTS.FollowRequestIDsMaxSize' field +func SetCacheGTSFollowRequestIDsMaxSize(v int) { global.SetCacheGTSFollowRequestIDsMaxSize(v) } + +// GetCacheGTSFollowRequestIDsTTL safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestIDsTTL' field +func (st *ConfigState) GetCacheGTSFollowRequestIDsTTL() (v time.Duration) { +	st.mutex.RLock() +	v = st.config.Cache.GTS.FollowRequestIDsTTL +	st.mutex.RUnlock() +	return +} + +// SetCacheGTSFollowRequestIDsTTL safely sets the Configuration value for state's 'Cache.GTS.FollowRequestIDsTTL' field +func (st *ConfigState) SetCacheGTSFollowRequestIDsTTL(v time.Duration) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.Cache.GTS.FollowRequestIDsTTL = v +	st.reloadToViper() +} + +// CacheGTSFollowRequestIDsTTLFlag returns the flag name for the 'Cache.GTS.FollowRequestIDsTTL' field +func CacheGTSFollowRequestIDsTTLFlag() string { return "cache-gts-follow-request-ids-ttl" } + +// GetCacheGTSFollowRequestIDsTTL safely fetches the value for global configuration 'Cache.GTS.FollowRequestIDsTTL' field +func GetCacheGTSFollowRequestIDsTTL() time.Duration { return global.GetCacheGTSFollowRequestIDsTTL() } + +// SetCacheGTSFollowRequestIDsTTL safely sets the value for global configuration 'Cache.GTS.FollowRequestIDsTTL' field +func SetCacheGTSFollowRequestIDsTTL(v time.Duration) { global.SetCacheGTSFollowRequestIDsTTL(v) } + +// GetCacheGTSFollowRequestIDsSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestIDsSweepFreq' field +func (st *ConfigState) GetCacheGTSFollowRequestIDsSweepFreq() (v time.Duration) { +	st.mutex.RLock() +	v = st.config.Cache.GTS.FollowRequestIDsSweepFreq +	st.mutex.RUnlock() +	return +} + +// SetCacheGTSFollowRequestIDsSweepFreq safely sets the Configuration value for state's 'Cache.GTS.FollowRequestIDsSweepFreq' field +func (st *ConfigState) SetCacheGTSFollowRequestIDsSweepFreq(v time.Duration) { +	st.mutex.Lock() +	defer st.mutex.Unlock() +	st.config.Cache.GTS.FollowRequestIDsSweepFreq = v +	st.reloadToViper() +} + +// CacheGTSFollowRequestIDsSweepFreqFlag returns the flag name for the 'Cache.GTS.FollowRequestIDsSweepFreq' field +func CacheGTSFollowRequestIDsSweepFreqFlag() string { return "cache-gts-follow-request-ids-sweep-freq" } + +// GetCacheGTSFollowRequestIDsSweepFreq safely fetches the value for global configuration 'Cache.GTS.FollowRequestIDsSweepFreq' field +func GetCacheGTSFollowRequestIDsSweepFreq() time.Duration { +	return global.GetCacheGTSFollowRequestIDsSweepFreq() +} + +// SetCacheGTSFollowRequestIDsSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestIDsSweepFreq' field +func SetCacheGTSFollowRequestIDsSweepFreq(v time.Duration) { +	global.SetCacheGTSFollowRequestIDsSweepFreq(v) +} +  // GetCacheGTSInstanceMaxSize safely fetches the Configuration value for state's 'Cache.GTS.InstanceMaxSize' field  func (st *ConfigState) GetCacheGTSInstanceMaxSize() (v int) {  	st.mutex.RLock() diff --git a/internal/db/account.go b/internal/db/account.go index 21b8d6a1f..505ca4004 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -104,8 +104,6 @@ type Account interface {  	// In the case of no statuses, this function will return db.ErrNoEntries.  	GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) -	GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) -  	// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.  	//  	// If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned. diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 2ef1618db..e57c01a82 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -694,46 +694,6 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,  	return a.statusesFromIDs(ctx, statusIDs)  } -func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) { -	blocks := []*gtsmodel.Block{} - -	fq := a.db. -		NewSelect(). -		Model(&blocks). -		Where("? = ?", bun.Ident("block.account_id"), accountID). -		Relation("TargetAccount"). -		Order("block.id DESC") - -	if maxID != "" { -		fq = fq.Where("? < ?", bun.Ident("block.id"), maxID) -	} - -	if sinceID != "" { -		fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID) -	} - -	if limit > 0 { -		fq = fq.Limit(limit) -	} - -	if err := fq.Scan(ctx); err != nil { -		return nil, "", "", a.db.ProcessError(err) -	} - -	if len(blocks) == 0 { -		return nil, "", "", db.ErrNoEntries -	} - -	accounts := []*gtsmodel.Account{} -	for _, b := range blocks { -		accounts = append(accounts, b.TargetAccount) -	} - -	nextMaxID := blocks[len(blocks)-1].ID -	prevMinID := blocks[0].ID -	return accounts, nextMaxID, prevMinID, nil -} -  func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) {  	// Catch case of no statuses early  	if len(statusIDs) == 0 { diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 90bcd134d..04f22b6e9 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -126,16 +126,12 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {  			return err  		} -		// Prepare SELECT accounts query. -		aq := tx.NewSelect(). -			Table("accounts"). -			Column("id") - -		// Append a WHERE LIKE clause to the query +		// Prepare a SELECT query with a WHERE LIKE  		// that checks the `emoji` column for any  		// text containing this specific emoji ID.  		//  		// (see GetStatusesUsingEmoji() for details.) +		aq := tx.NewSelect().Table("accounts").Column("id")  		aq = whereLike(aq, "emojis", id)  		// Select all accounts using this emoji into accountIDss. @@ -170,16 +166,12 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {  			}  		} -		// Prepare SELECT statuses query. -		sq := tx.NewSelect(). -			Table("statuses"). -			Column("id") - -		// Append a WHERE LIKE clause to the query +		// Prepare a SELECT query with a WHERE LIKE  		// that checks the `emoji` column for any  		// text containing this specific emoji ID.  		//  		// (see GetStatusesUsingEmoji() for details.) +		sq := tx.NewSelect().Table("statuses").Column("id")  		sq = whereLike(sq, "emojis", id)  		// Select all statuses using this emoji into statusIDs. diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 25bb3a65d..70faf837a 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -189,11 +189,10 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error {  		gtscontext.SetBarebones(ctx),  		id,  	) -	if err != nil { -		if errors.Is(err, db.ErrNoEntries) { -			// Already gone. -			return nil -		} +	if err != nil && !errors.Is(err, db.ErrNoEntries) { +		// NOTE: even if db.ErrNoEntries is returned, we +		// still run the below transaction to ensure related +		// objects are appropriately deleted.  		return err  	} diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index 3b885af61..b8120b87a 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -106,8 +106,6 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt  }  func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { -	defer m.state.Caches.GTS.Media().Invalidate("ID", id) -  	// Load media into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -120,10 +118,8 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {  		return err  	} -	var ( -		invalidateAccount bool -		invalidateStatus  bool -	) +	// On return, ensure that media with ID is invalidated. +	defer m.state.Caches.GTS.Media().Invalidate("ID", id)  	// Delete media attachment in new transaction.  	err = m.db.RunInTx(ctx, func(tx bun.Tx) error { @@ -161,9 +157,6 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {  				if _, err := set(q).Exec(ctx); err != nil {  					return gtserror.Newf("error updating account: %w", err)  				} - -				// Mark as needing invalidate. -				invalidateAccount = true  			}  		} @@ -178,33 +171,18 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {  				return gtserror.Newf("error selecting status: %w", err)  			} -			// Get length of attachments beforehand. -			before := len(status.AttachmentIDs) - -			for i := 0; i < len(status.AttachmentIDs); { -				if status.AttachmentIDs[i] == id { -					// Remove this reference to deleted attachment ID. -					copy(status.AttachmentIDs[i:], status.AttachmentIDs[i+1:]) -					status.AttachmentIDs = status.AttachmentIDs[:len(status.AttachmentIDs)-1] -					continue -				} -				i++ -			} - -			if before != len(status.AttachmentIDs) { -				// Note: this accounts for status not found. +			if updatedIDs := dropID(status.AttachmentIDs, id); // nocollapse +			len(updatedIDs) != len(status.AttachmentIDs) { +				// Note: this handles not found.  				//  				// Attachments changed, update the status.  				if _, err := tx.NewUpdate().  					Table("statuses").  					Where("? = ?", bun.Ident("id"), status.ID). -					Set("? = ?", bun.Ident("attachment_ids"), status.AttachmentIDs). +					Set("? = ?", bun.Ident("attachment_ids"), updatedIDs).  					Exec(ctx); err != nil {  					return gtserror.Newf("error updating status: %w", err)  				} - -				// Mark as needing invalidate. -				invalidateStatus = true  			}  		} @@ -219,16 +197,6 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {  		return nil  	}) -	if invalidateAccount { -		// The account for given ID will have been updated in transaction. -		m.state.Caches.GTS.Account().Invalidate("ID", media.AccountID) -	} - -	if invalidateStatus { -		// The status for given ID will have been updated in transaction. -		m.state.Caches.GTS.Status().Invalidate("ID", media.StatusID) -	} -  	return m.db.ProcessError(err)  } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index eddd73b49..e7b563f2e 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -20,11 +20,12 @@ package bundb  import (  	"context"  	"errors" -	"fmt"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/paging"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/uptrace/bun"  ) @@ -45,7 +46,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  		targetAccount,  	)  	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err) +		return nil, gtserror.Newf("error fetching follow: %w", err)  	}  	if follow != nil { @@ -61,7 +62,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  		requestingAccount,  	)  	if err != nil { -		return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err) +		return nil, gtserror.Newf("error checking followedBy: %w", err)  	}  	// check if requesting has follow requested target @@ -70,19 +71,19 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  		targetAccount,  	)  	if err != nil { -		return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err) +		return nil, gtserror.Newf("error checking requested: %w", err)  	}  	// check if the requesting account is blocking the target account  	rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount)  	if err != nil { -		return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err) +		return nil, gtserror.Newf("error checking blocking: %w", err)  	}  	// check if the requesting account is blocked by the target account  	rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount)  	if err != nil { -		return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err) +		return nil, gtserror.Newf("error checking blockedBy: %w", err)  	}  	// retrieve a note by the requesting account on the target account, if there is one @@ -92,7 +93,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  		targetAccount,  	)  	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		return nil, fmt.Errorf("GetRelationship: error fetching note: %w", err) +		return nil, gtserror.Newf("error fetching note: %w", err)  	}  	if note != nil {  		rel.Note = note.Comment @@ -102,87 +103,186 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  }  func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { -	var followIDs []string -	if err := newSelectFollows(r.db, accountID). -		Scan(ctx, &followIDs); err != nil { -		return nil, r.db.ProcessError(err) +	followIDs, err := r.getAccountFollowIDs(ctx, accountID) +	if err != nil { +		return nil, err  	}  	return r.GetFollowsByIDs(ctx, followIDs)  }  func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { -	var followIDs []string -	if err := newSelectLocalFollows(r.db, accountID). -		Scan(ctx, &followIDs); err != nil { -		return nil, r.db.ProcessError(err) +	followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) +	if err != nil { +		return nil, err  	}  	return r.GetFollowsByIDs(ctx, followIDs)  }  func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { -	var followIDs []string -	if err := newSelectFollowers(r.db, accountID). -		Scan(ctx, &followIDs); err != nil { -		return nil, r.db.ProcessError(err) +	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) +	if err != nil { +		return nil, err  	} -	return r.GetFollowsByIDs(ctx, followIDs) +	return r.GetFollowsByIDs(ctx, followerIDs)  }  func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { -	var followIDs []string -	if err := newSelectLocalFollowers(r.db, accountID). -		Scan(ctx, &followIDs); err != nil { -		return nil, r.db.ProcessError(err) +	followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) +	if err != nil { +		return nil, err  	} -	return r.GetFollowsByIDs(ctx, followIDs) +	return r.GetFollowsByIDs(ctx, followerIDs) +} + +func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { +	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) +	if err != nil { +		return nil, err +	} +	return r.GetFollowRequestsByIDs(ctx, followReqIDs) +} + +func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { +	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) +	if err != nil { +		return nil, err +	} +	return r.GetFollowRequestsByIDs(ctx, followReqIDs) +} + +func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Pager) ([]*gtsmodel.Block, error) { +	// Load block IDs from cache with database loader callback. +	blockIDs, err := r.state.Caches.GTS.BlockIDs().LoadRange(accountID, func() ([]string, error) { +		var blockIDs []string + +		// Block IDs not in cache, perform DB query! +		q := newSelectBlocks(r.db, accountID) +		if _, err := q.Exec(ctx, &blockIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return blockIDs, nil +	}, page.PageDesc) +	if err != nil { +		return nil, err +	} + +	// Convert these IDs to full block objects. +	return r.GetBlocksByIDs(ctx, blockIDs)  }  func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectFollows(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +	followIDs, err := r.getAccountFollowIDs(ctx, accountID) +	return len(followIDs), err  }  func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectLocalFollows(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +	followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) +	return len(followIDs), err  }  func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectFollowers(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) +	return len(followerIDs), err  }  func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectLocalFollowers(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +	followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) +	return len(followerIDs), err  } -func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { -	var followReqIDs []string -	if err := newSelectFollowRequests(r.db, accountID). -		Scan(ctx, &followReqIDs); err != nil { -		return nil, r.db.ProcessError(err) -	} -	return r.GetFollowRequestsByIDs(ctx, followReqIDs) +func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { +	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) +	return len(followReqIDs), err  } -func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { -	var followReqIDs []string -	if err := newSelectFollowRequesting(r.db, accountID). -		Scan(ctx, &followReqIDs); err != nil { -		return nil, r.db.ProcessError(err) -	} -	return r.GetFollowRequestsByIDs(ctx, followReqIDs) +func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { +	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) +	return len(followReqIDs), err  } -func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectFollowRequests(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowIDs().Load(">"+accountID, func() ([]string, error) { +		var followIDs []string + +		// Follow IDs not in cache, perform DB query! +		q := newSelectFollows(r.db, accountID) +		if _, err := q.Exec(ctx, &followIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followIDs, nil +	})  } -func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectFollowRequesting(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowIDs().Load("l>"+accountID, func() ([]string, error) { +		var followIDs []string + +		// Follow IDs not in cache, perform DB query! +		q := newSelectLocalFollows(r.db, accountID) +		if _, err := q.Exec(ctx, &followIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followIDs, nil +	}) +} + +func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowIDs().Load("<"+accountID, func() ([]string, error) { +		var followIDs []string + +		// Follow IDs not in cache, perform DB query! +		q := newSelectFollowers(r.db, accountID) +		if _, err := q.Exec(ctx, &followIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followIDs, nil +	}) +} + +func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowIDs().Load("l<"+accountID, func() ([]string, error) { +		var followIDs []string + +		// Follow IDs not in cache, perform DB query! +		q := newSelectLocalFollowers(r.db, accountID) +		if _, err := q.Exec(ctx, &followIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followIDs, nil +	}) +} + +func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowRequestIDs().Load(">"+accountID, func() ([]string, error) { +		var followReqIDs []string + +		// Follow request IDs not in cache, perform DB query! +		q := newSelectFollowRequests(r.db, accountID) +		if _, err := q.Exec(ctx, &followReqIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followReqIDs, nil +	}) +} + +func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowRequestIDs().Load("<"+accountID, func() ([]string, error) { +		var followReqIDs []string + +		// Follow request IDs not in cache, perform DB query! +		q := newSelectFollowRequesting(r.db, accountID) +		if _, err := q.Exec(ctx, &followReqIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followReqIDs, nil +	})  }  // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. @@ -256,3 +356,12 @@ func newSelectLocalFollowers(db *WrappedDB, accountID string) *bun.SelectQuery {  		).  		OrderExpr("? DESC", bun.Ident("updated_at"))  } + +// newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. +func newSelectBlocks(db *WrappedDB, accountID string) *bun.SelectQuery { +	return db.NewSelect(). +		TableExpr("?", bun.Ident("blocks")). +		ColumnExpr("?", bun.Ident("?")). +		Where("? = ?", bun.Ident("account_id"), accountID). +		OrderExpr("? DESC", bun.Ident("updated_at")) +} diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go index 948e82fcb..2a042bed4 100644 --- a/internal/db/bundb/relationship_block.go +++ b/internal/db/bundb/relationship_block.go @@ -25,6 +25,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/uptrace/bun"  ) @@ -97,6 +98,25 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t  	)  } +func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) { +	// Preallocate slice of expected length. +	blocks := make([]*gtsmodel.Block, 0, len(ids)) + +	for _, id := range ids { +		// Fetch block model for this ID. +		block, err := r.GetBlockByID(ctx, id) +		if err != nil { +			log.Errorf(ctx, "error getting block %q: %v", id, err) +			continue +		} + +		// Append to return slice. +		blocks = append(blocks, block) +	} + +	return blocks, nil +} +  func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {  	// Fetch block from cache with loader callback  	block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) { @@ -148,8 +168,6 @@ func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) er  }  func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error { -	defer r.state.Caches.GTS.Block().Invalidate("ID", id) -  	// Load block into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -162,6 +180,9 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {  		return err  	} +	// Drop this now-cached block on return after delete. +	defer r.state.Caches.GTS.Block().Invalidate("ID", id) +  	// Finally delete block from DB.  	_, err = r.db.NewDelete().  		Table("blocks"). @@ -171,8 +192,6 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {  }  func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error { -	defer r.state.Caches.GTS.Block().Invalidate("URI", uri) -  	// Load block into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -185,6 +204,9 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error  		return err  	} +	// Drop this now-cached block on return after delete. +	defer r.state.Caches.GTS.Block().Invalidate("URI", uri) +  	// Finally delete block from DB.  	_, err = r.db.NewDelete().  		Table("blocks"). @@ -211,10 +233,9 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri  	}  	defer func() { -		// Invalidate all IDs on return. -		for _, id := range blockIDs { -			r.state.Caches.GTS.Block().Invalidate("ID", id) -		} +		// Invalidate all account's incoming / outoing blocks on return. +		r.state.Caches.GTS.Block().Invalidate("AccountID", accountID) +		r.state.Caches.GTS.Block().Invalidate("TargetAccountID", accountID)  	}()  	// Load all blocks into cache, this *really* isn't great diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index 84501b0be..3b0597612 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -233,8 +233,6 @@ func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error {  }  func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID string, targetAccountID string) error { -	defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) -  	// Load follow into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -251,13 +249,14 @@ func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID strin  		return err  	} +	// Drop this now-cached follow on return after delete. +	defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) +  	// Finally delete follow from DB.  	return r.deleteFollow(ctx, follow.ID)  }  func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error { -	defer r.state.Caches.GTS.Follow().Invalidate("ID", id) -  	// Load follow into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -270,13 +269,14 @@ func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error  		return err  	} +	// Drop this now-cached follow on return after delete. +	defer r.state.Caches.GTS.Follow().Invalidate("ID", id) +  	// Finally delete follow from DB.  	return r.deleteFollow(ctx, follow.ID)  }  func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error { -	defer r.state.Caches.GTS.Follow().Invalidate("URI", uri) -  	// Load follow into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -289,6 +289,9 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro  		return err  	} +	// Drop this now-cached follow on return after delete. +	defer r.state.Caches.GTS.Follow().Invalidate("URI", uri) +  	// Finally delete follow from DB.  	return r.deleteFollow(ctx, follow.ID)  } @@ -312,10 +315,9 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str  	}  	defer func() { -		// Invalidate all IDs on return. -		for _, id := range followIDs { -			r.state.Caches.GTS.Follow().Invalidate("ID", id) -		} +		// Invalidate all account's incoming / outoing follows on return. +		r.state.Caches.GTS.Follow().Invalidate("AccountID", accountID) +		r.state.Caches.GTS.Follow().Invalidate("TargetAccountID", accountID)  	}()  	// Load all follows into cache, this *really* isn't great diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go index a6e913953..dc5e760e6 100644 --- a/internal/db/bundb/relationship_follow_req.go +++ b/internal/db/bundb/relationship_follow_req.go @@ -208,9 +208,6 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI  		return nil, err  	} -	// Invalidate follow request from cache lookups on return. -	defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID) -  	// Delete original follow request.  	if _, err := r.db.  		NewDelete(). @@ -243,8 +240,6 @@ func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountI  }  func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) error { -	defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) -  	// Load followreq into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -261,6 +256,9 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI  		return err  	} +	// Drop this now-cached follow request on return after delete. +	defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) +  	// Finally delete followreq from DB.  	_, err = r.db.NewDelete().  		Table("follow_requests"). @@ -270,8 +268,6 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI  }  func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error { -	defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) -  	// Load followreq into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -284,6 +280,9 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)  		return err  	} +	// Drop this now-cached follow request on return after delete. +	defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) +  	// Finally delete followreq from DB.  	_, err = r.db.NewDelete().  		Table("follow_requests"). @@ -293,8 +292,6 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)  }  func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error { -	defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) -  	// Load followreq into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -307,6 +304,9 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin  		return err  	} +	// Drop this now-cached follow request on return after delete. +	defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) +  	// Finally delete followreq from DB.  	_, err = r.db.NewDelete().  		Table("follow_requests"). @@ -334,10 +334,9 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun  	}  	defer func() { -		// Invalidate all IDs on return. -		for _, id := range followReqIDs { -			r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) -		} +		// Invalidate all account's incoming / outoing follow requests on return. +		r.state.Caches.GTS.FollowRequest().Invalidate("AccountID", accountID) +		r.state.Caches.GTS.FollowRequest().Invalidate("TargetAccountID", accountID)  	}()  	// Load all followreqs into cache, this *really* isn't diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index a019216d0..4dc7d8468 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -381,8 +381,6 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co  }  func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error { -	defer s.state.Caches.GTS.Status().Invalidate("ID", id) -  	// Load status into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -397,6 +395,9 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error {  		return err  	} +	// On return ensure status invalidated from cache. +	defer s.state.Caches.GTS.Status().Invalidate("ID", id) +  	return s.db.RunInTx(ctx, func(tx bun.Tx) error {  		// delete links between this status and any emojis it uses  		if _, err := tx. diff --git a/internal/db/relationship.go b/internal/db/relationship.go index e19aee646..6ba9fdf8c 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -21,6 +21,7 @@ import (  	"context"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/paging"  )  // Relationship contains functions for getting or modifying the relationship between two accounts. @@ -166,6 +167,9 @@ type Relationship interface {  	// CountAccountFollowerRequests returns number of follow requests originating from the given account.  	CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) +	// GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. +	GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Pager) ([]*gtsmodel.Block, error) +  	// GetNote gets a private note from a source account on a target account, if it exists.  	GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) diff --git a/internal/paging/paging.go b/internal/paging/paging.go new file mode 100644 index 000000000..0323f40bc --- /dev/null +++ b/internal/paging/paging.go @@ -0,0 +1,227 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package paging + +import "golang.org/x/exp/slices" + +// Pager provides a means of paging serialized IDs, +// using the terminology of our API endpoint queries. +type Pager struct { +	// SinceID will limit the returned +	// page of IDs to contain newer than +	// since ID (excluding it). Result +	// will be returned DESCENDING. +	SinceID string + +	// MinID will limit the returned +	// page of IDs to contain newer than +	// min ID (excluding it). Result +	// will be returned ASCENDING. +	MinID string + +	// MaxID will limit the returned +	// page of IDs to contain older +	// than (excluding) this max ID. +	MaxID string + +	// Limit will limit the returned +	// page of IDs to at most 'limit'. +	Limit int +} + +// Page will page the given slice of GoToSocial IDs according +// to the receiving Pager's SinceID, MinID, MaxID and Limits. +// NOTE THE INPUT SLICE MUST BE SORTED IN ASCENDING ORDER +// (I.E. OLDEST ITEMS AT LOWEST INDICES, NEWER AT HIGHER). +func (p *Pager) PageAsc(ids []string) []string { +	if p == nil { +		// no paging. +		return ids +	} + +	var asc bool + +	if p.SinceID != "" { +		// If a sinceID is given, we +		// page down i.e. descending. +		asc = false + +		for i := 0; i < len(ids); i++ { +			if ids[i] == p.SinceID { +				// Hit the boundary. +				// Reslice to be: +				// "from here" +				ids = ids[i+1:] +				break +			} +		} +	} else if p.MinID != "" { +		// We only support minID if +		// no sinceID is provided. +		// +		// If a minID is given, we +		// page up, i.e. ascending. +		asc = true + +		for i := 0; i < len(ids); i++ { +			if ids[i] == p.MinID { +				// Hit the boundary. +				// Reslice to be: +				// "from here" +				ids = ids[i+1:] +				break +			} +		} +	} + +	if p.MaxID != "" { +		for i := 0; i < len(ids); i++ { +			if ids[i] == p.MaxID { +				// Hit the boundary. +				// Reslice to be: +				// "up to here" +				ids = ids[:i] +				break +			} +		} +	} + +	if !asc && len(ids) > 1 { +		var ( +			// Start at front. +			i = 0 + +			// Start at back. +			j = len(ids) - 1 +		) + +		// Clone input IDs before +		// we perform modifications. +		ids = slices.Clone(ids) + +		for i < j { +			// Swap i,j index values in slice. +			ids[i], ids[j] = ids[j], ids[i] + +			// incr + decr, +			// looping until +			// they meet in +			// the middle. +			i++ +			j-- +		} +	} + +	if p.Limit > 0 && p.Limit < len(ids) { +		// Reslice IDs to given limit. +		ids = ids[:p.Limit] +	} + +	return ids +} + +// Page will page the given slice of GoToSocial IDs according +// to the receiving Pager's SinceID, MinID, MaxID and Limits. +// NOTE THE INPUT SLICE MUST BE SORTED IN ASCENDING ORDER. +// (I.E. NEWEST ITEMS AT LOWEST INDICES, OLDER AT HIGHER). +func (p *Pager) PageDesc(ids []string) []string { +	if p == nil { +		// no paging. +		return ids +	} + +	var asc bool + +	if p.MaxID != "" { +		for i := 0; i < len(ids); i++ { +			if ids[i] == p.MaxID { +				// Hit the boundary. +				// Reslice to be: +				// "from here" +				ids = ids[i+1:] +				break +			} +		} +	} + +	if p.SinceID != "" { +		// If a sinceID is given, we +		// page down i.e. descending. +		asc = false + +		for i := 0; i < len(ids); i++ { +			if ids[i] == p.SinceID { +				// Hit the boundary. +				// Reslice to be: +				// "up to here" +				ids = ids[:i] +				break +			} +		} +	} else if p.MinID != "" { +		// We only support minID if +		// no sinceID is provided. +		// +		// If a minID is given, we +		// page up, i.e. ascending. +		asc = true + +		for i := 0; i < len(ids); i++ { +			if ids[i] == p.MinID { +				// Hit the boundary. +				// Reslice to be: +				// "up to here" +				ids = ids[:i] +				break +			} +		} +	} + +	if asc && len(ids) > 1 { +		var ( +			// Start at front. +			i = 0 + +			// Start at back. +			j = len(ids) - 1 +		) + +		// Clone input IDs before +		// we perform modifications. +		ids = slices.Clone(ids) + +		for i < j { +			// Swap i,j index values in slice. +			ids[i], ids[j] = ids[j], ids[i] + +			// incr + decr, +			// looping until +			// they meet in +			// the middle. +			i++ +			j-- +		} +	} + +	if p.Limit > 0 && p.Limit < len(ids) { +		// Reslice IDs to given limit. +		ids = ids[:p.Limit] +	} + +	return ids +} diff --git a/internal/paging/paging_test.go b/internal/paging/paging_test.go new file mode 100644 index 000000000..71c3be0c9 --- /dev/null +++ b/internal/paging/paging_test.go @@ -0,0 +1,171 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package paging_test + +import ( +	"testing" + +	"github.com/superseriousbusiness/gotosocial/internal/paging" +	"golang.org/x/exp/slices" +) + +type Case struct { +	// Name is the test case name. +	Name string + +	// Input contains test case input ID slice. +	Input []string + +	// Expect contains expected test case output. +	Expect []string + +	// Page contains the paging function to use. +	Page func([]string) []string +} + +var cases = []Case{ +	{ +		Name: "min_id and max_id set", +		Input: []string{ +			"064Q5D7VG6TPPQ46T09MHJ96FW", +			"064Q5D7VGPTC4NK5T070VYSSF8", +			"064Q5D7VH5F0JXG6W5NCQ3JCWW", +			"064Q5D7VHMSW9DF3GCS088VAZC", +			"064Q5D7VJ073XG9ZTWHA2KHN10", +			"064Q5D7VJADJTPA3GW8WAX10TW", +			"064Q5D7VJMWXZD3S1KT7RD51N8", +			"064Q5D7VJYFBYSAH86KDBKZ6AC", +			"064Q5D7VK8H7WMJS399SHEPCB0", +			"064Q5D7VKG5EQ43TYP71B4K6K0", +		}, +		Expect: []string{ +			"064Q5D7VGPTC4NK5T070VYSSF8", +			"064Q5D7VH5F0JXG6W5NCQ3JCWW", +			"064Q5D7VHMSW9DF3GCS088VAZC", +			"064Q5D7VJ073XG9ZTWHA2KHN10", +			"064Q5D7VJADJTPA3GW8WAX10TW", +			"064Q5D7VJMWXZD3S1KT7RD51N8", +			"064Q5D7VJYFBYSAH86KDBKZ6AC", +			"064Q5D7VK8H7WMJS399SHEPCB0", +		}, +		Page: (&paging.Pager{ +			MinID: "064Q5D7VG6TPPQ46T09MHJ96FW", +			MaxID: "064Q5D7VKG5EQ43TYP71B4K6K0", +		}).PageAsc, +	}, +	{ +		Name: "min_id, max_id and limit set", +		Input: []string{ +			"064Q5D7VG6TPPQ46T09MHJ96FW", +			"064Q5D7VGPTC4NK5T070VYSSF8", +			"064Q5D7VH5F0JXG6W5NCQ3JCWW", +			"064Q5D7VHMSW9DF3GCS088VAZC", +			"064Q5D7VJ073XG9ZTWHA2KHN10", +			"064Q5D7VJADJTPA3GW8WAX10TW", +			"064Q5D7VJMWXZD3S1KT7RD51N8", +			"064Q5D7VJYFBYSAH86KDBKZ6AC", +			"064Q5D7VK8H7WMJS399SHEPCB0", +			"064Q5D7VKG5EQ43TYP71B4K6K0", +		}, +		Expect: []string{ +			"064Q5D7VGPTC4NK5T070VYSSF8", +			"064Q5D7VH5F0JXG6W5NCQ3JCWW", +			"064Q5D7VHMSW9DF3GCS088VAZC", +			"064Q5D7VJ073XG9ZTWHA2KHN10", +			"064Q5D7VJADJTPA3GW8WAX10TW", +		}, +		Page: (&paging.Pager{ +			MinID: "064Q5D7VG6TPPQ46T09MHJ96FW", +			MaxID: "064Q5D7VKG5EQ43TYP71B4K6K0", +			Limit: 5, +		}).PageAsc, +	}, +	{ +		Name: "min_id, max_id and too-large limit set", +		Input: []string{ +			"064Q5D7VG6TPPQ46T09MHJ96FW", +			"064Q5D7VGPTC4NK5T070VYSSF8", +			"064Q5D7VH5F0JXG6W5NCQ3JCWW", +			"064Q5D7VHMSW9DF3GCS088VAZC", +			"064Q5D7VJ073XG9ZTWHA2KHN10", +			"064Q5D7VJADJTPA3GW8WAX10TW", +			"064Q5D7VJMWXZD3S1KT7RD51N8", +			"064Q5D7VJYFBYSAH86KDBKZ6AC", +			"064Q5D7VK8H7WMJS399SHEPCB0", +			"064Q5D7VKG5EQ43TYP71B4K6K0", +		}, +		Expect: []string{ +			"064Q5D7VGPTC4NK5T070VYSSF8", +			"064Q5D7VH5F0JXG6W5NCQ3JCWW", +			"064Q5D7VHMSW9DF3GCS088VAZC", +			"064Q5D7VJ073XG9ZTWHA2KHN10", +			"064Q5D7VJADJTPA3GW8WAX10TW", +			"064Q5D7VJMWXZD3S1KT7RD51N8", +			"064Q5D7VJYFBYSAH86KDBKZ6AC", +			"064Q5D7VK8H7WMJS399SHEPCB0", +		}, +		Page: (&paging.Pager{ +			MinID: "064Q5D7VG6TPPQ46T09MHJ96FW", +			MaxID: "064Q5D7VKG5EQ43TYP71B4K6K0", +			Limit: 100, +		}).PageAsc, +	}, +	{ +		Name: "since_id and max_id set", +		Input: []string{ +			"064Q5D7VG6TPPQ46T09MHJ96FW", +			"064Q5D7VGPTC4NK5T070VYSSF8", +			"064Q5D7VH5F0JXG6W5NCQ3JCWW", +			"064Q5D7VHMSW9DF3GCS088VAZC", +			"064Q5D7VJ073XG9ZTWHA2KHN10", +			"064Q5D7VJADJTPA3GW8WAX10TW", +			"064Q5D7VJMWXZD3S1KT7RD51N8", +			"064Q5D7VJYFBYSAH86KDBKZ6AC", +			"064Q5D7VK8H7WMJS399SHEPCB0", +			"064Q5D7VKG5EQ43TYP71B4K6K0", +		}, +		Expect: []string{ +			"064Q5D7VK8H7WMJS399SHEPCB0", +			"064Q5D7VJYFBYSAH86KDBKZ6AC", +			"064Q5D7VJMWXZD3S1KT7RD51N8", +			"064Q5D7VJADJTPA3GW8WAX10TW", +			"064Q5D7VJ073XG9ZTWHA2KHN10", +			"064Q5D7VHMSW9DF3GCS088VAZC", +			"064Q5D7VH5F0JXG6W5NCQ3JCWW", +			"064Q5D7VGPTC4NK5T070VYSSF8", +		}, +		Page: (&paging.Pager{ +			SinceID: "064Q5D7VG6TPPQ46T09MHJ96FW", +			MaxID:   "064Q5D7VKG5EQ43TYP71B4K6K0", +		}).PageAsc, +	}, +} + +func TestPage(t *testing.T) { +	for _, c := range cases { +		t.Run(c.Name, func(t *testing.T) { +			// Page the input slice. +			out := c.Page(c.Input) + +			// Check paged output is as expected. +			if !slices.Equal(out, c.Expect) { +				t.Errorf("\nreceived=%v\nexpect%v\n", out, c.Expect) +			} +		}) +	} +} diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index 2a20ec96e..a613ba485 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -20,7 +20,6 @@ package account  import (  	"context"  	"errors" -	"fmt"  	"net"  	"time" @@ -114,38 +113,38 @@ func (p *Processor) DeleteSelf(ctx context.Context, account *gtsmodel.Account) g  func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account *gtsmodel.Account) error {  	user, err := p.state.DB.GetUserByAccountID(ctx, account.ID)  	if err != nil { -		return fmt.Errorf("deleteUserAndTokensForAccount: db error getting user: %w", err) +		return gtserror.Newf("db error getting user: %w", err)  	}  	tokens := []*gtsmodel.Token{}  	if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err != nil { -		return fmt.Errorf("deleteUserAndTokensForAccount: db error getting tokens: %w", err) +		return gtserror.Newf("db error getting tokens: %w", err)  	}  	for _, t := range tokens {  		// Delete any OAuth clients associated with this token.  		if err := p.state.DB.DeleteByID(ctx, t.ClientID, &[]*gtsmodel.Client{}); err != nil { -			return fmt.Errorf("deleteUserAndTokensForAccount: db error deleting client: %w", err) +			return gtserror.Newf("db error deleting client: %w", err)  		}  		// Delete any OAuth applications associated with this token.  		if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, &[]*gtsmodel.Application{}); err != nil { -			return fmt.Errorf("deleteUserAndTokensForAccount: db error deleting application: %w", err) +			return gtserror.Newf("db error deleting application: %w", err)  		}  		// Delete the token itself.  		if err := p.state.DB.DeleteByID(ctx, t.ID, t); err != nil { -			return fmt.Errorf("deleteUserAndTokensForAccount: db error deleting token: %w", err) +			return gtserror.Newf("db error deleting token: %w", err)  		}  	}  	columns, err := stubbifyUser(user)  	if err != nil { -		return fmt.Errorf("deleteUserAndTokensForAccount: error stubbifying user: %w", err) +		return gtserror.Newf("error stubbifying user: %w", err)  	}  	if err := p.state.DB.UpdateUser(ctx, user, columns...); err != nil { -		return fmt.Errorf("deleteUserAndTokensForAccount: db error updating user: %w", err) +		return gtserror.Newf("db error updating user: %w", err)  	}  	return nil @@ -160,24 +159,24 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.  	// Delete follows targeting this account.  	followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID)  	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		return fmt.Errorf("deleteAccountFollows: db error getting follows targeting account %s: %w", account.ID, err) +		return gtserror.Newf("db error getting follows targeting account %s: %w", account.ID, err)  	}  	for _, follow := range followedBy {  		if err := p.state.DB.DeleteFollowByID(ctx, follow.ID); err != nil { -			return fmt.Errorf("deleteAccountFollows: db error unfollowing account followedBy: %w", err) +			return gtserror.Newf("db error unfollowing account followedBy: %w", err)  		}  	}  	// Delete follow requests targeting this account.  	followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID)  	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		return fmt.Errorf("deleteAccountFollows: db error getting follow requests targeting account %s: %w", account.ID, err) +		return gtserror.Newf("db error getting follow requests targeting account %s: %w", account.ID, err)  	}  	for _, followRequest := range followRequestedBy {  		if err := p.state.DB.DeleteFollowRequestByID(ctx, followRequest.ID); err != nil { -			return fmt.Errorf("deleteAccountFollows: db error unfollowing account followRequestedBy: %w", err) +			return gtserror.Newf("db error unfollowing account followRequestedBy: %w", err)  		}  	} @@ -193,14 +192,14 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.  	// Delete follows originating from this account.  	following, err := p.state.DB.GetAccountFollows(ctx, account.ID)  	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		return fmt.Errorf("deleteAccountFollows: db error getting follows owned by account %s: %w", account.ID, err) +		return gtserror.Newf("db error getting follows owned by account %s: %w", account.ID, err)  	}  	// For each follow owned by this account, unfollow  	// and process side effects (noop if remote account).  	for _, follow := range following {  		if err := p.state.DB.DeleteFollowByID(ctx, follow.ID); err != nil { -			return fmt.Errorf("deleteAccountFollows: db error unfollowing account: %w", err) +			return gtserror.Newf("db error unfollowing account: %w", err)  		}  		if msg := unfollowSideEffects(ctx, account, follow); msg != nil {  			// There was a side effect to process. @@ -211,14 +210,14 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.  	// Delete follow requests originating from this account.  	followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID)  	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		return fmt.Errorf("deleteAccountFollows: db error getting follow requests owned by account %s: %w", account.ID, err) +		return gtserror.Newf("db error getting follow requests owned by account %s: %w", account.ID, err)  	}  	// For each follow owned by this account, unfollow  	// and process side effects (noop if remote account).  	for _, followRequest := range followRequesting {  		if err := p.state.DB.DeleteFollowRequestByID(ctx, followRequest.ID); err != nil { -			return fmt.Errorf("deleteAccountFollows: db error unfollowingRequesting account: %w", err) +			return gtserror.Newf("db error unfollowingRequesting account: %w", err)  		}  		// Dummy out a follow so our side effects func @@ -279,7 +278,7 @@ func (p *Processor) unfollowSideEffectsFunc(deletedAccount *gtsmodel.Account) fu  func (p *Processor) deleteAccountBlocks(ctx context.Context, account *gtsmodel.Account) error {  	if err := p.state.DB.DeleteAccountBlocks(ctx, account.ID); err != nil { -		return fmt.Errorf("deleteAccountBlocks: db error deleting account blocks for %s: %w", account.ID, err) +		return gtserror.Newf("db error deleting account blocks for %s: %w", account.ID, err)  	}  	return nil  } @@ -333,7 +332,7 @@ statusLoop:  			// Look for any boosts of this status in DB.  			boosts, err := p.state.DB.GetStatusReblogs(ctx, status)  			if err != nil && !errors.Is(err, db.ErrNoEntries) { -				return fmt.Errorf("deleteAccountStatuses: error fetching status reblogs for %s: %w", status.ID, err) +				return gtserror.Newf("error fetching status reblogs for %s: %w", status.ID, err)  			}  			for _, boost := range boosts { @@ -347,7 +346,7 @@ statusLoop:  							log.WithContext(ctx).WithField("boost", boost).Warnf("no account found with id %s for boost %s", boost.AccountID, boost.ID)  							continue  						} -						return fmt.Errorf("deleteAccountStatuses: error fetching boosted status account for %s: %w", boost.AccountID, err) +						return gtserror.Newf("error fetching boosted status account for %s: %w", boost.AccountID, err)  					}  					// Set account model @@ -505,7 +504,7 @@ func stubbifyUser(user *gtsmodel.User) ([]string, error) {  		return nil, err  	} -	var never = time.Time{} +	never := time.Time{}  	user.EncryptedPassword = string(dummyPassword)  	user.SignUpIP = net.IPv4zero diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go index 644f28ca9..8996dff92 100644 --- a/internal/processing/blocks.go +++ b/internal/processing/blocks.go @@ -19,69 +19,71 @@ package processing  import (  	"context" -	"fmt" -	"net/url" +	"errors"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" -	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror" -	"github.com/superseriousbusiness/gotosocial/internal/oauth" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/paging" +	"github.com/superseriousbusiness/gotosocial/internal/util"  ) -func (p *Processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { -	accounts, nextMaxID, prevMinID, err := p.state.DB.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit) -	if err != nil { -		if err == db.ErrNoEntries { -			// there are just no entries -			return &apimodel.BlocksResponse{ -				Accounts: []*apimodel.Account{}, -			}, nil -		} -		// there's an actual error +// BlocksGet ... +func (p *Processor) BlocksGet( +	ctx context.Context, +	requestingAccount *gtsmodel.Account, +	page paging.Pager, +) (*apimodel.PageableResponse, gtserror.WithCode) { +	blocks, err := p.state.DB.GetAccountBlocks(ctx, +		requestingAccount.ID, +		&page, +	) +	if err != nil && !errors.Is(err, db.ErrNoEntries) {  		return nil, gtserror.NewErrorInternalError(err)  	} -	apiAccounts := []*apimodel.Account{} -	for _, a := range accounts { -		apiAccount, err := p.tc.AccountToAPIAccountBlocked(ctx, a) -		if err != nil { -			continue -		} -		apiAccounts = append(apiAccounts, apiAccount) +	// Check for zero length. +	count := len(blocks) +	if len(blocks) == 0 { +		return util.EmptyPageableResponse(), nil  	} -	return p.packageBlocksResponse(apiAccounts, "/api/v1/blocks", nextMaxID, prevMinID, limit) -} +	var ( +		items = make([]interface{}, 0, count) -func (p *Processor) packageBlocksResponse(accounts []*apimodel.Account, path string, nextMaxID string, prevMinID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { -	resp := &apimodel.BlocksResponse{ -		Accounts: []*apimodel.Account{}, -	} -	resp.Accounts = accounts +		// Set next + prev values before API converting +		// so the caller can still page even on error. +		nextMaxIDValue = blocks[count-1].ID +		prevMinIDValue = blocks[0].ID +	) -	// prepare the next and previous links -	if len(accounts) != 0 { -		protocol := config.GetProtocol() -		host := config.GetHost() - -		nextLink := &url.URL{ -			Scheme:   protocol, -			Host:     host, -			Path:     path, -			RawQuery: fmt.Sprintf("limit=%d&max_id=%s", limit, nextMaxID), +	for _, block := range blocks { +		if block.TargetAccount == nil { +			// All models should be populated at this point. +			log.Warnf(ctx, "block target account was nil: %v", err) +			continue  		} -		next := fmt.Sprintf("<%s>; rel=\"next\"", nextLink.String()) -		prevLink := &url.URL{ -			Scheme:   protocol, -			Host:     host, -			Path:     path, -			RawQuery: fmt.Sprintf("limit=%d&min_id=%s", limit, prevMinID), +		// Convert target account to frontend API model. +		account, err := p.tc.AccountToAPIAccountBlocked(ctx, block.TargetAccount) +		if err != nil { +			log.Errorf(ctx, "error converting account to public api account: %v", err) +			continue  		} -		prev := fmt.Sprintf("<%s>; rel=\"prev\"", prevLink.String()) -		resp.LinkHeader = fmt.Sprintf("%s, %s", next, prev) + +		// Append target to return items. +		items = append(items, account)  	} -	return resp, nil +	return util.PackagePageableResponse(util.PageableResponseParams{ +		Items:          items, +		Path:           "/api/v1/blocks", +		NextMaxIDKey:   "max_id", +		PrevMinIDKey:   "since_id", +		NextMaxIDValue: nextMaxIDValue, +		PrevMinIDValue: prevMinIDValue, +		Limit:          page.Limit, +	})  } diff --git a/test/envparsing.sh b/test/envparsing.sh index 8f4372906..b9017d0be 100755 --- a/test/envparsing.sh +++ b/test/envparsing.sh @@ -25,6 +25,9 @@ EXPECT=$(cat <<"EOF"              "account-note-ttl": 1800000000000,              "account-sweep-freq": 1000000000,              "account-ttl": 10800000000000, +            "block-ids-max-size": 500, +            "block-ids-sweep-freq": 60000000000, +            "block-ids-ttl": 1800000000000,              "block-max-size": 1000,              "block-sweep-freq": 60000000000,              "block-ttl": 1800000000000, @@ -37,7 +40,13 @@ EXPECT=$(cat <<"EOF"              "emoji-max-size": 2000,              "emoji-sweep-freq": 60000000000,              "emoji-ttl": 1800000000000, +            "follow-ids-max-size": 500, +            "follow-ids-sweep-freq": 60000000000, +            "follow-ids-ttl": 1800000000000,              "follow-max-size": 2000, +            "follow-request-ids-max-size": 500, +            "follow-request-ids-sweep-freq": 60000000000, +            "follow-request-ids-ttl": 1800000000000,              "follow-request-max-size": 2000,              "follow-request-sweep-freq": 60000000000,              "follow-request-ttl": 1800000000000, diff --git a/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go b/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go index 623a19910..af108e336 100644 --- a/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go +++ b/vendor/codeberg.org/gruf/go-cache/v3/ttl/ttl.go @@ -479,23 +479,23 @@ func (c *Cache[K, V]) InvalidateAll(keys ...K) (ok bool) {  	kvs = make([]kv[K, V], 0, len(keys))  	c.locked(func() { -		for _, key := range keys { +		for x := range keys {  			var item *Entry[K, V]  			// Check for item in cache -			item, ok = c.Cache.Get(key) +			item, ok = c.Cache.Get(keys[x])  			if !ok { -				return +				continue  			}  			// Append this old value to slice  			kvs = append(kvs, kv[K, V]{ -				K: key, +				K: keys[x],  				V: item.Value,  			})  			// Remove from cache map -			_ = c.Cache.Delete(key) +			_ = c.Cache.Delete(keys[x])  			// Free entry  			c.free(item) @@ -553,6 +553,7 @@ func (c *Cache[K, V]) Cap() (l int) {  	return  } +// locked performs given function within mutex lock (NOTE: UNLOCK IS NOT DEFERRED).  func (c *Cache[K, V]) locked(fn func()) {  	c.Lock()  	fn() diff --git a/vendor/modules.txt b/vendor/modules.txt index 64a310838..006cc3e5d 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -13,7 +13,7 @@ codeberg.org/gruf/go-bytesize  # codeberg.org/gruf/go-byteutil v1.1.2  ## explicit; go 1.16  codeberg.org/gruf/go-byteutil -# codeberg.org/gruf/go-cache/v3 v3.4.3 +# codeberg.org/gruf/go-cache/v3 v3.4.4  ## explicit; go 1.19  codeberg.org/gruf/go-cache/v3  codeberg.org/gruf/go-cache/v3/result  | 
