summaryrefslogtreecommitdiff
path: root/internal/api/client
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api/client')
-rw-r--r--internal/api/client/streaming/stream.go361
-rw-r--r--internal/api/client/streaming/streaming.go11
2 files changed, 238 insertions, 134 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go
index 88c682a75..1f34e3447 100644
--- a/internal/api/client/streaming/stream.go
+++ b/internal/api/client/streaming/stream.go
@@ -149,60 +149,78 @@ import (
// '400':
// description: bad request
func (m *Module) StreamGETHandler(c *gin.Context) {
+ var (
+ account *gtsmodel.Account
+ errWithCode gtserror.WithCode
+ )
- // First we check for a query param provided access token
+ // Try query param access token.
token := c.Query(AccessTokenQueryKey)
if token == "" {
- // Else we check the HTTP header provided token
+ // Try fallback HTTP header provided token.
token = c.GetHeader(AccessTokenHeader)
}
- var account *gtsmodel.Account
if token != "" {
- // Check the explicit token
- var errWithCode gtserror.WithCode
+ // Token was provided, use it to authorize stream.
account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token)
- if errWithCode != nil {
- apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
- return
- }
} else {
- // If no explicit token was provided, try regular oauth
- auth, errStr := oauth.Authed(c, true, true, true, true)
- if errStr != nil {
- err := gtserror.NewErrorUnauthorized(errStr, errStr.Error())
- apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1)
- return
- }
- account = auth.Account
+ // No explicit token was provided:
+ // try regular oauth as a last resort.
+ account, errWithCode = func() (*gtsmodel.Account, gtserror.WithCode) {
+ authed, err := oauth.Authed(c, true, true, true, true)
+ if err != nil {
+ return nil, gtserror.NewErrorUnauthorized(err, err.Error())
+ }
+
+ return authed.Account, nil
+ }()
+ }
+
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
+ return
}
- // Get the initial stream type, if there is one.
- // By appending other query params to the streamType,
- // we can allow for streaming for specific list IDs
- // or hashtags.
+ // Get the initial requested stream type, if there is one.
streamType := c.Query(StreamQueryKey)
+
+ // By appending other query params to the streamType, we
+ // can allow streaming for specific list IDs or hashtags.
+ // The streamType in this case will end up looking like
+ // `hashtag:example` or `list:01H3YF48G8B7KTPQFS8D2QBVG8`.
if list := c.Query(StreamListKey); list != "" {
streamType += ":" + list
} else if tag := c.Query(StreamTagKey); tag != "" {
streamType += ":" + tag
}
- stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType)
+ // Open a stream with the processor; this lets processor
+ // functions pass messages into a channel, which we can
+ // then read from and put into a websockets connection.
+ stream, errWithCode := m.processor.Stream().Open(
+ c.Request.Context(),
+ account,
+ streamType,
+ )
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
- l := log.WithContext(c.Request.Context()).
+ l := log.
+ WithContext(c.Request.Context()).
WithFields(kv.Fields{
- {"account", account.Username},
+ {"username", account.Username},
{"streamID", stream.ID},
- {"streamType", streamType},
}...)
- // Upgrade the incoming HTTP request, which hijacks the underlying
- // connection and reuses it for the websocket (non-http) protocol.
+ // Upgrade the incoming HTTP request. This hijacks the
+ // underlying connection and reuses it for the websocket
+ // (non-http) protocol.
+ //
+ // If the upgrade fails, then Upgrade replies to the client
+ // with an HTTP error response.
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
if err != nil {
l.Errorf("error upgrading websocket connection: %v", err)
@@ -210,125 +228,208 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
return
}
+ l.Info("opened websocket connection")
+
+ // We perform the main websocket rw loops in a separate
+ // goroutine in order to let the upgrade handler return.
+ // This prevents the upgrade handler from holding open any
+ // throttle / rate-limit request tokens which could become
+ // problematic on instances with multiple users.
+ go m.handleWSConn(account.Username, wsConn, stream)
+}
+
+// handleWSConn handles a two-way websocket streaming connection.
+// It will both read messages from the connection, and push messages
+// into the connection. If any errors are encountered while reading
+// or writing (including expected errors like clients leaving), the
+// connection will be closed.
+func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) {
+ // Create new context for the lifetime of this connection.
+ ctx, cancel := context.WithCancel(context.Background())
+
+ l := log.
+ WithContext(ctx).
+ WithFields(kv.Fields{
+ {"username", username},
+ {"streamID", stream.ID},
+ }...)
+
+ // Create ticker to send keepalive pings
+ pinger := time.NewTicker(m.dTicker)
+
+ // Read messages coming from the Websocket client connection into the server.
go func() {
- // We perform the main websocket send loop in a separate
- // goroutine in order to let the upgrade handler return.
- // This prevents the upgrade handler from holding open any
- // throttle / rate-limit request tokens which could become
- // problematic on instances with multiple users.
- l.Info("opened websocket connection")
- defer l.Info("closed websocket connection")
+ defer cancel()
+ m.readFromWSConn(ctx, username, wsConn, stream)
+ }()
- // Create new context for lifetime of the connection
- ctx, cncl := context.WithCancel(context.Background())
+ // Write messages coming from the processor into the Websocket client connection.
+ go func() {
+ defer cancel()
+ m.writeToWSConn(ctx, username, wsConn, stream, pinger)
+ }()
- // Create ticker to send alive pings
- pinger := time.NewTicker(m.dTicker)
+ // Wait for either the read or write functions to close, to indicate
+ // that the client has left, or something else has gone wrong.
+ <-ctx.Done()
- defer func() {
- // Signal done
- cncl()
+ // Tidy up underlying websocket connection.
+ if err := wsConn.Close(); err != nil {
+ l.Errorf("error closing websocket connection: %v", err)
+ }
- // Close websocket conn
- _ = wsConn.Close()
+ // Close processor channel so the processor knows
+ // not to send any more messages to this stream.
+ close(stream.Hangup)
- // Close processor stream
- close(stream.Hangup)
+ // Stop ping ticker (tiny resource saving).
+ pinger.Stop()
- // Stop ping ticker
- pinger.Stop()
- }()
+ l.Info("closed websocket connection")
+}
- go func() {
- // Signal done
- defer cncl()
-
- for {
- // We have to listen for received websocket messages in
- // order to trigger the underlying wsConn.PingHandler().
- //
- // Read JSON objects from the client and act on them
- var msg map[string]string
- err := wsConn.ReadJSON(&msg)
- if err != nil {
- if ctx.Err() == nil {
- // Only log error if the connection was not closed
- // by us. Uncanceled context indicates this is the case.
- l.Errorf("error reading from websocket: %v", err)
- }
- return
- }
- l.Tracef("received message from websocket: %v", msg)
-
- // If the message contains 'stream' and 'type' fields, we can
- // update the set of timelines that are subscribed for events.
- updateType, ok := msg["type"]
- if !ok {
- l.Warn("'type' field not provided")
- continue
- }
+// readFromWSConn reads control messages coming in from the given
+// websockets connection, and modifies the subscription StreamTypes
+// of the given stream accordingly after acquiring a lock on it.
+//
+// This is a blocking function; will return only on read error or
+// if the given context is canceled.
+func (m *Module) readFromWSConn(
+ ctx context.Context,
+ username string,
+ wsConn *websocket.Conn,
+ stream *streampkg.Stream,
+) {
+ l := log.
+ WithContext(ctx).
+ WithFields(kv.Fields{
+ {"username", username},
+ {"streamID", stream.ID},
+ }...)
- updateStream, ok := msg["stream"]
- if !ok {
- l.Warn("'stream' field not provided")
- continue
- }
+readLoop:
+ for {
+ select {
+ case <-ctx.Done():
+ // Connection closed.
+ break readLoop
- // Ignore if the updateStreamType is unknown (or missing),
- // so a bad client can't cause extra memory allocations
- if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
- l.Warnf("unknown 'stream' field: %v", msg)
- continue
+ default:
+ // Read JSON objects from the client and act on them.
+ var msg map[string]string
+ if err := wsConn.ReadJSON(&msg); err != nil {
+ // Only log an error if something weird happened.
+ // See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
+ if websocket.IsUnexpectedCloseError(err, []int{
+ websocket.CloseNormalClosure,
+ websocket.CloseGoingAway,
+ websocket.CloseNoStatusReceived,
+ }...) {
+ l.Errorf("error reading from websocket: %v", err)
}
- updateList, ok := msg["list"]
- if ok {
- updateStream += ":" + updateList
- }
+ // The connection is gone; no
+ // further streaming possible.
+ break readLoop
+ }
- switch updateType {
- case "subscribe":
- stream.Lock()
- stream.StreamTypes[updateStream] = true
- stream.Unlock()
- case "unsubscribe":
- stream.Lock()
- delete(stream.StreamTypes, updateStream)
- stream.Unlock()
- default:
- l.Warnf("invalid 'type' field: %v", msg)
- }
+ // Messages *from* the WS connection are infrequent
+ // and usually interesting, so log this at info.
+ l.Infof("received message from websocket: %v", msg)
+
+ // If the message contains 'stream' and 'type' fields, we can
+ // update the set of timelines that are subscribed for events.
+ updateType, ok := msg["type"]
+ if !ok {
+ l.Warn("'type' field not provided")
+ continue
}
- }()
- for {
- select {
- // Connection closed
- case <-ctx.Done():
- return
-
- // Received next stream message
- case msg := <-stream.Messages:
- l.Tracef("sending message to websocket: %+v", msg)
- if err := wsConn.WriteJSON(msg); err != nil {
- l.Debugf("error writing json to websocket: %v", err)
- return
- }
+ updateStream, ok := msg["stream"]
+ if !ok {
+ l.Warn("'stream' field not provided")
+ continue
+ }
- // Reset on each successful send.
- pinger.Reset(m.dTicker)
-
- // Send keep-alive "ping"
- case <-pinger.C:
- l.Trace("pinging websocket ...")
- if err := wsConn.WriteMessage(
- websocket.PingMessage,
- []byte{},
- ); err != nil {
- l.Debugf("error writing ping to websocket: %v", err)
- return
- }
+ // Ignore if the updateStreamType is unknown (or missing),
+ // so a bad client can't cause extra memory allocations
+ if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
+ l.Warnf("unknown 'stream' field: %v", msg)
+ continue
+ }
+
+ updateList, ok := msg["list"]
+ if ok {
+ updateStream += ":" + updateList
+ }
+
+ switch updateType {
+ case "subscribe":
+ stream.Lock()
+ stream.StreamTypes[updateStream] = true
+ stream.Unlock()
+ case "unsubscribe":
+ stream.Lock()
+ delete(stream.StreamTypes, updateStream)
+ stream.Unlock()
+ default:
+ l.Warnf("invalid 'type' field: %v", msg)
}
}
- }()
+ }
+
+ l.Debug("finished reading from websocket connection")
+}
+
+// writeToWSConn receives messages coming from the processor via the
+// given stream, and writes them into the given websockets connection.
+// This function also handles sending ping messages into the websockets
+// connection to keep it alive when no other activity occurs.
+//
+// This is a blocking function; will return only on write error or
+// if the given context is canceled.
+func (m *Module) writeToWSConn(
+ ctx context.Context,
+ username string,
+ wsConn *websocket.Conn,
+ stream *streampkg.Stream,
+ pinger *time.Ticker,
+) {
+ l := log.
+ WithContext(ctx).
+ WithFields(kv.Fields{
+ {"username", username},
+ {"streamID", stream.ID},
+ }...)
+
+writeLoop:
+ for {
+ select {
+ case <-ctx.Done():
+ // Connection closed.
+ break writeLoop
+
+ case msg := <-stream.Messages:
+ // Received a new message from the processor.
+ l.Tracef("writing message to websocket: %+v", msg)
+ if err := wsConn.WriteJSON(msg); err != nil {
+ l.Debugf("error writing json to websocket: %v", err)
+ break writeLoop
+ }
+
+ // Reset pinger on successful send, since
+ // we know the connection is still there.
+ pinger.Reset(m.dTicker)
+
+ case <-pinger.C:
+ // Time to send a keep-alive "ping".
+ l.Trace("writing ping control message to websocket")
+ if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil {
+ l.Debugf("error writing ping to websocket: %v", err)
+ break writeLoop
+ }
+ }
+ }
+
+ l.Debug("finished writing to websocket connection")
}
diff --git a/internal/api/client/streaming/streaming.go b/internal/api/client/streaming/streaming.go
index edddeab73..303e16cd3 100644
--- a/internal/api/client/streaming/streaming.go
+++ b/internal/api/client/streaming/streaming.go
@@ -42,15 +42,18 @@ type Module struct {
}
func New(processor *processing.Processor, dTicker time.Duration, wsBuf int) *Module {
+ // We expect CORS requests for websockets,
+ // (via eg., semaphore.social) so be lenient.
+ // TODO: make this customizable?
+ checkOrigin := func(r *http.Request) bool { return true }
+
return &Module{
processor: processor,
dTicker: dTicker,
wsUpgrade: websocket.Upgrader{
- ReadBufferSize: wsBuf, // we don't expect reads
+ ReadBufferSize: wsBuf,
WriteBufferSize: wsBuf,
-
- // we expect cors requests (via eg., semaphore.social) so be lenient
- CheckOrigin: func(r *http.Request) bool { return true },
+ CheckOrigin: checkOrigin,
},
}
}