summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/api/swagger.yaml96
-rw-r--r--go.mod1
-rw-r--r--go.sum2
-rw-r--r--internal/api/client/accounts/follow_test.go411
-rw-r--r--internal/api/client/accounts/followers.go62
-rw-r--r--internal/api/client/accounts/following.go62
-rw-r--r--internal/api/client/blocks/blocksget.go43
-rw-r--r--internal/api/client/followrequests/authorize.go2
-rw-r--r--internal/api/client/followrequests/get.go59
-rw-r--r--internal/api/client/followrequests/get_test.go220
-rw-r--r--internal/api/client/followrequests/reject.go2
-rw-r--r--internal/db/bundb/relationship.go101
-rw-r--r--internal/db/bundb/relationship_test.go6
-rw-r--r--internal/db/bundb/timeline.go1
-rw-r--r--internal/db/bundb/timeline_test.go2
-rw-r--r--internal/db/bundb/util.go25
-rw-r--r--internal/db/relationship.go33
-rw-r--r--internal/federation/federatingdb/followers.go2
-rw-r--r--internal/federation/federatingdb/following.go2
-rw-r--r--internal/federation/federatingdb/following_test.go4
-rw-r--r--internal/federation/federatingdb/inbox.go2
-rw-r--r--internal/paging/boundary.go48
-rw-r--r--internal/paging/page.go144
-rw-r--r--internal/paging/page_test.go12
-rw-r--r--internal/paging/parse.go57
-rw-r--r--internal/paging/response.go8
-rw-r--r--internal/paging/response_test.go32
-rw-r--r--internal/paging/util.go6
-rw-r--r--internal/processing/account/account.go6
-rw-r--r--internal/processing/account/account_test.go4
-rw-r--r--internal/processing/account/block.go50
-rw-r--r--internal/processing/account/delete.go8
-rw-r--r--internal/processing/account/follow.go63
-rw-r--r--internal/processing/account/follow_request.go119
-rw-r--r--internal/processing/account/relationships.go166
-rw-r--r--internal/processing/blocks.go86
-rw-r--r--internal/processing/common/account.go.go238
-rw-r--r--internal/processing/common/common.go50
-rw-r--r--internal/processing/common/status.go248
-rw-r--r--internal/processing/followrequest.go123
-rw-r--r--internal/processing/followrequest_test.go76
-rw-r--r--internal/processing/processor.go4
-rw-r--r--internal/timeline/get_test.go2
-rw-r--r--testrig/testmodels.go4
-rw-r--r--vendor/github.com/tomnomnom/linkheader/.gitignore2
-rw-r--r--vendor/github.com/tomnomnom/linkheader/.travis.yml6
-rw-r--r--vendor/github.com/tomnomnom/linkheader/CONTRIBUTING.mkd10
-rw-r--r--vendor/github.com/tomnomnom/linkheader/LICENSE21
-rw-r--r--vendor/github.com/tomnomnom/linkheader/README.mkd35
-rw-r--r--vendor/github.com/tomnomnom/linkheader/main.go151
-rw-r--r--vendor/modules.txt3
51 files changed, 2280 insertions, 640 deletions
diff --git a/docs/api/swagger.yaml b/docs/api/swagger.yaml
index d9bf40b06..e522cdb2a 100644
--- a/docs/api/swagger.yaml
+++ b/docs/api/swagger.yaml
@@ -3072,6 +3072,13 @@ paths:
- accounts
/api/v1/accounts/{id}/followers:
get:
+ description: |-
+ The next and previous queries can be parsed from the returned Link header.
+ Example:
+
+ ```
+ <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
+ ````
operationId: accountFollowers
parameters:
- description: Account ID.
@@ -3079,6 +3086,25 @@ paths:
name: id
required: true
type: string
+ - description: 'Return only follower accounts *OLDER* than the given max ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
+ in: query
+ name: max_id
+ type: string
+ - description: 'Return only follower accounts *NEWER* than the given since ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
+ in: query
+ name: since_id
+ type: string
+ - description: 'Return only follower accounts *IMMEDIATELY NEWER* than the given min ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
+ in: query
+ name: min_id
+ type: string
+ - default: 40
+ description: Number of follower accounts to return.
+ in: query
+ maximum: 80
+ minimum: 1
+ name: limit
+ type: integer
produces:
- application/json
responses:
@@ -3106,6 +3132,13 @@ paths:
- accounts
/api/v1/accounts/{id}/following:
get:
+ description: |-
+ The next and previous queries can be parsed from the returned Link header.
+ Example:
+
+ ```
+ <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
+ ````
operationId: accountFollowing
parameters:
- description: Account ID.
@@ -3113,6 +3146,25 @@ paths:
name: id
required: true
type: string
+ - description: 'Return only following accounts *OLDER* than the given max ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
+ in: query
+ name: max_id
+ type: string
+ - description: 'Return only following accounts *NEWER* than the given since ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
+ in: query
+ name: since_id
+ type: string
+ - description: 'Return only following accounts *IMMEDIATELY NEWER* than the given min ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
+ in: query
+ name: min_id
+ type: string
+ - default: 40
+ description: Number of following accounts to return.
+ in: query
+ maximum: 80
+ minimum: 1
+ name: limit
+ type: integer
produces:
- application/json
responses:
@@ -4679,19 +4731,25 @@ paths:
````
operationId: blocksGet
parameters:
- - default: 20
- description: Number of blocks to return.
- in: query
- name: limit
- type: integer
- - description: Return only blocks *OLDER* than the given block ID. The block with the specified ID will not be included in the response.
+ - description: 'Return only blocked accounts *OLDER* than the given max ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.'
in: query
name: max_id
type: string
- - description: Return only blocks *NEWER* than the given block ID. The block with the specified ID will not be included in the response.
+ - description: 'Return only blocked accounts *NEWER* than the given since ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.'
in: query
name: since_id
type: string
+ - description: 'Return only blocked accounts *IMMEDIATELY NEWER* than the given min ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.'
+ in: query
+ name: min_id
+ type: string
+ - default: 40
+ description: Number of blocked accounts to return.
+ in: query
+ maximum: 80
+ minimum: 1
+ name: limit
+ type: integer
produces:
- application/json
responses:
@@ -4857,12 +4915,32 @@ paths:
- featured_tags
/api/v1/follow_requests:
get:
- description: Accounts will be sorted in order of follow request date descending (newest first).
+ description: |-
+ The next and previous queries can be parsed from the returned Link header.
+ Example:
+
+ ```
+ <https://example.org/api/v1/follow_requests?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/follow_requests?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
+ ````
operationId: getFollowRequests
parameters:
+ - description: 'Return only follow requesting accounts *OLDER* than the given max ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.'
+ in: query
+ name: max_id
+ type: string
+ - description: 'Return only follow requesting accounts *NEWER* than the given since ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.'
+ in: query
+ name: since_id
+ type: string
+ - description: 'Return only follow requesting accounts *IMMEDIATELY NEWER* than the given min ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.'
+ in: query
+ name: min_id
+ type: string
- default: 40
- description: Number of accounts to return.
+ description: Number of follow requesting accounts to return.
in: query
+ maximum: 80
+ minimum: 1
name: limit
type: integer
produces:
diff --git a/go.mod b/go.mod
index 2a6658319..db2a4c3b1 100644
--- a/go.mod
+++ b/go.mod
@@ -46,6 +46,7 @@ require (
github.com/superseriousbusiness/exif-terminator v0.5.0
github.com/superseriousbusiness/oauth2/v4 v4.3.2-SSB.0.20230227143000-f4900831d6c8
github.com/tdewolff/minify/v2 v2.12.9
+ github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80
github.com/ulule/limiter/v3 v3.11.2
github.com/uptrace/bun v1.1.15
github.com/uptrace/bun/dialect/pgdialect v1.1.15
diff --git a/go.sum b/go.sum
index 0da102d44..de9eff1ee 100644
--- a/go.sum
+++ b/go.sum
@@ -568,6 +568,8 @@ github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 h1:Otn9S136ELckZ
github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
+github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y=
+github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
diff --git a/internal/api/client/accounts/follow_test.go b/internal/api/client/accounts/follow_test.go
index 9660acd4f..47526da1d 100644
--- a/internal/api/client/accounts/follow_test.go
+++ b/internal/api/client/accounts/follow_test.go
@@ -18,21 +18,33 @@
package accounts_test
import (
+ "context"
+ "encoding/json"
"fmt"
"io/ioutil"
+ "math/rand"
"net/http"
"net/http/httptest"
+ "net/url"
+ "strconv"
"strings"
"testing"
+ "time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
+ "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig"
+ "github.com/tomnomnom/linkheader"
)
+// random reader according to current-time source seed.
+var randRd = rand.New(rand.NewSource(time.Now().Unix()))
+
type FollowTestSuite struct {
AccountStandardTestSuite
}
@@ -69,6 +81,405 @@ func (suite *FollowTestSuite) TestFollowSelf() {
assert.NoError(suite.T(), err)
}
+func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit2() {
+ suite.testGetFollowersPage(2, "backward")
+}
+
+func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit4() {
+ suite.testGetFollowersPage(4, "backward")
+}
+
+func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit6() {
+ suite.testGetFollowersPage(6, "backward")
+}
+
+func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit2() {
+ suite.testGetFollowersPage(2, "forward")
+}
+
+func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit4() {
+ suite.testGetFollowersPage(4, "forward")
+}
+
+func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit6() {
+ suite.testGetFollowersPage(6, "forward")
+}
+
+func (suite *FollowTestSuite) testGetFollowersPage(limit int, direction string) {
+ ctx := context.Background()
+
+ // The authed local account we are going to use for HTTP requests
+ requestingAccount := suite.testAccounts["local_account_1"]
+ suite.clearAccountRelations(requestingAccount.ID)
+
+ // Get current time.
+ now := time.Now()
+
+ var i int
+
+ for _, targetAccount := range suite.testAccounts {
+ if targetAccount.ID == requestingAccount.ID {
+ // we cannot be our own target...
+ continue
+ }
+
+ // Get next simple ID.
+ id := strconv.Itoa(i)
+ i++
+
+ // put a follow in the database
+ err := suite.db.PutFollow(ctx, &gtsmodel.Follow{
+ ID: id,
+ CreatedAt: now,
+ UpdatedAt: now,
+ URI: fmt.Sprintf("%s/follow/%s", targetAccount.URI, id),
+ AccountID: targetAccount.ID,
+ TargetAccountID: requestingAccount.ID,
+ })
+ suite.NoError(err)
+
+ // Bump now by 1 second.
+ now = now.Add(time.Second)
+ }
+
+ // Get _ALL_ follows we expect to see without any paging (this filters invisible).
+ apiRsp, err := suite.processor.Account().FollowersGet(ctx, requestingAccount, requestingAccount.ID, nil)
+ suite.NoError(err)
+ expectAccounts := apiRsp.Items // interfaced{} account slice
+
+ // Iteratively set
+ // link query string.
+ var query string
+
+ switch direction {
+ case "backward":
+ // Set the starting query to page backward from newest.
+ acc := expectAccounts[0].(*model.Account)
+ newest, _ := suite.db.GetFollow(ctx, acc.ID, requestingAccount.ID)
+ expectAccounts = expectAccounts[1:]
+ query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID)
+
+ case "forward":
+ // Set the starting query to page forward from the oldest.
+ acc := expectAccounts[len(expectAccounts)-1].(*model.Account)
+ oldest, _ := suite.db.GetFollow(ctx, acc.ID, requestingAccount.ID)
+ expectAccounts = expectAccounts[:len(expectAccounts)-1]
+ query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID)
+ }
+
+ for p := 0; ; p++ {
+ // Prepare new request for endpoint
+ recorder := httptest.NewRecorder()
+ endpoint := fmt.Sprintf("/api/v1/accounts/%s/followers", requestingAccount.ID)
+ ctx := suite.newContext(recorder, http.MethodGet, []byte{}, endpoint, "")
+ ctx.Params = gin.Params{{Key: "id", Value: requestingAccount.ID}}
+ ctx.Request.URL.RawQuery = query // setting provided next query value
+
+ // call the handler and check for valid response code.
+ suite.T().Logf("direction=%q page=%d query=%q", direction, p, query)
+ suite.accountsModule.AccountFollowersGETHandler(ctx)
+ suite.Equal(http.StatusOK, recorder.Code)
+
+ var accounts []*model.Account
+
+ // Decode response body into API account models
+ result := recorder.Result()
+ dec := json.NewDecoder(result.Body)
+ err := dec.Decode(&accounts)
+ suite.NoError(err)
+ _ = result.Body.Close()
+
+ var (
+
+ // start provides the starting index for loop in accounts.
+ start func([]*model.Account) int
+
+ // iter performs the loop iter step with index.
+ iter func(int) int
+
+ // check performs the loop conditional check against index and accounts.
+ check func(int, []*model.Account) bool
+
+ // expect pulls the next account to check against from expectAccounts.
+ expect func([]interface{}) interface{}
+
+ // trunc drops the last checked account from expectAccounts.
+ trunc func([]interface{}) []interface{}
+ )
+
+ switch direction {
+ case "backward":
+ // When paging backwards (DESC) we:
+ // - iter from end of received accounts
+ // - iterate backward through received accounts
+ // - stop when we reach last index of received accounts
+ // - compare each received with the first index of expected accounts
+ // - after each compare, drop the first index of expected accounts
+ start = func([]*model.Account) int { return 0 }
+ iter = func(i int) int { return i + 1 }
+ check = func(idx int, i []*model.Account) bool { return idx < len(i) }
+ expect = func(i []interface{}) interface{} { return i[0] }
+ trunc = func(i []interface{}) []interface{} { return i[1:] }
+
+ case "forward":
+ // When paging forwards (ASC) we:
+ // - iter from end of received accounts
+ // - iterate backward through received accounts
+ // - stop when we reach first index of received accounts
+ // - compare each received with the last index of expected accounts
+ // - after each compare, drop the last index of expected accounts
+ start = func(i []*model.Account) int { return len(i) - 1 }
+ iter = func(i int) int { return i - 1 }
+ check = func(idx int, i []*model.Account) bool { return idx >= 0 }
+ expect = func(i []interface{}) interface{} { return i[len(i)-1] }
+ trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] }
+ }
+
+ for i := start(accounts); check(i, accounts); i = iter(i) {
+ // Get next expected account.
+ iface := expect(expectAccounts)
+
+ // Check that expected account matches received.
+ expectAccID := iface.(*model.Account).ID
+ receivdAccID := accounts[i].ID
+ suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p)
+
+ // Drop checked from expected accounts.
+ expectAccounts = trunc(expectAccounts)
+ }
+
+ if len(expectAccounts) == 0 {
+ // Reached end.
+ break
+ }
+
+ // Parse response link header values.
+ values := result.Header.Values("Link")
+ links := linkheader.ParseMultiple(values)
+ filteredLinks := links.FilterByRel("next")
+ suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p)
+
+ // A ref link header was set.
+ link := filteredLinks[0]
+
+ // Parse URI from URI string.
+ uri, err := url.Parse(link.URL)
+ suite.NoError(err)
+
+ // Set next raw query value.
+ query = uri.RawQuery
+ }
+}
+
+func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit2() {
+ suite.testGetFollowingPage(2, "backward")
+}
+
+func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit4() {
+ suite.testGetFollowingPage(4, "backward")
+}
+
+func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit6() {
+ suite.testGetFollowingPage(6, "backward")
+}
+
+func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit2() {
+ suite.testGetFollowingPage(2, "forward")
+}
+
+func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit4() {
+ suite.testGetFollowingPage(4, "forward")
+}
+
+func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit6() {
+ suite.testGetFollowingPage(6, "forward")
+}
+
+func (suite *FollowTestSuite) testGetFollowingPage(limit int, direction string) {
+ ctx := context.Background()
+
+ // The authed local account we are going to use for HTTP requests
+ requestingAccount := suite.testAccounts["local_account_1"]
+ suite.clearAccountRelations(requestingAccount.ID)
+
+ // Get current time.
+ now := time.Now()
+
+ var i int
+
+ for _, targetAccount := range suite.testAccounts {
+ if targetAccount.ID == requestingAccount.ID {
+ // we cannot be our own target...
+ continue
+ }
+
+ // Get next simple ID.
+ id := strconv.Itoa(i)
+ i++
+
+ // put a follow in the database
+ err := suite.db.PutFollow(ctx, &gtsmodel.Follow{
+ ID: id,
+ CreatedAt: now,
+ UpdatedAt: now,
+ URI: fmt.Sprintf("%s/follow/%s", requestingAccount.URI, id),
+ AccountID: requestingAccount.ID,
+ TargetAccountID: targetAccount.ID,
+ })
+ suite.NoError(err)
+
+ // Bump now by 1 second.
+ now = now.Add(time.Second)
+ }
+
+ // Get _ALL_ follows we expect to see without any paging (this filters invisible).
+ apiRsp, err := suite.processor.Account().FollowingGet(ctx, requestingAccount, requestingAccount.ID, nil)
+ suite.NoError(err)
+ expectAccounts := apiRsp.Items // interfaced{} account slice
+
+ // Iteratively set
+ // link query string.
+ var query string
+
+ switch direction {
+ case "backward":
+ // Set the starting query to page backward from newest.
+ acc := expectAccounts[0].(*model.Account)
+ newest, _ := suite.db.GetFollow(ctx, requestingAccount.ID, acc.ID)
+ expectAccounts = expectAccounts[1:]
+ query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID)
+
+ case "forward":
+ // Set the starting query to page forward from the oldest.
+ acc := expectAccounts[len(expectAccounts)-1].(*model.Account)
+ oldest, _ := suite.db.GetFollow(ctx, requestingAccount.ID, acc.ID)
+ expectAccounts = expectAccounts[:len(expectAccounts)-1]
+ query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID)
+ }
+
+ for p := 0; ; p++ {
+ // Prepare new request for endpoint
+ recorder := httptest.NewRecorder()
+ endpoint := fmt.Sprintf("/api/v1/accounts/%s/following", requestingAccount.ID)
+ ctx := suite.newContext(recorder, http.MethodGet, []byte{}, endpoint, "")
+ ctx.Params = gin.Params{{Key: "id", Value: requestingAccount.ID}}
+ ctx.Request.URL.RawQuery = query // setting provided next query value
+
+ // call the handler and check for valid response code.
+ suite.T().Logf("direction=%q page=%d query=%q", direction, p, query)
+ suite.accountsModule.AccountFollowingGETHandler(ctx)
+ suite.Equal(http.StatusOK, recorder.Code)
+
+ var accounts []*model.Account
+
+ // Decode response body into API account models
+ result := recorder.Result()
+ dec := json.NewDecoder(result.Body)
+ err := dec.Decode(&accounts)
+ suite.NoError(err)
+ _ = result.Body.Close()
+
+ var (
+ // start provides the starting index for loop in accounts.
+ start func([]*model.Account) int
+
+ // iter performs the loop iter step with index.
+ iter func(int) int
+
+ // check performs the loop conditional check against index and accounts.
+ check func(int, []*model.Account) bool
+
+ // expect pulls the next account to check against from expectAccounts.
+ expect func([]interface{}) interface{}
+
+ // trunc drops the last checked account from expectAccounts.
+ trunc func([]interface{}) []interface{}
+ )
+
+ switch direction {
+ case "backward":
+ // When paging backwards (DESC) we:
+ // - iter from end of received accounts
+ // - iterate backward through received accounts
+ // - stop when we reach last index of received accounts
+ // - compare each received with the first index of expected accounts
+ // - after each compare, drop the first index of expected accounts
+ start = func([]*model.Account) int { return 0 }
+ iter = func(i int) int { return i + 1 }
+ check = func(idx int, i []*model.Account) bool { return idx < len(i) }
+ expect = func(i []interface{}) interface{} { return i[0] }
+ trunc = func(i []interface{}) []interface{} { return i[1:] }
+
+ case "forward":
+ // When paging forwards (ASC) we:
+ // - iter from end of received accounts
+ // - iterate backward through received accounts
+ // - stop when we reach first index of received accounts
+ // - compare each received with the last index of expected accounts
+ // - after each compare, drop the last index of expected accounts
+ start = func(i []*model.Account) int { return len(i) - 1 }
+ iter = func(i int) int { return i - 1 }
+ check = func(idx int, i []*model.Account) bool { return idx >= 0 }
+ expect = func(i []interface{}) interface{} { return i[len(i)-1] }
+ trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] }
+ }
+
+ for i := start(accounts); check(i, accounts); i = iter(i) {
+ // Get next expected account.
+ iface := expect(expectAccounts)
+
+ // Check that expected account matches received.
+ expectAccID := iface.(*model.Account).ID
+ receivdAccID := accounts[i].ID
+ suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p)
+
+ // Drop checked from expected accounts.
+ expectAccounts = trunc(expectAccounts)
+ }
+
+ if len(expectAccounts) == 0 {
+ // Reached end.
+ break
+ }
+
+ // Parse response link header values.
+ values := result.Header.Values("Link")
+ links := linkheader.ParseMultiple(values)
+ filteredLinks := links.FilterByRel("next")
+ suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p)
+
+ // A ref link header was set.
+ link := filteredLinks[0]
+
+ // Parse URI from URI string.
+ uri, err := url.Parse(link.URL)
+ suite.NoError(err)
+
+ // Set next raw query value.
+ query = uri.RawQuery
+ }
+}
+
+func (suite *FollowTestSuite) clearAccountRelations(id string) {
+ // Esnure no account blocks exist between accounts.
+ _ = suite.db.DeleteAccountBlocks(
+ context.Background(),
+ id,
+ )
+
+ // Ensure no account follows exist between accounts.
+ _ = suite.db.DeleteAccountFollows(
+ context.Background(),
+ id,
+ )
+
+ // Ensure no account follow_requests exist between accounts.
+ _ = suite.db.DeleteAccountFollowRequests(
+ context.Background(),
+ id,
+ )
+}
+
func TestFollowTestSuite(t *testing.T) {
suite.Run(t, new(FollowTestSuite))
}
diff --git a/internal/api/client/accounts/followers.go b/internal/api/client/accounts/followers.go
index 96b034877..2448bc50a 100644
--- a/internal/api/client/accounts/followers.go
+++ b/internal/api/client/accounts/followers.go
@@ -25,12 +25,20 @@ import (
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
)
// AccountFollowersGETHandler swagger:operation GET /api/v1/accounts/{id}/followers accountFollowers
//
// See followers of account with given id.
//
+// The next and previous queries can be parsed from the returned Link header.
+// Example:
+//
+// ```
+// <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
+// ````
+//
// ---
// tags:
// - accounts
@@ -45,6 +53,42 @@ import (
// description: Account ID.
// in: path
// required: true
+// -
+// name: max_id
+// type: string
+// description: >-
+// Return only follower accounts *OLDER* than the given max ID.
+// The follower account with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
+// in: query
+// required: false
+// -
+// name: since_id
+// type: string
+// description: >-
+// Return only follower accounts *NEWER* than the given since ID.
+// The follower account with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
+// in: query
+// required: false
+// -
+// name: min_id
+// type: string
+// description: >-
+// Return only follower accounts *IMMEDIATELY NEWER* than the given min ID.
+// The follower account with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
+// in: query
+// required: false
+// -
+// name: limit
+// type: integer
+// description: Number of follower accounts to return.
+// default: 40
+// minimum: 1
+// maximum: 80
+// in: query
+// required: false
//
// security:
// - OAuth2 Bearer:
@@ -87,11 +131,25 @@ func (m *Module) AccountFollowersGETHandler(c *gin.Context) {
return
}
- followers, errWithCode := m.processor.Account().FollowersGet(c.Request.Context(), authed.Account, targetAcctID)
+ page, errWithCode := paging.ParseIDPage(c,
+ 1, // min limit
+ 80, // max limit
+ 40, // default limit
+ )
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
+ return
+ }
+
+ resp, errWithCode := m.processor.Account().FollowersGet(c.Request.Context(), authed.Account, targetAcctID, page)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
- c.JSON(http.StatusOK, followers)
+ if resp.LinkHeader != "" {
+ c.Header("Link", resp.LinkHeader)
+ }
+
+ c.JSON(http.StatusOK, resp.Items)
}
diff --git a/internal/api/client/accounts/following.go b/internal/api/client/accounts/following.go
index 122a12a6e..d106d6ea6 100644
--- a/internal/api/client/accounts/following.go
+++ b/internal/api/client/accounts/following.go
@@ -25,12 +25,20 @@ import (
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
)
// AccountFollowingGETHandler swagger:operation GET /api/v1/accounts/{id}/following accountFollowing
//
// See accounts followed by given account id.
//
+// The next and previous queries can be parsed from the returned Link header.
+// Example:
+//
+// ```
+// <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
+// ````
+//
// ---
// tags:
// - accounts
@@ -45,6 +53,42 @@ import (
// description: Account ID.
// in: path
// required: true
+// -
+// name: max_id
+// type: string
+// description: >-
+// Return only following accounts *OLDER* than the given max ID.
+// The following account with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
+// in: query
+// required: false
+// -
+// name: since_id
+// type: string
+// description: >-
+// Return only following accounts *NEWER* than the given since ID.
+// The following account with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
+// in: query
+// required: false
+// -
+// name: min_id
+// type: string
+// description: >-
+// Return only following accounts *IMMEDIATELY NEWER* than the given min ID.
+// The following account with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
+// in: query
+// required: false
+// -
+// name: limit
+// type: integer
+// description: Number of following accounts to return.
+// default: 40
+// minimum: 1
+// maximum: 80
+// in: query
+// required: false
//
// security:
// - OAuth2 Bearer:
@@ -87,11 +131,25 @@ func (m *Module) AccountFollowingGETHandler(c *gin.Context) {
return
}
- following, errWithCode := m.processor.Account().FollowingGet(c.Request.Context(), authed.Account, targetAcctID)
+ page, errWithCode := paging.ParseIDPage(c,
+ 1, // min limit
+ 80, // max limit
+ 40, // default limit
+ )
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
+ return
+ }
+
+ resp, errWithCode := m.processor.Account().FollowingGet(c.Request.Context(), authed.Account, targetAcctID, page)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
- c.JSON(http.StatusOK, following)
+ if resp.LinkHeader != "" {
+ c.Header("Link", resp.LinkHeader)
+ }
+
+ c.JSON(http.StatusOK, resp.Items)
}
diff --git a/internal/api/client/blocks/blocksget.go b/internal/api/client/blocks/blocksget.go
index dcf70e9cf..0761160bc 100644
--- a/internal/api/client/blocks/blocksget.go
+++ b/internal/api/client/blocks/blocksget.go
@@ -47,25 +47,40 @@ import (
//
// parameters:
// -
-// name: limit
-// type: integer
-// description: Number of blocks to return.
-// default: 20
-// in: query
-// -
// name: max_id
// type: string
// description: >-
-// Return only blocks *OLDER* than the given block ID.
-// The block with the specified ID will not be included in the response.
+// Return only blocked accounts *OLDER* than the given max ID.
+// The blocked account with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal block, NOT any of the returned accounts.
// in: query
+// required: false
// -
// name: since_id
// type: string
// description: >-
-// Return only blocks *NEWER* than the given block ID.
-// The block with the specified ID will not be included in the response.
+// Return only blocked accounts *NEWER* than the given since ID.
+// The blocked account with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal block, NOT any of the returned accounts.
+// in: query
+// -
+// name: min_id
+// type: string
+// description: >-
+// Return only blocked accounts *IMMEDIATELY NEWER* than the given min ID.
+// The blocked account with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal block, NOT any of the returned accounts.
+// in: query
+// required: false
+// -
+// name: limit
+// type: integer
+// description: Number of blocked accounts to return.
+// default: 40
+// minimum: 1
+// maximum: 80
// in: query
+// required: false
//
// security:
// - OAuth2 Bearer:
@@ -104,16 +119,16 @@ func (m *Module) BlocksGETHandler(c *gin.Context) {
}
page, errWithCode := paging.ParseIDPage(c,
- 1, // min limit
- 100, // max limit
- 20, // default limit
+ 1, // min limit
+ 80, // max limit
+ 40, // default limit
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
- resp, errWithCode := m.processor.BlocksGet(
+ resp, errWithCode := m.processor.Account().BlocksGet(
c.Request.Context(),
authed.Account,
page,
diff --git a/internal/api/client/followrequests/authorize.go b/internal/api/client/followrequests/authorize.go
index 7a19c0f86..707d3db26 100644
--- a/internal/api/client/followrequests/authorize.go
+++ b/internal/api/client/followrequests/authorize.go
@@ -87,7 +87,7 @@ func (m *Module) FollowRequestAuthorizePOSTHandler(c *gin.Context) {
return
}
- relationship, errWithCode := m.processor.FollowRequestAccept(c.Request.Context(), authed, originAccountID)
+ relationship, errWithCode := m.processor.Account().FollowRequestAccept(c.Request.Context(), authed.Account, originAccountID)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
diff --git a/internal/api/client/followrequests/get.go b/internal/api/client/followrequests/get.go
index 628e3b807..af2f3741c 100644
--- a/internal/api/client/followrequests/get.go
+++ b/internal/api/client/followrequests/get.go
@@ -24,12 +24,19 @@ import (
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
)
// FollowRequestGETHandler swagger:operation GET /api/v1/follow_requests getFollowRequests
//
// Get an array of accounts that have requested to follow you.
-// Accounts will be sorted in order of follow request date descending (newest first).
+//
+// The next and previous queries can be parsed from the returned Link header.
+// Example:
+//
+// ```
+// <https://example.org/api/v1/follow_requests?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/follow_requests?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
+// ````
//
// ---
// tags:
@@ -40,11 +47,41 @@ import (
//
// parameters:
// -
+// name: max_id
+// type: string
+// description: >-
+// Return only follow requesting accounts *OLDER* than the given max ID.
+// The follow requester with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal follow request, NOT any of the returned accounts.
+// in: query
+// required: false
+// -
+// name: since_id
+// type: string
+// description: >-
+// Return only follow requesting accounts *NEWER* than the given since ID.
+// The follow requester with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal follow request, NOT any of the returned accounts.
+// in: query
+// required: false
+// -
+// name: min_id
+// type: string
+// description: >-
+// Return only follow requesting accounts *IMMEDIATELY NEWER* than the given min ID.
+// The follow requester with the specified ID will not be included in the response.
+// NOTE: the ID is of the internal follow request, NOT any of the returned accounts.
+// in: query
+// required: false
+// -
// name: limit
// type: integer
-// description: Number of accounts to return.
+// description: Number of follow requesting accounts to return.
// default: 40
+// minimum: 1
+// maximum: 80
// in: query
+// required: false
//
// security:
// - OAuth2 Bearer:
@@ -82,11 +119,25 @@ func (m *Module) FollowRequestGETHandler(c *gin.Context) {
return
}
- accts, errWithCode := m.processor.FollowRequestsGet(c.Request.Context(), authed)
+ page, errWithCode := paging.ParseIDPage(c,
+ 1, // min limit
+ 80, // max limit
+ 40, // default limit
+ )
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
- c.JSON(http.StatusOK, accts)
+ resp, errWithCode := m.processor.Account().FollowRequestsGet(c.Request.Context(), authed.Account, page)
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
+ return
+ }
+
+ if resp.LinkHeader != "" {
+ c.Header("Link", resp.LinkHeader)
+ }
+
+ c.JSON(http.StatusOK, resp.Items)
}
diff --git a/internal/api/client/followrequests/get_test.go b/internal/api/client/followrequests/get_test.go
index d95c9878c..f2fa832a1 100644
--- a/internal/api/client/followrequests/get_test.go
+++ b/internal/api/client/followrequests/get_test.go
@@ -22,17 +22,25 @@ import (
"context"
"encoding/json"
"fmt"
- "io/ioutil"
+ "io"
+ "math/rand"
"net/http"
"net/http/httptest"
+ "net/url"
+ "strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/tomnomnom/linkheader"
)
+// random reader according to current-time source seed.
+var randRd = rand.New(rand.NewSource(time.Now().Unix()))
+
type GetTestSuite struct {
FollowRequestStandardTestSuite
}
@@ -68,7 +76,7 @@ func (suite *GetTestSuite) TestGet() {
defer result.Body.Close()
// check the response
- b, err := ioutil.ReadAll(result.Body)
+ b, err := io.ReadAll(result.Body)
assert.NoError(suite.T(), err)
dst := new(bytes.Buffer)
err = json.Indent(dst, b, "", " ")
@@ -99,6 +107,214 @@ func (suite *GetTestSuite) TestGet() {
]`, dst.String())
}
+func (suite *GetTestSuite) TestGetPageBackwardLimit2() {
+ suite.testGetPage(2, "backward")
+}
+
+func (suite *GetTestSuite) TestGetPageBackwardLimit4() {
+ suite.testGetPage(4, "backward")
+}
+
+func (suite *GetTestSuite) TestGetPageBackwardLimit6() {
+ suite.testGetPage(6, "backward")
+}
+
+func (suite *GetTestSuite) TestGetPageForwardLimit2() {
+ suite.testGetPage(2, "forward")
+}
+
+func (suite *GetTestSuite) TestGetPageForwardLimit4() {
+ suite.testGetPage(4, "forward")
+}
+
+func (suite *GetTestSuite) TestGetPageForwardLimit6() {
+ suite.testGetPage(6, "forward")
+}
+
+func (suite *GetTestSuite) testGetPage(limit int, direction string) {
+ ctx := context.Background()
+
+ // The authed local account we are going to use for HTTP requests
+ requestingAccount := suite.testAccounts["local_account_1"]
+ suite.clearAccountRelations(requestingAccount.ID)
+
+ // Get current time.
+ now := time.Now()
+
+ var i int
+
+ for _, targetAccount := range suite.testAccounts {
+ if targetAccount.ID == requestingAccount.ID {
+ // we cannot be our own target...
+ continue
+ }
+
+ // Get next simple ID.
+ id := strconv.Itoa(i)
+ i++
+
+ // put a follow request in the database
+ err := suite.db.PutFollowRequest(ctx, &gtsmodel.FollowRequest{
+ ID: id,
+ CreatedAt: now,
+ UpdatedAt: now,
+ URI: fmt.Sprintf("%s/follow/%s", targetAccount.URI, id),
+ AccountID: targetAccount.ID,
+ TargetAccountID: requestingAccount.ID,
+ })
+ suite.NoError(err)
+
+ // Bump now by 1 second.
+ now = now.Add(time.Second)
+ }
+
+ // Get _ALL_ follow requests we expect to see without any paging (this filters invisible).
+ apiRsp, err := suite.processor.Account().FollowRequestsGet(ctx, requestingAccount, nil)
+ suite.NoError(err)
+ expectAccounts := apiRsp.Items // interfaced{} account slice
+
+ // Iteratively set
+ // link query string.
+ var query string
+
+ switch direction {
+ case "backward":
+ // Set the starting query to page backward from newest.
+ acc := expectAccounts[0].(*model.Account)
+ newest, _ := suite.db.GetFollowRequest(ctx, acc.ID, requestingAccount.ID)
+ expectAccounts = expectAccounts[1:]
+ query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID)
+
+ case "forward":
+ // Set the starting query to page forward from the oldest.
+ acc := expectAccounts[len(expectAccounts)-1].(*model.Account)
+ oldest, _ := suite.db.GetFollowRequest(ctx, acc.ID, requestingAccount.ID)
+ expectAccounts = expectAccounts[:len(expectAccounts)-1]
+ query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID)
+ }
+
+ for p := 0; ; p++ {
+ // Prepare new request for endpoint
+ recorder := httptest.NewRecorder()
+ ctx := suite.newContext(recorder, http.MethodGet, []byte{}, "/api/v1/follow_requests", "")
+ ctx.Request.URL.RawQuery = query // setting provided next query value
+
+ // call the handler and check for valid response code.
+ suite.T().Logf("direction=%q page=%d query=%q", direction, p, query)
+ suite.followRequestModule.FollowRequestGETHandler(ctx)
+ suite.Equal(http.StatusOK, recorder.Code)
+
+ var accounts []*model.Account
+
+ // Decode response body into API account models
+ result := recorder.Result()
+ dec := json.NewDecoder(result.Body)
+ err := dec.Decode(&accounts)
+ suite.NoError(err)
+ _ = result.Body.Close()
+
+ var (
+
+ // start provides the starting index for loop in accounts.
+ start func([]*model.Account) int
+
+ // iter performs the loop iter step with index.
+ iter func(int) int
+
+ // check performs the loop conditional check against index and accounts.
+ check func(int, []*model.Account) bool
+
+ // expect pulls the next account to check against from expectAccounts.
+ expect func([]interface{}) interface{}
+
+ // trunc drops the last checked account from expectAccounts.
+ trunc func([]interface{}) []interface{}
+ )
+
+ switch direction {
+ case "backward":
+ // When paging backwards (DESC) we:
+ // - iter from end of received accounts
+ // - iterate backward through received accounts
+ // - stop when we reach last index of received accounts
+ // - compare each received with the first index of expected accounts
+ // - after each compare, drop the first index of expected accounts
+ start = func([]*model.Account) int { return 0 }
+ iter = func(i int) int { return i + 1 }
+ check = func(idx int, i []*model.Account) bool { return idx < len(i) }
+ expect = func(i []interface{}) interface{} { return i[0] }
+ trunc = func(i []interface{}) []interface{} { return i[1:] }
+
+ case "forward":
+ // When paging forwards (ASC) we:
+ // - iter from end of received accounts
+ // - iterate backward through received accounts
+ // - stop when we reach first index of received accounts
+ // - compare each received with the last index of expected accounts
+ // - after each compare, drop the last index of expected accounts
+ start = func(i []*model.Account) int { return len(i) - 1 }
+ iter = func(i int) int { return i - 1 }
+ check = func(idx int, i []*model.Account) bool { return idx >= 0 }
+ expect = func(i []interface{}) interface{} { return i[len(i)-1] }
+ trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] }
+ }
+
+ for i := start(accounts); check(i, accounts); i = iter(i) {
+ // Get next expected account.
+ iface := expect(expectAccounts)
+
+ // Check that expected account matches received.
+ expectAccID := iface.(*model.Account).ID
+ receivdAccID := accounts[i].ID
+ suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p)
+
+ // Drop checked from expected accounts.
+ expectAccounts = trunc(expectAccounts)
+ }
+
+ if len(expectAccounts) == 0 {
+ // Reached end.
+ break
+ }
+
+ // Parse response link header values.
+ values := result.Header.Values("Link")
+ links := linkheader.ParseMultiple(values)
+ filteredLinks := links.FilterByRel("next")
+ suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p)
+
+ // A ref link header was set.
+ link := filteredLinks[0]
+
+ // Parse URI from URI string.
+ uri, err := url.Parse(link.URL)
+ suite.NoError(err)
+
+ // Set next raw query value.
+ query = uri.RawQuery
+ }
+}
+
+func (suite *GetTestSuite) clearAccountRelations(id string) {
+ // Esnure no account blocks exist between accounts.
+ _ = suite.db.DeleteAccountBlocks(
+ context.Background(),
+ id,
+ )
+
+ // Ensure no account follows exist between accounts.
+ _ = suite.db.DeleteAccountFollows(
+ context.Background(),
+ id,
+ )
+
+ // Ensure no account follow_requests exist between accounts.
+ _ = suite.db.DeleteAccountFollowRequests(
+ context.Background(),
+ id,
+ )
+}
+
func TestGetTestSuite(t *testing.T) {
suite.Run(t, &GetTestSuite{})
}
diff --git a/internal/api/client/followrequests/reject.go b/internal/api/client/followrequests/reject.go
index 3f75facba..6514a615e 100644
--- a/internal/api/client/followrequests/reject.go
+++ b/internal/api/client/followrequests/reject.go
@@ -85,7 +85,7 @@ func (m *Module) FollowRequestRejectPOSTHandler(c *gin.Context) {
return
}
- relationship, errWithCode := m.processor.FollowRequestReject(c.Request.Context(), authed, originAccountID)
+ relationship, errWithCode := m.processor.Account().FollowRequestReject(c.Request.Context(), authed.Account, originAccountID)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index f1bdcf52b..822e697c1 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -102,8 +102,8 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
return &rel, nil
}
-func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
- followIDs, err := r.getAccountFollowIDs(ctx, accountID)
+func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) {
+ followIDs, err := r.getAccountFollowIDs(ctx, accountID, page)
if err != nil {
return nil, err
}
@@ -118,8 +118,8 @@ func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID s
return r.GetFollowsByIDs(ctx, followIDs)
}
-func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
- followerIDs, err := r.getAccountFollowerIDs(ctx, accountID)
+func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) {
+ followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, page)
if err != nil {
return nil, err
}
@@ -134,16 +134,16 @@ func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID
return r.GetFollowsByIDs(ctx, followerIDs)
}
-func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
- followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID)
+func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) {
+ followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, page)
if err != nil {
return nil, err
}
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
-func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
- followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID)
+func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) {
+ followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, page)
if err != nil {
return nil, err
}
@@ -151,39 +151,15 @@ func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, account
}
func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) {
- // Load block IDs from cache with database loader callback.
- blockIDs, err := r.state.Caches.GTS.BlockIDs().Load(accountID, func() ([]string, error) {
- var blockIDs []string
-
- // Block IDs not in cache, perform DB query!
- q := newSelectBlocks(r.db, accountID)
- if _, err := q.Exec(ctx, &blockIDs); err != nil {
- return nil, err
- }
-
- return blockIDs, nil
- })
+ blockIDs, err := r.getAccountBlockIDs(ctx, accountID, page)
if err != nil {
return nil, err
}
-
- // Our cached / selected block IDs are
- // ALWAYS stored in descending order.
- // Depending on the paging requested
- // this may be an unexpected order.
- if !page.GetOrder().Ascending() {
- blockIDs = paging.Reverse(blockIDs)
- }
-
- // Page the resulting block IDs.
- blockIDs = page.Page(blockIDs)
-
- // Convert these IDs to full block objects.
return r.GetBlocksByIDs(ctx, blockIDs)
}
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
- followIDs, err := r.getAccountFollowIDs(ctx, accountID)
+ followIDs, err := r.getAccountFollowIDs(ctx, accountID, nil)
return len(followIDs), err
}
@@ -193,7 +169,7 @@ func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID
}
func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
- followerIDs, err := r.getAccountFollowerIDs(ctx, accountID)
+ followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, nil)
return len(followerIDs), err
}
@@ -203,17 +179,22 @@ func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, account
}
func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
- followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID)
+ followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, nil)
return len(followReqIDs), err
}
func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
- followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID)
+ followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, nil)
return len(followReqIDs), err
}
-func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string) ([]string, error) {
- return r.state.Caches.GTS.FollowIDs().Load(">"+accountID, func() ([]string, error) {
+func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID string) (int, error) {
+ blockIDs, err := r.getAccountBlockIDs(ctx, accountID, nil)
+ return len(blockIDs), err
+}
+
+func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+ return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), ">"+accountID, page, func() ([]string, error) {
var followIDs []string
// Follow IDs not in cache, perform DB query!
@@ -240,8 +221,8 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID
})
}
-func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string) ([]string, error) {
- return r.state.Caches.GTS.FollowIDs().Load("<"+accountID, func() ([]string, error) {
+func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+ return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), "<"+accountID, page, func() ([]string, error) {
var followIDs []string
// Follow IDs not in cache, perform DB query!
@@ -268,8 +249,8 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account
})
}
-func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string) ([]string, error) {
- return r.state.Caches.GTS.FollowRequestIDs().Load(">"+accountID, func() ([]string, error) {
+func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+ return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), ">"+accountID, page, func() ([]string, error) {
var followReqIDs []string
// Follow request IDs not in cache, perform DB query!
@@ -282,8 +263,8 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account
})
}
-func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string) ([]string, error) {
- return r.state.Caches.GTS.FollowRequestIDs().Load("<"+accountID, func() ([]string, error) {
+func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+ return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), "<"+accountID, page, func() ([]string, error) {
var followReqIDs []string
// Follow request IDs not in cache, perform DB query!
@@ -296,13 +277,27 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco
})
}
+func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+ return loadPagedIDs(r.state.Caches.GTS.BlockIDs(), accountID, page, func() ([]string, error) {
+ var blockIDs []string
+
+ // Block IDs not in cache, perform DB query!
+ q := newSelectBlocks(r.db, accountID)
+ if _, err := q.Exec(ctx, &blockIDs); err != nil {
+ return nil, err
+ }
+
+ return blockIDs, nil
+ })
+}
+
// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery {
return db.NewSelect().
TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID).
- OrderExpr("? DESC", bun.Ident("updated_at"))
+ OrderExpr("? DESC", bun.Ident("id"))
}
// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
@@ -311,7 +306,7 @@ func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery {
TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID).
- OrderExpr("? DESC", bun.Ident("updated_at"))
+ OrderExpr("? DESC", bun.Ident("id"))
}
// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
@@ -320,7 +315,7 @@ func newSelectFollows(db *DB, accountID string) *bun.SelectQuery {
Table("follows").
Column("id").
Where("? = ?", bun.Ident("account_id"), accountID).
- OrderExpr("? DESC", bun.Ident("updated_at"))
+ OrderExpr("? DESC", bun.Ident("id"))
}
// newSelectLocalFollows returns a new select query for all rows in the follows table with
@@ -338,7 +333,7 @@ func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery {
Column("id").
Where("? IS NULL", bun.Ident("domain")),
).
- OrderExpr("? DESC", bun.Ident("updated_at"))
+ OrderExpr("? DESC", bun.Ident("id"))
}
// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
@@ -347,7 +342,7 @@ func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery {
Table("follows").
Column("id").
Where("? = ?", bun.Ident("target_account_id"), accountID).
- OrderExpr("? DESC", bun.Ident("updated_at"))
+ OrderExpr("? DESC", bun.Ident("id"))
}
// newSelectLocalFollowers returns a new select query for all rows in the follows table with
@@ -365,14 +360,14 @@ func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery {
Column("id").
Where("? IS NULL", bun.Ident("domain")),
).
- OrderExpr("? DESC", bun.Ident("updated_at"))
+ OrderExpr("? DESC", bun.Ident("id"))
}
// newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID.
func newSelectBlocks(db *DB, accountID string) *bun.SelectQuery {
return db.NewSelect().
TableExpr("?", bun.Ident("blocks")).
- ColumnExpr("?", bun.Ident("?")).
+ ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("account_id"), accountID).
- OrderExpr("? DESC", bun.Ident("updated_at"))
+ OrderExpr("? DESC", bun.Ident("id"))
}
diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go
index d7c93ff0e..aa2353961 100644
--- a/internal/db/bundb/relationship_test.go
+++ b/internal/db/bundb/relationship_test.go
@@ -753,14 +753,14 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
suite.FailNow(err.Error())
}
- followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID)
+ followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID, nil)
suite.NoError(err)
suite.Len(followRequests, 1)
}
func (suite *RelationshipTestSuite) TestGetAccountFollows() {
account := suite.testAccounts["local_account_1"]
- follows, err := suite.db.GetAccountFollows(context.Background(), account.ID)
+ follows, err := suite.db.GetAccountFollows(context.Background(), account.ID, nil)
suite.NoError(err)
suite.Len(follows, 2)
}
@@ -781,7 +781,7 @@ func (suite *RelationshipTestSuite) TestCountAccountFollows() {
func (suite *RelationshipTestSuite) TestGetAccountFollowers() {
account := suite.testAccounts["local_account_1"]
- follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID)
+ follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID, nil)
suite.NoError(err)
suite.Len(follows, 2)
}
diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go
index f63937bc1..229245899 100644
--- a/internal/db/bundb/timeline.go
+++ b/internal/db/bundb/timeline.go
@@ -114,6 +114,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
follows, err := t.state.DB.GetAccountFollows(
gtscontext.SetBarebones(ctx),
accountID,
+ nil, // select all
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.Newf("db error getting follows for account %s: %w", accountID, err)
diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go
index e5a78dfd1..ac169ec4a 100644
--- a/internal/db/bundb/timeline_test.go
+++ b/internal/db/bundb/timeline_test.go
@@ -167,8 +167,8 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() {
follows, err := suite.state.DB.GetAccountFollows(
gtscontext.SetBarebones(ctx),
viewingAccount.ID,
+ nil, // select all
)
-
if err != nil {
suite.FailNow(err.Error())
}
diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go
index 3c3249daf..1d820d081 100644
--- a/internal/db/bundb/util.go
+++ b/internal/db/bundb/util.go
@@ -20,7 +20,9 @@ package bundb
import (
"strings"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/uptrace/bun"
)
@@ -83,6 +85,29 @@ func whereStartsLike(
)
}
+// loadPagedIDs loads a page of IDs from given SliceCache by `key`, resorting to `loadDESC` if required. Uses `page` to sort + page resulting IDs.
+// NOTE: IDs returned from `cache` / `loadDESC` MUST be in descending order, otherwise paging will not work correctly / return things out of order.
+func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page, loadDESC func() ([]string, error)) ([]string, error) {
+ // Check cache for IDs, else load.
+ ids, err := cache.Load(key, loadDESC)
+ if err != nil {
+ return nil, err
+ }
+
+ // Our cached / selected IDs are ALWAYS
+ // fetched from `loadDESC` in descending
+ // order. Depending on the paging requested
+ // this may be an unexpected order.
+ if page.GetOrder().Ascending() {
+ ids = paging.Reverse(ids)
+ }
+
+ // Page the resulting IDs.
+ ids = page.Page(ids)
+
+ return ids, nil
+}
+
// updateWhere parses []db.Where and adds it to the given update query.
func updateWhere(q *bun.UpdateQuery, where []db.Where) {
for _, w := range where {
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index 91c98644c..b3b45551b 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -138,43 +138,46 @@ type Relationship interface {
RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) error
// GetAccountFollows returns a slice of follows owned by the given accountID.
- GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
+ GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error)
// GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance.
GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
+ // GetAccountFollowers fetches follows that target given accountID.
+ GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error)
+
+ // GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance.
+ GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
+
+ // GetAccountFollowRequests returns all follow requests targeting the given account.
+ GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error)
+
+ // GetAccountFollowRequesting returns all follow requests originating from the given account.
+ GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error)
+
+ // GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters.
+ GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error)
+
// CountAccountFollows returns the amount of accounts that the given accountID is following.
CountAccountFollows(ctx context.Context, accountID string) (int, error)
// CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance.
CountAccountLocalFollows(ctx context.Context, accountID string) (int, error)
- // GetAccountFollowers fetches follows that target given accountID.
- GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
-
- // GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance.
- GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
-
// CountAccountFollowers returns the amounts that the given ID is followed by.
CountAccountFollowers(ctx context.Context, accountID string) (int, error)
// CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance.
CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error)
- // GetAccountFollowRequests returns all follow requests targeting the given account.
- GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
-
- // GetAccountFollowRequesting returns all follow requests originating from the given account.
- GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
-
// CountAccountFollowRequests returns number of follow requests targeting the given account.
CountAccountFollowRequests(ctx context.Context, accountID string) (int, error)
// CountAccountFollowerRequests returns number of follow requests originating from the given account.
CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error)
- // GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters.
- GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error)
+ // CountAccountBlocks ...
+ CountAccountBlocks(ctx context.Context, accountID string) (int, error)
// GetNote gets a private note from a source account on a target account, if it exists.
GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error)
diff --git a/internal/federation/federatingdb/followers.go b/internal/federation/federatingdb/followers.go
index 4ca2e2683..eada48c1b 100644
--- a/internal/federation/federatingdb/followers.go
+++ b/internal/federation/federatingdb/followers.go
@@ -38,7 +38,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow
return nil, err
}
- follows, err := f.state.DB.GetAccountFollowers(ctx, acct.ID)
+ follows, err := f.state.DB.GetAccountFollowers(ctx, acct.ID, nil)
if err != nil {
return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err)
}
diff --git a/internal/federation/federatingdb/following.go b/internal/federation/federatingdb/following.go
index 391a2f810..deb965564 100644
--- a/internal/federation/federatingdb/following.go
+++ b/internal/federation/federatingdb/following.go
@@ -38,7 +38,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow
return nil, err
}
- follows, err := f.state.DB.GetAccountFollows(ctx, acct.ID)
+ follows, err := f.state.DB.GetAccountFollows(ctx, acct.ID, nil)
if err != nil {
return nil, fmt.Errorf("Following: db error getting following for account id %s: %w", acct.ID, err)
}
diff --git a/internal/federation/federatingdb/following_test.go b/internal/federation/federatingdb/following_test.go
index 83d1a72b5..93bc6d348 100644
--- a/internal/federation/federatingdb/following_test.go
+++ b/internal/federation/federatingdb/following_test.go
@@ -47,8 +47,8 @@ func (suite *FollowingTestSuite) TestGetFollowing() {
suite.Equal(`{
"@context": "https://www.w3.org/ns/activitystreams",
"items": [
- "http://localhost:8080/users/admin",
- "http://localhost:8080/users/1happyturtle"
+ "http://localhost:8080/users/1happyturtle",
+ "http://localhost:8080/users/admin"
],
"type": "Collection"
}`, string(fJson))
diff --git a/internal/federation/federatingdb/inbox.go b/internal/federation/federatingdb/inbox.go
index 18974ba79..9bd9f8d87 100644
--- a/internal/federation/federatingdb/inbox.go
+++ b/internal/federation/federatingdb/inbox.go
@@ -89,7 +89,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs
return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err)
}
- follows, err := f.state.DB.GetAccountFollowers(c, account.ID)
+ follows, err := f.state.DB.GetAccountFollowers(c, account.ID, nil)
if err != nil {
return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err)
}
diff --git a/internal/paging/boundary.go b/internal/paging/boundary.go
index 2f202097b..15af65e0c 100644
--- a/internal/paging/boundary.go
+++ b/internal/paging/boundary.go
@@ -17,10 +17,10 @@
package paging
-// MinID returns an ID boundary with given min ID value,
+// EitherMinID returns an ID boundary with given min ID value,
// using either the `since_id`,"DESC" name,ordering or
// `min_id`,"ASC" name,ordering depending on which is set.
-func MinID(minID, sinceID string) Boundary {
+func EitherMinID(minID, sinceID string) Boundary {
/*
Paging with `since_id` vs `min_id`:
@@ -47,18 +47,28 @@ func MinID(minID, sinceID string) Boundary {
*/
switch {
case minID != "":
- return Boundary{
- Name: "min_id",
- Value: minID,
- Order: OrderAscending,
- }
+ return MinID(minID)
default:
// default min is `since_id`
- return Boundary{
- Name: "since_id",
- Value: sinceID,
- Order: OrderDescending,
- }
+ return SinceID(sinceID)
+ }
+}
+
+// SinceID ...
+func SinceID(sinceID string) Boundary {
+ return Boundary{
+ Name: "since_id",
+ Value: sinceID,
+ Order: OrderDescending,
+ }
+}
+
+// MinID ...
+func MinID(minID string) Boundary {
+ return Boundary{
+ Name: "min_id",
+ Value: minID,
+ Order: OrderAscending,
}
}
@@ -111,7 +121,7 @@ func (b Boundary) new(value string) Boundary {
// Find finds the boundary's set value in input slice, or returns -1.
func (b Boundary) Find(in []string) int {
- if zero(b.Value) {
+ if b.Value == "" {
return -1
}
for i := range in {
@@ -121,15 +131,3 @@ func (b Boundary) Find(in []string) int {
}
return -1
}
-
-// Query returns this boundary as assembled query key=value pair.
-func (b Boundary) Query() string {
- switch {
- case zero(b.Value):
- return ""
- case b.Name == "":
- panic("value without boundary name")
- default:
- return b.Name + "=" + b.Value
- }
-}
diff --git a/internal/paging/page.go b/internal/paging/page.go
index 7d8f84aab..0a9bc71b1 100644
--- a/internal/paging/page.go
+++ b/internal/paging/page.go
@@ -20,7 +20,6 @@ package paging
import (
"net/url"
"strconv"
- "strings"
"golang.org/x/exp/slices"
)
@@ -70,26 +69,10 @@ func (p *Page) GetOrder() Order {
}
func (p *Page) order() Order {
- var (
- // Check if min/max values set.
- minValue = zero(p.Min.Value)
- maxValue = zero(p.Max.Value)
-
- // Check if min/max orders set.
- minOrder = (p.Min.Order != 0)
- maxOrder = (p.Max.Order != 0)
- )
-
switch {
- // Boundaries with a value AND order set
- // take priority. Min always comes first.
- case minValue && minOrder:
- return p.Min.Order
- case maxValue && maxOrder:
- return p.Max.Order
- case minOrder:
+ case p.Min.Order != 0:
return p.Min.Order
- case maxOrder:
+ case p.Max.Order != 0:
return p.Max.Order
default:
return 0
@@ -108,31 +91,9 @@ func (p *Page) Page(in []string) []string {
return in
}
- if o := p.order(); !o.Ascending() {
- // Default sort is descending,
- // catching all cases when NOT
- // ascending (even zero value).
- //
- // NOTE: sorted data does not always
- // occur according to string ineqs
- // so we unfortunately cannot check.
-
- if maxIdx := p.Max.Find(in); maxIdx != -1 {
- // Reslice skipping up to max.
- in = in[maxIdx+1:]
- }
-
- if minIdx := p.Min.Find(in); minIdx != -1 {
- // Reslice stripping past min.
- in = in[:minIdx]
- }
- } else {
+ if p.order().Ascending() {
// Sort type is ascending, input
// data is assumed to be ascending.
- //
- // NOTE: sorted data does not always
- // occur according to string ineqs
- // so we unfortunately cannot check.
if minIdx := p.Min.Find(in); minIdx != -1 {
// Reslice skipping up to min.
@@ -144,6 +105,11 @@ func (p *Page) Page(in []string) []string {
in = in[:maxIdx]
}
+ if p.Limit > 0 && p.Limit < len(in) {
+ // Reslice input to limit.
+ in = in[:p.Limit]
+ }
+
if len(in) > 1 {
// Clone input before
// any modifications.
@@ -153,11 +119,25 @@ func (p *Page) Page(in []string) []string {
// ALWAYS be descending.
in = Reverse(in)
}
- }
+ } else {
+ // Default sort is descending,
+ // catching all cases when NOT
+ // ascending (even zero value).
+
+ if maxIdx := p.Max.Find(in); maxIdx != -1 {
+ // Reslice skipping up to max.
+ in = in[maxIdx+1:]
+ }
+
+ if minIdx := p.Min.Find(in); minIdx != -1 {
+ // Reslice stripping past min.
+ in = in[:minIdx]
+ }
- if p.Limit > 0 && p.Limit < len(in) {
- // Reslice input to limit.
- in = in[:p.Limit]
+ if p.Limit > 0 && p.Limit < len(in) {
+ // Reslice input to limit.
+ in = in[:p.Limit]
+ }
}
return in
@@ -165,8 +145,8 @@ func (p *Page) Page(in []string) []string {
// Next creates a new instance for the next returnable page, using
// given max value. This preserves original limit and max key name.
-func (p *Page) Next(max string) *Page {
- if p == nil || max == "" {
+func (p *Page) Next(lo, hi string) *Page {
+ if p == nil || lo == "" || hi == "" {
// no paging.
return nil
}
@@ -177,16 +157,27 @@ func (p *Page) Next(max string) *Page {
// Set original limit.
p2.Limit = p.Limit
- // Create new from old.
- p2.Max = p.Max.new(max)
+ if p.order().Ascending() {
+ // When ascending, next page
+ // needs to start with min at
+ // the next highest value.
+ p2.Min = p.Min.new(hi)
+ p2.Max = p.Max.new("")
+ } else {
+ // When descending, next page
+ // needs to start with max at
+ // the next lowest value.
+ p2.Min = p.Min.new("")
+ p2.Max = p.Max.new(lo)
+ }
return p2
}
// Prev creates a new instance for the prev returnable page, using
// given min value. This preserves original limit and min key name.
-func (p *Page) Prev(min string) *Page {
- if p == nil || min == "" {
+func (p *Page) Prev(lo, hi string) *Page {
+ if p == nil || lo == "" || hi == "" {
// no paging.
return nil
}
@@ -197,55 +188,56 @@ func (p *Page) Prev(min string) *Page {
// Set original limit.
p2.Limit = p.Limit
- // Create new from old.
- p2.Min = p.Min.new(min)
+ if p.order().Ascending() {
+ // When ascending, prev page
+ // needs to start with max at
+ // the next lowest value.
+ p2.Min = p.Min.new("")
+ p2.Max = p.Max.new(lo)
+ } else {
+ // When descending, next page
+ // needs to start with max at
+ // the next lowest value.
+ p2.Min = p.Min.new(hi)
+ p2.Max = p.Max.new("")
+ }
return p2
}
// ToLink builds a URL link for given endpoint information and extra query parameters,
// appending this Page's minimum / maximum boundaries and available limit (if any).
-func (p *Page) ToLink(proto, host, path string, queryParams []string) string {
+func (p *Page) ToLink(proto, host, path string, queryParams url.Values) string {
if p == nil {
// no paging.
return ""
}
- // Check length before
- // adding boundary params.
- old := len(queryParams)
+ if queryParams == nil {
+ // Allocate new query parameters.
+ queryParams = make(url.Values)
+ }
- if minParam := p.Min.Query(); minParam != "" {
+ if p.Min.Value != "" {
// A page-minimum query parameter is available.
- queryParams = append(queryParams, minParam)
+ queryParams.Add(p.Min.Name, p.Min.Value)
}
- if maxParam := p.Max.Query(); maxParam != "" {
+ if p.Max.Value != "" {
// A page-maximum query parameter is available.
- queryParams = append(queryParams, maxParam)
- }
-
- if len(queryParams) == old {
- // No page boundaries.
- return ""
+ queryParams.Add(p.Max.Name, p.Max.Value)
}
if p.Limit > 0 {
- // Build limit key-value query parameter.
- param := "limit=" + strconv.Itoa(p.Limit)
-
- // Append `limit=$value` query parameter.
- queryParams = append(queryParams, param)
+ // A page limit query parameter is available.
+ queryParams.Add("limit", strconv.Itoa(p.Limit))
}
- // Join collected params into query str.
- query := strings.Join(queryParams, "&")
-
// Build URL string.
return (&url.URL{
Scheme: proto,
Host: host,
Path: path,
- RawQuery: query,
+ RawQuery: queryParams.Encode(),
}).String()
}
diff --git a/internal/paging/page_test.go b/internal/paging/page_test.go
index 419b9ea44..01cc74d9f 100644
--- a/internal/paging/page_test.go
+++ b/internal/paging/page_test.go
@@ -97,7 +97,7 @@ var cases = []Case{
// Return page and expected IDs.
return ids, &paging.Page{
- Min: paging.MinID(minID, ""),
+ Min: paging.MinID(minID),
Max: paging.MaxID(maxID),
}, expect
}),
@@ -129,7 +129,7 @@ var cases = []Case{
// Return page and expected IDs.
return ids, &paging.Page{
- Min: paging.MinID(minID, ""),
+ Min: paging.MinID(minID),
Max: paging.MaxID(maxID),
Limit: limit,
}, expect
@@ -156,7 +156,7 @@ var cases = []Case{
// Return page and expected IDs.
return ids, &paging.Page{
- Min: paging.MinID(minID, ""),
+ Min: paging.MinID(minID),
Max: paging.MaxID(maxID),
Limit: len(ids) * 2,
}, expect
@@ -182,7 +182,7 @@ var cases = []Case{
// Return page and expected IDs.
return ids, &paging.Page{
- Min: paging.MinID("", sinceID),
+ Min: paging.SinceID(sinceID),
Max: paging.MaxID(maxID),
}, expect
}),
@@ -225,7 +225,7 @@ var cases = []Case{
// Return page and expected IDs.
return ids, &paging.Page{
- Min: paging.MinID("", sinceID),
+ Min: paging.SinceID(sinceID),
}, expect
}),
CreateCase("minID set", func(ids []string) ([]string, *paging.Page, []string) {
@@ -247,7 +247,7 @@ var cases = []Case{
// Return page and expected IDs.
return ids, &paging.Page{
- Min: paging.MinID(minID, ""),
+ Min: paging.MinID(minID),
}, expect
}),
}
diff --git a/internal/paging/parse.go b/internal/paging/parse.go
index 55ebef7f5..ce6391708 100644
--- a/internal/paging/parse.go
+++ b/internal/paging/parse.go
@@ -30,9 +30,9 @@ import (
// While conversely, a zero default limit will not enforce paging, returning a nil page value.
func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) {
// Extract request query params.
- sinceID := c.Query("since_id")
- minID := c.Query("min_id")
- maxID := c.Query("max_id")
+ sinceID, haveSince := c.GetQuery("since_id")
+ minID, haveMin := c.GetQuery("min_id")
+ maxID, haveMax := c.GetQuery("max_id")
// Extract request limit parameter.
limit, errWithCode := ParseLimit(c, min, max, _default)
@@ -40,20 +40,38 @@ func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCo
return nil, errWithCode
}
- if sinceID == "" &&
- minID == "" &&
- maxID == "" &&
- limit == 0 {
+ switch {
+ case haveMin:
+ // A min_id was supplied, even if the value
+ // itself is empty. This indicates ASC order.
+ return &Page{
+ Min: MinID(minID),
+ Max: MaxID(maxID),
+ Limit: limit,
+ }, nil
+
+ case haveMax || haveSince:
+ // A max_id or since_id was supplied, even if the
+ // value itself is empty. This indicates DESC order.
+ return &Page{
+ Min: SinceID(sinceID),
+ Max: MaxID(maxID),
+ Limit: limit,
+ }, nil
+
+ case limit == 0:
// No ID paging params provided, and no default
// limit value which indicates paging not enforced.
return nil, nil
- }
- return &Page{
- Min: MinID(minID, sinceID),
- Max: MaxID(maxID),
- Limit: limit,
- }, nil
+ default:
+ // only limit.
+ return &Page{
+ Min: SinceID(""),
+ Max: MaxID(""),
+ Limit: limit,
+ }, nil
+ }
}
// ParseShortcodeDomainPage parses an emoji shortcode domain Page from a request context, returning BadRequest
@@ -62,8 +80,8 @@ func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCo
// a zero default limit will not enforce paging, returning a nil page value.
func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) {
// Extract request query parameters.
- minShortcode := c.Query("min_shortcode_domain")
- maxShortcode := c.Query("max_shortcode_domain")
+ minShortcode, haveMin := c.GetQuery("min_shortcode_domain")
+ maxShortcode, haveMax := c.GetQuery("max_shortcode_domain")
// Extract request limit parameter.
limit, errWithCode := ParseLimit(c, min, max, _default)
@@ -71,8 +89,8 @@ func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gt
return nil, errWithCode
}
- if minShortcode == "" &&
- maxShortcode == "" &&
+ if !haveMin &&
+ !haveMax &&
limit == 0 {
// No ID paging params provided, and no default
// limit value which indicates paging not enforced.
@@ -89,7 +107,10 @@ func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gt
// ParseLimit parses the limit query parameter from a request context, returning BadRequest on error parsing and _default if zero limit given.
func ParseLimit(c *gin.Context, min, max, _default int) (int, gtserror.WithCode) {
// Get limit query param.
- str := c.Query("limit")
+ str, ok := c.GetQuery("limit")
+ if !ok {
+ return _default, nil
+ }
// Attempt to parse limit int.
i, err := strconv.Atoi(str)
diff --git a/internal/paging/response.go b/internal/paging/response.go
index 498b42d34..71b0cf213 100644
--- a/internal/paging/response.go
+++ b/internal/paging/response.go
@@ -18,6 +18,7 @@
package paging
import (
+ "net/url"
"strings"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -35,18 +36,13 @@ type ResponseParams struct {
Path string // path to use for next/prev queries in the link header
Next *Page // page details for the next page
Prev *Page // page details for the previous page
- Query []string // any extra query parameters to provide in the link header, should be in the format 'example=value'
+ Query url.Values // any extra query parameters to provide in the link header, should be in the format 'example=value'
}
// PackageResponse is a convenience function for returning
// a bunch of pageable items (notifications, statuses, etc), as well
// as a Link header to inform callers of where to find next/prev items.
func PackageResponse(params ResponseParams) *apimodel.PageableResponse {
- if len(params.Items) == 0 {
- // No items to page through.
- return EmptyResponse()
- }
-
var (
// Extract paging params.
nextPg = params.Next
diff --git a/internal/paging/response_test.go b/internal/paging/response_test.go
index 8eca2a601..b4b7d6058 100644
--- a/internal/paging/response_test.go
+++ b/internal/paging/response_test.go
@@ -42,9 +42,9 @@ func (suite *PagingSuite) TestPagingStandard() {
resp := paging.PackageResponse(params)
suite.Equal(make([]interface{}, 10, 10), resp.Items)
- suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10>; rel="next", <https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10>; rel="prev"`, resp.LinkHeader)
- suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10`, resp.NextLink)
- suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10`, resp.PrevLink)
+ suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN>; rel="next", <https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R>; rel="prev"`, resp.LinkHeader)
+ suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN`, resp.NextLink)
+ suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R`, resp.PrevLink)
}
func (suite *PagingSuite) TestPagingNoLimit() {
@@ -77,9 +77,9 @@ func (suite *PagingSuite) TestPagingNoNextID() {
resp := paging.PackageResponse(params)
suite.Equal(make([]interface{}, 10, 10), resp.Items)
- suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10>; rel="prev"`, resp.LinkHeader)
+ suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R>; rel="prev"`, resp.LinkHeader)
suite.Equal(``, resp.NextLink)
- suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10`, resp.PrevLink)
+ suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R`, resp.PrevLink)
}
func (suite *PagingSuite) TestPagingNoPrevID() {
@@ -94,27 +94,11 @@ func (suite *PagingSuite) TestPagingNoPrevID() {
resp := paging.PackageResponse(params)
suite.Equal(make([]interface{}, 10, 10), resp.Items)
- suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10>; rel="next"`, resp.LinkHeader)
- suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10`, resp.NextLink)
+ suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN>; rel="next"`, resp.LinkHeader)
+ suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN`, resp.NextLink)
suite.Equal(``, resp.PrevLink)
}
-func (suite *PagingSuite) TestPagingNoItems() {
- config.SetHost("example.org")
-
- params := paging.ResponseParams{
- Next: nextPage("01H11KA1DM2VH3747YDE7FV5HN", 10),
- Prev: prevPage("01H11KBBVRRDYYC5KEPME1NP5R", 10),
- }
-
- resp := paging.PackageResponse(params)
-
- suite.Empty(resp.Items)
- suite.Empty(resp.LinkHeader)
- suite.Empty(resp.NextLink)
- suite.Empty(resp.PrevLink)
-}
-
func TestPagingSuite(t *testing.T) {
suite.Run(t, &PagingSuite{})
}
@@ -128,7 +112,7 @@ func nextPage(id string, limit int) *paging.Page {
func prevPage(id string, limit int) *paging.Page {
return &paging.Page{
- Min: paging.MinID(id, ""),
+ Min: paging.MinID(id),
Limit: limit,
}
}
diff --git a/internal/paging/util.go b/internal/paging/util.go
index d9adb9cbf..dd941dd88 100644
--- a/internal/paging/util.go
+++ b/internal/paging/util.go
@@ -41,9 +41,3 @@ func Reverse(in []string) []string {
return in
}
-
-// zero is a shorthand to check a generic value is its zero value.
-func zero[T comparable](t T) bool {
- var z T
- return t == z
-}
diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go
index 7bef8b0c5..a32a73ac1 100644
--- a/internal/processing/account/account.go
+++ b/internal/processing/account/account.go
@@ -22,6 +22,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/internal/processing/common"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
@@ -32,6 +33,9 @@ import (
//
// It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc.
type Processor struct {
+ // common processor logic
+ c *common.Processor
+
state *state.State
tc typeutils.TypeConverter
mediaManager *media.Manager
@@ -44,6 +48,7 @@ type Processor struct {
// New returns a new account processor.
func New(
+ common *common.Processor,
state *state.State,
tc typeutils.TypeConverter,
mediaManager *media.Manager,
@@ -53,6 +58,7 @@ func New(
parseMention gtsmodel.ParseMentionFunc,
) Processor {
return Processor{
+ c: common,
state: state,
tc: tc,
mediaManager: mediaManager,
diff --git a/internal/processing/account/account_test.go b/internal/processing/account/account_test.go
index 4ba7de16e..2e4a64844 100644
--- a/internal/processing/account/account_test.go
+++ b/internal/processing/account/account_test.go
@@ -30,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
+ "github.com/superseriousbusiness/gotosocial/internal/processing/common"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
@@ -113,7 +114,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails)
filter := visibility.NewFilter(&suite.state)
- suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, filter, processing.GetParseMentionFunc(suite.db, suite.federator))
+ common := common.New(&suite.state, suite.tc, suite.federator, filter)
+ suite.accountProcessor = account.New(&common, &suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, filter, processing.GetParseMentionFunc(suite.db, suite.federator))
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
}
diff --git a/internal/processing/account/block.go b/internal/processing/account/block.go
index 1ec31a753..270048100 100644
--- a/internal/processing/account/block.go
+++ b/internal/processing/account/block.go
@@ -28,8 +28,11 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/uris"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
)
// BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local.
@@ -128,6 +131,53 @@ func (p *Processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel
return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
}
+// BlocksGet ...
+func (p *Processor) BlocksGet(
+ ctx context.Context,
+ requestingAccount *gtsmodel.Account,
+ page *paging.Page,
+) (*apimodel.PageableResponse, gtserror.WithCode) {
+ blocks, err := p.state.DB.GetAccountBlocks(ctx,
+ requestingAccount.ID,
+ page,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+
+ // Check for empty response.
+ count := len(blocks)
+ if len(blocks) == 0 {
+ return util.EmptyPageableResponse(), nil
+ }
+
+ items := make([]interface{}, 0, count)
+
+ for _, block := range blocks {
+ // Convert target account to frontend API model. (target will never be nil)
+ account, err := p.tc.AccountToAPIAccountBlocked(ctx, block.TargetAccount)
+ if err != nil {
+ log.Errorf(ctx, "error converting account to public api account: %v", err)
+ continue
+ }
+
+ // Append target to return items.
+ items = append(items, account)
+ }
+
+ // Get the lowest and highest
+ // ID values, used for paging.
+ lo := blocks[count-1].ID
+ hi := blocks[0].ID
+
+ return paging.PackageResponse(paging.ResponseParams{
+ Items: items,
+ Path: "/api/v1/blocks",
+ Next: page.Next(lo, hi),
+ Prev: page.Prev(lo, hi),
+ }), nil
+}
+
func (p *Processor) getBlockTarget(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*gtsmodel.Account, *gtsmodel.Block, gtserror.WithCode) {
// Account should not block or unblock itself.
if requestingAccount.ID == targetAccountID {
diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go
index da13eb20e..e89ebf13f 100644
--- a/internal/processing/account/delete.go
+++ b/internal/processing/account/delete.go
@@ -160,7 +160,7 @@ func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account *
// - Follow requests created by account.
func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.Account) error {
// Delete follows targeting this account.
- followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID)
+ followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID, nil)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return gtserror.Newf("db error getting follows targeting account %s: %w", account.ID, err)
}
@@ -172,7 +172,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
}
// Delete follow requests targeting this account.
- followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID)
+ followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID, nil)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return gtserror.Newf("db error getting follow requests targeting account %s: %w", account.ID, err)
}
@@ -193,7 +193,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
)
// Delete follows originating from this account.
- following, err := p.state.DB.GetAccountFollows(ctx, account.ID)
+ following, err := p.state.DB.GetAccountFollows(ctx, account.ID, nil)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return gtserror.Newf("db error getting follows owned by account %s: %w", account.ID, err)
}
@@ -211,7 +211,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
}
// Delete follow requests originating from this account.
- followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID)
+ followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID, nil)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return gtserror.Newf("db error getting follow requests owned by account %s: %w", account.ID, err)
}
diff --git a/internal/processing/account/follow.go b/internal/processing/account/follow.go
index 1aed92e75..8006f8d79 100644
--- a/internal/processing/account/follow.go
+++ b/internal/processing/account/follow.go
@@ -20,7 +20,6 @@ package account
import (
"context"
"errors"
- "fmt"
"github.com/superseriousbusiness/gotosocial/internal/ap"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -35,7 +34,7 @@ import (
// FollowCreate handles a follow request to an account, either remote or local.
func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
- targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount.ID, form.ID)
+ targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount, form.ID)
if errWithCode != nil {
return nil, errWithCode
}
@@ -46,7 +45,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
requestingAccount.ID,
targetAccount.ID,
); err != nil && !errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("FollowCreate: db error checking existing follow: %w", err)
+ err = gtserror.Newf("db error checking existing follow: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if follow != nil {
// Already follows, update if necessary + return relationship.
@@ -66,7 +65,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
requestingAccount.ID,
targetAccount.ID,
); err != nil && !errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("FollowCreate: db error checking existing follow request: %w", err)
+ err = gtserror.Newf("db error checking existing follow request: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if followRequest != nil {
// Already requested, update if necessary + return relationship.
@@ -100,7 +99,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
}
if err := p.state.DB.PutFollowRequest(ctx, fr); err != nil {
- err = fmt.Errorf("FollowCreate: error creating follow request in db: %s", err)
+ err = gtserror.Newf("error creating follow request in db: %s", err)
return nil, gtserror.NewErrorInternalError(err)
}
@@ -112,7 +111,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
// Because we know the requestingAccount is also
// local, we don't need to federate the accept out.
if _, err := p.state.DB.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil {
- err = fmt.Errorf("FollowCreate: error accepting follow request for local unlocked account: %w", err)
+ err = gtserror.Newf("error accepting follow request for local unlocked account: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
} else if targetAccount.IsRemote() {
@@ -132,7 +131,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
// FollowRemove handles the removal of a follow/follow request to an account, either remote or local.
func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
- targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount.ID, targetAccountID)
+ targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount, targetAccountID)
if errWithCode != nil {
return nil, errWithCode
}
@@ -140,7 +139,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
// Unfollow and deal with side effects.
msgs, err := p.unfollow(ctx, requestingAccount, targetAccount)
if err != nil {
- return nil, gtserror.NewErrorNotFound(fmt.Errorf("FollowRemove: account %s not found in the db: %s", targetAccountID, err))
+ return nil, gtserror.NewErrorNotFound(gtserror.Newf("account %s not found in the db: %s", targetAccountID, err))
}
// Batch queue accreted client api messages.
@@ -166,7 +165,6 @@ func (p *Processor) updateFollow(
currentNotify *bool,
update func(...string) error,
) (*apimodel.Relationship, gtserror.WithCode) {
-
if form.Reblogs == nil && form.Notify == nil {
// There's nothing to update.
return p.RelationshipGet(ctx, requestingAccount, form.ID)
@@ -192,7 +190,7 @@ func (p *Processor) updateFollow(
}
if err := update(columns...); err != nil {
- err = fmt.Errorf("updateFollow: error updating existing follow (request): %w", err)
+ err = gtserror.Newf("error updating existing follow (request): %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
@@ -201,38 +199,23 @@ func (p *Processor) updateFollow(
// getFollowTarget is a convenience function which:
// - Checks if account is trying to follow/unfollow itself.
-// - Returns not found if there's a block in place between accounts.
+// - Returns not found if target should not be visible to requester.
// - Returns target account according to its id.
-func (p *Processor) getFollowTarget(ctx context.Context, requestingAccountID string, targetAccountID string) (*gtsmodel.Account, gtserror.WithCode) {
+func (p *Processor) getFollowTarget(ctx context.Context, requester *gtsmodel.Account, targetID string) (*gtsmodel.Account, gtserror.WithCode) {
+ // Check for requester.
+ if requester == nil {
+ err := errors.New("no authorized user")
+ return nil, gtserror.NewErrorUnauthorized(err)
+ }
+
// Account can't follow or unfollow itself.
- if requestingAccountID == targetAccountID {
+ if requester.ID == targetID {
err := errors.New("account can't follow or unfollow itself")
return nil, gtserror.NewErrorNotAcceptable(err)
}
- // Do nothing if a block exists in either direction between accounts.
- if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, targetAccountID); err != nil {
- err = fmt.Errorf("db error checking block between accounts: %w", err)
- return nil, gtserror.NewErrorInternalError(err)
- } else if blocked {
- err = errors.New("block exists between accounts")
- return nil, gtserror.NewErrorNotFound(err)
- }
-
- // Ensure target account retrievable.
- targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
- if err != nil {
- if !errors.Is(err, db.ErrNoEntries) {
- // Real db error.
- err = fmt.Errorf("db error looking for target account %s: %w", targetAccountID, err)
- return nil, gtserror.NewErrorInternalError(err)
- }
- // Account not found.
- err = fmt.Errorf("target account %s not found in the db", targetAccountID)
- return nil, gtserror.NewErrorNotFound(err, err.Error())
- }
-
- return targetAccount, nil
+ // Fetch the target account for requesting user account.
+ return p.c.GetVisibleTargetAccount(ctx, requester, targetID)
}
// unfollow is a convenience function for having requesting account
@@ -248,7 +231,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac
// Get follow from requesting account to target account.
follow, err := p.state.DB.GetFollow(ctx, requestingAccount.ID, targetAccount.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("unfollow: error getting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
+ err = gtserror.Newf("error getting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err
}
@@ -257,7 +240,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac
err = p.state.DB.DeleteFollowByID(ctx, follow.ID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("unfollow: error deleting request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
+ err = gtserror.Newf("error deleting request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err
}
@@ -284,7 +267,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac
// Get follow request from requesting account to target account.
followReq, err := p.state.DB.GetFollowRequest(ctx, requestingAccount.ID, targetAccount.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("unfollow: error getting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
+ err = gtserror.Newf("error getting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err
}
@@ -293,7 +276,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac
err = p.state.DB.DeleteFollowRequestByID(ctx, followReq.ID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("unfollow: error deleting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
+ err = gtserror.Newf("error deleting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err
}
diff --git a/internal/processing/account/follow_request.go b/internal/processing/account/follow_request.go
new file mode 100644
index 000000000..c054637c8
--- /dev/null
+++ b/internal/processing/account/follow_request.go
@@ -0,0 +1,119 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package account
+
+import (
+ "context"
+ "errors"
+
+ "github.com/superseriousbusiness/gotosocial/internal/ap"
+ apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
+)
+
+// FollowRequestAccept handles the accepting of a follow request from the sourceAccountID to the requestingAccount (the currently authorized account).
+func (p *Processor) FollowRequestAccept(ctx context.Context, requestingAccount *gtsmodel.Account, sourceAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ follow, err := p.state.DB.AcceptFollowRequest(ctx, sourceAccountID, requestingAccount.ID)
+ if err != nil {
+ return nil, gtserror.NewErrorNotFound(err)
+ }
+
+ if follow.Account != nil {
+ // Only enqueue work in the case we have a request creating account stored.
+ // NOTE: due to how AcceptFollowRequest works, the inverse shouldn't be possible.
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
+ APObjectType: ap.ActivityFollow,
+ APActivityType: ap.ActivityAccept,
+ GTSModel: follow,
+ OriginAccount: follow.Account,
+ TargetAccount: follow.TargetAccount,
+ })
+ }
+
+ return p.RelationshipGet(ctx, requestingAccount, sourceAccountID)
+}
+
+// FollowRequestReject handles the rejection of a follow request from the sourceAccountID to the requestingAccount (the currently authorized account).
+func (p *Processor) FollowRequestReject(ctx context.Context, requestingAccount *gtsmodel.Account, sourceAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ followRequest, err := p.state.DB.GetFollowRequest(ctx, sourceAccountID, requestingAccount.ID)
+ if err != nil {
+ return nil, gtserror.NewErrorNotFound(err)
+ }
+
+ err = p.state.DB.RejectFollowRequest(ctx, sourceAccountID, requestingAccount.ID)
+ if err != nil {
+ return nil, gtserror.NewErrorNotFound(err)
+ }
+
+ if followRequest.Account != nil {
+ // Only enqueue work in the case we have a request creating account stored.
+ // NOTE: due to how GetFollowRequest works, the inverse shouldn't be possible.
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
+ APObjectType: ap.ActivityFollow,
+ APActivityType: ap.ActivityReject,
+ GTSModel: followRequest,
+ OriginAccount: followRequest.Account,
+ TargetAccount: followRequest.TargetAccount,
+ })
+ }
+
+ return p.RelationshipGet(ctx, requestingAccount, sourceAccountID)
+}
+
+// FollowRequestsGet fetches a list of the accounts that are follow requesting the given requestingAccount (the currently authorized account).
+func (p *Processor) FollowRequestsGet(ctx context.Context, requestingAccount *gtsmodel.Account, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) {
+ // Fetch follow requests targeting the given requesting account model.
+ followRequests, err := p.state.DB.GetAccountFollowRequests(ctx, requestingAccount.ID, page)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+
+ // Check for empty response.
+ count := len(followRequests)
+ if count == 0 {
+ return paging.EmptyResponse(), nil
+ }
+
+ // Get the lowest and highest
+ // ID values, used for paging.
+ lo := followRequests[count-1].ID
+ hi := followRequests[0].ID
+
+ // Func to fetch follow source at index.
+ getIdx := func(i int) *gtsmodel.Account {
+ return followRequests[i].Account
+ }
+
+ // Get a filtered slice of public API account models.
+ items := p.c.GetVisibleAPIAccountsPaged(ctx,
+ requestingAccount,
+ getIdx,
+ count,
+ )
+
+ return paging.PackageResponse(paging.ResponseParams{
+ Items: items,
+ Path: "/api/v1/follow_requests",
+ Next: page.Next(lo, hi),
+ Prev: page.Prev(lo, hi),
+ }), nil
+}
diff --git a/internal/processing/account/relationships.go b/internal/processing/account/relationships.go
index d12d989ef..58c98f3ba 100644
--- a/internal/processing/account/relationships.go
+++ b/internal/processing/account/relationships.go
@@ -20,128 +20,120 @@ package account
import (
"context"
"errors"
- "fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/paging"
)
// FollowersGet fetches a list of the target account's followers.
-func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil {
- err = fmt.Errorf("FollowersGet: db error checking block: %w", err)
+func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) {
+ // Fetch target account to check it exists, and visibility of requester->target.
+ _, errWithCode := p.c.GetVisibleTargetAccount(ctx, requestingAccount, targetAccountID)
+ if errWithCode != nil {
+ return nil, errWithCode
+ }
+
+ follows, err := p.state.DB.GetAccountFollowers(ctx, targetAccountID, page)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ err = gtserror.Newf("db error getting followers: %w", err)
return nil, gtserror.NewErrorInternalError(err)
- } else if blocked {
- err = errors.New("FollowersGet: block exists between accounts")
- return nil, gtserror.NewErrorNotFound(err)
}
- follows, err := p.state.DB.GetAccountFollowers(ctx, targetAccountID)
- if err != nil {
- if !errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("FollowersGet: db error getting followers: %w", err)
- return nil, gtserror.NewErrorInternalError(err)
- }
- return []apimodel.Account{}, nil
+ // Check for empty response.
+ count := len(follows)
+ if count == 0 {
+ return paging.EmptyResponse(), nil
}
- return p.accountsFromFollows(ctx, follows, requestingAccount.ID)
+ // Get the lowest and highest
+ // ID values, used for paging.
+ lo := follows[count-1].ID
+ hi := follows[0].ID
+
+ // Func to fetch follow source at index.
+ getIdx := func(i int) *gtsmodel.Account {
+ return follows[i].Account
+ }
+
+ // Get a filtered slice of public API account models.
+ items := p.c.GetVisibleAPIAccountsPaged(ctx,
+ requestingAccount,
+ getIdx,
+ len(follows),
+ )
+
+ return paging.PackageResponse(paging.ResponseParams{
+ Items: items,
+ Path: "/api/v1/accounts/" + targetAccountID + "/followers",
+ Next: page.Next(lo, hi),
+ Prev: page.Prev(lo, hi),
+ }), nil
}
// FollowingGet fetches a list of the accounts that target account is following.
-func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil {
- err = fmt.Errorf("FollowingGet: db error checking block: %w", err)
+func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) {
+ // Fetch target account to check it exists, and visibility of requester->target.
+ _, errWithCode := p.c.GetVisibleTargetAccount(ctx, requestingAccount, targetAccountID)
+ if errWithCode != nil {
+ return nil, errWithCode
+ }
+
+ // Fetch known accounts that follow given target account ID.
+ follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID, page)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ err = gtserror.Newf("db error getting followers: %w", err)
return nil, gtserror.NewErrorInternalError(err)
- } else if blocked {
- err = errors.New("FollowingGet: block exists between accounts")
- return nil, gtserror.NewErrorNotFound(err)
}
- follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID)
- if err != nil {
- if !errors.Is(err, db.ErrNoEntries) {
- err = fmt.Errorf("FollowingGet: db error getting followers: %w", err)
- return nil, gtserror.NewErrorInternalError(err)
- }
- return []apimodel.Account{}, nil
+ // Check for empty response.
+ count := len(follows)
+ if count == 0 {
+ return paging.EmptyResponse(), nil
}
- return p.targetAccountsFromFollows(ctx, follows, requestingAccount.ID)
+ // Get the lowest and highest
+ // ID values, used for paging.
+ lo := follows[count-1].ID
+ hi := follows[0].ID
+
+ // Func to fetch follow source at index.
+ getIdx := func(i int) *gtsmodel.Account {
+ return follows[i].TargetAccount
+ }
+
+ // Get a filtered slice of public API account models.
+ items := p.c.GetVisibleAPIAccountsPaged(ctx,
+ requestingAccount,
+ getIdx,
+ len(follows),
+ )
+
+ return paging.PackageResponse(paging.ResponseParams{
+ Items: items,
+ Path: "/api/v1/accounts/" + targetAccountID + "/following",
+ Next: page.Next(lo, hi),
+ Prev: page.Prev(lo, hi),
+ }), nil
}
// RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account.
func (p *Processor) RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
if requestingAccount == nil {
- return nil, gtserror.NewErrorForbidden(errors.New("not authed"))
+ return nil, gtserror.NewErrorForbidden(gtserror.New("not authed"))
}
gtsR, err := p.state.DB.GetRelationship(ctx, requestingAccount.ID, targetAccountID)
if err != nil {
- return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err))
+ return nil, gtserror.NewErrorInternalError(gtserror.Newf("error getting relationship: %s", err))
}
r, err := p.tc.RelationshipToAPIRelationship(ctx, gtsR)
if err != nil {
- return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting relationship: %s", err))
+ return nil, gtserror.NewErrorInternalError(gtserror.Newf("error converting relationship: %s", err))
}
return r, nil
}
-
-func (p *Processor) accountsFromFollows(ctx context.Context, follows []*gtsmodel.Follow, requestingAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- accounts := make([]apimodel.Account, 0, len(follows))
- for _, follow := range follows {
- if follow.Account == nil {
- // No account set for some reason; just skip.
- log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated account")
- continue
- }
-
- if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.AccountID); err != nil {
- err = fmt.Errorf("accountsFromFollows: db error checking block: %w", err)
- return nil, gtserror.NewErrorInternalError(err)
- } else if blocked {
- continue
- }
-
- account, err := p.tc.AccountToAPIAccountPublic(ctx, follow.Account)
- if err != nil {
- err = fmt.Errorf("accountsFromFollows: error converting account to api account: %w", err)
- return nil, gtserror.NewErrorInternalError(err)
- }
-
- accounts = append(accounts, *account)
- }
- return accounts, nil
-}
-
-func (p *Processor) targetAccountsFromFollows(ctx context.Context, follows []*gtsmodel.Follow, requestingAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- accounts := make([]apimodel.Account, 0, len(follows))
- for _, follow := range follows {
- if follow.TargetAccount == nil {
- // No account set for some reason; just skip.
- log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated target account")
- continue
- }
-
- if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.TargetAccountID); err != nil {
- err = fmt.Errorf("targetAccountsFromFollows: db error checking block: %w", err)
- return nil, gtserror.NewErrorInternalError(err)
- } else if blocked {
- continue
- }
-
- account, err := p.tc.AccountToAPIAccountPublic(ctx, follow.TargetAccount)
- if err != nil {
- err = fmt.Errorf("targetAccountsFromFollows: error converting account to api account: %w", err)
- return nil, gtserror.NewErrorInternalError(err)
- }
-
- accounts = append(accounts, *account)
- }
- return accounts, nil
-}
diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go
deleted file mode 100644
index 014b6af21..000000000
--- a/internal/processing/blocks.go
+++ /dev/null
@@ -1,86 +0,0 @@
-// GoToSocial
-// Copyright (C) GoToSocial Authors admin@gotosocial.org
-// SPDX-License-Identifier: AGPL-3.0-or-later
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program. If not, see <http://www.gnu.org/licenses/>.
-
-package processing
-
-import (
- "context"
- "errors"
-
- apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtserror"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
- "github.com/superseriousbusiness/gotosocial/internal/paging"
- "github.com/superseriousbusiness/gotosocial/internal/util"
-)
-
-// BlocksGet ...
-func (p *Processor) BlocksGet(
- ctx context.Context,
- requestingAccount *gtsmodel.Account,
- page *paging.Page,
-) (*apimodel.PageableResponse, gtserror.WithCode) {
- blocks, err := p.state.DB.GetAccountBlocks(ctx,
- requestingAccount.ID,
- page,
- )
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return nil, gtserror.NewErrorInternalError(err)
- }
-
- // Check for zero length.
- count := len(blocks)
- if len(blocks) == 0 {
- return util.EmptyPageableResponse(), nil
- }
-
- var (
- items = make([]interface{}, 0, count)
-
- // Set next + prev values before API converting
- // so the caller can still page even on error.
- nextMaxIDValue = blocks[count-1].ID
- prevMinIDValue = blocks[0].ID
- )
-
- for _, block := range blocks {
- if block.TargetAccount == nil {
- // All models should be populated at this point.
- log.Warnf(ctx, "block target account was nil: %v", err)
- continue
- }
-
- // Convert target account to frontend API model.
- account, err := p.tc.AccountToAPIAccountBlocked(ctx, block.TargetAccount)
- if err != nil {
- log.Errorf(ctx, "error converting account to public api account: %v", err)
- continue
- }
-
- // Append target to return items.
- items = append(items, account)
- }
-
- return paging.PackageResponse(paging.ResponseParams{
- Items: items,
- Path: "/api/v1/blocks",
- Next: page.Next(nextMaxIDValue),
- Prev: page.Prev(prevMinIDValue),
- }), nil
-}
diff --git a/internal/processing/common/account.go.go b/internal/processing/common/account.go.go
new file mode 100644
index 000000000..06e87fa0e
--- /dev/null
+++ b/internal/processing/common/account.go.go
@@ -0,0 +1,238 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package common
+
+import (
+ "context"
+ "errors"
+
+ apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+)
+
+// GetTargetAccountBy fetches the target account with db load function, given the authorized (or, nil) requester's
+// account. This returns an approprate gtserror.WithCode accounting (ha) for not found and visibility to requester.
+func (p *Processor) GetTargetAccountBy(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ getTargetFromDB func() (*gtsmodel.Account, error),
+) (
+ account *gtsmodel.Account,
+ visible bool,
+ errWithCode gtserror.WithCode,
+) {
+ // Fetch the target account from db.
+ target, err := getTargetFromDB()
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, false, gtserror.NewErrorInternalError(err)
+ }
+
+ if target == nil {
+ // DB loader could not find account in database.
+ err := errors.New("target account not found")
+ return nil, false, gtserror.NewErrorNotFound(err)
+ }
+
+ // Check whether target account is visible to requesting account.
+ visible, err = p.filter.AccountVisible(ctx, requester, target)
+ if err != nil {
+ return nil, false, gtserror.NewErrorInternalError(err)
+ }
+
+ if requester != nil && visible {
+ // Ensure the account is up-to-date.
+ p.federator.RefreshAccountAsync(ctx,
+ requester.Username,
+ target,
+ nil,
+ false,
+ )
+ }
+
+ return target, visible, nil
+}
+
+// GetTargetAccountByID is a call-through to GetTargetAccountBy() using the db GetAccountByID() function.
+func (p *Processor) GetTargetAccountByID(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ targetID string,
+) (
+ account *gtsmodel.Account,
+ visible bool,
+ errWithCode gtserror.WithCode,
+) {
+ return p.GetTargetAccountBy(ctx, requester, func() (*gtsmodel.Account, error) {
+ return p.state.DB.GetAccountByID(ctx, targetID)
+ })
+}
+
+// GetVisibleTargetAccount calls GetTargetAccountByID(),
+// but converts a non-visible result to not-found error.
+func (p *Processor) GetVisibleTargetAccount(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ targetID string,
+) (
+ account *gtsmodel.Account,
+ errWithCode gtserror.WithCode,
+) {
+ // Fetch the target account by ID from the database.
+ target, visible, errWithCode := p.GetTargetAccountByID(ctx,
+ requester,
+ targetID,
+ )
+ if errWithCode != nil {
+ return nil, errWithCode
+ }
+
+ if !visible {
+ // Pretend account doesn't exist if not visible.
+ err := errors.New("target account not found")
+ return nil, gtserror.NewErrorNotFound(err)
+ }
+
+ return target, nil
+}
+
+// GetAPIAccount fetches the appropriate API account model depending on whether requester = target.
+func (p *Processor) GetAPIAccount(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ target *gtsmodel.Account,
+) (
+ apiAcc *apimodel.Account,
+ errWithCode gtserror.WithCode,
+) {
+ var err error
+
+ if requester != nil && requester.ID == target.ID {
+ // Only return sensitive account model _if_ requester = target.
+ apiAcc, err = p.converter.AccountToAPIAccountSensitive(ctx, target)
+ } else {
+ // Else, fall back to returning the public account model.
+ apiAcc, err = p.converter.AccountToAPIAccountPublic(ctx, target)
+ }
+
+ if err != nil {
+ err := gtserror.Newf("error converting account: %w", err)
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+
+ return apiAcc, nil
+}
+
+// GetAPIAccountBlocked fetches the limited "blocked" account model for given target.
+func (p *Processor) GetAPIAccountBlocked(
+ ctx context.Context,
+ targetAcc *gtsmodel.Account,
+) (
+ apiAcc *apimodel.Account,
+ errWithCode gtserror.WithCode,
+) {
+ apiAccount, err := p.converter.AccountToAPIAccountBlocked(ctx, targetAcc)
+ if err != nil {
+ err = gtserror.Newf("error converting account: %w", err)
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+ return apiAccount, nil
+}
+
+// GetVisibleAPIAccounts converts an array of gtsmodel.Accounts (inputted by next function) into
+// public API model accounts, checking first for visibility. Please note that all errors will be
+// logged at ERROR level, but will not be returned. Callers are likely to run into show-stopping
+// errors in the lead-up to this function, whereas calling this should not be a show-stopper.
+func (p *Processor) GetVisibleAPIAccounts(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ next func(int) *gtsmodel.Account,
+ length int,
+) []*apimodel.Account {
+ return p.getVisibleAPIAccounts(ctx, 3, requester, next, length)
+}
+
+// GetVisibleAPIAccountsPaged is functionally equivalent to GetVisibleAPIAccounts(),
+// except the accounts are returned as a converted slice of accounts as interface{}.
+func (p *Processor) GetVisibleAPIAccountsPaged(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ next func(int) *gtsmodel.Account,
+ length int,
+) []interface{} {
+ accounts := p.getVisibleAPIAccounts(ctx, 3, requester, next, length)
+ if len(accounts) == 0 {
+ return nil
+ }
+ items := make([]interface{}, len(accounts))
+ for i, account := range accounts {
+ items[i] = account
+ }
+ return items
+}
+
+func (p *Processor) getVisibleAPIAccounts(
+ ctx context.Context,
+ calldepth int, // used to skip wrapping func above these's names
+ requester *gtsmodel.Account,
+ next func(int) *gtsmodel.Account,
+ length int,
+) []*apimodel.Account {
+ // Start new log entry with
+ // the above calling func's name.
+ l := log.
+ WithContext(ctx).
+ WithField("caller", log.Caller(calldepth+1))
+
+ // Preallocate slice according to expected length.
+ accounts := make([]*apimodel.Account, 0, length)
+
+ for i := 0; i < length; i++ {
+ // Get next account.
+ account := next(i)
+ if account == nil {
+ continue
+ }
+
+ // Check whether this account is visible to requesting account.
+ visible, err := p.filter.AccountVisible(ctx, requester, account)
+ if err != nil {
+ l.Errorf("error checking account visibility: %v", err)
+ continue
+ }
+
+ if !visible {
+ // Not visible to requester.
+ continue
+ }
+
+ // Convert the account to a public API model representation.
+ apiAcc, err := p.converter.AccountToAPIAccountPublic(ctx, account)
+ if err != nil {
+ l.Errorf("error converting account: %v", err)
+ continue
+ }
+
+ // Append API model to return slice.
+ accounts = append(accounts, apiAcc)
+ }
+
+ return accounts
+}
diff --git a/internal/processing/common/common.go b/internal/processing/common/common.go
new file mode 100644
index 000000000..53c298579
--- /dev/null
+++ b/internal/processing/common/common.go
@@ -0,0 +1,50 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package common
+
+import (
+ "github.com/superseriousbusiness/gotosocial/internal/federation"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/typeutils"
+ "github.com/superseriousbusiness/gotosocial/internal/visibility"
+)
+
+// Processor provides a processor with logic
+// common to multiple logical domains of the
+// processing subsection of the codebase.
+type Processor struct {
+ state *state.State
+ converter typeutils.TypeConverter
+ federator federation.Federator
+ filter *visibility.Filter
+}
+
+// New returns a new Processor instance.
+func New(
+ state *state.State,
+ converter typeutils.TypeConverter,
+ federator federation.Federator,
+ filter *visibility.Filter,
+) Processor {
+ return Processor{
+ state: state,
+ converter: converter,
+ federator: federator,
+ filter: filter,
+ }
+}
diff --git a/internal/processing/common/status.go b/internal/processing/common/status.go
new file mode 100644
index 000000000..fb480ec7e
--- /dev/null
+++ b/internal/processing/common/status.go
@@ -0,0 +1,248 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package common
+
+import (
+ "context"
+ "errors"
+
+ apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+)
+
+// GetTargetStatusBy fetches the target status with db load function, given the authorized (or, nil) requester's
+// account. This returns an approprate gtserror.WithCode accounting for not found and visibility to requester.
+func (p *Processor) GetTargetStatusBy(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ getTargetFromDB func() (*gtsmodel.Status, error),
+) (
+ status *gtsmodel.Status,
+ visible bool,
+ errWithCode gtserror.WithCode,
+) {
+ // Fetch the target status from db.
+ target, err := getTargetFromDB()
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, false, gtserror.NewErrorInternalError(err)
+ }
+
+ if target == nil {
+ // DB loader could not find status in database.
+ err := errors.New("target status not found")
+ return nil, false, gtserror.NewErrorNotFound(err)
+ }
+
+ // Check whether target status is visible to requesting account.
+ visible, err = p.filter.StatusVisible(ctx, requester, target)
+ if err != nil {
+ return nil, false, gtserror.NewErrorInternalError(err)
+ }
+
+ if requester != nil && visible {
+ // Ensure remote status is up-to-date.
+ p.federator.RefreshStatusAsync(ctx,
+ requester.Username,
+ target,
+ nil,
+ false,
+ )
+ }
+
+ return target, visible, nil
+}
+
+// GetTargetStatusByID is a call-through to GetTargetStatus() using the db GetStatusByID() function.
+func (p *Processor) GetTargetStatusByID(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ targetID string,
+) (
+ status *gtsmodel.Status,
+ visible bool,
+ errWithCode gtserror.WithCode,
+) {
+ return p.GetTargetStatusBy(ctx, requester, func() (*gtsmodel.Status, error) {
+ return p.state.DB.GetStatusByID(ctx, targetID)
+ })
+}
+
+// GetVisibleTargetStatus calls GetTargetStatusByID(),
+// but converts a non-visible result to not-found error.
+func (p *Processor) GetVisibleTargetStatus(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ targetID string,
+) (
+ status *gtsmodel.Status,
+ errWithCode gtserror.WithCode,
+) {
+ // Fetch the target status by ID from the database.
+ target, visible, errWithCode := p.GetTargetStatusByID(ctx,
+ requester,
+ targetID,
+ )
+ if errWithCode != nil {
+ return nil, errWithCode
+ }
+
+ if !visible {
+ // Target should not be seen by requester.
+ err := errors.New("target status not found")
+ return nil, gtserror.NewErrorNotFound(err)
+ }
+
+ return target, nil
+}
+
+// GetAPIStatus fetches the appropriate API status model for target.
+func (p *Processor) GetAPIStatus(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ target *gtsmodel.Status,
+) (
+ apiStatus *apimodel.Status,
+ errWithCode gtserror.WithCode,
+) {
+ apiStatus, err := p.converter.StatusToAPIStatus(ctx, target, requester)
+ if err != nil {
+ err = gtserror.Newf("error converting status: %w", err)
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+ return apiStatus, nil
+}
+
+// GetVisibleAPIStatuses converts an array of gtsmodel.Status (inputted by next function) into
+// API model statuses, checking first for visibility. Please note that all errors will be
+// logged at ERROR level, but will not be returned. Callers are likely to run into show-stopping
+// errors in the lead-up to this function, whereas calling this should not be a show-stopper.
+func (p *Processor) GetVisibleAPIStatuses(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ next func(int) *gtsmodel.Status,
+ length int,
+) []*apimodel.Status {
+ return p.getVisibleAPIStatuses(ctx, 3, requester, next, length)
+}
+
+// GetVisibleAPIStatusesPaged is functionally equivalent to GetVisibleAPIStatuses(),
+// except the statuses are returned as a converted slice of statuses as interface{}.
+func (p *Processor) GetVisibleAPIStatusesPaged(
+ ctx context.Context,
+ requester *gtsmodel.Account,
+ next func(int) *gtsmodel.Status,
+ length int,
+) []interface{} {
+ statuses := p.getVisibleAPIStatuses(ctx, 3, requester, next, length)
+ if len(statuses) == 0 {
+ return nil
+ }
+ items := make([]interface{}, len(statuses))
+ for i, status := range statuses {
+ items[i] = status
+ }
+ return items
+}
+
+func (p *Processor) getVisibleAPIStatuses(
+ ctx context.Context,
+ calldepth int, // used to skip wrapping func above these's names
+ requester *gtsmodel.Account,
+ next func(int) *gtsmodel.Status,
+ length int,
+) []*apimodel.Status {
+ // Start new log entry with
+ // the above calling func's name.
+ l := log.
+ WithContext(ctx).
+ WithField("caller", log.Caller(calldepth+1))
+
+ // Preallocate slice according to expected length.
+ statuses := make([]*apimodel.Status, 0, length)
+
+ for i := 0; i < length; i++ {
+ // Get next status.
+ status := next(i)
+ if status == nil {
+ continue
+ }
+
+ // Check whether this status is visible to requesting account.
+ visible, err := p.filter.StatusVisible(ctx, requester, status)
+ if err != nil {
+ l.Errorf("error checking status visibility: %v", err)
+ continue
+ }
+
+ if !visible {
+ // Not visible to requester.
+ continue
+ }
+
+ // Convert the status to an API model representation.
+ apiStatus, err := p.converter.StatusToAPIStatus(ctx, status, requester)
+ if err != nil {
+ l.Errorf("error converting status: %v", err)
+ continue
+ }
+
+ // Append API model to return slice.
+ statuses = append(statuses, apiStatus)
+ }
+
+ return statuses
+}
+
+// InvalidateTimelinedStatus is a shortcut function for invalidating the cached
+// representation one status in the home timeline and all list timelines of the
+// given accountID. It should only be called in cases where a status update
+// does *not* need to be passed into the processor via the worker queue, since
+// such invalidation will, in that case, be handled by the processor instead.
+func (p *Processor) InvalidateTimelinedStatus(ctx context.Context, accountID string, statusID string) error {
+ // Get lists first + bail if this fails.
+ lists, err := p.state.DB.GetListsForAccountID(ctx, accountID)
+ if err != nil {
+ return gtserror.Newf("db error getting lists for account %s: %w", accountID, err)
+ }
+
+ // Start new log entry with
+ // the above calling func's name.
+ l := log.
+ WithContext(ctx).
+ WithField("caller", log.Caller(3)).
+ WithField("accountID", accountID).
+ WithField("statusID", statusID)
+
+ // Unprepare item from home + list timelines, just log
+ // if something goes wrong since this is not a showstopper.
+
+ if err := p.state.Timelines.Home.UnprepareItem(ctx, accountID, statusID); err != nil {
+ l.Errorf("error unpreparing item from home timeline: %v", err)
+ }
+
+ for _, list := range lists {
+ if err := p.state.Timelines.List.UnprepareItem(ctx, list.ID, statusID); err != nil {
+ l.Errorf("error unpreparing item from list timeline %s: %v", list.ID, err)
+ }
+ }
+
+ return nil
+}
diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go
deleted file mode 100644
index 6587b73bb..000000000
--- a/internal/processing/followrequest.go
+++ /dev/null
@@ -1,123 +0,0 @@
-// GoToSocial
-// Copyright (C) GoToSocial Authors admin@gotosocial.org
-// SPDX-License-Identifier: AGPL-3.0-or-later
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program. If not, see <http://www.gnu.org/licenses/>.
-
-package processing
-
-import (
- "context"
- "errors"
-
- "github.com/superseriousbusiness/gotosocial/internal/ap"
- apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtserror"
- "github.com/superseriousbusiness/gotosocial/internal/log"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/oauth"
-)
-
-func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
- followRequests, err := p.state.DB.GetAccountFollowRequests(ctx, auth.Account.ID)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return nil, gtserror.NewErrorInternalError(err)
- }
-
- accts := make([]apimodel.Account, 0, len(followRequests))
- for _, followRequest := range followRequests {
- if followRequest.Account == nil {
- // The creator of the follow doesn't exist,
- // just skip this one.
- log.WithContext(ctx).WithField("followRequest", followRequest).Warn("follow request had no associated account")
- continue
- }
-
- apiAcct, err := p.tc.AccountToAPIAccountPublic(ctx, followRequest.Account)
- if err != nil {
- return nil, gtserror.NewErrorInternalError(err)
- }
-
- accts = append(accts, *apiAcct)
- }
-
- return accts, nil
-}
-
-func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
- follow, err := p.state.DB.AcceptFollowRequest(ctx, accountID, auth.Account.ID)
- if err != nil {
- return nil, gtserror.NewErrorNotFound(err)
- }
-
- if follow.Account == nil {
- // The creator of the follow doesn't exist,
- // so we can't do further processing.
- log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated account")
- return p.relationship(ctx, auth.Account.ID, accountID)
- }
-
- p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
- APObjectType: ap.ActivityFollow,
- APActivityType: ap.ActivityAccept,
- GTSModel: follow,
- OriginAccount: follow.Account,
- TargetAccount: follow.TargetAccount,
- })
-
- return p.relationship(ctx, auth.Account.ID, accountID)
-}
-
-func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
- followRequest, err := p.state.DB.GetFollowRequest(ctx, accountID, auth.Account.ID)
- if err != nil {
- return nil, gtserror.NewErrorNotFound(err)
- }
-
- err = p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID)
- if err != nil {
- return nil, gtserror.NewErrorNotFound(err)
- }
-
- if followRequest.Account == nil {
- // The creator of the request doesn't exist,
- // so we can't do further processing.
- return p.relationship(ctx, auth.Account.ID, accountID)
- }
-
- p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
- APObjectType: ap.ActivityFollow,
- APActivityType: ap.ActivityReject,
- GTSModel: followRequest,
- OriginAccount: followRequest.Account,
- TargetAccount: followRequest.TargetAccount,
- })
-
- return p.relationship(ctx, auth.Account.ID, accountID)
-}
-
-func (p *Processor) relationship(ctx context.Context, accountID string, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
- relationship, err := p.state.DB.GetRelationship(ctx, accountID, targetAccountID)
- if err != nil {
- return nil, gtserror.NewErrorInternalError(err)
- }
-
- apiRelationship, err := p.tc.RelationshipToAPIRelationship(ctx, relationship)
- if err != nil {
- return nil, gtserror.NewErrorInternalError(err)
- }
-
- return apiRelationship, nil
-}
diff --git a/internal/processing/followrequest_test.go b/internal/processing/followrequest_test.go
index addb5052e..4c089be4a 100644
--- a/internal/processing/followrequest_test.go
+++ b/internal/processing/followrequest_test.go
@@ -30,35 +30,57 @@ import (
"github.com/superseriousbusiness/gotosocial/testrig"
)
+// TODO: move this to the "internal/processing/account" pkg
type FollowRequestTestSuite struct {
ProcessingStandardTestSuite
}
func (suite *FollowRequestTestSuite) TestFollowRequestAccept() {
- requestingAccount := suite.testAccounts["remote_account_2"]
- targetAccount := suite.testAccounts["local_account_1"]
+ // The authed local account we are going to use for HTTP requests
+ requestingAccount := suite.testAccounts["local_account_1"]
+
+ // The remote account whose follow request we are accepting
+ targetAccount := suite.testAccounts["remote_account_2"]
// put a follow request in the database
fr := &gtsmodel.FollowRequest{
ID: "01FJ1S8DX3STJJ6CEYPMZ1M0R3",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
- URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", requestingAccount.URI),
- AccountID: requestingAccount.ID,
- TargetAccountID: targetAccount.ID,
+ URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", targetAccount.URI),
+ AccountID: targetAccount.ID,
+ TargetAccountID: requestingAccount.ID,
}
err := suite.db.Put(context.Background(), fr)
suite.NoError(err)
- relationship, errWithCode := suite.processor.FollowRequestAccept(context.Background(), suite.testAutheds["local_account_1"], requestingAccount.ID)
+ relationship, errWithCode := suite.processor.Account().FollowRequestAccept(
+ context.Background(),
+ requestingAccount,
+ targetAccount.ID,
+ )
suite.NoError(errWithCode)
- suite.EqualValues(&apimodel.Relationship{ID: "01FHMQX3GAABWSM0S2VZEC2SWC", Following: false, ShowingReblogs: false, Notifying: false, FollowedBy: true, Blocking: false, BlockedBy: false, Muting: false, MutingNotifications: false, Requested: false, DomainBlocking: false, Endorsed: false, Note: ""}, relationship)
+ suite.EqualValues(&apimodel.Relationship{
+ ID: "01FHMQX3GAABWSM0S2VZEC2SWC",
+ Following: false,
+ ShowingReblogs: false,
+ Notifying: false,
+ FollowedBy: true,
+ Blocking: false,
+ BlockedBy: false,
+ Muting: false,
+ MutingNotifications: false,
+ Requested: false,
+ DomainBlocking: false,
+ Endorsed: false,
+ Note: "",
+ }, relationship)
// accept should be sent to Some_User
var sent [][]byte
if !testrig.WaitFor(func() bool {
- sentI, ok := suite.httpClient.SentMessages.Load(requestingAccount.InboxURI)
+ sentI, ok := suite.httpClient.SentMessages.Load(targetAccount.InboxURI)
if ok {
sent, ok = sentI.([][]byte)
if !ok {
@@ -87,41 +109,45 @@ func (suite *FollowRequestTestSuite) TestFollowRequestAccept() {
err = json.Unmarshal(sent[0], accept)
suite.NoError(err)
- suite.Equal(targetAccount.URI, accept.Actor)
- suite.Equal(requestingAccount.URI, accept.Object.Actor)
+ suite.Equal(requestingAccount.URI, accept.Actor)
+ suite.Equal(targetAccount.URI, accept.Object.Actor)
suite.Equal(fr.URI, accept.Object.ID)
- suite.Equal(targetAccount.URI, accept.Object.Object)
- suite.Equal(targetAccount.URI, accept.Object.To)
+ suite.Equal(requestingAccount.URI, accept.Object.Object)
+ suite.Equal(requestingAccount.URI, accept.Object.To)
suite.Equal("Follow", accept.Object.Type)
- suite.Equal(requestingAccount.URI, accept.To)
+ suite.Equal(targetAccount.URI, accept.To)
suite.Equal("Accept", accept.Type)
}
func (suite *FollowRequestTestSuite) TestFollowRequestReject() {
- requestingAccount := suite.testAccounts["remote_account_2"]
- targetAccount := suite.testAccounts["local_account_1"]
+ requestingAccount := suite.testAccounts["local_account_1"]
+ targetAccount := suite.testAccounts["remote_account_2"]
// put a follow request in the database
fr := &gtsmodel.FollowRequest{
ID: "01FJ1S8DX3STJJ6CEYPMZ1M0R3",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
- URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", requestingAccount.URI),
- AccountID: requestingAccount.ID,
- TargetAccountID: targetAccount.ID,
+ URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", targetAccount.URI),
+ AccountID: targetAccount.ID,
+ TargetAccountID: requestingAccount.ID,
}
err := suite.db.Put(context.Background(), fr)
suite.NoError(err)
- relationship, errWithCode := suite.processor.FollowRequestReject(context.Background(), suite.testAutheds["local_account_1"], requestingAccount.ID)
+ relationship, errWithCode := suite.processor.Account().FollowRequestReject(
+ context.Background(),
+ requestingAccount,
+ targetAccount.ID,
+ )
suite.NoError(errWithCode)
suite.EqualValues(&apimodel.Relationship{ID: "01FHMQX3GAABWSM0S2VZEC2SWC", Following: false, ShowingReblogs: false, Notifying: false, FollowedBy: false, Blocking: false, BlockedBy: false, Muting: false, MutingNotifications: false, Requested: false, DomainBlocking: false, Endorsed: false, Note: ""}, relationship)
// reject should be sent to Some_User
var sent [][]byte
if !testrig.WaitFor(func() bool {
- sentI, ok := suite.httpClient.SentMessages.Load(requestingAccount.InboxURI)
+ sentI, ok := suite.httpClient.SentMessages.Load(targetAccount.InboxURI)
if ok {
sent, ok = sentI.([][]byte)
if !ok {
@@ -150,13 +176,13 @@ func (suite *FollowRequestTestSuite) TestFollowRequestReject() {
err = json.Unmarshal(sent[0], reject)
suite.NoError(err)
- suite.Equal(targetAccount.URI, reject.Actor)
- suite.Equal(requestingAccount.URI, reject.Object.Actor)
+ suite.Equal(requestingAccount.URI, reject.Actor)
+ suite.Equal(targetAccount.URI, reject.Object.Actor)
suite.Equal(fr.URI, reject.Object.ID)
- suite.Equal(targetAccount.URI, reject.Object.Object)
- suite.Equal(targetAccount.URI, reject.Object.To)
+ suite.Equal(requestingAccount.URI, reject.Object.Object)
+ suite.Equal(requestingAccount.URI, reject.Object.To)
suite.Equal("Follow", reject.Object.Type)
- suite.Equal(requestingAccount.URI, reject.To)
+ suite.Equal(targetAccount.URI, reject.To)
suite.Equal("Reject", reject.Type)
}
diff --git a/internal/processing/processor.go b/internal/processing/processor.go
index c0fd15a24..f814d5a96 100644
--- a/internal/processing/processor.go
+++ b/internal/processing/processor.go
@@ -24,6 +24,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
"github.com/superseriousbusiness/gotosocial/internal/processing/admin"
+ "github.com/superseriousbusiness/gotosocial/internal/processing/common"
"github.com/superseriousbusiness/gotosocial/internal/processing/fedi"
"github.com/superseriousbusiness/gotosocial/internal/processing/list"
"github.com/superseriousbusiness/gotosocial/internal/processing/markers"
@@ -147,7 +148,8 @@ func NewProcessor(
//
// Start with sub processors that will
// be required by the workers processor.
- accountProcessor := account.New(state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc)
+ commonProcessor := common.New(state, tc, federator, filter)
+ accountProcessor := account.New(&commonProcessor, state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc)
mediaProcessor := media.New(state, tc, mediaManager, federator.TransportController())
streamProcessor := stream.New(state, oauthServer)
diff --git a/internal/timeline/get_test.go b/internal/timeline/get_test.go
index 7b31ec977..4522d5858 100644
--- a/internal/timeline/get_test.go
+++ b/internal/timeline/get_test.go
@@ -66,6 +66,7 @@ func (suite *GetTestSuite) emptyAccountFollows(ctx context.Context, accountID st
follows, err := suite.state.DB.GetAccountFollows(
gtscontext.SetBarebones(ctx),
accountID,
+ nil, // select all
)
if err != nil {
suite.FailNow(err.Error())
@@ -82,6 +83,7 @@ func (suite *GetTestSuite) emptyAccountFollows(ctx context.Context, accountID st
follows, err = suite.state.DB.GetAccountFollows(
gtscontext.SetBarebones(ctx),
accountID,
+ nil, // select all
)
if err != nil {
suite.FailNow(err.Error())
diff --git a/testrig/testmodels.go b/testrig/testmodels.go
index 4f0768b45..fa6ff92ff 100644
--- a/testrig/testmodels.go
+++ b/testrig/testmodels.go
@@ -364,6 +364,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account {
SuspendedAt: time.Time{},
HideCollections: util.Ptr(false),
SuspensionOrigin: "",
+ EnableRSS: util.Ptr(false),
},
"admin_account": {
ID: "01F8MH17FWEB39HZJ76B6VXSKF",
@@ -539,6 +540,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account {
SuspendedAt: time.Time{},
HideCollections: util.Ptr(false),
SuspensionOrigin: "",
+ EnableRSS: util.Ptr(false),
},
"remote_account_2": {
ID: "01FHMQX3GAABWSM0S2VZEC2SWC",
@@ -575,6 +577,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account {
SuspendedAt: time.Time{},
HideCollections: util.Ptr(false),
SuspensionOrigin: "",
+ EnableRSS: util.Ptr(false),
},
"remote_account_3": {
ID: "062G5WYKY35KKD12EMSM3F8PJ8",
@@ -612,6 +615,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account {
HideCollections: util.Ptr(false),
SuspensionOrigin: "",
HeaderMediaAttachmentID: "01PFPMWK2FF0D9WMHEJHR07C3R",
+ EnableRSS: util.Ptr(false),
},
"remote_account_4": {
ID: "07GZRBAEMBNKGZ8Z9VSKSXKR98",
diff --git a/vendor/github.com/tomnomnom/linkheader/.gitignore b/vendor/github.com/tomnomnom/linkheader/.gitignore
new file mode 100644
index 000000000..0a00ddebb
--- /dev/null
+++ b/vendor/github.com/tomnomnom/linkheader/.gitignore
@@ -0,0 +1,2 @@
+cpu.out
+linkheader.test
diff --git a/vendor/github.com/tomnomnom/linkheader/.travis.yml b/vendor/github.com/tomnomnom/linkheader/.travis.yml
new file mode 100644
index 000000000..cfda08659
--- /dev/null
+++ b/vendor/github.com/tomnomnom/linkheader/.travis.yml
@@ -0,0 +1,6 @@
+language: go
+
+go:
+ - 1.6
+ - 1.7
+ - tip
diff --git a/vendor/github.com/tomnomnom/linkheader/CONTRIBUTING.mkd b/vendor/github.com/tomnomnom/linkheader/CONTRIBUTING.mkd
new file mode 100644
index 000000000..0339bec55
--- /dev/null
+++ b/vendor/github.com/tomnomnom/linkheader/CONTRIBUTING.mkd
@@ -0,0 +1,10 @@
+# Contributing
+
+* Raise an issue if appropriate
+* Fork the repo
+* Bootstrap the dev dependencies (run `./script/bootstrap`)
+* Make your changes
+* Use [gofmt](https://golang.org/cmd/gofmt/)
+* Make sure the tests pass (run `./script/test`)
+* Make sure the linters pass (run `./script/lint`)
+* Issue a pull request
diff --git a/vendor/github.com/tomnomnom/linkheader/LICENSE b/vendor/github.com/tomnomnom/linkheader/LICENSE
new file mode 100644
index 000000000..55192df56
--- /dev/null
+++ b/vendor/github.com/tomnomnom/linkheader/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2016 Tom Hudson
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/vendor/github.com/tomnomnom/linkheader/README.mkd b/vendor/github.com/tomnomnom/linkheader/README.mkd
new file mode 100644
index 000000000..2a949cac2
--- /dev/null
+++ b/vendor/github.com/tomnomnom/linkheader/README.mkd
@@ -0,0 +1,35 @@
+# Golang Link Header Parser
+
+Library for parsing HTTP Link headers. Requires Go 1.6 or higher.
+
+Docs can be found on [the GoDoc page](https://godoc.org/github.com/tomnomnom/linkheader).
+
+[![Build Status](https://travis-ci.org/tomnomnom/linkheader.svg)](https://travis-ci.org/tomnomnom/linkheader)
+
+## Basic Example
+
+```go
+package main
+
+import (
+ "fmt"
+
+ "github.com/tomnomnom/linkheader"
+)
+
+func main() {
+ header := "<https://api.github.com/user/58276/repos?page=2>; rel=\"next\"," +
+ "<https://api.github.com/user/58276/repos?page=2>; rel=\"last\""
+ links := linkheader.Parse(header)
+
+ for _, link := range links {
+ fmt.Printf("URL: %s; Rel: %s\n", link.URL, link.Rel)
+ }
+}
+
+// Output:
+// URL: https://api.github.com/user/58276/repos?page=2; Rel: next
+// URL: https://api.github.com/user/58276/repos?page=2; Rel: last
+```
+
+
diff --git a/vendor/github.com/tomnomnom/linkheader/main.go b/vendor/github.com/tomnomnom/linkheader/main.go
new file mode 100644
index 000000000..6b81321b8
--- /dev/null
+++ b/vendor/github.com/tomnomnom/linkheader/main.go
@@ -0,0 +1,151 @@
+// Package linkheader provides functions for parsing HTTP Link headers
+package linkheader
+
+import (
+ "fmt"
+ "strings"
+)
+
+// A Link is a single URL and related parameters
+type Link struct {
+ URL string
+ Rel string
+ Params map[string]string
+}
+
+// HasParam returns if a Link has a particular parameter or not
+func (l Link) HasParam(key string) bool {
+ for p := range l.Params {
+ if p == key {
+ return true
+ }
+ }
+ return false
+}
+
+// Param returns the value of a parameter if it exists
+func (l Link) Param(key string) string {
+ for k, v := range l.Params {
+ if key == k {
+ return v
+ }
+ }
+ return ""
+}
+
+// String returns the string representation of a link
+func (l Link) String() string {
+
+ p := make([]string, 0, len(l.Params))
+ for k, v := range l.Params {
+ p = append(p, fmt.Sprintf("%s=\"%s\"", k, v))
+ }
+ if l.Rel != "" {
+ p = append(p, fmt.Sprintf("%s=\"%s\"", "rel", l.Rel))
+ }
+ return fmt.Sprintf("<%s>; %s", l.URL, strings.Join(p, "; "))
+}
+
+// Links is a slice of Link structs
+type Links []Link
+
+// FilterByRel filters a group of Links by the provided Rel attribute
+func (l Links) FilterByRel(r string) Links {
+ links := make(Links, 0)
+ for _, link := range l {
+ if link.Rel == r {
+ links = append(links, link)
+ }
+ }
+ return links
+}
+
+// String returns the string representation of multiple Links
+// for use in HTTP responses etc
+func (l Links) String() string {
+ if l == nil {
+ return fmt.Sprint(nil)
+ }
+
+ var strs []string
+ for _, link := range l {
+ strs = append(strs, link.String())
+ }
+ return strings.Join(strs, ", ")
+}
+
+// Parse parses a raw Link header in the form:
+// <url>; rel="foo", <url>; rel="bar"; wat="dis"
+// returning a slice of Link structs
+func Parse(raw string) Links {
+ var links Links
+
+ // One chunk: <url>; rel="foo"
+ for _, chunk := range strings.Split(raw, ",") {
+
+ link := Link{URL: "", Rel: "", Params: make(map[string]string)}
+
+ // Figure out what each piece of the chunk is
+ for _, piece := range strings.Split(chunk, ";") {
+
+ piece = strings.Trim(piece, " ")
+ if piece == "" {
+ continue
+ }
+
+ // URL
+ if piece[0] == '<' && piece[len(piece)-1] == '>' {
+ link.URL = strings.Trim(piece, "<>")
+ continue
+ }
+
+ // Params
+ key, val := parseParam(piece)
+ if key == "" {
+ continue
+ }
+
+ // Special case for rel
+ if strings.ToLower(key) == "rel" {
+ link.Rel = val
+ } else {
+ link.Params[key] = val
+ }
+ }
+
+ if link.URL != "" {
+ links = append(links, link)
+ }
+ }
+
+ return links
+}
+
+// ParseMultiple is like Parse, but accepts a slice of headers
+// rather than just one header string
+func ParseMultiple(headers []string) Links {
+ links := make(Links, 0)
+ for _, header := range headers {
+ links = append(links, Parse(header)...)
+ }
+ return links
+}
+
+// parseParam takes a raw param in the form key="val" and
+// returns the key and value as seperate strings
+func parseParam(raw string) (key, val string) {
+
+ parts := strings.SplitN(raw, "=", 2)
+ if len(parts) == 1 {
+ return parts[0], ""
+ }
+ if len(parts) != 2 {
+ return "", ""
+ }
+
+ key = parts[0]
+ val = strings.Trim(parts[1], "\"")
+
+ return key, val
+
+}
diff --git a/vendor/modules.txt b/vendor/modules.txt
index dae43a87f..99c7b384f 100644
--- a/vendor/modules.txt
+++ b/vendor/modules.txt
@@ -672,6 +672,9 @@ github.com/tdewolff/parse/v2/strconv
# github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc
## explicit
github.com/tmthrgd/go-hex
+# github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80
+## explicit
+github.com/tomnomnom/linkheader
# github.com/twitchyliquid64/golang-asm v0.15.1
## explicit; go 1.13
github.com/twitchyliquid64/golang-asm/asm/arch