diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/bundb/user.go | 72 |
1 files changed, 54 insertions, 18 deletions
diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index 2800a32e9..f51d1bf74 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -19,12 +19,15 @@ package bundb import ( "context" + "slices" "time" "code.superseriousbusiness.org/gotosocial/internal/gtscontext" "code.superseriousbusiness.org/gotosocial/internal/gtserror" "code.superseriousbusiness.org/gotosocial/internal/gtsmodel" + "code.superseriousbusiness.org/gotosocial/internal/log" "code.superseriousbusiness.org/gotosocial/internal/state" + "code.superseriousbusiness.org/gotosocial/internal/util/xslices" "github.com/uptrace/bun" ) @@ -45,27 +48,47 @@ func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, er } func (u *userDB) GetUsersByIDs(ctx context.Context, ids []string) ([]*gtsmodel.User, error) { - var ( - users = make([]*gtsmodel.User, 0, len(ids)) - - // Collect errors instead of - // returning early on any. - errs gtserror.MultiError + // Load all input user IDs via cache loader callback. + users, err := u.state.Caches.DB.User.LoadIDs("ID", + ids, + func(uncached []string) ([]*gtsmodel.User, error) { + // Preallocate expected length of uncached users. + users := make([]*gtsmodel.User, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) user IDs. + if err := u.db.NewSelect(). + Model(&users). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return users, nil + }, ) + if err != nil { + return nil, err + } - for _, id := range ids { - // Attempt to fetch user from DB. - user, err := u.GetUserByID(ctx, id) - if err != nil { - errs.Appendf("error getting user %s: %w", id, err) - continue - } + // Reorder the users by their + // IDs to ensure in correct order. + getID := func(s *gtsmodel.User) string { return s.ID } + xslices.OrderBy(users, ids, getID) - // Append user to return slice. - users = append(users, user) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return users, nil } - return users, errs.Combine() + // Populate all loaded users. + for _, user := range users { + if err := u.PopulateUser(ctx, user); err != nil { + log.Errorf(ctx, "error populating user %s: %v", user.ID, err) + } + } + + return users, nil } func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) { @@ -161,7 +184,11 @@ func (u *userDB) PopulateUser(ctx context.Context, user *gtsmodel.User) error { return errs.Combine() } -func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) { +func (u *userDB) GetAllUserIDs(ctx context.Context) ([]string, error) { + if p := u.state.Caches.DB.LocalInstance.UserIDs.Load(); p != nil { + return slices.Clone(*p), nil + } + var userIDs []string // Scan all user IDs into slice. @@ -172,7 +199,16 @@ func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) { return nil, err } - // Transform user IDs into user slice. + // Store the scanned user IDs in our local cache ptr. + u.state.Caches.DB.LocalInstance.UserIDs.Store(&userIDs) + return userIDs, nil +} + +func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) { + userIDs, err := u.GetAllUserIDs(ctx) + if err != nil { + return nil, err + } return u.GetUsersByIDs(ctx, userIDs) } |
