diff options
Diffstat (limited to 'internal/api/client/streaming/stream.go')
-rw-r--r-- | internal/api/client/streaming/stream.go | 152 |
1 files changed, 94 insertions, 58 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index fc14e87e3..c175c8461 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -19,8 +19,9 @@ package streaming import ( + "context" + "errors" "fmt" - "net/http" "time" "codeberg.org/gruf/go-kv" @@ -32,16 +33,6 @@ import ( "github.com/gorilla/websocket" ) -var ( - wsUpgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - // we expect cors requests (via eg., pinafore.social) so be lenient - CheckOrigin: func(r *http.Request) bool { return true }, - } - errNoToken = fmt.Errorf("no access token provided under query key %s or under header %s", AccessTokenQueryKey, AccessTokenHeader) -) - // StreamGETHandler swagger:operation GET /api/v1/streaming streamGet // // Initiate a websocket connection for live streaming of statuses and notifications. @@ -150,21 +141,20 @@ func (m *Module) StreamGETHandler(c *gin.Context) { return } - var accessToken string - if t := c.Query(AccessTokenQueryKey); t != "" { - // try query param first - accessToken = t - } else if t := c.GetHeader(AccessTokenHeader); t != "" { - // fall back to Sec-Websocket-Protocol - accessToken = t - } else { - // no token - err := errNoToken - apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) - return + var token string + + // First we check for a query param provided access token + if token = c.Query(AccessTokenQueryKey); token == "" { + // Else we check the HTTP header provided token + if token = c.GetHeader(AccessTokenHeader); token == "" { + const errStr = "no access token provided" + err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr) + apiutil.ErrorHandler(c, err, m.processor.InstanceGet) + return + } } - account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken) + account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), token) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) return @@ -178,51 +168,97 @@ func (m *Module) StreamGETHandler(c *gin.Context) { l := log.WithFields(kv.Fields{ {"account", account.Username}, - {"path", BasePath}, {"streamID", stream.ID}, {"streamType", streamType}, }...) - wsConn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil) + // Upgrade the incoming HTTP request, which hijacks the underlying + // connection and reuses it for the websocket (non-http) protocol. + wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil) if err != nil { - // If the upgrade fails, then Upgrade replies to the client with an HTTP error response. - // Because websocket issues are a pretty common source of headaches, we should also log - // this at Error to make this plenty visible and help admins out a bit. - l.Errorf("error upgrading websocket connection: %s", err) + l.Errorf("error upgrading websocket connection: %v", err) close(stream.Hangup) return } - defer func() { - // cleanup - wsConn.Close() - close(stream.Hangup) - }() + go func() { + // We perform the main websocket send loop in a separate + // goroutine in order to let the upgrade handler return. + // This prevents the upgrade handler from holding open any + // throttle / rate-limit request tokens which could become + // problematic on instances with multiple users. + l.Info("opened websocket connection") + defer l.Info("closed websocket connection") + + // Create new context for lifetime of the connection + ctx, cncl := context.WithCancel(context.Background()) + + // Create ticker to send alive pings + pinger := time.NewTicker(m.dTicker) + + defer func() { + // Signal done + cncl() - streamTicker := time.NewTicker(m.tickDuration) - defer streamTicker.Stop() - - // We want to stay in the loop as long as possible while the client is connected. - // The only thing that should break the loop is if the client leaves or the connection becomes unhealthy. - // - // If the loop does break, we expect the client to reattempt connection, so it's cheap to leave + try again -wsLoop: - for { - select { - case m := <-stream.Messages: - l.Trace("received message from stream") - if err := wsConn.WriteJSON(m); err != nil { - l.Debugf("error writing json to websocket connection; breaking off: %s", err) - break wsLoop + // Close websocket conn + _ = wsConn.Close() + + // Close processor stream + close(stream.Hangup) + + // Stop ping ticker + pinger.Stop() + }() + + go func() { + // Signal done + defer cncl() + + for { + // We have to listen for received websocket messages in + // order to trigger the underlying wsConn.PingHandler(). + // + // So we wait on received messages but only act on errors. + _, _, err := wsConn.ReadMessage() + if err != nil { + if ctx.Err() == nil { + // Only log error if the connection was not closed + // by us. Uncanceled context indicates this is the case. + l.Errorf("error reading from websocket: %v", err) + } + return + } } - l.Trace("wrote message into websocket connection") - case <-streamTicker.C: - l.Trace("received TICK from ticker") - if err := wsConn.WriteMessage(websocket.PingMessage, []byte(": ping")); err != nil { - l.Debugf("error writing ping to websocket connection; breaking off: %s", err) - break wsLoop + }() + + for { + select { + // Connection closed + case <-ctx.Done(): + return + + // Received next stream message + case msg := <-stream.Messages: + l.Tracef("sending message to websocket: %+v", msg) + if err := wsConn.WriteJSON(msg); err != nil { + l.Errorf("error writing json to websocket: %v", err) + return + } + + // Reset on each successful send. + pinger.Reset(m.dTicker) + + // Send keep-alive "ping" + case <-pinger.C: + l.Trace("pinging websocket ...") + if err := wsConn.WriteMessage( + websocket.PingMessage, + []byte{}, + ); err != nil { + l.Errorf("error writing ping to websocket: %v", err) + return + } } - l.Trace("wrote ping message into websocket connection") } - } + }() } |