diff options
Diffstat (limited to 'internal/processing/list')
-rw-r--r-- | internal/processing/list/get.go | 138 | ||||
-rw-r--r-- | internal/processing/list/updateentries.go | 177 | ||||
-rw-r--r-- | internal/processing/list/util.go | 46 |
3 files changed, 148 insertions, 213 deletions
diff --git a/internal/processing/list/get.go b/internal/processing/list/get.go index cdd3c6e0c..b98678eef 100644 --- a/internal/processing/list/get.go +++ b/internal/processing/list/get.go @@ -20,7 +20,6 @@ package list import ( "context" "errors" - "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -28,7 +27,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" - "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // Get returns the api model of one list with the given ID. @@ -49,16 +48,14 @@ func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id strin // GetAll returns multiple lists created by the given account, sorted by list ID DESC (newest first). func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*apimodel.List, gtserror.WithCode) { - lists, err := p.state.DB.GetListsForAccountID( + lists, err := p.state.DB.GetListsByAccountID( + // Use barebones ctx; no embedded // structs necessary for simple GET. gtscontext.SetBarebones(ctx), account.ID, ) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - return nil, nil - } + if err != nil && !errors.Is(err, db.ErrNoEntries) { return nil, gtserror.NewErrorInternalError(err) } @@ -68,66 +65,23 @@ func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*a if errWithCode != nil { return nil, errWithCode } - apiLists = append(apiLists, apiList) } return apiLists, nil } -// GetAllListAccounts returns all accounts that are in the given list, -// owned by the given account. There's no pagination for this endpoint. -// -// See https://docs.joinmastodon.org/methods/lists/#query-parameters: -// -// Limit: Integer. Maximum number of results. Defaults to 40 accounts. -// Max 80 accounts. Set to 0 in order to get all accounts without pagination. -func (p *Processor) GetAllListAccounts( - ctx context.Context, - account *gtsmodel.Account, - listID string, -) ([]*apimodel.Account, gtserror.WithCode) { - // Ensure list exists + is owned by requesting account. - _, errWithCode := p.getList( - // Use barebones ctx; no embedded - // structs necessary for this call. - gtscontext.SetBarebones(ctx), - account.ID, - listID, - ) - if errWithCode != nil { - return nil, errWithCode - } - - // Get all entries for this list. - listEntries, err := p.state.DB.GetListEntries(ctx, listID, "", "", "", 0) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - err = gtserror.Newf("error getting list entries: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } - - // Extract accounts from list entries + add them to response. - accounts := make([]*apimodel.Account, 0, len(listEntries)) - p.accountsFromListEntries(ctx, listEntries, func(acc *apimodel.Account) { - accounts = append(accounts, acc) - }) - - return accounts, nil -} - // GetListAccounts returns accounts that are in the given list, owned by the given account. -// The additional parameters can be used for paging. +// The additional parameters can be used for paging. Nil page param returns all accounts. func (p *Processor) GetListAccounts( ctx context.Context, account *gtsmodel.Account, listID string, - maxID string, - sinceID string, - minID string, - limit int, + page *paging.Page, ) (*apimodel.PageableResponse, gtserror.WithCode) { // Ensure list exists + is owned by requesting account. _, errWithCode := p.getList( + // Use barebones ctx; no embedded // structs necessary for this call. gtscontext.SetBarebones(ctx), @@ -138,71 +92,45 @@ func (p *Processor) GetListAccounts( return nil, errWithCode } - // To know which accounts are in the list, - // we need to first get requested list entries. - listEntries, err := p.state.DB.GetListEntries(ctx, listID, maxID, sinceID, minID, limit) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("GetListAccounts: error getting list entries: %w", err) + // Get all accounts contained within list. + accounts, err := p.state.DB.GetAccountsInList(ctx, + listID, + page, + ) + if err != nil { + err := gtserror.Newf("db error getting accounts in list: %w", err) return nil, gtserror.NewErrorInternalError(err) } - count := len(listEntries) + // Check for any accounts. + count := len(accounts) if count == 0 { - // No list entries means no accounts. - return util.EmptyPageableResponse(), nil + return paging.EmptyResponse(), nil } var ( + // Preallocate expected frontend items. items = make([]interface{}, 0, count) - // Set next + prev values before filtering and API - // converting, so caller can still page properly. - nextMaxIDValue = listEntries[count-1].ID - prevMinIDValue = listEntries[0].ID + // Set paging low / high IDs. + lo = accounts[count-1].ID + hi = accounts[0].ID ) - // Extract accounts from list entries + add them to response. - p.accountsFromListEntries(ctx, listEntries, func(acc *apimodel.Account) { - items = append(items, acc) - }) - - return util.PackagePageableResponse(util.PageableResponseParams{ - Items: items, - Path: "/api/v1/lists/" + listID + "/accounts", - NextMaxIDValue: nextMaxIDValue, - PrevMinIDValue: prevMinIDValue, - Limit: limit, - }) -} - -func (p *Processor) accountsFromListEntries( - ctx context.Context, - listEntries []*gtsmodel.ListEntry, - appendAcc func(*apimodel.Account), -) { - // For each list entry, we want the account it points to. - // To get this, we need to first get the follow that the - // list entry pertains to, then extract the target account - // from that follow. - // - // We do paging not by account ID, but by list entry ID. - for _, listEntry := range listEntries { - if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil { - log.Errorf(ctx, "error populating list entry: %v", err) - continue - } - - if err := p.state.DB.PopulateFollow(ctx, listEntry.Follow); err != nil { - log.Errorf(ctx, "error populating follow: %v", err) - continue - } - - apiAccount, err := p.converter.AccountToAPIAccountPublic(ctx, listEntry.Follow.TargetAccount) + // Convert accounts to frontend. + for _, account := range accounts { + apiAccount, err := p.converter.AccountToAPIAccountPublic(ctx, account) if err != nil { - log.Errorf(ctx, "error converting to public api account: %v", err) + log.Errorf(ctx, "error converting to api account: %v", err) continue } - - appendAcc(apiAccount) + items = append(items, apiAccount) } + + return paging.PackageResponse(paging.ResponseParams{ + Items: items, + Path: "/api/v1/lists/" + listID + "/accounts", + Next: page.Next(lo, hi), + Prev: page.Prev(lo, hi), + }), nil } diff --git a/internal/processing/list/updateentries.go b/internal/processing/list/updateentries.go index 6dcb951a7..c15248f39 100644 --- a/internal/processing/list/updateentries.go +++ b/internal/processing/list/updateentries.go @@ -23,73 +23,90 @@ import ( "fmt" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/util" ) // AddToList adds targetAccountIDs to the given list, if valid. func (p *Processor) AddToList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode { + // Ensure this list exists + account owns it. - list, errWithCode := p.getList(ctx, account.ID, listID) + _, errWithCode := p.getList(ctx, account.ID, listID) if errWithCode != nil { return errWithCode } - // Pre-assemble list of entries to add. We *could* add these - // one by one as we iterate through accountIDs, but according - // to the Mastodon API we should only add them all once we know - // they're all valid, no partial updates. - listEntries := make([]*gtsmodel.ListEntry, 0, len(targetAccountIDs)) + // Get all follows that are entries in list. + follows, err := p.state.DB.GetFollowsInList( + + // We only need barebones model. + gtscontext.SetBarebones(ctx), + listID, + nil, + ) + if err != nil { + err := gtserror.Newf("error getting list follows: %w", err) + return gtserror.NewErrorInternalError(err) + } + + // Convert the follows to a hash set containing the target account IDs. + inFollows := util.ToSetFunc(follows, func(follow *gtsmodel.Follow) string { + return follow.TargetAccountID + }) - // Check each targetAccountID is valid. - // - Follow must exist. - // - Follow must not already be in the given list. + // Preallocate a slice of expected list entries, we specifically + // gather and add all the target accounts in one go rather than + // individually, to ensure we don't end up with partial updates. + entries := make([]*gtsmodel.ListEntry, 0, len(targetAccountIDs)) + + // Iterate all the account IDs in given target list. for _, targetAccountID := range targetAccountIDs { - // Ensure follow exists. - follow, err := p.state.DB.GetFollow(ctx, account.ID, targetAccountID) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("you do not follow account %s", targetAccountID) - return gtserror.NewErrorNotFound(err, err.Error()) - } - return gtserror.NewErrorInternalError(err) + + // Look for follow to target account. + if inFollows.Has(targetAccountID) { + text := fmt.Sprintf("account %s is already in list %s", targetAccountID, listID) + return gtserror.NewErrorUnprocessableEntity(errors.New(text), text) } - // Ensure followID not already in list. - // This particular call to isInList will - // never error, so just check entryID. - entryID, _ := isInList( - list, - follow.ID, - func(listEntry *gtsmodel.ListEntry) (string, error) { - // Looking for the listEntry follow ID. - return listEntry.FollowID, nil - }, + // Get the actual follow to target. + follow, err := p.state.DB.GetFollow( + + // We don't need any sub-models. + gtscontext.SetBarebones(ctx), + account.ID, + targetAccountID, ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err := gtserror.Newf("db error getting follow: %w", err) + return gtserror.NewErrorInternalError(err) + } - // Empty entryID means entry with given - // followID wasn't found in the list. - if entryID != "" { - err = fmt.Errorf("account with id %s is already in list %s with entryID %s", targetAccountID, listID, entryID) - return gtserror.NewErrorUnprocessableEntity(err, err.Error()) + if follow == nil { + text := fmt.Sprintf("account %s not currently followed", targetAccountID) + return gtserror.NewErrorNotFound(errors.New(text), text) } - // Entry wasn't in the list, we can add it. - listEntries = append(listEntries, >smodel.ListEntry{ + // Generate new entry for this follow in list. + entries = append(entries, >smodel.ListEntry{ ID: id.NewULID(), ListID: listID, FollowID: follow.ID, }) } - // If we get to here we can assume all - // entries are valid, so try to add them. - if err := p.state.DB.PutListEntries(ctx, listEntries); err != nil { - if errors.Is(err, db.ErrAlreadyExists) { - err = fmt.Errorf("one or more errors inserting list entries: %w", err) - return gtserror.NewErrorUnprocessableEntity(err, err.Error()) - } + // Add all of the gathered list entries to the database. + switch err := p.state.DB.PutListEntries(ctx, entries); { + case err == nil: + + case errors.Is(err, db.ErrAlreadyExists): + err := gtserror.Newf("conflict adding list entry: %w", err) + return gtserror.NewErrorUnprocessableEntity(err) + + default: + err := gtserror.Newf("db error inserting list entries: %w", err) return gtserror.NewErrorInternalError(err) } @@ -97,55 +114,61 @@ func (p *Processor) AddToList(ctx context.Context, account *gtsmodel.Account, li } // RemoveFromList removes targetAccountIDs from the given list, if valid. -func (p *Processor) RemoveFromList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode { +func (p *Processor) RemoveFromList( + ctx context.Context, + account *gtsmodel.Account, + listID string, + targetAccountIDs []string, +) gtserror.WithCode { // Ensure this list exists + account owns it. - list, errWithCode := p.getList(ctx, account.ID, listID) + _, errWithCode := p.getList(ctx, account.ID, listID) if errWithCode != nil { return errWithCode } - // For each targetAccountID, we want to check if - // a follow with that targetAccountID is in the - // given list. If it is in there, we want to remove - // it from the list. + // Get all follows that are entries in list. + follows, err := p.state.DB.GetFollowsInList( + + // We only need barebones model. + gtscontext.SetBarebones(ctx), + listID, + nil, + ) + if err != nil { + err := gtserror.Newf("error getting list follows: %w", err) + return gtserror.NewErrorInternalError(err) + } + + // Convert the follows to a map keyed by the target account ID. + followsMap := util.KeyBy(follows, func(follow *gtsmodel.Follow) string { + return follow.TargetAccountID + }) + + var errs gtserror.MultiError + + // Iterate all the account IDs in given target list. for _, targetAccountID := range targetAccountIDs { - // Check if targetAccountID is - // on a follow in the list. - entryID, err := isInList( - list, - targetAccountID, - func(listEntry *gtsmodel.ListEntry) (string, error) { - // We need the follow so populate this - // entry, if it's not already populated. - if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil { - return "", err - } - - // Looking for the list entry targetAccountID. - return listEntry.Follow.TargetAccountID, nil - }, - ) - // Error may be returned here if there was an issue - // populating the list entry. We only return on proper - // DB errors, we can just skip no entry errors. - if err != nil && !errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("error checking if targetAccountID %s was in list %s: %w", targetAccountID, listID, err) - return gtserror.NewErrorInternalError(err) - } + // Look for follow targetting this account. + follow, ok := followsMap[targetAccountID] - if entryID == "" { - // There was an errNoEntries or targetAccount - // wasn't in this list anyway, so we can skip it. + if !ok { + // not in list. continue } - // TargetAccount was in the list, remove the entry. - if err := p.state.DB.DeleteListEntry(ctx, entryID); err != nil && !errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("error removing list entry %s from list %s: %w", entryID, listID, err) - return gtserror.NewErrorInternalError(err) + // Delete the list entry containing follow ID in list. + err := p.state.DB.DeleteListEntry(ctx, listID, follow.ID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + errs.Appendf("error removing list entry: %w", err) + continue } } + // Wrap errors in errWithCode if set. + if err := errs.Combine(); err != nil { + return gtserror.NewErrorInternalError(err) + } + return nil } diff --git a/internal/processing/list/util.go b/internal/processing/list/util.go index c5b1e5081..74d148704 100644 --- a/internal/processing/list/util.go +++ b/internal/processing/list/util.go @@ -33,18 +33,25 @@ import ( // appropriate errors so caller doesn't need to bother. func (p *Processor) getList(ctx context.Context, accountID string, listID string) (*gtsmodel.List, gtserror.WithCode) { list, err := p.state.DB.GetListByID(ctx, listID) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - // List doesn't seem to exist. - return nil, gtserror.NewErrorNotFound(err) - } - // Real database error. + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err := gtserror.Newf("db error getting list: %w", err) return nil, gtserror.NewErrorInternalError(err) } + if list == nil { + const text = "list not found" + return nil, gtserror.NewErrorNotFound( + errors.New(text), + text, + ) + } + if list.AccountID != accountID { - err = fmt.Errorf("list with id %s does not belong to account %s", list.ID, accountID) - return nil, gtserror.NewErrorNotFound(err) + const text = "list not found" + return nil, gtserror.NewErrorNotFound( + errors.New("list does not belong to account"), + text, + ) } return list, nil @@ -60,26 +67,3 @@ func (p *Processor) apiList(ctx context.Context, list *gtsmodel.List) (*apimodel return apiList, nil } - -// isInList check if thisID is equal to the result of thatID -// for any entry in the given list. -// -// Will return the id of the listEntry if true, empty if false, -// or an error if the result of thatID returns an error. -func isInList( - list *gtsmodel.List, - thisID string, - getThatID func(listEntry *gtsmodel.ListEntry) (string, error), -) (string, error) { - for _, listEntry := range list.ListEntries { - thatID, err := getThatID(listEntry) - if err != nil { - return "", err - } - - if thisID == thatID { - return listEntry.ID, nil - } - } - return "", nil -} |