summaryrefslogtreecommitdiff
path: root/internal/api/client/streaming
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api/client/streaming')
-rw-r--r--internal/api/client/streaming/stream.go63
-rw-r--r--internal/api/client/streaming/streaming.go17
-rw-r--r--internal/api/client/streaming/streaming_test.go9
3 files changed, 63 insertions, 26 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go
index f41bc0ac2..88c682a75 100644
--- a/internal/api/client/streaming/stream.go
+++ b/internal/api/client/streaming/stream.go
@@ -82,6 +82,20 @@ import (
// `direct`: receive updates for direct messages.
// in: query
// required: true
+// -
+// name: list
+// type: string
+// description: |-
+// ID of the list to subscribe to.
+// Only used if stream type is 'list'.
+// in: query
+// -
+// name: tag
+// type: string
+// description: |-
+// Name of the tag to subscribe to.
+// Only used if stream type is 'hashtag' or 'hashtag:local'.
+// in: query
//
// security:
// - OAuth2 Bearer:
@@ -164,8 +178,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
}
// Get the initial stream type, if there is one.
- // streamType will be an empty string if one wasn't supplied. Open() will deal with this
+ // By appending other query params to the streamType,
+ // we can allow for streaming for specific list IDs
+ // or hashtags.
streamType := c.Query(StreamQueryKey)
+ if list := c.Query(StreamListKey); list != "" {
+ streamType += ":" + list
+ } else if tag := c.Query(StreamTagKey); tag != "" {
+ streamType += ":" + tag
+ }
+
stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
@@ -240,28 +262,41 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// If the message contains 'stream' and 'type' fields, we can
// update the set of timelines that are subscribed for events.
- // everything else is ignored.
- action := msg["type"]
- streamType := msg["stream"]
+ updateType, ok := msg["type"]
+ if !ok {
+ l.Warn("'type' field not provided")
+ continue
+ }
+
+ updateStream, ok := msg["stream"]
+ if !ok {
+ l.Warn("'stream' field not provided")
+ continue
+ }
- // Ignore if the streamType is unknown (or missing), so a bad
- // client can't cause extra memory allocations
- if !slices.Contains(streampkg.AllStatusTimelines, streamType) {
- l.Warnf("Unknown 'stream' field: %v", msg)
+ // Ignore if the updateStreamType is unknown (or missing),
+ // so a bad client can't cause extra memory allocations
+ if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
+ l.Warnf("unknown 'stream' field: %v", msg)
continue
}
- switch action {
+ updateList, ok := msg["list"]
+ if ok {
+ updateStream += ":" + updateList
+ }
+
+ switch updateType {
case "subscribe":
stream.Lock()
- stream.Timelines[streamType] = true
+ stream.StreamTypes[updateStream] = true
stream.Unlock()
case "unsubscribe":
stream.Lock()
- delete(stream.Timelines, streamType)
+ delete(stream.StreamTypes, updateStream)
stream.Unlock()
default:
- l.Warnf("Invalid 'type' field: %v", msg)
+ l.Warnf("invalid 'type' field: %v", msg)
}
}
}()
@@ -276,7 +311,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
case msg := <-stream.Messages:
l.Tracef("sending message to websocket: %+v", msg)
if err := wsConn.WriteJSON(msg); err != nil {
- l.Errorf("error writing json to websocket: %v", err)
+ l.Debugf("error writing json to websocket: %v", err)
return
}
@@ -290,7 +325,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
websocket.PingMessage,
[]byte{},
); err != nil {
- l.Errorf("error writing ping to websocket: %v", err)
+ l.Debugf("error writing ping to websocket: %v", err)
return
}
}
diff --git a/internal/api/client/streaming/streaming.go b/internal/api/client/streaming/streaming.go
index 71b325089..edddeab73 100644
--- a/internal/api/client/streaming/streaming.go
+++ b/internal/api/client/streaming/streaming.go
@@ -27,17 +27,12 @@ import (
)
const (
- // BasePath is the path for the streaming api, minus the 'api' prefix
- BasePath = "/v1/streaming"
-
- // StreamQueryKey is the query key for the type of stream being requested
- StreamQueryKey = "stream"
-
- // AccessTokenQueryKey is the query key for an oauth access token that should be passed in streaming requests.
- AccessTokenQueryKey = "access_token"
- // AccessTokenHeader is the header for an oauth access token that can be passed in streaming requests instead of AccessTokenQueryKey
- //nolint:gosec
- AccessTokenHeader = "Sec-Websocket-Protocol"
+ BasePath = "/v1/streaming" // path for the streaming api, minus the 'api' prefix
+ StreamQueryKey = "stream" // type of stream being requested
+ StreamListKey = "list" // id of list being requested
+ StreamTagKey = "tag" // name of tag being requested
+ AccessTokenQueryKey = "access_token" // oauth access token
+ AccessTokenHeader = "Sec-Websocket-Protocol" //nolint:gosec
)
type Module struct {
diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go
index b429461c6..cece99bac 100644
--- a/internal/api/client/streaming/streaming_test.go
+++ b/internal/api/client/streaming/streaming_test.go
@@ -41,6 +41,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
+ "github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -94,6 +95,13 @@ func (suite *StreamingTestSuite) SetupTest() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
+
+ testrig.StartTimelines(
+ &suite.state,
+ visibility.NewFilter(&suite.state),
+ suite.tc,
+ )
+
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@@ -102,7 +110,6 @@ func (suite *StreamingTestSuite) SetupTest() {
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.streamingModule = streaming.New(suite.processor, 1, 4096)
- suite.NoError(suite.processor.Start())
}
func (suite *StreamingTestSuite) TearDownTest() {