summaryrefslogtreecommitdiff
path: root/internal/processing/list
diff options
context:
space:
mode:
Diffstat (limited to 'internal/processing/list')
-rw-r--r--internal/processing/list/get.go138
-rw-r--r--internal/processing/list/updateentries.go177
-rw-r--r--internal/processing/list/util.go46
3 files changed, 148 insertions, 213 deletions
diff --git a/internal/processing/list/get.go b/internal/processing/list/get.go
index cdd3c6e0c..b98678eef 100644
--- a/internal/processing/list/get.go
+++ b/internal/processing/list/get.go
@@ -20,7 +20,6 @@ package list
import (
"context"
"errors"
- "fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -28,7 +27,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
- "github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
)
// Get returns the api model of one list with the given ID.
@@ -49,16 +48,14 @@ func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id strin
// GetAll returns multiple lists created by the given account, sorted by list ID DESC (newest first).
func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*apimodel.List, gtserror.WithCode) {
- lists, err := p.state.DB.GetListsForAccountID(
+ lists, err := p.state.DB.GetListsByAccountID(
+
// Use barebones ctx; no embedded
// structs necessary for simple GET.
gtscontext.SetBarebones(ctx),
account.ID,
)
- if err != nil {
- if errors.Is(err, db.ErrNoEntries) {
- return nil, nil
- }
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -68,66 +65,23 @@ func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*a
if errWithCode != nil {
return nil, errWithCode
}
-
apiLists = append(apiLists, apiList)
}
return apiLists, nil
}
-// GetAllListAccounts returns all accounts that are in the given list,
-// owned by the given account. There's no pagination for this endpoint.
-//
-// See https://docs.joinmastodon.org/methods/lists/#query-parameters:
-//
-// Limit: Integer. Maximum number of results. Defaults to 40 accounts.
-// Max 80 accounts. Set to 0 in order to get all accounts without pagination.
-func (p *Processor) GetAllListAccounts(
- ctx context.Context,
- account *gtsmodel.Account,
- listID string,
-) ([]*apimodel.Account, gtserror.WithCode) {
- // Ensure list exists + is owned by requesting account.
- _, errWithCode := p.getList(
- // Use barebones ctx; no embedded
- // structs necessary for this call.
- gtscontext.SetBarebones(ctx),
- account.ID,
- listID,
- )
- if errWithCode != nil {
- return nil, errWithCode
- }
-
- // Get all entries for this list.
- listEntries, err := p.state.DB.GetListEntries(ctx, listID, "", "", "", 0)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- err = gtserror.Newf("error getting list entries: %w", err)
- return nil, gtserror.NewErrorInternalError(err)
- }
-
- // Extract accounts from list entries + add them to response.
- accounts := make([]*apimodel.Account, 0, len(listEntries))
- p.accountsFromListEntries(ctx, listEntries, func(acc *apimodel.Account) {
- accounts = append(accounts, acc)
- })
-
- return accounts, nil
-}
-
// GetListAccounts returns accounts that are in the given list, owned by the given account.
-// The additional parameters can be used for paging.
+// The additional parameters can be used for paging. Nil page param returns all accounts.
func (p *Processor) GetListAccounts(
ctx context.Context,
account *gtsmodel.Account,
listID string,
- maxID string,
- sinceID string,
- minID string,
- limit int,
+ page *paging.Page,
) (*apimodel.PageableResponse, gtserror.WithCode) {
// Ensure list exists + is owned by requesting account.
_, errWithCode := p.getList(
+
// Use barebones ctx; no embedded
// structs necessary for this call.
gtscontext.SetBarebones(ctx),
@@ -138,71 +92,45 @@ func (p *Processor) GetListAccounts(
return nil, errWithCode
}
- // To know which accounts are in the list,
- // we need to first get requested list entries.
- listEntries, err := p.state.DB.GetListEntries(ctx, listID, maxID, sinceID, minID, limit)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("GetListAccounts: error getting list entries: %w", err)
+ // Get all accounts contained within list.
+ accounts, err := p.state.DB.GetAccountsInList(ctx,
+ listID,
+ page,
+ )
+ if err != nil {
+ err := gtserror.Newf("db error getting accounts in list: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
- count := len(listEntries)
+ // Check for any accounts.
+ count := len(accounts)
if count == 0 {
- // No list entries means no accounts.
- return util.EmptyPageableResponse(), nil
+ return paging.EmptyResponse(), nil
}
var (
+ // Preallocate expected frontend items.
items = make([]interface{}, 0, count)
- // Set next + prev values before filtering and API
- // converting, so caller can still page properly.
- nextMaxIDValue = listEntries[count-1].ID
- prevMinIDValue = listEntries[0].ID
+ // Set paging low / high IDs.
+ lo = accounts[count-1].ID
+ hi = accounts[0].ID
)
- // Extract accounts from list entries + add them to response.
- p.accountsFromListEntries(ctx, listEntries, func(acc *apimodel.Account) {
- items = append(items, acc)
- })
-
- return util.PackagePageableResponse(util.PageableResponseParams{
- Items: items,
- Path: "/api/v1/lists/" + listID + "/accounts",
- NextMaxIDValue: nextMaxIDValue,
- PrevMinIDValue: prevMinIDValue,
- Limit: limit,
- })
-}
-
-func (p *Processor) accountsFromListEntries(
- ctx context.Context,
- listEntries []*gtsmodel.ListEntry,
- appendAcc func(*apimodel.Account),
-) {
- // For each list entry, we want the account it points to.
- // To get this, we need to first get the follow that the
- // list entry pertains to, then extract the target account
- // from that follow.
- //
- // We do paging not by account ID, but by list entry ID.
- for _, listEntry := range listEntries {
- if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil {
- log.Errorf(ctx, "error populating list entry: %v", err)
- continue
- }
-
- if err := p.state.DB.PopulateFollow(ctx, listEntry.Follow); err != nil {
- log.Errorf(ctx, "error populating follow: %v", err)
- continue
- }
-
- apiAccount, err := p.converter.AccountToAPIAccountPublic(ctx, listEntry.Follow.TargetAccount)
+ // Convert accounts to frontend.
+ for _, account := range accounts {
+ apiAccount, err := p.converter.AccountToAPIAccountPublic(ctx, account)
if err != nil {
- log.Errorf(ctx, "error converting to public api account: %v", err)
+ log.Errorf(ctx, "error converting to api account: %v", err)
continue
}
-
- appendAcc(apiAccount)
+ items = append(items, apiAccount)
}
+
+ return paging.PackageResponse(paging.ResponseParams{
+ Items: items,
+ Path: "/api/v1/lists/" + listID + "/accounts",
+ Next: page.Next(lo, hi),
+ Prev: page.Prev(lo, hi),
+ }), nil
}
diff --git a/internal/processing/list/updateentries.go b/internal/processing/list/updateentries.go
index 6dcb951a7..c15248f39 100644
--- a/internal/processing/list/updateentries.go
+++ b/internal/processing/list/updateentries.go
@@ -23,73 +23,90 @@ import (
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
)
// AddToList adds targetAccountIDs to the given list, if valid.
func (p *Processor) AddToList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode {
+
// Ensure this list exists + account owns it.
- list, errWithCode := p.getList(ctx, account.ID, listID)
+ _, errWithCode := p.getList(ctx, account.ID, listID)
if errWithCode != nil {
return errWithCode
}
- // Pre-assemble list of entries to add. We *could* add these
- // one by one as we iterate through accountIDs, but according
- // to the Mastodon API we should only add them all once we know
- // they're all valid, no partial updates.
- listEntries := make([]*gtsmodel.ListEntry, 0, len(targetAccountIDs))
+ // Get all follows that are entries in list.
+ follows, err := p.state.DB.GetFollowsInList(
+
+ // We only need barebones model.
+ gtscontext.SetBarebones(ctx),
+ listID,
+ nil,
+ )
+ if err != nil {
+ err := gtserror.Newf("error getting list follows: %w", err)
+ return gtserror.NewErrorInternalError(err)
+ }
+
+ // Convert the follows to a hash set containing the target account IDs.
+ inFollows := util.ToSetFunc(follows, func(follow *gtsmodel.Follow) string {
+ return follow.TargetAccountID
+ })
- // Check each targetAccountID is valid.
- // - Follow must exist.
- // - Follow must not already be in the given list.
+ // Preallocate a slice of expected list entries, we specifically
+ // gather and add all the target accounts in one go rather than
+ // individually, to ensure we don't end up with partial updates.
+ entries := make([]*gtsmodel.ListEntry, 0, len(targetAccountIDs))
+
+ // Iterate all the account IDs in given target list.
for _, targetAccountID := range targetAccountIDs {
- // Ensure follow exists.
- follow, err := p.state.DB.GetFollow(ctx, account.ID, targetAccountID)
- if err != nil {
- if errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("you do not follow account %s", targetAccountID)
- return gtserror.NewErrorNotFound(err, err.Error())
- }
- return gtserror.NewErrorInternalError(err)
+
+ // Look for follow to target account.
+ if inFollows.Has(targetAccountID) {
+ text := fmt.Sprintf("account %s is already in list %s", targetAccountID, listID)
+ return gtserror.NewErrorUnprocessableEntity(errors.New(text), text)
}
- // Ensure followID not already in list.
- // This particular call to isInList will
- // never error, so just check entryID.
- entryID, _ := isInList(
- list,
- follow.ID,
- func(listEntry *gtsmodel.ListEntry) (string, error) {
- // Looking for the listEntry follow ID.
- return listEntry.FollowID, nil
- },
+ // Get the actual follow to target.
+ follow, err := p.state.DB.GetFollow(
+
+ // We don't need any sub-models.
+ gtscontext.SetBarebones(ctx),
+ account.ID,
+ targetAccountID,
)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ err := gtserror.Newf("db error getting follow: %w", err)
+ return gtserror.NewErrorInternalError(err)
+ }
- // Empty entryID means entry with given
- // followID wasn't found in the list.
- if entryID != "" {
- err = fmt.Errorf("account with id %s is already in list %s with entryID %s", targetAccountID, listID, entryID)
- return gtserror.NewErrorUnprocessableEntity(err, err.Error())
+ if follow == nil {
+ text := fmt.Sprintf("account %s not currently followed", targetAccountID)
+ return gtserror.NewErrorNotFound(errors.New(text), text)
}
- // Entry wasn't in the list, we can add it.
- listEntries = append(listEntries, &gtsmodel.ListEntry{
+ // Generate new entry for this follow in list.
+ entries = append(entries, &gtsmodel.ListEntry{
ID: id.NewULID(),
ListID: listID,
FollowID: follow.ID,
})
}
- // If we get to here we can assume all
- // entries are valid, so try to add them.
- if err := p.state.DB.PutListEntries(ctx, listEntries); err != nil {
- if errors.Is(err, db.ErrAlreadyExists) {
- err = fmt.Errorf("one or more errors inserting list entries: %w", err)
- return gtserror.NewErrorUnprocessableEntity(err, err.Error())
- }
+ // Add all of the gathered list entries to the database.
+ switch err := p.state.DB.PutListEntries(ctx, entries); {
+ case err == nil:
+
+ case errors.Is(err, db.ErrAlreadyExists):
+ err := gtserror.Newf("conflict adding list entry: %w", err)
+ return gtserror.NewErrorUnprocessableEntity(err)
+
+ default:
+ err := gtserror.Newf("db error inserting list entries: %w", err)
return gtserror.NewErrorInternalError(err)
}
@@ -97,55 +114,61 @@ func (p *Processor) AddToList(ctx context.Context, account *gtsmodel.Account, li
}
// RemoveFromList removes targetAccountIDs from the given list, if valid.
-func (p *Processor) RemoveFromList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode {
+func (p *Processor) RemoveFromList(
+ ctx context.Context,
+ account *gtsmodel.Account,
+ listID string,
+ targetAccountIDs []string,
+) gtserror.WithCode {
// Ensure this list exists + account owns it.
- list, errWithCode := p.getList(ctx, account.ID, listID)
+ _, errWithCode := p.getList(ctx, account.ID, listID)
if errWithCode != nil {
return errWithCode
}
- // For each targetAccountID, we want to check if
- // a follow with that targetAccountID is in the
- // given list. If it is in there, we want to remove
- // it from the list.
+ // Get all follows that are entries in list.
+ follows, err := p.state.DB.GetFollowsInList(
+
+ // We only need barebones model.
+ gtscontext.SetBarebones(ctx),
+ listID,
+ nil,
+ )
+ if err != nil {
+ err := gtserror.Newf("error getting list follows: %w", err)
+ return gtserror.NewErrorInternalError(err)
+ }
+
+ // Convert the follows to a map keyed by the target account ID.
+ followsMap := util.KeyBy(follows, func(follow *gtsmodel.Follow) string {
+ return follow.TargetAccountID
+ })
+
+ var errs gtserror.MultiError
+
+ // Iterate all the account IDs in given target list.
for _, targetAccountID := range targetAccountIDs {
- // Check if targetAccountID is
- // on a follow in the list.
- entryID, err := isInList(
- list,
- targetAccountID,
- func(listEntry *gtsmodel.ListEntry) (string, error) {
- // We need the follow so populate this
- // entry, if it's not already populated.
- if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil {
- return "", err
- }
-
- // Looking for the list entry targetAccountID.
- return listEntry.Follow.TargetAccountID, nil
- },
- )
- // Error may be returned here if there was an issue
- // populating the list entry. We only return on proper
- // DB errors, we can just skip no entry errors.
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("error checking if targetAccountID %s was in list %s: %w", targetAccountID, listID, err)
- return gtserror.NewErrorInternalError(err)
- }
+ // Look for follow targetting this account.
+ follow, ok := followsMap[targetAccountID]
- if entryID == "" {
- // There was an errNoEntries or targetAccount
- // wasn't in this list anyway, so we can skip it.
+ if !ok {
+ // not in list.
continue
}
- // TargetAccount was in the list, remove the entry.
- if err := p.state.DB.DeleteListEntry(ctx, entryID); err != nil && !errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("error removing list entry %s from list %s: %w", entryID, listID, err)
- return gtserror.NewErrorInternalError(err)
+ // Delete the list entry containing follow ID in list.
+ err := p.state.DB.DeleteListEntry(ctx, listID, follow.ID)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ errs.Appendf("error removing list entry: %w", err)
+ continue
}
}
+ // Wrap errors in errWithCode if set.
+ if err := errs.Combine(); err != nil {
+ return gtserror.NewErrorInternalError(err)
+ }
+
return nil
}
diff --git a/internal/processing/list/util.go b/internal/processing/list/util.go
index c5b1e5081..74d148704 100644
--- a/internal/processing/list/util.go
+++ b/internal/processing/list/util.go
@@ -33,18 +33,25 @@ import (
// appropriate errors so caller doesn't need to bother.
func (p *Processor) getList(ctx context.Context, accountID string, listID string) (*gtsmodel.List, gtserror.WithCode) {
list, err := p.state.DB.GetListByID(ctx, listID)
- if err != nil {
- if errors.Is(err, db.ErrNoEntries) {
- // List doesn't seem to exist.
- return nil, gtserror.NewErrorNotFound(err)
- }
- // Real database error.
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ err := gtserror.Newf("db error getting list: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
+ if list == nil {
+ const text = "list not found"
+ return nil, gtserror.NewErrorNotFound(
+ errors.New(text),
+ text,
+ )
+ }
+
if list.AccountID != accountID {
- err = fmt.Errorf("list with id %s does not belong to account %s", list.ID, accountID)
- return nil, gtserror.NewErrorNotFound(err)
+ const text = "list not found"
+ return nil, gtserror.NewErrorNotFound(
+ errors.New("list does not belong to account"),
+ text,
+ )
}
return list, nil
@@ -60,26 +67,3 @@ func (p *Processor) apiList(ctx context.Context, list *gtsmodel.List) (*apimodel
return apiList, nil
}
-
-// isInList check if thisID is equal to the result of thatID
-// for any entry in the given list.
-//
-// Will return the id of the listEntry if true, empty if false,
-// or an error if the result of thatID returns an error.
-func isInList(
- list *gtsmodel.List,
- thisID string,
- getThatID func(listEntry *gtsmodel.ListEntry) (string, error),
-) (string, error) {
- for _, listEntry := range list.ListEntries {
- thatID, err := getThatID(listEntry)
- if err != nil {
- return "", err
- }
-
- if thisID == thatID {
- return listEntry.ID, nil
- }
- }
- return "", nil
-}