diff options
Diffstat (limited to 'internal/api/s2s/webfinger')
| -rw-r--r-- | internal/api/s2s/webfinger/webfinger_test.go | 2 | ||||
| -rw-r--r-- | internal/api/s2s/webfinger/webfingerget.go | 34 | 
2 files changed, 10 insertions, 26 deletions
| diff --git a/internal/api/s2s/webfinger/webfinger_test.go b/internal/api/s2s/webfinger/webfinger_test.go index 0df50c503..9758a6be7 100644 --- a/internal/api/s2s/webfinger/webfinger_test.go +++ b/internal/api/s2s/webfinger/webfinger_test.go @@ -88,7 +88,7 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {  	suite.tc = testrig.NewTestTypeConverter(suite.db)  	suite.storage = testrig.NewTestStorage()  	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)  	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)  	suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module) diff --git a/internal/api/s2s/webfinger/webfingerget.go b/internal/api/s2s/webfinger/webfingerget.go index 7e7ca006b..6d4764ce5 100644 --- a/internal/api/s2s/webfinger/webfingerget.go +++ b/internal/api/s2s/webfinger/webfingerget.go @@ -22,13 +22,13 @@ import (  	"context"  	"fmt"  	"net/http" -	"strings"  	"github.com/gin-gonic/gin"  	"github.com/sirupsen/logrus"  	"github.com/superseriousbusiness/gotosocial/internal/ap"  	"github.com/superseriousbusiness/gotosocial/internal/api"  	"github.com/superseriousbusiness/gotosocial/internal/config" +	"github.com/superseriousbusiness/gotosocial/internal/util"  )  // WebfingerGETRequest swagger:operation GET /.well-known/webfinger webfingerGet @@ -67,35 +67,19 @@ func (m *Module) WebfingerGETRequest(c *gin.Context) {  		return  	} -	// remove the acct: prefix if it's present -	trimAcct := strings.TrimPrefix(resourceQuery, "acct:") -	// remove the first @ in @whatever@example.org if it's present -	namestring := strings.TrimPrefix(trimAcct, "@") - -	// at this point we should have a string like some_user@example.org -	l.Debugf("got finger request for '%s'", namestring) - -	usernameAndAccountDomain := strings.Split(namestring, "@") -	if len(usernameAndAccountDomain) != 2 { -		l.Debugf("aborting request because username and domain could not be parsed from %s", namestring) -		c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) -		return -	} - -	username := strings.ToLower(usernameAndAccountDomain[0]) -	requestedAccountDomain := strings.ToLower(usernameAndAccountDomain[1]) -	if username == "" || requestedAccountDomain == "" { -		l.Debug("aborting request because username or domain was empty") -		c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) +	requestedUsername, requestedHost, err := util.ExtractWebfingerParts(resourceQuery) +	if err != nil { +		l.Debugf("bad webfinger request with resource query %s: %s", resourceQuery, err) +		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("bad webfinger request with resource query %s", resourceQuery)})  		return  	}  	accountDomain := config.GetAccountDomain()  	host := config.GetHost() -	if requestedAccountDomain != accountDomain && requestedAccountDomain != host { -		l.Debugf("aborting request because accountDomain %s does not belong to this instance", requestedAccountDomain) -		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("accountDomain %s does not belong to this instance", requestedAccountDomain)}) +	if requestedHost != host && requestedHost != accountDomain { +		l.Debugf("aborting request because requestedHost %s does not belong to this instance", requestedHost) +		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("requested host %s does not belong to this instance", requestedHost)})  		return  	} @@ -106,7 +90,7 @@ func (m *Module) WebfingerGETRequest(c *gin.Context) {  		ctx = context.WithValue(ctx, ap.ContextRequestingPublicKeyVerifier, verifier)  	} -	resp, errWithCode := m.processor.GetWebfingerAccount(ctx, username) +	resp, errWithCode := m.processor.GetWebfingerAccount(ctx, requestedUsername)  	if errWithCode != nil {  		api.ErrorHandler(c, errWithCode, m.processor.InstanceGet)  		return | 
