summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/client/auth/callback.go4
-rw-r--r--internal/api/security/signaturecheck.go23
2 files changed, 4 insertions, 23 deletions
diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go
index 8bf2a50b5..a26838aa3 100644
--- a/internal/api/client/auth/callback.go
+++ b/internal/api/client/auth/callback.go
@@ -116,7 +116,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
return user, nil
}
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// we have an actual error in the database
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
}
@@ -128,7 +128,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
}
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// we have an actual error in the database
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
}
diff --git a/internal/api/security/signaturecheck.go b/internal/api/security/signaturecheck.go
index b852c92ab..88b0b4dff 100644
--- a/internal/api/security/signaturecheck.go
+++ b/internal/api/security/signaturecheck.go
@@ -6,8 +6,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-fed/httpsig"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
@@ -33,13 +31,13 @@ func (m *Module) SignatureCheck(c *gin.Context) {
// we managed to parse the url!
// if the domain is blocked we want to bail as early as possible
- blockedDomain, err := m.blockedDomain(requestingPublicKeyID.Host)
+ blocked, err := m.db.IsURIBlocked(requestingPublicKeyID)
if err != nil {
l.Errorf("could not tell if domain %s was blocked or not: %s", requestingPublicKeyID.Host, err)
c.AbortWithStatus(http.StatusInternalServerError)
return
}
- if blockedDomain {
+ if blocked {
l.Infof("domain %s is blocked", requestingPublicKeyID.Host)
c.AbortWithStatus(http.StatusForbidden)
return
@@ -50,20 +48,3 @@ func (m *Module) SignatureCheck(c *gin.Context) {
}
}
}
-
-func (m *Module) blockedDomain(host string) (bool, error) {
- b := &gtsmodel.DomainBlock{}
- err := m.db.GetWhere([]db.Where{{Key: "domain", Value: host, CaseInsensitive: true}}, b)
- if err == nil {
- // block exists
- return true, nil
- }
-
- if _, ok := err.(db.ErrNoEntries); ok {
- // there are no entries so there's no block
- return false, nil
- }
-
- // there's an actual error
- return false, err
-}