diff options
Diffstat (limited to 'internal/processing/list')
| -rw-r--r-- | internal/processing/list/get.go | 138 | ||||
| -rw-r--r-- | internal/processing/list/updateentries.go | 177 | ||||
| -rw-r--r-- | internal/processing/list/util.go | 46 | 
3 files changed, 148 insertions, 213 deletions
diff --git a/internal/processing/list/get.go b/internal/processing/list/get.go index cdd3c6e0c..b98678eef 100644 --- a/internal/processing/list/get.go +++ b/internal/processing/list/get.go @@ -20,7 +20,6 @@ package list  import (  	"context"  	"errors" -	"fmt"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"  	"github.com/superseriousbusiness/gotosocial/internal/db" @@ -28,7 +27,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" -	"github.com/superseriousbusiness/gotosocial/internal/util" +	"github.com/superseriousbusiness/gotosocial/internal/paging"  )  // Get returns the api model of one list with the given ID. @@ -49,16 +48,14 @@ func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id strin  // GetAll returns multiple lists created by the given account, sorted by list ID DESC (newest first).  func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*apimodel.List, gtserror.WithCode) { -	lists, err := p.state.DB.GetListsForAccountID( +	lists, err := p.state.DB.GetListsByAccountID( +  		// Use barebones ctx; no embedded  		// structs necessary for simple GET.  		gtscontext.SetBarebones(ctx),  		account.ID,  	) -	if err != nil { -		if errors.Is(err, db.ErrNoEntries) { -			return nil, nil -		} +	if err != nil && !errors.Is(err, db.ErrNoEntries) {  		return nil, gtserror.NewErrorInternalError(err)  	} @@ -68,66 +65,23 @@ func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*a  		if errWithCode != nil {  			return nil, errWithCode  		} -  		apiLists = append(apiLists, apiList)  	}  	return apiLists, nil  } -// GetAllListAccounts returns all accounts that are in the given list, -// owned by the given account. There's no pagination for this endpoint. -// -// See https://docs.joinmastodon.org/methods/lists/#query-parameters: -// -//	Limit: Integer. Maximum number of results. Defaults to 40 accounts. -//	Max 80 accounts. Set to 0 in order to get all accounts without pagination. -func (p *Processor) GetAllListAccounts( -	ctx context.Context, -	account *gtsmodel.Account, -	listID string, -) ([]*apimodel.Account, gtserror.WithCode) { -	// Ensure list exists + is owned by requesting account. -	_, errWithCode := p.getList( -		// Use barebones ctx; no embedded -		// structs necessary for this call. -		gtscontext.SetBarebones(ctx), -		account.ID, -		listID, -	) -	if errWithCode != nil { -		return nil, errWithCode -	} - -	// Get all entries for this list. -	listEntries, err := p.state.DB.GetListEntries(ctx, listID, "", "", "", 0) -	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		err = gtserror.Newf("error getting list entries: %w", err) -		return nil, gtserror.NewErrorInternalError(err) -	} - -	// Extract accounts from list entries + add them to response. -	accounts := make([]*apimodel.Account, 0, len(listEntries)) -	p.accountsFromListEntries(ctx, listEntries, func(acc *apimodel.Account) { -		accounts = append(accounts, acc) -	}) - -	return accounts, nil -} -  // GetListAccounts returns accounts that are in the given list, owned by the given account. -// The additional parameters can be used for paging. +// The additional parameters can be used for paging. Nil page param returns all accounts.  func (p *Processor) GetListAccounts(  	ctx context.Context,  	account *gtsmodel.Account,  	listID string, -	maxID string, -	sinceID string, -	minID string, -	limit int, +	page *paging.Page,  ) (*apimodel.PageableResponse, gtserror.WithCode) {  	// Ensure list exists + is owned by requesting account.  	_, errWithCode := p.getList( +  		// Use barebones ctx; no embedded  		// structs necessary for this call.  		gtscontext.SetBarebones(ctx), @@ -138,71 +92,45 @@ func (p *Processor) GetListAccounts(  		return nil, errWithCode  	} -	// To know which accounts are in the list, -	// we need to first get requested list entries. -	listEntries, err := p.state.DB.GetListEntries(ctx, listID, maxID, sinceID, minID, limit) -	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		err = fmt.Errorf("GetListAccounts: error getting list entries: %w", err) +	// Get all accounts contained within list. +	accounts, err := p.state.DB.GetAccountsInList(ctx, +		listID, +		page, +	) +	if err != nil { +		err := gtserror.Newf("db error getting accounts in list: %w", err)  		return nil, gtserror.NewErrorInternalError(err)  	} -	count := len(listEntries) +	// Check for any accounts. +	count := len(accounts)  	if count == 0 { -		// No list entries means no accounts. -		return util.EmptyPageableResponse(), nil +		return paging.EmptyResponse(), nil  	}  	var ( +		// Preallocate expected frontend items.  		items = make([]interface{}, 0, count) -		// Set next + prev values before filtering and API -		// converting, so caller can still page properly. -		nextMaxIDValue = listEntries[count-1].ID -		prevMinIDValue = listEntries[0].ID +		// Set paging low / high IDs. +		lo = accounts[count-1].ID +		hi = accounts[0].ID  	) -	// Extract accounts from list entries + add them to response. -	p.accountsFromListEntries(ctx, listEntries, func(acc *apimodel.Account) { -		items = append(items, acc) -	}) - -	return util.PackagePageableResponse(util.PageableResponseParams{ -		Items:          items, -		Path:           "/api/v1/lists/" + listID + "/accounts", -		NextMaxIDValue: nextMaxIDValue, -		PrevMinIDValue: prevMinIDValue, -		Limit:          limit, -	}) -} - -func (p *Processor) accountsFromListEntries( -	ctx context.Context, -	listEntries []*gtsmodel.ListEntry, -	appendAcc func(*apimodel.Account), -) { -	// For each list entry, we want the account it points to. -	// To get this, we need to first get the follow that the -	// list entry pertains to, then extract the target account -	// from that follow. -	// -	// We do paging not by account ID, but by list entry ID. -	for _, listEntry := range listEntries { -		if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil { -			log.Errorf(ctx, "error populating list entry: %v", err) -			continue -		} - -		if err := p.state.DB.PopulateFollow(ctx, listEntry.Follow); err != nil { -			log.Errorf(ctx, "error populating follow: %v", err) -			continue -		} - -		apiAccount, err := p.converter.AccountToAPIAccountPublic(ctx, listEntry.Follow.TargetAccount) +	// Convert accounts to frontend. +	for _, account := range accounts { +		apiAccount, err := p.converter.AccountToAPIAccountPublic(ctx, account)  		if err != nil { -			log.Errorf(ctx, "error converting to public api account: %v", err) +			log.Errorf(ctx, "error converting to api account: %v", err)  			continue  		} - -		appendAcc(apiAccount) +		items = append(items, apiAccount)  	} + +	return paging.PackageResponse(paging.ResponseParams{ +		Items: items, +		Path:  "/api/v1/lists/" + listID + "/accounts", +		Next:  page.Next(lo, hi), +		Prev:  page.Prev(lo, hi), +	}), nil  } diff --git a/internal/processing/list/updateentries.go b/internal/processing/list/updateentries.go index 6dcb951a7..c15248f39 100644 --- a/internal/processing/list/updateentries.go +++ b/internal/processing/list/updateentries.go @@ -23,73 +23,90 @@ import (  	"fmt"  	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/id" +	"github.com/superseriousbusiness/gotosocial/internal/util"  )  // AddToList adds targetAccountIDs to the given list, if valid.  func (p *Processor) AddToList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode { +  	// Ensure this list exists + account owns it. -	list, errWithCode := p.getList(ctx, account.ID, listID) +	_, errWithCode := p.getList(ctx, account.ID, listID)  	if errWithCode != nil {  		return errWithCode  	} -	// Pre-assemble list of entries to add. We *could* add these -	// one by one as we iterate through accountIDs, but according -	// to the Mastodon API we should only add them all once we know -	// they're all valid, no partial updates. -	listEntries := make([]*gtsmodel.ListEntry, 0, len(targetAccountIDs)) +	// Get all follows that are entries in list. +	follows, err := p.state.DB.GetFollowsInList( + +		// We only need barebones model. +		gtscontext.SetBarebones(ctx), +		listID, +		nil, +	) +	if err != nil { +		err := gtserror.Newf("error getting list follows: %w", err) +		return gtserror.NewErrorInternalError(err) +	} + +	// Convert the follows to a hash set containing the target account IDs. +	inFollows := util.ToSetFunc(follows, func(follow *gtsmodel.Follow) string { +		return follow.TargetAccountID +	}) -	// Check each targetAccountID is valid. -	//   - Follow must exist. -	//   - Follow must not already be in the given list. +	// Preallocate a slice of expected list entries, we specifically +	// gather and add all the target accounts in one go rather than +	// individually, to ensure we don't end up with partial updates. +	entries := make([]*gtsmodel.ListEntry, 0, len(targetAccountIDs)) + +	// Iterate all the account IDs in given target list.  	for _, targetAccountID := range targetAccountIDs { -		// Ensure follow exists. -		follow, err := p.state.DB.GetFollow(ctx, account.ID, targetAccountID) -		if err != nil { -			if errors.Is(err, db.ErrNoEntries) { -				err = fmt.Errorf("you do not follow account %s", targetAccountID) -				return gtserror.NewErrorNotFound(err, err.Error()) -			} -			return gtserror.NewErrorInternalError(err) + +		// Look for follow to target account. +		if inFollows.Has(targetAccountID) { +			text := fmt.Sprintf("account %s is already in list %s", targetAccountID, listID) +			return gtserror.NewErrorUnprocessableEntity(errors.New(text), text)  		} -		// Ensure followID not already in list. -		// This particular call to isInList will -		// never error, so just check entryID. -		entryID, _ := isInList( -			list, -			follow.ID, -			func(listEntry *gtsmodel.ListEntry) (string, error) { -				// Looking for the listEntry follow ID. -				return listEntry.FollowID, nil -			}, +		// Get the actual follow to target. +		follow, err := p.state.DB.GetFollow( + +			// We don't need any sub-models. +			gtscontext.SetBarebones(ctx), +			account.ID, +			targetAccountID,  		) +		if err != nil && !errors.Is(err, db.ErrNoEntries) { +			err := gtserror.Newf("db error getting follow: %w", err) +			return gtserror.NewErrorInternalError(err) +		} -		// Empty entryID means entry with given -		// followID wasn't found in the list. -		if entryID != "" { -			err = fmt.Errorf("account with id %s is already in list %s with entryID %s", targetAccountID, listID, entryID) -			return gtserror.NewErrorUnprocessableEntity(err, err.Error()) +		if follow == nil { +			text := fmt.Sprintf("account %s not currently followed", targetAccountID) +			return gtserror.NewErrorNotFound(errors.New(text), text)  		} -		// Entry wasn't in the list, we can add it. -		listEntries = append(listEntries, >smodel.ListEntry{ +		// Generate new entry for this follow in list. +		entries = append(entries, >smodel.ListEntry{  			ID:       id.NewULID(),  			ListID:   listID,  			FollowID: follow.ID,  		})  	} -	// If we get to here we can assume all -	// entries are valid, so try to add them. -	if err := p.state.DB.PutListEntries(ctx, listEntries); err != nil { -		if errors.Is(err, db.ErrAlreadyExists) { -			err = fmt.Errorf("one or more errors inserting list entries: %w", err) -			return gtserror.NewErrorUnprocessableEntity(err, err.Error()) -		} +	// Add all of the gathered list entries to the database. +	switch err := p.state.DB.PutListEntries(ctx, entries); { +	case err == nil: + +	case errors.Is(err, db.ErrAlreadyExists): +		err := gtserror.Newf("conflict adding list entry: %w", err) +		return gtserror.NewErrorUnprocessableEntity(err) + +	default: +		err := gtserror.Newf("db error inserting list entries: %w", err)  		return gtserror.NewErrorInternalError(err)  	} @@ -97,55 +114,61 @@ func (p *Processor) AddToList(ctx context.Context, account *gtsmodel.Account, li  }  // RemoveFromList removes targetAccountIDs from the given list, if valid. -func (p *Processor) RemoveFromList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode { +func (p *Processor) RemoveFromList( +	ctx context.Context, +	account *gtsmodel.Account, +	listID string, +	targetAccountIDs []string, +) gtserror.WithCode {  	// Ensure this list exists + account owns it. -	list, errWithCode := p.getList(ctx, account.ID, listID) +	_, errWithCode := p.getList(ctx, account.ID, listID)  	if errWithCode != nil {  		return errWithCode  	} -	// For each targetAccountID, we want to check if -	// a follow with that targetAccountID is in the -	// given list. If it is in there, we want to remove -	// it from the list. +	// Get all follows that are entries in list. +	follows, err := p.state.DB.GetFollowsInList( + +		// We only need barebones model. +		gtscontext.SetBarebones(ctx), +		listID, +		nil, +	) +	if err != nil { +		err := gtserror.Newf("error getting list follows: %w", err) +		return gtserror.NewErrorInternalError(err) +	} + +	// Convert the follows to a map keyed by the target account ID. +	followsMap := util.KeyBy(follows, func(follow *gtsmodel.Follow) string { +		return follow.TargetAccountID +	}) + +	var errs gtserror.MultiError + +	// Iterate all the account IDs in given target list.  	for _, targetAccountID := range targetAccountIDs { -		// Check if targetAccountID is -		// on a follow in the list. -		entryID, err := isInList( -			list, -			targetAccountID, -			func(listEntry *gtsmodel.ListEntry) (string, error) { -				// We need the follow so populate this -				// entry, if it's not already populated. -				if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil { -					return "", err -				} - -				// Looking for the list entry targetAccountID. -				return listEntry.Follow.TargetAccountID, nil -			}, -		) -		// Error may be returned here if there was an issue -		// populating the list entry. We only return on proper -		// DB errors, we can just skip no entry errors. -		if err != nil && !errors.Is(err, db.ErrNoEntries) { -			err = fmt.Errorf("error checking if targetAccountID %s was in list %s: %w", targetAccountID, listID, err) -			return gtserror.NewErrorInternalError(err) -		} +		// Look for follow targetting this account. +		follow, ok := followsMap[targetAccountID] -		if entryID == "" { -			// There was an errNoEntries or targetAccount -			// wasn't in this list anyway, so we can skip it. +		if !ok { +			// not in list.  			continue  		} -		// TargetAccount was in the list, remove the entry. -		if err := p.state.DB.DeleteListEntry(ctx, entryID); err != nil && !errors.Is(err, db.ErrNoEntries) { -			err = fmt.Errorf("error removing list entry %s from list %s: %w", entryID, listID, err) -			return gtserror.NewErrorInternalError(err) +		// Delete the list entry containing follow ID in list. +		err := p.state.DB.DeleteListEntry(ctx, listID, follow.ID) +		if err != nil && !errors.Is(err, db.ErrNoEntries) { +			errs.Appendf("error removing list entry: %w", err) +			continue  		}  	} +	// Wrap errors in errWithCode if set. +	if err := errs.Combine(); err != nil { +		return gtserror.NewErrorInternalError(err) +	} +  	return nil  } diff --git a/internal/processing/list/util.go b/internal/processing/list/util.go index c5b1e5081..74d148704 100644 --- a/internal/processing/list/util.go +++ b/internal/processing/list/util.go @@ -33,18 +33,25 @@ import (  // appropriate errors so caller doesn't need to bother.  func (p *Processor) getList(ctx context.Context, accountID string, listID string) (*gtsmodel.List, gtserror.WithCode) {  	list, err := p.state.DB.GetListByID(ctx, listID) -	if err != nil { -		if errors.Is(err, db.ErrNoEntries) { -			// List doesn't seem to exist. -			return nil, gtserror.NewErrorNotFound(err) -		} -		// Real database error. +	if err != nil && !errors.Is(err, db.ErrNoEntries) { +		err := gtserror.Newf("db error getting list: %w", err)  		return nil, gtserror.NewErrorInternalError(err)  	} +	if list == nil { +		const text = "list not found" +		return nil, gtserror.NewErrorNotFound( +			errors.New(text), +			text, +		) +	} +  	if list.AccountID != accountID { -		err = fmt.Errorf("list with id %s does not belong to account %s", list.ID, accountID) -		return nil, gtserror.NewErrorNotFound(err) +		const text = "list not found" +		return nil, gtserror.NewErrorNotFound( +			errors.New("list does not belong to account"), +			text, +		)  	}  	return list, nil @@ -60,26 +67,3 @@ func (p *Processor) apiList(ctx context.Context, list *gtsmodel.List) (*apimodel  	return apiList, nil  } - -// isInList check if thisID is equal to the result of thatID -// for any entry in the given list. -// -// Will return the id of the listEntry if true, empty if false, -// or an error if the result of thatID returns an error. -func isInList( -	list *gtsmodel.List, -	thisID string, -	getThatID func(listEntry *gtsmodel.ListEntry) (string, error), -) (string, error) { -	for _, listEntry := range list.ListEntries { -		thatID, err := getThatID(listEntry) -		if err != nil { -			return "", err -		} - -		if thisID == thatID { -			return listEntry.ID, nil -		} -	} -	return "", nil -}  | 
