summaryrefslogtreecommitdiff
path: root/internal/federation/federatingprotocol.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/federation/federatingprotocol.go')
-rw-r--r--internal/federation/federatingprotocol.go40
1 files changed, 18 insertions, 22 deletions
diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go
index 9e21b43bf..5da68afd3 100644
--- a/internal/federation/federatingprotocol.go
+++ b/internal/federation/federatingprotocol.go
@@ -113,8 +113,8 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
return nil, false, errors.New("username was empty")
}
- requestedAccount := &gtsmodel.Account{}
- if err := f.db.GetLocalAccountByUsername(username, requestedAccount); err != nil {
+ requestedAccount, err := f.db.GetLocalAccountByUsername(username)
+ if err != nil {
return nil, false, fmt.Errorf("could not fetch requested account with username %s: %s", username, err)
}
@@ -132,7 +132,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
// authentication has passed, so add an instance entry for this instance if it hasn't been done already
i := &gtsmodel.Instance{}
if err := f.db.GetWhere([]db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// there's been an actual error
return ctx, false, fmt.Errorf("error getting requesting account with public key id %s: %s", publicKeyOwnerURI.String(), err)
}
@@ -176,8 +176,6 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
// Finally, if the authentication and authorization succeeds, then
// blocked must be false and error nil. The request will continue
// to be processed.
-//
-// TODO: implement domain block checking here as well
func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, error) {
l := f.log.WithFields(logrus.Fields{
"func": "Blocked",
@@ -191,19 +189,18 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
return false, errors.New("requested account not set on request context, so couldn't determine blocks")
}
+ blocked, err := f.db.AreURIsBlocked(actorIRIs)
+ if err != nil {
+ return false, fmt.Errorf("error checking domain blocks: %s", err)
+ }
+ if blocked {
+ return blocked, nil
+ }
+
for _, uri := range actorIRIs {
- blockedDomain, err := f.blockedDomain(uri.Host)
+ requestingAccount, err := f.db.GetAccountByURI(uri.String())
if err != nil {
- return false, fmt.Errorf("error checking domain block: %s", err)
- }
- if blockedDomain {
- return true, nil
- }
-
- requestingAccount := &gtsmodel.Account{}
- if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uri.String()}}, requestingAccount); err != nil {
- _, ok := err.(db.ErrNoEntries)
- if ok {
+ if err == db.ErrNoEntries {
// we don't have an entry for this account so it's not blocked
// TODO: allow a different default to be set for this behavior
continue
@@ -211,12 +208,11 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
return false, fmt.Errorf("error getting account with uri %s: %s", uri.String(), err)
}
- // check if requested account blocks requesting account
- if err := f.db.GetWhere([]db.Where{
- {Key: "account_id", Value: requestedAccount.ID},
- {Key: "target_account_id", Value: requestingAccount.ID},
- }, &gtsmodel.Block{}); err == nil {
- // a block exists
+ blocked, err = f.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ if err != nil {
+ return false, fmt.Errorf("error checking account block: %s", err)
+ }
+ if blocked {
return true, nil
}
}