summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/client/streaming/stream.go27
-rw-r--r--internal/api/client/streaming/streaming_test.go15
2 files changed, 25 insertions, 17 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go
index 1f34e3447..2d1c48341 100644
--- a/internal/api/client/streaming/stream.go
+++ b/internal/api/client/streaming/stream.go
@@ -162,24 +162,27 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
}
if token != "" {
+
// Token was provided, use it to authorize stream.
account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token)
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
+ return
+ }
+
} else {
+
// No explicit token was provided:
// try regular oauth as a last resort.
- account, errWithCode = func() (*gtsmodel.Account, gtserror.WithCode) {
- authed, err := oauth.Authed(c, true, true, true, true)
- if err != nil {
- return nil, gtserror.NewErrorUnauthorized(err, err.Error())
- }
-
- return authed.Account, nil
- }()
- }
+ authed, err := oauth.Authed(c, true, true, true, true)
+ if err != nil {
+ errWithCode := gtserror.NewErrorUnauthorized(err, err.Error())
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
+ return
+ }
- if errWithCode != nil {
- apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
- return
+ // Set the auth'ed account.
+ account = authed.Account
}
// Get the initial requested stream type, if there is one.
diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go
index 30574080e..df4009890 100644
--- a/internal/api/client/streaming/streaming_test.go
+++ b/internal/api/client/streaming/streaming_test.go
@@ -19,6 +19,7 @@ package streaming_test
import (
"bufio"
+ "encoding/base64"
"errors"
"fmt"
"io/ioutil"
@@ -227,17 +228,21 @@ func (suite *StreamingTestSuite) TestSecurityHeader() {
ctx.Request.Header.Set("Connection", "upgrade")
ctx.Request.Header.Set("Upgrade", "websocket")
ctx.Request.Header.Set("Sec-Websocket-Version", "13")
- ctx.Request.Header.Set("Sec-Websocket-Key", "abcd")
+ key := [16]byte{'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'}
+ key64 := base64.StdEncoding.EncodeToString(key[:]) // sec-websocket-key must be base64 encoded and 16 bytes long
+ ctx.Request.Header.Set("Sec-Websocket-Key", key64)
suite.streamingModule.StreamGETHandler(ctx)
- // check response
- suite.EqualValues(http.StatusOK, recorder.Code)
-
result := recorder.Result()
defer result.Body.Close()
- _, err := ioutil.ReadAll(result.Body)
+ b, err := ioutil.ReadAll(result.Body)
suite.NoError(err)
+
+ // check response
+ if !suite.EqualValues(http.StatusOK, recorder.Code) {
+ suite.T().Logf("%s", b)
+ }
}
func TestStreamingTestSuite(t *testing.T) {