summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/client/streaming/stream.go35
-rw-r--r--internal/api/client/streaming/streaming_test.go4
2 files changed, 29 insertions, 10 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go
index e41531a59..900df4383 100644
--- a/internal/api/client/streaming/stream.go
+++ b/internal/api/client/streaming/stream.go
@@ -19,6 +19,7 @@ package streaming
import (
"context"
+ "net/http"
"slices"
"time"
@@ -151,15 +152,24 @@ import (
// description: bad request
func (m *Module) StreamGETHandler(c *gin.Context) {
var (
- account *gtsmodel.Account
- errWithCode gtserror.WithCode
+ token string
+ tokenInHeader bool
+ account *gtsmodel.Account
+ errWithCode gtserror.WithCode
)
- // Try query param access token.
- token := c.Query(AccessTokenQueryKey)
- if token == "" {
- // Try fallback HTTP header provided token.
- token = c.GetHeader(AccessTokenHeader)
+ if t := c.Query(AccessTokenQueryKey); t != "" {
+ // Token was provided as
+ // query param, no problem.
+ token = t
+ } else if t := c.GetHeader(AccessTokenHeader); t != "" {
+ // Token was provided in "Sec-Websocket-Protocol" header.
+ //
+ // This is hacky and not technically correct but some
+ // clients do it since Mastodon allows it, so we must
+ // also allow it to avoid breaking expectations.
+ token = t
+ tokenInHeader = true
}
if token != "" {
@@ -230,7 +240,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
//
// If the upgrade fails, then Upgrade replies to the client
// with an HTTP error response.
- wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
+ var responseHeader http.Header
+ if tokenInHeader {
+ // Return the token in the response,
+ // else Chrome fails to connect.
+ //
+ // https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism#sec-websocket-protocol
+ responseHeader = http.Header{AccessTokenHeader: {token}}
+ }
+
+ wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, responseHeader)
if err != nil {
l.Errorf("error upgrading websocket connection: %v", err)
stream.Close()
diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go
index 1d94a87ec..acdcafd8a 100644
--- a/internal/api/client/streaming/streaming_test.go
+++ b/internal/api/client/streaming/streaming_test.go
@@ -22,7 +22,7 @@ import (
"encoding/base64"
"errors"
"fmt"
- "io/ioutil"
+ "io"
"net"
"net/http"
"net/http/httptest"
@@ -236,7 +236,7 @@ func (suite *StreamingTestSuite) TestSecurityHeader() {
result := recorder.Result()
defer result.Body.Close()
- b, err := ioutil.ReadAll(result.Body)
+ b, err := io.ReadAll(result.Body)
suite.NoError(err)
// check response