diff options
Diffstat (limited to 'internal/api/client/streaming/stream.go')
-rw-r--r-- | internal/api/client/streaming/stream.go | 63 |
1 files changed, 49 insertions, 14 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index f41bc0ac2..88c682a75 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -82,6 +82,20 @@ import ( // `direct`: receive updates for direct messages. // in: query // required: true +// - +// name: list +// type: string +// description: |- +// ID of the list to subscribe to. +// Only used if stream type is 'list'. +// in: query +// - +// name: tag +// type: string +// description: |- +// Name of the tag to subscribe to. +// Only used if stream type is 'hashtag' or 'hashtag:local'. +// in: query // // security: // - OAuth2 Bearer: @@ -164,8 +178,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) { } // Get the initial stream type, if there is one. - // streamType will be an empty string if one wasn't supplied. Open() will deal with this + // By appending other query params to the streamType, + // we can allow for streaming for specific list IDs + // or hashtags. streamType := c.Query(StreamQueryKey) + if list := c.Query(StreamListKey); list != "" { + streamType += ":" + list + } else if tag := c.Query(StreamTagKey); tag != "" { + streamType += ":" + tag + } + stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) @@ -240,28 +262,41 @@ func (m *Module) StreamGETHandler(c *gin.Context) { // If the message contains 'stream' and 'type' fields, we can // update the set of timelines that are subscribed for events. - // everything else is ignored. - action := msg["type"] - streamType := msg["stream"] + updateType, ok := msg["type"] + if !ok { + l.Warn("'type' field not provided") + continue + } + + updateStream, ok := msg["stream"] + if !ok { + l.Warn("'stream' field not provided") + continue + } - // Ignore if the streamType is unknown (or missing), so a bad - // client can't cause extra memory allocations - if !slices.Contains(streampkg.AllStatusTimelines, streamType) { - l.Warnf("Unknown 'stream' field: %v", msg) + // Ignore if the updateStreamType is unknown (or missing), + // so a bad client can't cause extra memory allocations + if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { + l.Warnf("unknown 'stream' field: %v", msg) continue } - switch action { + updateList, ok := msg["list"] + if ok { + updateStream += ":" + updateList + } + + switch updateType { case "subscribe": stream.Lock() - stream.Timelines[streamType] = true + stream.StreamTypes[updateStream] = true stream.Unlock() case "unsubscribe": stream.Lock() - delete(stream.Timelines, streamType) + delete(stream.StreamTypes, updateStream) stream.Unlock() default: - l.Warnf("Invalid 'type' field: %v", msg) + l.Warnf("invalid 'type' field: %v", msg) } } }() @@ -276,7 +311,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { case msg := <-stream.Messages: l.Tracef("sending message to websocket: %+v", msg) if err := wsConn.WriteJSON(msg); err != nil { - l.Errorf("error writing json to websocket: %v", err) + l.Debugf("error writing json to websocket: %v", err) return } @@ -290,7 +325,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { websocket.PingMessage, []byte{}, ); err != nil { - l.Errorf("error writing ping to websocket: %v", err) + l.Debugf("error writing ping to websocket: %v", err) return } } |