diff options
Diffstat (limited to 'cmd/gotosocial/action/testrig/testrig.go')
-rw-r--r-- | cmd/gotosocial/action/testrig/testrig.go | 148 |
1 files changed, 58 insertions, 90 deletions
diff --git a/cmd/gotosocial/action/testrig/testrig.go b/cmd/gotosocial/action/testrig/testrig.go index 88ac369de..f13845b20 100644 --- a/cmd/gotosocial/action/testrig/testrig.go +++ b/cmd/gotosocial/action/testrig/testrig.go @@ -21,6 +21,7 @@ package testrig import ( "bytes" "context" + "errors" "fmt" "io" "net/http" @@ -28,35 +29,17 @@ import ( "os/signal" "syscall" + "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action" "github.com/superseriousbusiness/gotosocial/internal/api" - "github.com/superseriousbusiness/gotosocial/internal/api/client/account" - "github.com/superseriousbusiness/gotosocial/internal/api/client/admin" - "github.com/superseriousbusiness/gotosocial/internal/api/client/app" - "github.com/superseriousbusiness/gotosocial/internal/api/client/auth" - "github.com/superseriousbusiness/gotosocial/internal/api/client/blocks" - "github.com/superseriousbusiness/gotosocial/internal/api/client/emoji" - "github.com/superseriousbusiness/gotosocial/internal/api/client/favourites" - "github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver" - "github.com/superseriousbusiness/gotosocial/internal/api/client/filter" - "github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest" - "github.com/superseriousbusiness/gotosocial/internal/api/client/instance" - "github.com/superseriousbusiness/gotosocial/internal/api/client/list" - mediaModule "github.com/superseriousbusiness/gotosocial/internal/api/client/media" - "github.com/superseriousbusiness/gotosocial/internal/api/client/notification" - "github.com/superseriousbusiness/gotosocial/internal/api/client/search" - "github.com/superseriousbusiness/gotosocial/internal/api/client/status" - "github.com/superseriousbusiness/gotosocial/internal/api/client/streaming" - "github.com/superseriousbusiness/gotosocial/internal/api/client/timeline" - userClient "github.com/superseriousbusiness/gotosocial/internal/api/client/user" - "github.com/superseriousbusiness/gotosocial/internal/api/s2s/nodeinfo" - "github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" - "github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger" - "github.com/superseriousbusiness/gotosocial/internal/api/security" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/concurrency" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/gotosocial" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/web" @@ -70,7 +53,6 @@ var Start action.GTSAction = func(ctx context.Context) error { dbService := testrig.NewTestDB() testrig.StandardDBSetup(dbService, nil) - router := testrig.NewTestRouter(dbService) var storageBackend *storage.Driver if os.Getenv("GTS_STORAGE_BACKEND") == "s3" { storageBackend, _ = storage.NewS3Storage() @@ -84,7 +66,6 @@ var Start action.GTSAction = func(ctx context.Context) error { fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) // build backend handlers - oauthServer := testrig.NewTestOauthServer(dbService) transportController := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) { r := io.NopCloser(bytes.NewReader([]byte{})) return &http.Response{ @@ -102,77 +83,64 @@ var Start action.GTSAction = func(ctx context.Context) error { return fmt.Errorf("error starting processor: %s", err) } - idp, err := oidc.NewIDP(ctx) - if err != nil { - return fmt.Errorf("error creating oidc idp: %s", err) + /* + HTTP router initialization + */ + + router := testrig.NewTestRouter(dbService) + + // attach global middlewares which are used for every request + router.AttachGlobalMiddleware( + middleware.Logger(), + middleware.UserAgent(), + middleware.CORS(), + middleware.ExtraHeaders(), + ) + + // attach global no route / 404 handler to the router + router.AttachNoRouteHandler(func(c *gin.Context) { + apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), processor.InstanceGet) + }) + + // build router modules + var idp oidc.IDP + var err error + if config.GetOIDCEnabled() { + idp, err = oidc.NewIDP(ctx) + if err != nil { + return fmt.Errorf("error creating oidc idp: %w", err) + } } - // build web module - webModule := web.New(processor) - - // build client api modules - authModule := auth.New(dbService, idp, processor) - accountModule := account.New(processor) - instanceModule := instance.New(processor) - appsModule := app.New(processor) - followRequestsModule := followrequest.New(processor) - webfingerModule := webfinger.New(processor) - nodeInfoModule := nodeinfo.New(processor) - usersModule := user.New(processor) - timelineModule := timeline.New(processor) - notificationModule := notification.New(processor) - searchModule := search.New(processor) - filtersModule := filter.New(processor) - emojiModule := emoji.New(processor) - listsModule := list.New(processor) - mm := mediaModule.New(processor) - fileServerModule := fileserver.New(processor) - adminModule := admin.New(processor) - statusModule := status.New(processor) - securityModule := security.New(dbService, oauthServer) - streamingModule := streaming.New(processor) - favouritesModule := favourites.New(processor) - blocksModule := blocks.New(processor) - userClientModule := userClient.New(processor) - - apis := []api.ClientModule{ - // modules with middleware go first - securityModule, - authModule, - - // now the web module - webModule, - - // now everything else - accountModule, - instanceModule, - appsModule, - followRequestsModule, - mm, - fileServerModule, - adminModule, - statusModule, - webfingerModule, - nodeInfoModule, - usersModule, - timelineModule, - notificationModule, - searchModule, - filtersModule, - emojiModule, - listsModule, - streamingModule, - favouritesModule, - blocksModule, - userClientModule, + routerSession, err := dbService.GetSession(ctx) + if err != nil { + return fmt.Errorf("error retrieving router session for session middleware: %w", err) } - for _, m := range apis { - if err := m.Route(router); err != nil { - return fmt.Errorf("routing error: %s", err) - } + sessionName, err := middleware.SessionName() + if err != nil { + return fmt.Errorf("error generating session name for session middleware: %w", err) } + var ( + authModule = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths + clientModule = api.NewClient(dbService, processor) // api client endpoints + fileserverModule = api.NewFileserver(processor) // fileserver endpoints + wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints + nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint + activityPubModule = api.NewActivityPub(dbService, processor) // ActivityPub endpoints + webModule = web.New(processor) // web pages + user profiles + settings panels etc + ) + + // these should be routed in order + authModule.Route(router) + clientModule.Route(router) + fileserverModule.Route(router) + wellKnownModule.Route(router) + nodeInfoModule.Route(router) + activityPubModule.Route(router) + webModule.Route(router) + gts, err := gotosocial.NewServer(dbService, router, federator, mediaManager) if err != nil { return fmt.Errorf("error creating gotosocial service: %s", err) |