diff options
Diffstat (limited to 'internal/api/client/streaming')
-rw-r--r-- | internal/api/client/streaming/stream.go | 63 | ||||
-rw-r--r-- | internal/api/client/streaming/streaming.go | 17 | ||||
-rw-r--r-- | internal/api/client/streaming/streaming_test.go | 9 |
3 files changed, 63 insertions, 26 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index f41bc0ac2..88c682a75 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -82,6 +82,20 @@ import ( // `direct`: receive updates for direct messages. // in: query // required: true +// - +// name: list +// type: string +// description: |- +// ID of the list to subscribe to. +// Only used if stream type is 'list'. +// in: query +// - +// name: tag +// type: string +// description: |- +// Name of the tag to subscribe to. +// Only used if stream type is 'hashtag' or 'hashtag:local'. +// in: query // // security: // - OAuth2 Bearer: @@ -164,8 +178,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) { } // Get the initial stream type, if there is one. - // streamType will be an empty string if one wasn't supplied. Open() will deal with this + // By appending other query params to the streamType, + // we can allow for streaming for specific list IDs + // or hashtags. streamType := c.Query(StreamQueryKey) + if list := c.Query(StreamListKey); list != "" { + streamType += ":" + list + } else if tag := c.Query(StreamTagKey); tag != "" { + streamType += ":" + tag + } + stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) @@ -240,28 +262,41 @@ func (m *Module) StreamGETHandler(c *gin.Context) { // If the message contains 'stream' and 'type' fields, we can // update the set of timelines that are subscribed for events. - // everything else is ignored. - action := msg["type"] - streamType := msg["stream"] + updateType, ok := msg["type"] + if !ok { + l.Warn("'type' field not provided") + continue + } + + updateStream, ok := msg["stream"] + if !ok { + l.Warn("'stream' field not provided") + continue + } - // Ignore if the streamType is unknown (or missing), so a bad - // client can't cause extra memory allocations - if !slices.Contains(streampkg.AllStatusTimelines, streamType) { - l.Warnf("Unknown 'stream' field: %v", msg) + // Ignore if the updateStreamType is unknown (or missing), + // so a bad client can't cause extra memory allocations + if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { + l.Warnf("unknown 'stream' field: %v", msg) continue } - switch action { + updateList, ok := msg["list"] + if ok { + updateStream += ":" + updateList + } + + switch updateType { case "subscribe": stream.Lock() - stream.Timelines[streamType] = true + stream.StreamTypes[updateStream] = true stream.Unlock() case "unsubscribe": stream.Lock() - delete(stream.Timelines, streamType) + delete(stream.StreamTypes, updateStream) stream.Unlock() default: - l.Warnf("Invalid 'type' field: %v", msg) + l.Warnf("invalid 'type' field: %v", msg) } } }() @@ -276,7 +311,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { case msg := <-stream.Messages: l.Tracef("sending message to websocket: %+v", msg) if err := wsConn.WriteJSON(msg); err != nil { - l.Errorf("error writing json to websocket: %v", err) + l.Debugf("error writing json to websocket: %v", err) return } @@ -290,7 +325,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { websocket.PingMessage, []byte{}, ); err != nil { - l.Errorf("error writing ping to websocket: %v", err) + l.Debugf("error writing ping to websocket: %v", err) return } } diff --git a/internal/api/client/streaming/streaming.go b/internal/api/client/streaming/streaming.go index 71b325089..edddeab73 100644 --- a/internal/api/client/streaming/streaming.go +++ b/internal/api/client/streaming/streaming.go @@ -27,17 +27,12 @@ import ( ) const ( - // BasePath is the path for the streaming api, minus the 'api' prefix - BasePath = "/v1/streaming" - - // StreamQueryKey is the query key for the type of stream being requested - StreamQueryKey = "stream" - - // AccessTokenQueryKey is the query key for an oauth access token that should be passed in streaming requests. - AccessTokenQueryKey = "access_token" - // AccessTokenHeader is the header for an oauth access token that can be passed in streaming requests instead of AccessTokenQueryKey - //nolint:gosec - AccessTokenHeader = "Sec-Websocket-Protocol" + BasePath = "/v1/streaming" // path for the streaming api, minus the 'api' prefix + StreamQueryKey = "stream" // type of stream being requested + StreamListKey = "list" // id of list being requested + StreamTagKey = "tag" // name of tag being requested + AccessTokenQueryKey = "access_token" // oauth access token + AccessTokenHeader = "Sec-Websocket-Protocol" //nolint:gosec ) type Module struct { diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go index b429461c6..cece99bac 100644 --- a/internal/api/client/streaming/streaming_test.go +++ b/internal/api/client/streaming/streaming_test.go @@ -41,6 +41,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -94,6 +95,13 @@ func (suite *StreamingTestSuite) SetupTest() { suite.state.Storage = suite.storage suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -102,7 +110,6 @@ func (suite *StreamingTestSuite) SetupTest() { suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.streamingModule = streaming.New(suite.processor, 1, 4096) - suite.NoError(suite.processor.Start()) } func (suite *StreamingTestSuite) TearDownTest() { |