summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/db.go10
-rw-r--r--internal/db/pg/pg.go101
2 files changed, 107 insertions, 4 deletions
diff --git a/internal/db/db.go b/internal/db/db.go
index 31a8ba5d9..1774420c6 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -253,12 +253,20 @@ type DB interface {
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
+ // WhoBoostedStatus returns a slice of accounts who boosted the given status.
+ // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
+ WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
+
// GetHomeTimelineForAccount fetches the account's HOME timeline -- ie., posts and replies from people they *follow*.
// It will use the given filters and try to return as many statuses up to the limit as possible.
GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
+ // GetPublicTimelineForAccount fetches the account's PUBLIC timline -- ie., posts and replies that are public.
+ // It will use the given filters and try to return as many statuses as possible up to the limit.
+ GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
+
// GetNotificationsForAccount returns a list of notifications that pertain to the given accountID.
- GetNotificationsForAccount(accountID string, limit int, maxID string) ([]*gtsmodel.Notification, error)
+ GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error)
/*
USEFUL CONVERSION FUNCTIONS
diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go
index f352404aa..c2bb7032b 100644
--- a/internal/db/pg/pg.go
+++ b/internal/db/pg/pg.go
@@ -1115,11 +1115,36 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.
return accounts, nil
}
+func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
+ accounts := []*gtsmodel.Account{}
+
+ boosts := []*gtsmodel.Status{}
+ if err := ps.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return accounts, nil // no rows just means nobody has boosted this status, so that's fine
+ }
+ return nil, err // an actual error has occurred
+ }
+
+ for _, f := range boosts {
+ acc := &gtsmodel.Account{}
+ if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
+ }
+ return nil, err // an actual error has occurred
+ }
+ accounts = append(accounts, acc)
+ }
+ return accounts, nil
+}
+
func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
statuses := []*gtsmodel.Status{}
- q := ps.conn.Model(&statuses).
- ColumnExpr("status.*").
+ q := ps.conn.Model(&statuses)
+
+ q = q.ColumnExpr("status.*").
Join("JOIN follows AS f ON f.target_account_id = status.account_id").
Where("f.account_id = ?", accountID).
Limit(limit).
@@ -1133,6 +1158,68 @@ func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID str
q = q.Where("status.created_at < ?", s.CreatedAt)
}
+ if minID != "" {
+ s := &gtsmodel.Status{}
+ if err := ps.conn.Model(s).Where("id = ?", minID).Select(); err != nil {
+ return nil, err
+ }
+ q = q.Where("status.created_at > ?", s.CreatedAt)
+ }
+
+ if sinceID != "" {
+ s := &gtsmodel.Status{}
+ if err := ps.conn.Model(s).Where("id = ?", sinceID).Select(); err != nil {
+ return nil, err
+ }
+ q = q.Where("status.created_at > ?", s.CreatedAt)
+ }
+
+ err := q.Select()
+ if err != nil {
+ if err != pg.ErrNoRows {
+ return nil, err
+ }
+ }
+
+ return statuses, nil
+}
+
+func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
+ statuses := []*gtsmodel.Status{}
+
+ q := ps.conn.Model(&statuses).
+ Where("visibility = ?", gtsmodel.VisibilityPublic).
+ Limit(limit).
+ Order("created_at DESC")
+
+ if maxID != "" {
+ s := &gtsmodel.Status{}
+ if err := ps.conn.Model(s).Where("id = ?", maxID).Select(); err != nil {
+ return nil, err
+ }
+ q = q.Where("created_at < ?", s.CreatedAt)
+ }
+
+ if minID != "" {
+ s := &gtsmodel.Status{}
+ if err := ps.conn.Model(s).Where("id = ?", minID).Select(); err != nil {
+ return nil, err
+ }
+ q = q.Where("created_at > ?", s.CreatedAt)
+ }
+
+ if sinceID != "" {
+ s := &gtsmodel.Status{}
+ if err := ps.conn.Model(s).Where("id = ?", sinceID).Select(); err != nil {
+ return nil, err
+ }
+ q = q.Where("created_at > ?", s.CreatedAt)
+ }
+
+ if local {
+ q = q.Where("local = ?", local)
+ }
+
err := q.Select()
if err != nil {
if err != pg.ErrNoRows {
@@ -1143,7 +1230,7 @@ func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID str
return statuses, nil
}
-func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string) ([]*gtsmodel.Notification, error) {
+func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) {
notifications := []*gtsmodel.Notification{}
q := ps.conn.Model(&notifications).Where("target_account_id = ?", accountID)
@@ -1156,6 +1243,14 @@ func (ps *postgresService) GetNotificationsForAccount(accountID string, limit in
q = q.Where("created_at < ?", n.CreatedAt)
}
+ if sinceID != "" {
+ n := &gtsmodel.Notification{}
+ if err := ps.conn.Model(n).Where("id = ?", sinceID).Select(); err != nil {
+ return nil, err
+ }
+ q = q.Where("created_at > ?", n.CreatedAt)
+ }
+
if limit != 0 {
q = q.Limit(limit)
}