diff options
Diffstat (limited to 'internal/api/client/streaming/stream.go')
-rw-r--r-- | internal/api/client/streaming/stream.go | 79 |
1 files changed, 57 insertions, 22 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 444157c1b..067f87392 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -20,14 +20,16 @@ package streaming import ( "context" - "errors" - "fmt" "time" "codeberg.org/gruf/go-kv" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + streampkg "github.com/superseriousbusiness/gotosocial/internal/stream" + "golang.org/x/exp/slices" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -134,32 +136,37 @@ import ( // '400': // description: bad request func (m *Module) StreamGETHandler(c *gin.Context) { - streamType := c.Query(StreamQueryKey) - if streamType == "" { - err := fmt.Errorf("no stream type provided under query key %s", StreamQueryKey) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - - var token string // First we check for a query param provided access token - if token = c.Query(AccessTokenQueryKey); token == "" { + token := c.Query(AccessTokenQueryKey) + if token == "" { // Else we check the HTTP header provided token - if token = c.GetHeader(AccessTokenHeader); token == "" { - const errStr = "no access token provided" - err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr) + token = c.GetHeader(AccessTokenHeader) + } + + var account *gtsmodel.Account + if token != "" { + // Check the explicit token + var errWithCode gtserror.WithCode + account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + } else { + // If no explicit token was provided, try regular oauth + auth, errStr := oauth.Authed(c, true, true, true, true) + if errStr != nil { + err := gtserror.NewErrorUnauthorized(errStr, errStr.Error()) apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1) return } + account = auth.Account } - account, errWithCode := m.processor.Stream().Authorize(c.Request.Context(), token) - if errWithCode != nil { - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) - return - } - + // Get the initial stream type, if there is one. + // streamType will be an empty string if one wasn't supplied. Open() will deal with this + streamType := c.Query(StreamQueryKey) stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) @@ -219,8 +226,9 @@ func (m *Module) StreamGETHandler(c *gin.Context) { // We have to listen for received websocket messages in // order to trigger the underlying wsConn.PingHandler(). // - // So we wait on received messages but only act on errors. - _, _, err := wsConn.ReadMessage() + // Read JSON objects from the client and act on them + var msg map[string]string + err := wsConn.ReadJSON(&msg) if err != nil { if ctx.Err() == nil { // Only log error if the connection was not closed @@ -229,6 +237,33 @@ func (m *Module) StreamGETHandler(c *gin.Context) { } return } + l.Tracef("received message from websocket: %v", msg) + + // If the message contains 'stream' and 'type' fields, we can + // update the set of timelines that are subscribed for events. + // everything else is ignored. + action := msg["type"] + streamType := msg["stream"] + + // Ignore if the streamType is unknown (or missing), so a bad + // client can't cause extra memory allocations + if !slices.Contains(streampkg.AllStatusTimelines, streamType) { + l.Warnf("Unknown 'stream' field: %v", msg) + continue + } + + switch action { + case "subscribe": + stream.Lock() + stream.Timelines[streamType] = true + stream.Unlock() + case "unsubscribe": + stream.Lock() + delete(stream.Timelines, streamType) + stream.Unlock() + default: + l.Warnf("Invalid 'type' field: %v", msg) + } } }() |