summaryrefslogtreecommitdiff
path: root/internal/db/bundb
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb')
-rw-r--r--internal/db/bundb/user.go72
1 files changed, 54 insertions, 18 deletions
diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go
index 2800a32e9..f51d1bf74 100644
--- a/internal/db/bundb/user.go
+++ b/internal/db/bundb/user.go
@@ -19,12 +19,15 @@ package bundb
import (
"context"
+ "slices"
"time"
"code.superseriousbusiness.org/gotosocial/internal/gtscontext"
"code.superseriousbusiness.org/gotosocial/internal/gtserror"
"code.superseriousbusiness.org/gotosocial/internal/gtsmodel"
+ "code.superseriousbusiness.org/gotosocial/internal/log"
"code.superseriousbusiness.org/gotosocial/internal/state"
+ "code.superseriousbusiness.org/gotosocial/internal/util/xslices"
"github.com/uptrace/bun"
)
@@ -45,27 +48,47 @@ func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, er
}
func (u *userDB) GetUsersByIDs(ctx context.Context, ids []string) ([]*gtsmodel.User, error) {
- var (
- users = make([]*gtsmodel.User, 0, len(ids))
-
- // Collect errors instead of
- // returning early on any.
- errs gtserror.MultiError
+ // Load all input user IDs via cache loader callback.
+ users, err := u.state.Caches.DB.User.LoadIDs("ID",
+ ids,
+ func(uncached []string) ([]*gtsmodel.User, error) {
+ // Preallocate expected length of uncached users.
+ users := make([]*gtsmodel.User, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) user IDs.
+ if err := u.db.NewSelect().
+ Model(&users).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return users, nil
+ },
)
+ if err != nil {
+ return nil, err
+ }
- for _, id := range ids {
- // Attempt to fetch user from DB.
- user, err := u.GetUserByID(ctx, id)
- if err != nil {
- errs.Appendf("error getting user %s: %w", id, err)
- continue
- }
+ // Reorder the users by their
+ // IDs to ensure in correct order.
+ getID := func(s *gtsmodel.User) string { return s.ID }
+ xslices.OrderBy(users, ids, getID)
- // Append user to return slice.
- users = append(users, user)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return users, nil
}
- return users, errs.Combine()
+ // Populate all loaded users.
+ for _, user := range users {
+ if err := u.PopulateUser(ctx, user); err != nil {
+ log.Errorf(ctx, "error populating user %s: %v", user.ID, err)
+ }
+ }
+
+ return users, nil
}
func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) {
@@ -161,7 +184,11 @@ func (u *userDB) PopulateUser(ctx context.Context, user *gtsmodel.User) error {
return errs.Combine()
}
-func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) {
+func (u *userDB) GetAllUserIDs(ctx context.Context) ([]string, error) {
+ if p := u.state.Caches.DB.LocalInstance.UserIDs.Load(); p != nil {
+ return slices.Clone(*p), nil
+ }
+
var userIDs []string
// Scan all user IDs into slice.
@@ -172,7 +199,16 @@ func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) {
return nil, err
}
- // Transform user IDs into user slice.
+ // Store the scanned user IDs in our local cache ptr.
+ u.state.Caches.DB.LocalInstance.UserIDs.Store(&userIDs)
+ return userIDs, nil
+}
+
+func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) {
+ userIDs, err := u.GetAllUserIDs(ctx)
+ if err != nil {
+ return nil, err
+ }
return u.GetUsersByIDs(ctx, userIDs)
}