summaryrefslogtreecommitdiff
path: root/internal/api/client/streaming/stream.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api/client/streaming/stream.go')
-rw-r--r--internal/api/client/streaming/stream.go63
1 files changed, 49 insertions, 14 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go
index f41bc0ac2..88c682a75 100644
--- a/internal/api/client/streaming/stream.go
+++ b/internal/api/client/streaming/stream.go
@@ -82,6 +82,20 @@ import (
// `direct`: receive updates for direct messages.
// in: query
// required: true
+// -
+// name: list
+// type: string
+// description: |-
+// ID of the list to subscribe to.
+// Only used if stream type is 'list'.
+// in: query
+// -
+// name: tag
+// type: string
+// description: |-
+// Name of the tag to subscribe to.
+// Only used if stream type is 'hashtag' or 'hashtag:local'.
+// in: query
//
// security:
// - OAuth2 Bearer:
@@ -164,8 +178,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
}
// Get the initial stream type, if there is one.
- // streamType will be an empty string if one wasn't supplied. Open() will deal with this
+ // By appending other query params to the streamType,
+ // we can allow for streaming for specific list IDs
+ // or hashtags.
streamType := c.Query(StreamQueryKey)
+ if list := c.Query(StreamListKey); list != "" {
+ streamType += ":" + list
+ } else if tag := c.Query(StreamTagKey); tag != "" {
+ streamType += ":" + tag
+ }
+
stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
@@ -240,28 +262,41 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// If the message contains 'stream' and 'type' fields, we can
// update the set of timelines that are subscribed for events.
- // everything else is ignored.
- action := msg["type"]
- streamType := msg["stream"]
+ updateType, ok := msg["type"]
+ if !ok {
+ l.Warn("'type' field not provided")
+ continue
+ }
+
+ updateStream, ok := msg["stream"]
+ if !ok {
+ l.Warn("'stream' field not provided")
+ continue
+ }
- // Ignore if the streamType is unknown (or missing), so a bad
- // client can't cause extra memory allocations
- if !slices.Contains(streampkg.AllStatusTimelines, streamType) {
- l.Warnf("Unknown 'stream' field: %v", msg)
+ // Ignore if the updateStreamType is unknown (or missing),
+ // so a bad client can't cause extra memory allocations
+ if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
+ l.Warnf("unknown 'stream' field: %v", msg)
continue
}
- switch action {
+ updateList, ok := msg["list"]
+ if ok {
+ updateStream += ":" + updateList
+ }
+
+ switch updateType {
case "subscribe":
stream.Lock()
- stream.Timelines[streamType] = true
+ stream.StreamTypes[updateStream] = true
stream.Unlock()
case "unsubscribe":
stream.Lock()
- delete(stream.Timelines, streamType)
+ delete(stream.StreamTypes, updateStream)
stream.Unlock()
default:
- l.Warnf("Invalid 'type' field: %v", msg)
+ l.Warnf("invalid 'type' field: %v", msg)
}
}
}()
@@ -276,7 +311,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
case msg := <-stream.Messages:
l.Tracef("sending message to websocket: %+v", msg)
if err := wsConn.WriteJSON(msg); err != nil {
- l.Errorf("error writing json to websocket: %v", err)
+ l.Debugf("error writing json to websocket: %v", err)
return
}
@@ -290,7 +325,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
websocket.PingMessage,
[]byte{},
); err != nil {
- l.Errorf("error writing ping to websocket: %v", err)
+ l.Debugf("error writing ping to websocket: %v", err)
return
}
}