diff options
Diffstat (limited to 'internal/api/client/streaming/stream.go')
-rw-r--r-- | internal/api/client/streaming/stream.go | 60 |
1 files changed, 43 insertions, 17 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index a9cb62732..de98719c2 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package streaming import ( @@ -6,7 +24,7 @@ import ( "time" "codeberg.org/gruf/go-kv" - "github.com/superseriousbusiness/gotosocial/internal/api" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -14,12 +32,15 @@ import ( "github.com/gorilla/websocket" ) -var wsUpgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - // we expect cors requests (via eg., pinafore.social) so be lenient - CheckOrigin: func(r *http.Request) bool { return true }, -} +var ( + wsUpgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + // we expect cors requests (via eg., pinafore.social) so be lenient + CheckOrigin: func(r *http.Request) bool { return true }, + } + errNoToken = fmt.Errorf("no access token provided under query key %s or under header %s", AccessTokenQueryKey, AccessTokenHeader) +) // StreamGETHandler swagger:operation GET /api/v1/streaming streamGet // @@ -125,29 +146,33 @@ func (m *Module) StreamGETHandler(c *gin.Context) { streamType := c.Query(StreamQueryKey) if streamType == "" { err := fmt.Errorf("no stream type provided under query key %s", StreamQueryKey) - api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) return } - accessToken := c.Query(AccessTokenQueryKey) - if accessToken == "" { - accessToken = c.GetHeader(AccessTokenHeader) - } - if accessToken == "" { - err := fmt.Errorf("no access token provided under query key %s or under header %s", AccessTokenQueryKey, AccessTokenHeader) - api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) + var accessToken string + if t := c.Query(AccessTokenQueryKey); t != "" { + // try query param first + accessToken = t + } else if t := c.GetHeader(AccessTokenHeader); t != "" { + // fall back to Sec-Websocket-Protocol + accessToken = t + } else { + // no token + err := errNoToken + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) return } account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken) if errWithCode != nil { - api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) return } stream, errWithCode := m.processor.OpenStreamForAccount(c.Request.Context(), account, streamType) if errWithCode != nil { - api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) return } @@ -175,6 +200,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { }() streamTicker := time.NewTicker(m.tickDuration) + defer streamTicker.Stop() // We want to stay in the loop as long as possible while the client is connected. // The only thing that should break the loop is if the client leaves or the connection becomes unhealthy. |