diff options
Diffstat (limited to 'internal/api/client')
| -rw-r--r-- | internal/api/client/streaming/stream.go | 79 | 
1 files changed, 57 insertions, 22 deletions
| diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 444157c1b..067f87392 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -20,14 +20,16 @@ package streaming  import (  	"context" -	"errors" -	"fmt"  	"time"  	"codeberg.org/gruf/go-kv"  	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/oauth" +	streampkg "github.com/superseriousbusiness/gotosocial/internal/stream" +	"golang.org/x/exp/slices"  	"github.com/gin-gonic/gin"  	"github.com/gorilla/websocket" @@ -134,32 +136,37 @@ import (  //		'400':  //			description: bad request  func (m *Module) StreamGETHandler(c *gin.Context) { -	streamType := c.Query(StreamQueryKey) -	if streamType == "" { -		err := fmt.Errorf("no stream type provided under query key %s", StreamQueryKey) -		apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) -		return -	} - -	var token string  	// First we check for a query param provided access token -	if token = c.Query(AccessTokenQueryKey); token == "" { +	token := c.Query(AccessTokenQueryKey) +	if token == "" {  		// Else we check the HTTP header provided token -		if token = c.GetHeader(AccessTokenHeader); token == "" { -			const errStr = "no access token provided" -			err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr) +		token = c.GetHeader(AccessTokenHeader) +	} + +	var account *gtsmodel.Account +	if token != "" { +		// Check the explicit token +		var errWithCode gtserror.WithCode +		account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token) +		if errWithCode != nil { +			apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) +			return +		} +	} else { +		// If no explicit token was provided, try regular oauth +		auth, errStr := oauth.Authed(c, true, true, true, true) +		if errStr != nil { +			err := gtserror.NewErrorUnauthorized(errStr, errStr.Error())  			apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1)  			return  		} +		account = auth.Account  	} -	account, errWithCode := m.processor.Stream().Authorize(c.Request.Context(), token) -	if errWithCode != nil { -		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) -		return -	} - +	// Get the initial stream type, if there is one. +	// streamType will be an empty string if one wasn't supplied. Open() will deal with this +	streamType := c.Query(StreamQueryKey)  	stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType)  	if errWithCode != nil {  		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) @@ -219,8 +226,9 @@ func (m *Module) StreamGETHandler(c *gin.Context) {  				// We have to listen for received websocket messages in  				// order to trigger the underlying wsConn.PingHandler().  				// -				// So we wait on received messages but only act on errors. -				_, _, err := wsConn.ReadMessage() +				// Read JSON objects from the client and act on them +				var msg map[string]string +				err := wsConn.ReadJSON(&msg)  				if err != nil {  					if ctx.Err() == nil {  						// Only log error if the connection was not closed @@ -229,6 +237,33 @@ func (m *Module) StreamGETHandler(c *gin.Context) {  					}  					return  				} +				l.Tracef("received message from websocket: %v", msg) + +				// If the message contains 'stream' and 'type' fields, we can +				// update the set of timelines that are subscribed for events. +				// everything else is ignored. +				action := msg["type"] +				streamType := msg["stream"] + +				// Ignore if the streamType is unknown (or missing), so a bad +				// client can't cause extra memory allocations +				if !slices.Contains(streampkg.AllStatusTimelines, streamType) { +					l.Warnf("Unknown 'stream' field: %v", msg) +					continue +				} + +				switch action { +				case "subscribe": +					stream.Lock() +					stream.Timelines[streamType] = true +					stream.Unlock() +				case "unsubscribe": +					stream.Lock() +					delete(stream.Timelines, streamType) +					stream.Unlock() +				default: +					l.Warnf("Invalid 'type' field: %v", msg) +				}  			}  		}() | 
