diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/client/streaming/stream.go | 27 | ||||
| -rw-r--r-- | internal/api/client/streaming/streaming_test.go | 15 | 
2 files changed, 25 insertions, 17 deletions
| diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 1f34e3447..2d1c48341 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -162,24 +162,27 @@ func (m *Module) StreamGETHandler(c *gin.Context) {  	}  	if token != "" { +  		// Token was provided, use it to authorize stream.  		account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token) +		if errWithCode != nil { +			apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) +			return +		} +  	} else { +  		// No explicit token was provided:  		// try regular oauth as a last resort. -		account, errWithCode = func() (*gtsmodel.Account, gtserror.WithCode) { -			authed, err := oauth.Authed(c, true, true, true, true) -			if err != nil { -				return nil, gtserror.NewErrorUnauthorized(err, err.Error()) -			} - -			return authed.Account, nil -		}() -	} +		authed, err := oauth.Authed(c, true, true, true, true) +		if err != nil { +			errWithCode := gtserror.NewErrorUnauthorized(err, err.Error()) +			apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) +			return +		} -	if errWithCode != nil { -		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) -		return +		// Set the auth'ed account. +		account = authed.Account  	}  	// Get the initial requested stream type, if there is one. diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go index 30574080e..df4009890 100644 --- a/internal/api/client/streaming/streaming_test.go +++ b/internal/api/client/streaming/streaming_test.go @@ -19,6 +19,7 @@ package streaming_test  import (  	"bufio" +	"encoding/base64"  	"errors"  	"fmt"  	"io/ioutil" @@ -227,17 +228,21 @@ func (suite *StreamingTestSuite) TestSecurityHeader() {  	ctx.Request.Header.Set("Connection", "upgrade")  	ctx.Request.Header.Set("Upgrade", "websocket")  	ctx.Request.Header.Set("Sec-Websocket-Version", "13") -	ctx.Request.Header.Set("Sec-Websocket-Key", "abcd") +	key := [16]byte{'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'} +	key64 := base64.StdEncoding.EncodeToString(key[:]) // sec-websocket-key must be base64 encoded and 16 bytes long +	ctx.Request.Header.Set("Sec-Websocket-Key", key64)  	suite.streamingModule.StreamGETHandler(ctx) -	// check response -	suite.EqualValues(http.StatusOK, recorder.Code) -  	result := recorder.Result()  	defer result.Body.Close() -	_, err := ioutil.ReadAll(result.Body) +	b, err := ioutil.ReadAll(result.Body)  	suite.NoError(err) + +	// check response +	if !suite.EqualValues(http.StatusOK, recorder.Code) { +		suite.T().Logf("%s", b) +	}  }  func TestStreamingTestSuite(t *testing.T) { | 
