diff options
Diffstat (limited to 'internal')
54 files changed, 1612 insertions, 660 deletions
diff --git a/internal/api/client/auth/auth.go b/internal/api/client/auth/auth.go index 67643244b..717d997a3 100644 --- a/internal/api/client/auth/auth.go +++ b/internal/api/client/auth/auth.go @@ -32,10 +32,23 @@ import ( const ( // AuthSignInPath is the API path for users to sign in through AuthSignInPath = "/auth/sign_in" + + // CheckYourEmailPath users land here after registering a new account, instructs them to confirm thier email + CheckYourEmailPath = "/check_your_email" + + // WaitForApprovalPath users land here after confirming thier email but before an admin approves thier account + // (if such is required) + WaitForApprovalPath = "/wait_for_approval" + + // AccountDisabledPath users land here when thier account is suspended by an admin + AccountDisabledPath = "/account_disabled" + // OauthTokenPath is the API path to use for granting token requests to users with valid credentials OauthTokenPath = "/oauth/token" + // OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user) OauthAuthorizePath = "/oauth/authorize" + // CallbackPath is the API path for receiving callback tokens from external OIDC providers CallbackPath = oidc.CallbackPath diff --git a/internal/api/client/auth/auth_test.go b/internal/api/client/auth/auth_test.go index a0ee8892d..fdf1b6baf 100644 --- a/internal/api/client/auth/auth_test.go +++ b/internal/api/client/auth/auth_test.go @@ -18,4 +18,96 @@ package auth_test -// TODO +import ( + "context" + "fmt" + "net/http/httptest" + + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/memstore" + "github.com/gin-gonic/gin" + "github.com/spf13/viper" + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/api/client/auth" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/oidc" + "github.com/superseriousbusiness/gotosocial/internal/router" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type AuthStandardTestSuite struct { + suite.Suite + db db.DB + idp oidc.IDP + oauthServer oauth.Server + + // standard suite models + testTokens map[string]*gtsmodel.Token + testClients map[string]*gtsmodel.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + + // module being tested + authModule *auth.Module +} + +const ( + sessionUserID = "userid" + sessionClientID = "client_id" +) + +func (suite *AuthStandardTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() +} + +func (suite *AuthStandardTestSuite) SetupTest() { + testrig.InitTestConfig() + suite.db = testrig.NewTestDB() + testrig.InitTestLog() + + suite.oauthServer = testrig.NewTestOauthServer(suite.db) + var err error + suite.idp, err = oidc.NewIDP(context.Background()) + if err != nil { + panic(err) + } + suite.authModule = auth.New(suite.db, suite.oauthServer, suite.idp).(*auth.Module) + testrig.StandardDBSetup(suite.db, nil) +} + +func (suite *AuthStandardTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *AuthStandardTestSuite) newContext(requestMethod string, requestPath string) (*gin.Context, *httptest.ResponseRecorder) { + // create the recorder and gin test context + recorder := httptest.NewRecorder() + ctx, engine := gin.CreateTestContext(recorder) + + // load templates into the engine + testrig.ConfigureTemplatesWithGin(engine) + + // create the request + protocol := viper.GetString(config.Keys.Protocol) + host := viper.GetString(config.Keys.Host) + baseURI := fmt.Sprintf("%s://%s", protocol, host) + requestURI := fmt.Sprintf("%s/%s", baseURI, requestPath) + ctx.Request = httptest.NewRequest(requestMethod, requestURI, nil) // the endpoint we're hitting + ctx.Request.Header.Set("accept", "text/html") + + // trigger the session middleware on the context + store := memstore.NewStore(make([]byte, 32), make([]byte, 32)) + store.Options(router.SessionOptions()) + sessionMiddleware := sessions.Sessions("gotosocial-localhost", store) + sessionMiddleware(ctx) + + return ctx, recorder +} diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go index 99f3cca68..387b83c1e 100644 --- a/internal/api/client/auth/authorize.go +++ b/internal/api/client/auth/authorize.go @@ -44,7 +44,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { s := sessions.Default(c) if _, err := api.NegotiateAccept(c, api.HTMLAcceptHeaders...); err != nil { - c.JSON(http.StatusNotAcceptable, gin.H{"error": err.Error()}) + c.HTML(http.StatusNotAcceptable, "error.tmpl", gin.H{"error": err.Error()}) return } @@ -57,7 +57,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { if err := c.Bind(form); err != nil { l.Debugf("invalid auth form: %s", err) m.clearSession(s) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()}) return } l.Debugf("parsed auth form: %+v", form) @@ -65,7 +65,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { if err := extractAuthForm(s, form); err != nil { l.Debugf(fmt.Sprintf("error parsing form at /oauth/authorize: %s", err)) m.clearSession(s) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()}) return } c.Redirect(http.StatusSeeOther, AuthSignInPath) @@ -75,28 +75,33 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { // We can use the client_id on the session to retrieve info about the app associated with the client_id clientID, ok := s.Get(sessionClientID).(string) if !ok || clientID == "" { - c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"}) + c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no client_id found in session"}) return } app := >smodel.Application{} if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil { m.clearSession(s) - c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) + c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{ + "error": fmt.Sprintf("no application found for client id %s", clientID), + }) return } - // we can also use the userid of the user to fetch their username from the db to greet them nicely <3 + // redirect the user if they have not confirmed their email yet, thier account has not been approved yet, + // or thier account has been disabled. user := >smodel.User{} if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { m.clearSession(s) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()}) return } - acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) if err != nil { m.clearSession(s) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()}) + return + } + if !ensureUserIsAuthorizedOrRedirect(c, user, acct) { return } @@ -104,13 +109,13 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { redirect, ok := s.Get(sessionRedirectURI).(string) if !ok || redirect == "" { m.clearSession(s) - c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"}) + c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no redirect_uri found in session"}) return } scope, ok := s.Get(sessionScope).(string) if !ok || scope == "" { m.clearSession(s) - c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"}) + c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no scope found in session"}) return } @@ -170,10 +175,28 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) { errs = append(errs, "session missing userid") } + // redirect the user if they have not confirmed their email yet, thier account has not been approved yet, + // or thier account has been disabled. + user := >smodel.User{} + if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { + m.clearSession(s) + c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()}) + return + } + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil { + m.clearSession(s) + c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()}) + return + } + if !ensureUserIsAuthorizedOrRedirect(c, user, acct) { + return + } + m.clearSession(s) if len(errs) != 0 { - c.JSON(http.StatusBadRequest, gin.H{"error": strings.Join(errs, ": ")}) + c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": strings.Join(errs, ": ")}) return } @@ -190,7 +213,7 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) { // and proceed with authorization using the oauth2 library if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()}) } } @@ -216,3 +239,27 @@ func extractAuthForm(s sessions.Session, form *model.OAuthAuthorize) error { s.Set(sessionState, uuid.NewString()) return s.Save() } + +func ensureUserIsAuthorizedOrRedirect(ctx *gin.Context, user *gtsmodel.User, account *gtsmodel.Account) bool { + if user.ConfirmedAt.IsZero() { + ctx.Redirect(http.StatusSeeOther, CheckYourEmailPath) + return false + } + + if !user.Approved { + ctx.Redirect(http.StatusSeeOther, WaitForApprovalPath) + return false + } + + if user.Disabled { + ctx.Redirect(http.StatusSeeOther, AccountDisabledPath) + return false + } + + if !account.SuspendedAt.IsZero() { + ctx.Redirect(http.StatusSeeOther, AccountDisabledPath) + return false + } + + return true +} diff --git a/internal/api/client/auth/authorize_test.go b/internal/api/client/auth/authorize_test.go new file mode 100644 index 000000000..8f16702da --- /dev/null +++ b/internal/api/client/auth/authorize_test.go @@ -0,0 +1,113 @@ +package auth_test + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" + + "codeberg.org/gruf/go-errors" + "github.com/gin-contrib/sessions" + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/api/client/auth" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type AuthAuthorizeTestSuite struct { + AuthStandardTestSuite +} + +type authorizeHandlerTestCase struct { + description string + mutateUserAccount func(*gtsmodel.User, *gtsmodel.Account) + expectedStatusCode int + expectedLocationHeader string +} + +func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() { + + var tests = []authorizeHandlerTestCase{ + { + description: "user has their email unconfirmed", + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { + // nothing to do, weed_lord420 already has their email unconfirmed + }, + expectedStatusCode: http.StatusSeeOther, + expectedLocationHeader: auth.CheckYourEmailPath, + }, + { + description: "user has their email confirmed but is not approved", + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { + user.ConfirmedAt = time.Now() + user.Email = user.UnconfirmedEmail + }, + expectedStatusCode: http.StatusSeeOther, + expectedLocationHeader: auth.WaitForApprovalPath, + }, + { + description: "user has their email confirmed and is approved, but User entity has been disabled", + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { + user.ConfirmedAt = time.Now() + user.Email = user.UnconfirmedEmail + user.Approved = true + user.Disabled = true + }, + expectedStatusCode: http.StatusSeeOther, + expectedLocationHeader: auth.AccountDisabledPath, + }, + { + description: "user has their email confirmed and is approved, but Account entity has been suspended", + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { + user.ConfirmedAt = time.Now() + user.Email = user.UnconfirmedEmail + user.Approved = true + user.Disabled = false + account.SuspendedAt = time.Now() + }, + expectedStatusCode: http.StatusSeeOther, + expectedLocationHeader: auth.AccountDisabledPath, + }, + } + + doTest := func(testCase authorizeHandlerTestCase) { + ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath) + + user := suite.testUsers["unconfirmed_account"] + account := suite.testAccounts["unconfirmed_account"] + + testSession := sessions.Default(ctx) + testSession.Set(sessionUserID, user.ID) + testSession.Set(sessionClientID, suite.testApplications["application_1"].ClientID) + if err := testSession.Save(); err != nil { + panic(errors.WrapMsgf(err, "failed on case: %s", testCase.description)) + } + + testCase.mutateUserAccount(user, account) + + testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, user.Disabled, account.SuspendedAt) + + user.UpdatedAt = time.Now() + err := suite.db.UpdateByPrimaryKey(context.Background(), user) + suite.NoError(err) + _, err = suite.db.UpdateAccount(context.Background(), account) + suite.NoError(err) + + // call the handler + suite.authModule.AuthorizeGETHandler(ctx) + + // 1. we should have a redirect + suite.Equal(testCase.expectedStatusCode, recorder.Code, fmt.Sprintf("failed on case: %s", testCase.description)) + + // 2. we should have a redirect to the check your email path, as this user has not confirmed their email yet. + suite.Equal(testCase.expectedLocationHeader, recorder.Header().Get("Location"), fmt.Sprintf("failed on case: %s", testCase.description)) + } + + for _, testCase := range tests { + doTest(testCase) + } +} + +func TestAccountUpdateTestSuite(t *testing.T) { + suite.Run(t, new(AuthAuthorizeTestSuite)) +} diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go index 8188cb7ce..a5c58647c 100644 --- a/internal/api/client/auth/callback.go +++ b/internal/api/client/auth/callback.go @@ -30,8 +30,6 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/google/uuid" - "github.com/spf13/viper" - "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oidc" @@ -206,19 +204,27 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i } } - // we still need to set *a* password even if it's not a password the user will end up using, so set something random - // in this case, we'll just set two uuids on top of each other, which should be long + random enough to baffle any attempts to crack. + // We still need to set *a* password even if it's not a password the user will end up using, so set something random. + // We'll just set two uuids on top of each other, which should be long + random enough to baffle any attempts to crack. // - // if the user ever wants to log in using gts password rather than oidc flow, they'll have to request a password reset, which is fine + // If the user ever wants to log in using gts password rather than oidc flow, they'll have to request a password reset, which is fine password := uuid.NewString() + uuid.NewString() + // Since this user is created via oidc, which has been set up by the admin, we can assume that the account is already + // implicitly approved, and that the email address has already been verified: otherwise, we end up in situations where + // the admin first approves the user in OIDC, and then has to approve them again in GoToSocial, which doesn't make sense. + // + // In other words, if a user logs in via OIDC, they should be able to use their account straight away. + // + // See: https://github.com/superseriousbusiness/gotosocial/issues/357 + requireApproval := false + emailVerified := true + // create the user! this will also create an account and store it in the database so we don't need to do that here - requireApproval := viper.GetBool(config.Keys.AccountsApprovalRequired) - user, err = m.db.NewSignup(ctx, username, "", requireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin) + user, err = m.db.NewSignup(ctx, username, "", requireApproval, claims.Email, password, ip, "", appID, emailVerified, admin) if err != nil { return nil, fmt.Errorf("error creating user: %s", err) } return user, nil - } diff --git a/internal/api/client/media/mediacreate.go b/internal/api/client/media/mediacreate.go index 7887461ee..5946ed398 100644 --- a/internal/api/client/media/mediacreate.go +++ b/internal/api/client/media/mediacreate.go @@ -149,11 +149,9 @@ func validateCreateMedia(form *model.AttachmentRequest) error { return fmt.Errorf("file size limit exceeded: limit is %d bytes but attachment was %d bytes", maxSize, form.File.Size) } - if len(form.Description) < minDescriptionChars || len(form.Description) > maxDescriptionChars { + if len(form.Description) > maxDescriptionChars { return fmt.Errorf("image description length must be between %d and %d characters (inclusive), but provided image description was %d chars", minDescriptionChars, maxDescriptionChars, len(form.Description)) } - // TODO: validate focus here - return nil } diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go index 22e0e2188..ca0b6483f 100644 --- a/internal/api/client/media/mediacreate_test.go +++ b/internal/api/client/media/mediacreate_test.go @@ -21,6 +21,8 @@ package media_test import ( "bytes" "context" + "crypto/rand" + "encoding/base64" "encoding/json" "fmt" "io/ioutil" @@ -31,10 +33,11 @@ import ( "codeberg.org/gruf/go-store/kv" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" + "github.com/spf13/viper" "github.com/stretchr/testify/suite" mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media" "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" @@ -117,7 +120,7 @@ func (suite *MediaCreateTestSuite) TearDownTest() { ACTUAL TESTS */ -func (suite *MediaCreateTestSuite) TestStatusCreatePOSTImageHandlerSuccessful() { +func (suite *MediaCreateTestSuite) TestMediaCreateSuccessful() { // set up the context for the request t := suite.testTokens["local_account_1"] oauthToken := oauth.DBTokenToToken(t) @@ -171,16 +174,16 @@ func (suite *MediaCreateTestSuite) TestStatusCreatePOSTImageHandlerSuccessful() result := recorder.Result() defer result.Body.Close() b, err := ioutil.ReadAll(result.Body) - assert.NoError(suite.T(), err) + suite.NoError(err) fmt.Println(string(b)) attachmentReply := &model.Attachment{} err = json.Unmarshal(b, attachmentReply) - assert.NoError(suite.T(), err) + suite.NoError(err) - assert.Equal(suite.T(), "this is a test image -- a cool background from somewhere", attachmentReply.Description) - assert.Equal(suite.T(), "image", attachmentReply.Type) - assert.EqualValues(suite.T(), model.MediaMeta{ + suite.Equal("this is a test image -- a cool background from somewhere", attachmentReply.Description) + suite.Equal("image", attachmentReply.Type) + suite.EqualValues(model.MediaMeta{ Original: model.MediaDimensions{ Width: 1920, Height: 1080, @@ -198,11 +201,89 @@ func (suite *MediaCreateTestSuite) TestStatusCreatePOSTImageHandlerSuccessful() Y: 0.5, }, }, attachmentReply.Meta) - assert.Equal(suite.T(), "LjBzUo#6RQR._NvzRjWF?urqV@a$", attachmentReply.Blurhash) - assert.NotEmpty(suite.T(), attachmentReply.ID) - assert.NotEmpty(suite.T(), attachmentReply.URL) - assert.NotEmpty(suite.T(), attachmentReply.PreviewURL) - assert.Equal(suite.T(), len(storageKeysBeforeRequest)+2, len(storageKeysAfterRequest)) // 2 images should be added to storage: the original and the thumbnail + suite.Equal("LjBzUo#6RQR._NvzRjWF?urqV@a$", attachmentReply.Blurhash) + suite.NotEmpty(attachmentReply.ID) + suite.NotEmpty(attachmentReply.URL) + suite.NotEmpty(attachmentReply.PreviewURL) + suite.Equal(len(storageKeysBeforeRequest)+2, len(storageKeysAfterRequest)) // 2 images should be added to storage: the original and the thumbnail +} + +func (suite *MediaCreateTestSuite) TestMediaCreateLongDescription() { + // set up the context for the request + t := suite.testTokens["local_account_1"] + oauthToken := oauth.DBTokenToToken(t) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauthToken) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + + // read a random string of a really long description + descriptionBytes := make([]byte, 5000) + if _, err := rand.Read(descriptionBytes); err != nil { + panic(err) + } + description := base64.RawStdEncoding.EncodeToString(descriptionBytes) + + // create the request + buf, w, err := testrig.CreateMultipartFormData("file", "../../../../testrig/media/test-jpeg.jpg", map[string]string{ + "description": description, + "focus": "-0.5,0.5", + }) + if err != nil { + panic(err) + } + ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", mediamodule.BasePath), bytes.NewReader(buf.Bytes())) // the endpoint we're hitting + ctx.Request.Header.Set("Content-Type", w.FormDataContentType()) + ctx.Request.Header.Set("accept", "application/json") + + // do the actual request + suite.mediaModule.MediaCreatePOSTHandler(ctx) + + // check response + suite.EqualValues(http.StatusUnprocessableEntity, recorder.Code) + + result := recorder.Result() + defer result.Body.Close() + b, err := ioutil.ReadAll(result.Body) + suite.NoError(err) + + expectedErr := fmt.Sprintf(`{"error":"image description length must be between 0 and 500 characters (inclusive), but provided image description was %d chars"}`, len(description)) + suite.Equal(expectedErr, string(b)) +} + +func (suite *MediaCreateTestSuite) TestMediaCreateTooShortDescription() { + // set the min description length + viper.Set(config.Keys.MediaDescriptionMinChars, 500) + + // set up the context for the request + t := suite.testTokens["local_account_1"] + oauthToken := oauth.DBTokenToToken(t) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauthToken) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + + // create the request + buf, w, err := testrig.CreateMultipartFormData("file", "../../../../testrig/media/test-jpeg.jpg", map[string]string{ + "description": "", // provide an empty description + "focus": "-0.5,0.5", + }) + if err != nil { + panic(err) + } + ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", mediamodule.BasePath), bytes.NewReader(buf.Bytes())) // the endpoint we're hitting + ctx.Request.Header.Set("Content-Type", w.FormDataContentType()) + ctx.Request.Header.Set("accept", "application/json") + + // do the actual request + suite.mediaModule.MediaCreatePOSTHandler(ctx) + + // check response -- there should be no error because minimum description length is checked on *UPDATE*, not initial upload + suite.EqualValues(http.StatusOK, recorder.Code) } func TestMediaCreateTestSuite(t *testing.T) { diff --git a/internal/api/client/media/mediaupdate_test.go b/internal/api/client/media/mediaupdate_test.go new file mode 100644 index 000000000..cac6c304e --- /dev/null +++ b/internal/api/client/media/mediaupdate_test.go @@ -0,0 +1,235 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package media_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "codeberg.org/gruf/go-store/kv" + "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "github.com/stretchr/testify/suite" + mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media" + "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/email" + "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/media" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type MediaUpdateTestSuite struct { + // standard suite interfaces + suite.Suite + db db.DB + storage *kv.KVStore + federator federation.Federator + tc typeutils.TypeConverter + mediaHandler media.Handler + oauthServer oauth.Server + emailSender email.Sender + processor processing.Processor + + // standard suite models + testTokens map[string]*gtsmodel.Token + testClients map[string]*gtsmodel.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + testAttachments map[string]*gtsmodel.MediaAttachment + + // item being tested + mediaModule *mediamodule.Module +} + +/* + TEST INFRASTRUCTURE +*/ + +func (suite *MediaUpdateTestSuite) SetupSuite() { + // setup standard items + testrig.InitTestConfig() + testrig.InitTestLog() + suite.db = testrig.NewTestDB() + suite.storage = testrig.NewTestStorage() + suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.mediaHandler = testrig.NewTestMediaHandler(suite.db, suite.storage) + suite.oauthServer = testrig.NewTestOauthServer(suite.db) + suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db), suite.storage) + suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) + suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender) + + // setup module being tested + suite.mediaModule = mediamodule.New(suite.processor).(*mediamodule.Module) +} + +func (suite *MediaUpdateTestSuite) TearDownSuite() { + if err := suite.db.Stop(context.Background()); err != nil { + logrus.Panicf("error closing db connection: %s", err) + } +} + +func (suite *MediaUpdateTestSuite) SetupTest() { + testrig.StandardDBSetup(suite.db, nil) + testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() +} + +func (suite *MediaUpdateTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) + testrig.StandardStorageTeardown(suite.storage) +} + +/* + ACTUAL TESTS +*/ + +func (suite *MediaUpdateTestSuite) TestUpdateImage() { + toUpdate := suite.testAttachments["local_account_1_unattached_1"] + + // set up the context for the request + t := suite.testTokens["local_account_1"] + oauthToken := oauth.DBTokenToToken(t) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauthToken) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + + // create the request + buf, w, err := testrig.CreateMultipartFormData("", "", map[string]string{ + "id": toUpdate.ID, + "description": "new description!", + "focus": "-0.1,0.3", + }) + if err != nil { + panic(err) + } + ctx.Request = httptest.NewRequest(http.MethodPut, fmt.Sprintf("http://localhost:8080/%s/%s", mediamodule.BasePath, toUpdate.ID), bytes.NewReader(buf.Bytes())) // the endpoint we're hitting + ctx.Request.Header.Set("Content-Type", w.FormDataContentType()) + ctx.Request.Header.Set("accept", "application/json") + ctx.Params = gin.Params{ + gin.Param{ + Key: mediamodule.IDKey, + Value: toUpdate.ID, + }, + } + + // do the actual request + suite.mediaModule.MediaPUTHandler(ctx) + + // check response + suite.EqualValues(http.StatusOK, recorder.Code) + + result := recorder.Result() + defer result.Body.Close() + b, err := ioutil.ReadAll(result.Body) + suite.NoError(err) + + // reply should be an attachment + attachmentReply := &model.Attachment{} + err = json.Unmarshal(b, attachmentReply) + suite.NoError(err) + + // the reply should contain the updated fields + suite.Equal("new description!", attachmentReply.Description) + suite.EqualValues("gif", attachmentReply.Type) + suite.EqualValues(model.MediaMeta{ + Original: model.MediaDimensions{Width: 800, Height: 450, FrameRate: "", Duration: 0, Bitrate: 0, Size: "800x450", Aspect: 1.7777778}, + Small: model.MediaDimensions{Width: 256, Height: 144, FrameRate: "", Duration: 0, Bitrate: 0, Size: "256x144", Aspect: 1.7777778}, + Focus: model.MediaFocus{X: -0.1, Y: 0.3}, + }, attachmentReply.Meta) + suite.Equal(toUpdate.Blurhash, attachmentReply.Blurhash) + suite.Equal(toUpdate.ID, attachmentReply.ID) + suite.Equal(toUpdate.URL, attachmentReply.URL) + suite.NotEmpty(toUpdate.Thumbnail.URL, attachmentReply.PreviewURL) +} + +func (suite *MediaUpdateTestSuite) TestUpdateImageShortDescription() { + // set the min description length + viper.Set(config.Keys.MediaDescriptionMinChars, 50) + + toUpdate := suite.testAttachments["local_account_1_unattached_1"] + + // set up the context for the request + t := suite.testTokens["local_account_1"] + oauthToken := oauth.DBTokenToToken(t) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauthToken) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + + // create the request + buf, w, err := testrig.CreateMultipartFormData("", "", map[string]string{ + "id": toUpdate.ID, + "description": "new description!", + "focus": "-0.1,0.3", + }) + if err != nil { + panic(err) + } + ctx.Request = httptest.NewRequest(http.MethodPut, fmt.Sprintf("http://localhost:8080/%s/%s", mediamodule.BasePath, toUpdate.ID), bytes.NewReader(buf.Bytes())) // the endpoint we're hitting + ctx.Request.Header.Set("Content-Type", w.FormDataContentType()) + ctx.Request.Header.Set("accept", "application/json") + ctx.Params = gin.Params{ + gin.Param{ + Key: mediamodule.IDKey, + Value: toUpdate.ID, + }, + } + + // do the actual request + suite.mediaModule.MediaPUTHandler(ctx) + + // check response + suite.EqualValues(http.StatusBadRequest, recorder.Code) + + result := recorder.Result() + defer result.Body.Close() + b, err := ioutil.ReadAll(result.Body) + suite.NoError(err) + + // reply should be an error message + suite.Equal(`{"error":"image description length must be between 50 and 500 characters (inclusive), but provided image description was 16 chars"}`, string(b)) +} + +func TestMediaUpdateTestSuite(t *testing.T) { + suite.Run(t, new(MediaUpdateTestSuite)) +} diff --git a/internal/api/model/status.go b/internal/api/model/status.go index 3ff3f791d..fade58a49 100644 --- a/internal/api/model/status.go +++ b/internal/api/model/status.go @@ -96,6 +96,36 @@ type Status struct { Text string `json:"text"` } +/* +** The below functions are added onto the API model status so that it satisfies +** the Preparable interface in internal/timeline. + */ + +func (s *Status) GetID() string { + return s.ID +} + +func (s *Status) GetAccountID() string { + if s.Account != nil { + return s.Account.ID + } + return "" +} + +func (s *Status) GetBoostOfID() string { + if s.Reblog != nil { + return s.Reblog.ID + } + return "" +} + +func (s *Status) GetBoostOfAccountID() string { + if s.Reblog != nil && s.Reblog.Account != nil { + return s.Reblog.Account.ID + } + return "" +} + // StatusReblogged represents a reblogged status. // // swagger:model statusReblogged diff --git a/internal/api/s2s/webfinger/webfingerget.go b/internal/api/s2s/webfinger/webfingerget.go index e05b8e388..5d237408f 100644 --- a/internal/api/s2s/webfinger/webfingerget.go +++ b/internal/api/s2s/webfinger/webfingerget.go @@ -28,7 +28,6 @@ import ( "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/superseriousbusiness/gotosocial/internal/ap" - "github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/config" ) @@ -68,11 +67,6 @@ func (m *Module) WebfingerGETRequest(c *gin.Context) { return } - if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { - c.JSON(http.StatusNotAcceptable, gin.H{"error": err.Error()}) - return - } - // remove the acct: prefix if it's present trimAcct := strings.TrimPrefix(resourceQuery, "acct:") // remove the first @ in @whatever@example.org if it's present diff --git a/internal/api/s2s/webfinger/webfingerget_test.go b/internal/api/s2s/webfinger/webfingerget_test.go index c5df1f7e5..d3b0c32e8 100644 --- a/internal/api/s2s/webfinger/webfingerget_test.go +++ b/internal/api/s2s/webfinger/webfingerget_test.go @@ -69,7 +69,7 @@ func (suite *WebfingerGetTestSuite) TestFingerUser() { func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHost() { viper.Set(config.Keys.Host, "gts.example.org") viper.Set(config.Keys.AccountDomain, "example.org") - suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, testrig.NewTestTimelineManager(suite.db), suite.db, suite.emailSender) + suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaHandler(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender) suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module) targetAccount := accountDomainAccount() @@ -103,7 +103,7 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAccountDomain() { viper.Set(config.Keys.Host, "gts.example.org") viper.Set(config.Keys.AccountDomain, "example.org") - suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, testrig.NewTestTimelineManager(suite.db), suite.db, suite.emailSender) + suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaHandler(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender) suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module) targetAccount := accountDomainAccount() diff --git a/internal/api/security/tokencheck.go b/internal/api/security/tokencheck.go index b68f0b94f..e366af2ea 100644 --- a/internal/api/security/tokencheck.go +++ b/internal/api/security/tokencheck.go @@ -62,6 +62,22 @@ func (m *Module) TokenCheck(c *gin.Context) { l.Warnf("no user found for userID %s", userID) return } + + if user.ConfirmedAt.IsZero() { + l.Warnf("authenticated user %s has never confirmed thier email address", userID) + return + } + + if !user.Approved { + l.Warnf("authenticated user %s's account was never approved by an admin", userID) + return + } + + if user.Disabled { + l.Warnf("authenticated user %s's account was disabled'", userID) + return + } + c.Set(oauth.SessionAuthorizedUser, user) // fetch account for this token @@ -74,6 +90,12 @@ func (m *Module) TokenCheck(c *gin.Context) { l.Warnf("no account found for userID %s", userID) return } + + if !acct.SuspendedAt.IsZero() { + l.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID) + return + } + c.Set(oauth.SessionAuthorizedAccount, acct) } diff --git a/internal/config/defaults.go b/internal/config/defaults.go index cf3e7b449..322e1e2c1 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -63,7 +63,7 @@ var Defaults = Values{ StatusesPollOptionMaxChars: 50, StatusesMediaMaxFiles: 6, - LetsEncryptEnabled: true, + LetsEncryptEnabled: false, LetsEncryptPort: 80, LetsEncryptCertDir: "/gotosocial/storage/certs", LetsEncryptEmailAddress: "", diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 37c0db6d3..a92834f9c 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -94,13 +94,13 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, // if something went wrong while creating a user, we might already have an account, so check here first... acct := >smodel.Account{} - err = a.conn.NewSelect(). + q := a.conn.NewSelect(). Model(acct). Where("username = ?", username). - WhereGroup(" AND ", whereEmptyOrNull("domain")). - Scan(ctx) - if err != nil { - // we just don't have an account yet so create one + WhereGroup(" AND ", whereEmptyOrNull("domain")) + + if err := q.Scan(ctx); err != nil { + // we just don't have an account yet so create one before we proceed accountURIs := uris.GenerateURIsForAccount(username) accountID, err := id.NewRandomULID() if err != nil { @@ -125,6 +125,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, FollowingURI: accountURIs.FollowingURI, FeaturedCollectionURI: accountURIs.CollectionURI, } + if _, err = a.conn. NewInsert(). Model(acct). @@ -158,6 +159,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, if emailVerified { u.ConfirmedAt = time.Now() u.Email = email + u.UnconfirmedEmail = "" } if admin { diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 47fe4fb47..ebdbc4ba2 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -204,7 +204,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { } func sqliteConn(ctx context.Context) (*DBConn, error) { + // validate db address has actually been set dbAddress := viper.GetString(config.Keys.DbAddress) + if dbAddress == "" { + return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.Keys.DbAddress) + } // Drop anything fancy from DB address dbAddress = strings.Split(dbAddress, "?")[0] diff --git a/internal/db/bundb/bundbnew_test.go b/internal/db/bundb/bundbnew_test.go new file mode 100644 index 000000000..40a05cb50 --- /dev/null +++ b/internal/db/bundb/bundbnew_test.go @@ -0,0 +1,52 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "context" + "testing" + + "github.com/spf13/viper" + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" +) + +type BundbNewTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *BundbNewTestSuite) TestCreateNewDB() { + // create a new db with standard test settings + db, err := bundb.NewBunDBService(context.Background()) + suite.NoError(err) + suite.NotNil(db) +} + +func (suite *BundbNewTestSuite) TestCreateNewSqliteDBNoAddress() { + // create a new db with no address specified + viper.Set(config.Keys.DbAddress, "") + db, err := bundb.NewBunDBService(context.Background()) + suite.EqualError(err, "'db-address' was not set when attempting to start sqlite") + suite.Nil(db) +} + +func TestBundbNewTestSuite(t *testing.T) { + suite.Run(t, new(BundbNewTestSuite)) +} diff --git a/internal/db/bundb/errors.go b/internal/db/bundb/errors.go index 7d0157373..113679226 100644 --- a/internal/db/bundb/errors.go +++ b/internal/db/bundb/errors.go @@ -19,7 +19,7 @@ func processPostgresError(err error) db.Error { // (https://www.postgresql.org/docs/10/errcodes-appendix.html) switch pgErr.Code { case "23505" /* unique_violation */ : - return db.ErrAlreadyExists + return db.NewErrAlreadyExists(pgErr.Message) default: return err } @@ -36,7 +36,7 @@ func processSQLiteError(err error) db.Error { // Handle supplied error code: switch sqliteErr.Code() { case sqlite3.SQLITE_CONSTRAINT_UNIQUE, sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY: - return db.ErrAlreadyExists + return db.NewErrAlreadyExists(err.Error()) default: return err } diff --git a/internal/db/error.go b/internal/db/error.go index 984f96401..9ac0b6aa0 100644 --- a/internal/db/error.go +++ b/internal/db/error.go @@ -28,8 +28,19 @@ var ( ErrNoEntries Error = fmt.Errorf("no entries") // ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found. ErrMultipleEntries Error = fmt.Errorf("multiple entries") - // ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db. - ErrAlreadyExists Error = fmt.Errorf("already exists") // ErrUnknown denotes an unknown database error. ErrUnknown Error = fmt.Errorf("unknown error") ) + +// ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db. +type ErrAlreadyExists struct { + message string +} + +func (e *ErrAlreadyExists) Error() string { + return e.message +} + +func NewErrAlreadyExists(msg string) error { + return &ErrAlreadyExists{message: msg} +} diff --git a/internal/email/confirm.go b/internal/email/confirm.go index 4503137b3..34e2fb660 100644 --- a/internal/email/confirm.go +++ b/internal/email/confirm.go @@ -21,11 +21,15 @@ package email import ( "bytes" "net/smtp" + + "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "github.com/superseriousbusiness/gotosocial/internal/config" ) const ( - confirmTemplate = "email_confirm.tmpl" - confirmSubject = "Subject: GoToSocial Email Confirmation" + confirmTemplate = "email_confirm_text.tmpl" + confirmSubject = "GoToSocial Email Confirmation" ) func (s *sender) SendConfirmEmail(toAddress string, data ConfirmData) error { @@ -35,7 +39,11 @@ func (s *sender) SendConfirmEmail(toAddress string, data ConfirmData) error { } confirmBody := buf.String() - msg := assembleMessage(confirmSubject, confirmBody, toAddress, s.from) + msg, err := assembleMessage(confirmSubject, confirmBody, toAddress, s.from) + if err != nil { + return err + } + logrus.WithField("func", "SendConfirmEmail").Trace(s.hostAddress + "\n" + viper.GetString(config.Keys.SMTPUsername) + ":password" + "\n" + s.from + "\n" + toAddress + "\n\n" + string(msg) + "\n") return smtp.SendMail(s.hostAddress, s.auth, s.from, []string{toAddress}, msg) } diff --git a/internal/email/noopsender.go b/internal/email/noopsender.go index efec303f0..9f587f319 100644 --- a/internal/email/noopsender.go +++ b/internal/email/noopsender.go @@ -20,7 +20,7 @@ package email import ( "bytes" - "html/template" + "text/template" "github.com/sirupsen/logrus" "github.com/spf13/viper" @@ -57,7 +57,10 @@ func (s *noopSender) SendConfirmEmail(toAddress string, data ConfirmData) error } confirmBody := buf.String() - msg := assembleMessage(confirmSubject, confirmBody, toAddress, "test@example.org") + msg, err := assembleMessage(confirmSubject, confirmBody, toAddress, "test@example.org") + if err != nil { + return err + } logrus.Tracef("NOT SENDING confirmation email to %s with contents: %s", toAddress, msg) @@ -74,7 +77,10 @@ func (s *noopSender) SendResetEmail(toAddress string, data ResetData) error { } resetBody := buf.String() - msg := assembleMessage(resetSubject, resetBody, toAddress, "test@example.org") + msg, err := assembleMessage(resetSubject, resetBody, toAddress, "test@example.org") + if err != nil { + return err + } logrus.Tracef("NOT SENDING reset email to %s with contents: %s", toAddress, msg) diff --git a/internal/email/reset.go b/internal/email/reset.go index 7a08ebda9..b646ef99b 100644 --- a/internal/email/reset.go +++ b/internal/email/reset.go @@ -24,8 +24,8 @@ import ( ) const ( - resetTemplate = "email_reset.tmpl" - resetSubject = "Subject: GoToSocial Password Reset" + resetTemplate = "email_reset_text.tmpl" + resetSubject = "GoToSocial Password Reset" ) func (s *sender) SendResetEmail(toAddress string, data ResetData) error { @@ -35,7 +35,10 @@ func (s *sender) SendResetEmail(toAddress string, data ResetData) error { } resetBody := buf.String() - msg := assembleMessage(resetSubject, resetBody, toAddress, s.from) + msg, err := assembleMessage(resetSubject, resetBody, toAddress, s.from) + if err != nil { + return err + } return smtp.SendMail(s.hostAddress, s.auth, s.from, []string{toAddress}, msg) } diff --git a/internal/email/sender.go b/internal/email/sender.go index 97bbcd23b..f44627496 100644 --- a/internal/email/sender.go +++ b/internal/email/sender.go @@ -20,8 +20,8 @@ package email import ( "fmt" - "html/template" "net/smtp" + "text/template" "github.com/spf13/viper" "github.com/superseriousbusiness/gotosocial/internal/config" diff --git a/internal/email/util.go b/internal/email/util.go index db95128fa..52290dbe4 100644 --- a/internal/email/util.go +++ b/internal/email/util.go @@ -19,15 +19,12 @@ package email import ( + "errors" "fmt" - "html/template" "os" "path/filepath" -) - -const ( - mime = `MIME-version: 1.0; -Content-Type: text/html;` + "strings" + "text/template" ) func loadTemplates(templateBaseDir string) (*template.Template, error) { @@ -41,16 +38,34 @@ func loadTemplates(templateBaseDir string) (*template.Template, error) { return template.ParseGlob(tmPath) } -func assembleMessage(mailSubject string, mailBody string, mailTo string, mailFrom string) []byte { - from := fmt.Sprintf("From: GoToSocial <%s>", mailFrom) - to := fmt.Sprintf("To: %s", mailTo) +// https://datatracker.ietf.org/doc/html/rfc2822 +// I did not read the RFC, I just copy and pasted from +// https://pkg.go.dev/net/smtp#SendMail +// and it did seem to work. +func assembleMessage(mailSubject string, mailBody string, mailTo string, mailFrom string) ([]byte, error) { + + if strings.Contains(mailSubject, "\r") || strings.Contains(mailSubject, "\n") { + return nil, errors.New("email subject must not contain newline characters") + } + + if strings.Contains(mailFrom, "\r") || strings.Contains(mailFrom, "\n") { + return nil, errors.New("email from address must not contain newline characters") + } + + if strings.Contains(mailTo, "\r") || strings.Contains(mailTo, "\n") { + return nil, errors.New("email to address must not contain newline characters") + } + + // normalize the message body to use CRLF line endings + mailBody = strings.ReplaceAll(mailBody, "\r\n", "\n") + mailBody = strings.ReplaceAll(mailBody, "\n", "\r\n") msg := []byte( - mailSubject + "\r\n" + - from + "\r\n" + - to + "\r\n" + - mime + "\r\n" + - mailBody + "\r\n") + "To: " + mailTo + "\r\n" + + "Subject: " + mailSubject + "\r\n" + + "\r\n" + + mailBody + "\r\n", + ) - return msg + return msg, nil } diff --git a/internal/email/util_test.go b/internal/email/util_test.go index b5c7a9852..8895785f7 100644 --- a/internal/email/util_test.go +++ b/internal/email/util_test.go @@ -39,7 +39,7 @@ func (suite *UtilTestSuite) TestTemplateConfirm() { suite.sender.SendConfirmEmail("user@example.org", confirmData) suite.Len(suite.sentEmails, 1) - suite.Equal("Subject: GoToSocial Email Confirmation\r\nFrom: GoToSocial <test@example.org>\r\nTo: user@example.org\r\nMIME-version: 1.0;\nContent-Type: text/html;\r\n<!DOCTYPE html>\n<html>\n </head>\n <body>\n <div>\n <h1>\n Hello test!\n </h1>\n </div>\n <div>\n <p>\n You are receiving this mail because you've requested an account on <a href=\"https://example.org\">Test Instance</a>.\n </p>\n <p>\n We just need to confirm that this is your email address. To confirm your email, <a href=\"https://example.org/confirm_email?token=ee24f71d-e615-43f9-afae-385c0799b7fa\">click here</a> or paste the following in your browser's address bar:\n </p>\n <p>\n <code>\n https://example.org/confirm_email?token=ee24f71d-e615-43f9-afae-385c0799b7fa\n </code>\n </p>\n </div>\n <div>\n <p>\n If you believe you've been sent this email in error, feel free to ignore it, or contact the administrator of <a href=\"https://example.org\">Test Instance</a>.\n </p>\n </div>\n </body>\n</html>\r\n", suite.sentEmails["user@example.org"]) + suite.Equal("To: user@example.org\r\nSubject: GoToSocial Email Confirmation\r\n\r\nHello test!\r\n\r\nYou are receiving this mail because you've requested an account on https://example.org.\r\n\r\nWe just need to confirm that this is your email address. To confirm your email, paste the following in your browser's address bar:\r\n\r\nhttps://example.org/confirm_email?token=ee24f71d-e615-43f9-afae-385c0799b7fa\r\n\r\nIf you believe you've been sent this email in error, feel free to ignore it, or contact the administrator of https://example.org\r\n\r\n", suite.sentEmails["user@example.org"]) } func (suite *UtilTestSuite) TestTemplateReset() { @@ -52,7 +52,7 @@ func (suite *UtilTestSuite) TestTemplateReset() { suite.sender.SendResetEmail("user@example.org", resetData) suite.Len(suite.sentEmails, 1) - suite.Equal("Subject: GoToSocial Password Reset\r\nFrom: GoToSocial <test@example.org>\r\nTo: user@example.org\r\nMIME-version: 1.0;\nContent-Type: text/html;\r\n<!DOCTYPE html>\n<html>\n </head>\n <body>\n <div>\n <h1>\n Hello test!\n </h1>\n </div>\n <div>\n <p>\n You are receiving this mail because a password reset has been requested for your account on <a href=\"https://example.org\">Test Instance</a>.\n </p>\n <p>\n To reset your password, <a href=\"https://example.org/reset_email?token=ee24f71d-e615-43f9-afae-385c0799b7fa\">click here</a> or paste the following in your browser's address bar:\n </p>\n <p>\n <code>\n https://example.org/reset_email?token=ee24f71d-e615-43f9-afae-385c0799b7fa\n </code>\n </p>\n </div>\n <div>\n <p>\n If you believe you've been sent this email in error, feel free to ignore it, or contact the administrator of <a href=\"https://example.org\">Test Instance</a>.\n </p>\n </div>\n </body>\n</html>\r\n", suite.sentEmails["user@example.org"]) + suite.Equal("To: user@example.org\r\nSubject: GoToSocial Password Reset\r\n\r\nHello test!\r\n\r\nYou are receiving this mail because a password reset has been requested for your account on https://example.org.\r\n\r\nTo reset your password, paste the following in your browser's address bar:\r\n\r\nhttps://example.org/reset_email?token=ee24f71d-e615-43f9-afae-385c0799b7fa\r\n\r\nIf you believe you've been sent this email in error, feel free to ignore it, or contact the administrator of https://example.org.\r\n\r\n", suite.sentEmails["user@example.org"]) } func TestUtilTestSuite(t *testing.T) { diff --git a/internal/federation/federatingdb/create.go b/internal/federation/federatingdb/create.go index b07f4f04f..6c86151f3 100644 --- a/internal/federation/federatingdb/create.go +++ b/internal/federation/federatingdb/create.go @@ -231,7 +231,8 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream status.ID = statusID if err := f.db.PutStatus(ctx, status); err != nil { - if err == db.ErrAlreadyExists { + var alreadyExistsError *db.ErrAlreadyExists + if errors.As(err, &alreadyExistsError) { // the status already exists in the database, which means we've already handled everything else, // so we can just return nil here and be done with it. return nil diff --git a/internal/gtsmodel/status.go b/internal/gtsmodel/status.go index e80924ca3..e798ea41b 100644 --- a/internal/gtsmodel/status.go +++ b/internal/gtsmodel/status.go @@ -66,6 +66,27 @@ type Status struct { Likeable bool `validate:"-" bun:",notnull"` // This status can be liked/faved } +/* + The below functions are added onto the gtsmodel status so that it satisfies + the Timelineable interface in internal/timeline. +*/ + +func (s *Status) GetID() string { + return s.ID +} + +func (s *Status) GetAccountID() string { + return s.AccountID +} + +func (s *Status) GetBoostOfID() string { + return s.BoostOfID +} + +func (s *Status) GetBoostOfAccountID() string { + return s.BoostOfAccountID +} + // StatusToTag is an intermediate struct to facilitate the many2many relationship between a status and one or more tags. type StatusToTag struct { StatusID string `validate:"ulid,required" bun:"type:CHAR(26),unique:statustag,nullzero,notnull"` diff --git a/internal/oauth/util.go b/internal/oauth/util.go index 540045f80..6f69f0ee4 100644 --- a/internal/oauth/util.go +++ b/internal/oauth/util.go @@ -78,25 +78,12 @@ func Authed(c *gin.Context, requireToken bool, requireApp bool, requireUser bool return nil, errors.New("application not supplied") } - if requireUser { - if a.User == nil { - return nil, errors.New("user not supplied") - } - if a.User.Disabled || !a.User.Approved { - return nil, errors.New("user disabled or not approved") - } - if a.User.Email == "" { - return nil, errors.New("user has no confirmed email address") - } + if requireUser && a.User == nil { + return nil, errors.New("user not supplied or not authorized") } - if requireAccount { - if a.Account == nil { - return nil, errors.New("account not supplied") - } - if !a.Account.SuspendedAt.IsZero() { - return nil, errors.New("account suspended") - } + if requireAccount && a.Account == nil { + return nil, errors.New("account not supplied or not authorized") } return a, nil diff --git a/internal/processing/federation/getwebfinger.go b/internal/processing/federation/getwebfinger.go index 14536549d..cbc4a7ebc 100644 --- a/internal/processing/federation/getwebfinger.go +++ b/internal/processing/federation/getwebfinger.go @@ -44,6 +44,9 @@ func (p *processor) GetWebfingerAccount(ctx context.Context, requestedUsername s } accountDomain := viper.GetString(config.Keys.AccountDomain) + if accountDomain == "" { + accountDomain = viper.GetString(config.Keys.Host) + } // return the webfinger representation return &apimodel.WellKnownResponse{ diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go index 51c896291..11ce2215e 100644 --- a/internal/processing/fromclientapi.go +++ b/internal/processing/fromclientapi.go @@ -192,10 +192,10 @@ func (p *processor) processCreateBlockFromClientAPI(ctx context.Context, clientM } // remove any of the blocking account's statuses from the blocked account's timeline, and vice versa - if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil { + if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil { return err } - if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil { + if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil { return err } diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 3c52cf669..8e7f20145 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -413,7 +413,7 @@ func (p *processor) timelineStatusForAccount(ctx context.Context, status *gtsmod } // stick the status in the timeline for the account and then immediately prepare it so they can see it right away - inserted, err := p.timelineManager.IngestAndPrepare(ctx, status, timelineAccount.ID) + inserted, err := p.statusTimelines.IngestAndPrepare(ctx, status, timelineAccount.ID) if err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %s", status.ID, err) return @@ -436,7 +436,7 @@ func (p *processor) timelineStatusForAccount(ctx context.Context, status *gtsmod // deleteStatusFromTimelines completely removes the given status from all timelines. // It will also stream deletion of the status to all open streams. func (p *processor) deleteStatusFromTimelines(ctx context.Context, status *gtsmodel.Status) error { - if err := p.timelineManager.WipeStatusFromAllTimelines(ctx, status.ID); err != nil { + if err := p.statusTimelines.WipeItemFromAllTimelines(ctx, status.ID); err != nil { return err } diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go index 3514614b5..bb2cb5323 100644 --- a/internal/processing/fromfederator.go +++ b/internal/processing/fromfederator.go @@ -309,10 +309,10 @@ func (p *processor) processCreateBlockFromFederator(ctx context.Context, federat } // remove any of the blocking account's statuses from the blocked account's timeline, and vice versa - if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil { + if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil { return err } - if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil { + if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil { return err } // TODO: same with notifications diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 2406681ea..46d17a160 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -237,7 +237,7 @@ type processor struct { oauthServer oauth.Server mediaManager media.Manager storage *kv.KVStore - timelineManager timeline.Manager + statusTimelines timeline.Manager db db.DB filter visibility.Filter @@ -261,7 +261,6 @@ func NewProcessor( oauthServer oauth.Server, mediaManager media.Manager, storage *kv.KVStore, - timelineManager timeline.Manager, db db.DB, emailSender email.Sender) Processor { fromClientAPI := make(chan messages.FromClientAPI, 1000) @@ -274,6 +273,7 @@ func NewProcessor( mediaProcessor := mediaProcessor.New(db, tc, mediaManager, storage) userProcessor := user.New(db, emailSender) federationProcessor := federationProcessor.New(db, tc, federator, fromFederator) + filter := visibility.NewFilter(db) return &processor{ fromClientAPI: fromClientAPI, @@ -284,7 +284,7 @@ func NewProcessor( oauthServer: oauthServer, mediaManager: mediaManager, storage: storage, - timelineManager: timelineManager, + statusTimelines: timeline.NewManager(StatusGrabFunction(db), StatusFilterFunction(db, filter), StatusPrepareFunction(db, tc), StatusSkipInsertFunction()), db: db, filter: visibility.NewFilter(db), diff --git a/internal/processing/processor_test.go b/internal/processing/processor_test.go index 851d3d5fb..794bcc197 100644 --- a/internal/processing/processor_test.go +++ b/internal/processing/processor_test.go @@ -219,10 +219,9 @@ func (suite *ProcessingStandardTestSuite) SetupTest() { suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.federator = testrig.NewTestFederator(suite.db, suite.transportController, suite.storage, suite.mediaManager) suite.oauthServer = testrig.NewTestOauthServer(suite.db) - suite.timelineManager = testrig.NewTestTimelineManager(suite.db) suite.emailSender = testrig.NewEmailSender("../../web/template/", nil) - suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, suite.storage, suite.timelineManager, suite.db, suite.emailSender) + suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, suite.storage, suite.db, suite.emailSender) testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardStorageSetup(suite.storage, "../../testrig/media") diff --git a/internal/processing/status/util.go b/internal/processing/status/util.go index 05a3bf48e..f2640929d 100644 --- a/internal/processing/status/util.go +++ b/internal/processing/status/util.go @@ -223,8 +223,11 @@ func (p *processor) ProcessTags(ctx context.Context, form *apimodel.AdvancedStat return fmt.Errorf("error generating hashtags from status: %s", err) } for _, tag := range gtsTags { - if err := p.db.Put(ctx, tag); err != nil && err != db.ErrAlreadyExists { - return fmt.Errorf("error putting tags in db: %s", err) + if err := p.db.Put(ctx, tag); err != nil { + var alreadyExistsError *db.ErrAlreadyExists + if !errors.As(err, &alreadyExistsError) { + return fmt.Errorf("error putting tags in db: %s", err) + } } tags = append(tags, tag.ID) } diff --git a/internal/processing/timeline.go b/internal/processing/statustimeline.go index 2e2b7d637..355825900 100644 --- a/internal/processing/timeline.go +++ b/internal/processing/statustimeline.go @@ -20,6 +20,7 @@ package processing import ( "context" + "errors" "fmt" "net/url" @@ -32,8 +33,113 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/timeline" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" ) +const boostReinsertionDepth = 50 + +// StatusGrabFunction returns a function that satisfies the GrabFunction interface in internal/timeline. +func StatusGrabFunction(database db.DB) timeline.GrabFunction { + return func(ctx context.Context, timelineAccountID string, maxID string, sinceID string, minID string, limit int) ([]timeline.Timelineable, bool, error) { + statuses, err := database.GetHomeTimeline(ctx, timelineAccountID, maxID, sinceID, minID, limit, false) + if err != nil { + if err == db.ErrNoEntries { + return nil, true, nil // we just don't have enough statuses left in the db so return stop = true + } + return nil, false, fmt.Errorf("statusGrabFunction: error getting statuses from db: %s", err) + } + + items := []timeline.Timelineable{} + for _, s := range statuses { + items = append(items, s) + } + + return items, false, nil + } +} + +// StatusFilterFunction returns a function that satisfies the FilterFunction interface in internal/timeline. +func StatusFilterFunction(database db.DB, filter visibility.Filter) timeline.FilterFunction { + return func(ctx context.Context, timelineAccountID string, item timeline.Timelineable) (shouldIndex bool, err error) { + status, ok := item.(*gtsmodel.Status) + if !ok { + return false, errors.New("statusFilterFunction: could not convert item to *gtsmodel.Status") + } + + requestingAccount, err := database.GetAccountByID(ctx, timelineAccountID) + if err != nil { + return false, fmt.Errorf("statusFilterFunction: error getting account with id %s", timelineAccountID) + } + + timelineable, err := filter.StatusHometimelineable(ctx, status, requestingAccount) + if err != nil { + logrus.Warnf("error checking hometimelineability of status %s for account %s: %s", status.ID, timelineAccountID, err) + } + + return timelineable, nil // we don't return the error here because we want to just skip this item if something goes wrong + } +} + +// StatusPrepareFunction returns a function that satisfies the PrepareFunction interface in internal/timeline. +func StatusPrepareFunction(database db.DB, tc typeutils.TypeConverter) timeline.PrepareFunction { + return func(ctx context.Context, timelineAccountID string, itemID string) (timeline.Preparable, error) { + status, err := database.GetStatusByID(ctx, itemID) + if err != nil { + return nil, fmt.Errorf("statusPrepareFunction: error getting status with id %s", itemID) + } + + requestingAccount, err := database.GetAccountByID(ctx, timelineAccountID) + if err != nil { + return nil, fmt.Errorf("statusPrepareFunction: error getting account with id %s", timelineAccountID) + } + + return tc.StatusToAPIStatus(ctx, status, requestingAccount) + } +} + +// StatusSkipInsertFunction returns a function that satisifes the SkipInsertFunction interface in internal/timeline. +func StatusSkipInsertFunction() timeline.SkipInsertFunction { + return func( + ctx context.Context, + newItemID string, + newItemAccountID string, + newItemBoostOfID string, + newItemBoostOfAccountID string, + nextItemID string, + nextItemAccountID string, + nextItemBoostOfID string, + nextItemBoostOfAccountID string, + depth int) (bool, error) { + + // make sure we don't insert a duplicate + if newItemID == nextItemID { + return true, nil + } + + // check if it's a boost + if newItemBoostOfID != "" { + // skip if we've recently put another boost of this status in the timeline + if newItemBoostOfID == nextItemBoostOfID { + if depth < boostReinsertionDepth { + return true, nil + } + } + + // skip if we've recently put the original status in the timeline + if newItemBoostOfID == nextItemID { + if depth < boostReinsertionDepth { + return true, nil + } + } + } + + // insert the item + return false, nil + } +} + func (p *processor) packageStatusResponse(statuses []*apimodel.Status, path string, nextMaxID string, prevMinID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) { resp := &apimodel.StatusTimelineResponse{ Statuses: []*apimodel.Status{}, @@ -67,18 +173,27 @@ func (p *processor) packageStatusResponse(statuses []*apimodel.Status, path stri } func (p *processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) { - statuses, err := p.timelineManager.HomeTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local) + preparedItems, err := p.statusTimelines.GetTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local) if err != nil { return nil, gtserror.NewErrorInternalError(err) } - if len(statuses) == 0 { + if len(preparedItems) == 0 { return &apimodel.StatusTimelineResponse{ Statuses: []*apimodel.Status{}, }, nil } - return p.packageStatusResponse(statuses, "api/v1/timelines/home", statuses[len(statuses)-1].ID, statuses[0].ID, limit) + statuses := []*apimodel.Status{} + for _, i := range preparedItems { + status, ok := i.(*apimodel.Status) + if !ok { + return nil, gtserror.NewErrorInternalError(errors.New("error converting prepared timeline entry to api status")) + } + statuses = append(statuses, status) + } + + return p.packageStatusResponse(statuses, "api/v1/timelines/home", statuses[len(preparedItems)-1].ID, statuses[0].ID, limit) } func (p *processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) { diff --git a/internal/processing/user/emailconfirm_test.go b/internal/processing/user/emailconfirm_test.go index 58836d40d..6f22306a1 100644 --- a/internal/processing/user/emailconfirm_test.go +++ b/internal/processing/user/emailconfirm_test.go @@ -54,7 +54,7 @@ func (suite *EmailConfirmTestSuite) TestSendConfirmEmail() { suite.NotEmpty(token) // email should contain the token - emailShould := fmt.Sprintf("Subject: GoToSocial Email Confirmation\r\nFrom: GoToSocial <test@example.org>\r\nTo: some.email@example.org\r\nMIME-version: 1.0;\nContent-Type: text/html;\r\n<!DOCTYPE html>\n<html>\n </head>\n <body>\n <div>\n <h1>\n Hello the_mighty_zork!\n </h1>\n </div>\n <div>\n <p>\n You are receiving this mail because you've requested an account on <a href=\"http://localhost:8080\">localhost:8080</a>.\n </p>\n <p>\n We just need to confirm that this is your email address. To confirm your email, <a href=\"http://localhost:8080/confirm_email?token=%s\">click here</a> or paste the following in your browser's address bar:\n </p>\n <p>\n <code>\n http://localhost:8080/confirm_email?token=%s\n </code>\n </p>\n </div>\n <div>\n <p>\n If you believe you've been sent this email in error, feel free to ignore it, or contact the administrator of <a href=\"http://localhost:8080\">localhost:8080</a>.\n </p>\n </div>\n </body>\n</html>\r\n", token, token) + emailShould := fmt.Sprintf("To: some.email@example.org\r\nSubject: GoToSocial Email Confirmation\r\n\r\nHello the_mighty_zork!\r\n\r\nYou are receiving this mail because you've requested an account on http://localhost:8080.\r\n\r\nWe just need to confirm that this is your email address. To confirm your email, paste the following in your browser's address bar:\r\n\r\nhttp://localhost:8080/confirm_email?token=%s\r\n\r\nIf you believe you've been sent this email in error, feel free to ignore it, or contact the administrator of http://localhost:8080\r\n\r\n", token) suite.Equal(emailShould, email) // confirmationSentAt should be recent diff --git a/internal/router/router.go b/internal/router/router.go index 88d900a9e..f1247d274 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -138,7 +138,7 @@ func New(ctx context.Context, db db.DB) (Router, error) { } // set template functions - loadTemplateFunctions(engine) + LoadTemplateFunctions(engine) // load templates onto the engine if err := loadTemplates(engine); err != nil { diff --git a/internal/router/session.go b/internal/router/session.go index 2127d70a7..066024601 100644 --- a/internal/router/session.go +++ b/internal/router/session.go @@ -33,8 +33,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" ) -// sessionOptions returns the standard set of options to use for each session. -func sessionOptions() sessions.Options { +// SessionOptions returns the standard set of options to use for each session. +func SessionOptions() sessions.Options { return sessions.Options{ Path: "/", Domain: viper.GetString(config.Keys.Host), @@ -75,7 +75,7 @@ func useSession(ctx context.Context, sessionDB db.Session, engine *gin.Engine) e } store := memstore.NewStore(rs.Auth, rs.Crypt) - store.Options(sessionOptions()) + store.Options(SessionOptions()) sessionName, err := SessionName() if err != nil { diff --git a/internal/router/template.go b/internal/router/template.go index 4cc9fde1d..1a0186d6d 100644 --- a/internal/router/template.go +++ b/internal/router/template.go @@ -31,7 +31,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" ) -// loadTemplates loads html templates for use by the given engine +// LoadTemplates loads html templates for use by the given engine func loadTemplates(engine *gin.Engine) error { cwd, err := os.Getwd() if err != nil { @@ -39,8 +39,13 @@ func loadTemplates(engine *gin.Engine) error { } templateBaseDir := viper.GetString(config.Keys.WebTemplateBaseDir) - tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", templateBaseDir)) + _, err = os.Stat(filepath.Join(cwd, templateBaseDir, "index.tmpl")) + if err != nil { + return fmt.Errorf("%s doesn't seem to contain the templates; index.tmpl is missing: %s", filepath.Join(cwd, templateBaseDir), err) + } + + tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", templateBaseDir)) engine.LoadHTMLGlob(tmPath) return nil } @@ -87,7 +92,7 @@ func visibilityIcon(visibility model.Visibility) template.HTML { return template.HTML(fmt.Sprintf(`<i aria-label="Visibility: %v" class="fa fa-%v"></i>`, icon.label, icon.faIcon)) } -func loadTemplateFunctions(engine *gin.Engine) { +func LoadTemplateFunctions(engine *gin.Engine) { engine.SetFuncMap(template.FuncMap{ "noescape": noescape, "oddOrEven": oddOrEven, diff --git a/internal/timeline/get.go b/internal/timeline/get.go index da6dc6f76..a3cc8ef9b 100644 --- a/internal/timeline/get.go +++ b/internal/timeline/get.go @@ -25,12 +25,11 @@ import ( "fmt" "github.com/sirupsen/logrus" - apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" ) const retries = 5 -func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]*apimodel.Status, error) { +func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]Preparable, error) { l := logrus.WithFields(logrus.Fields{ "func": "Get", "accountID": t.accountID, @@ -41,16 +40,16 @@ func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID st }) l.Debug("entering get") - var statuses []*apimodel.Status + var items []Preparable var err error // no params are defined to just fetch from the top - // this is equivalent to a user asking for the top x posts from their timeline + // this is equivalent to a user asking for the top x items from their timeline if maxID == "" && sinceID == "" && minID == "" { - statuses, err = t.GetXFromTop(ctx, amount) + items, err = t.GetXFromTop(ctx, amount) // aysnchronously prepare the next predicted query so it's ready when the user asks for it - if len(statuses) != 0 { - nextMaxID := statuses[len(statuses)-1].ID + if len(items) != 0 { + nextMaxID := items[len(items)-1].GetID() if prepareNext { // already cache the next query to speed up scrolling go func() { @@ -64,13 +63,13 @@ func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID st } // maxID is defined but sinceID isn't so take from behind - // this is equivalent to a user asking for the next x posts from their timeline, starting from maxID + // this is equivalent to a user asking for the next x items from their timeline, starting from maxID if maxID != "" && sinceID == "" { attempts := 0 - statuses, err = t.GetXBehindID(ctx, amount, maxID, &attempts) + items, err = t.GetXBehindID(ctx, amount, maxID, &attempts) // aysnchronously prepare the next predicted query so it's ready when the user asks for it - if len(statuses) != 0 { - nextMaxID := statuses[len(statuses)-1].ID + if len(items) != 0 { + nextMaxID := items[len(items)-1].GetID() if prepareNext { // already cache the next query to speed up scrolling go func() { @@ -84,59 +83,59 @@ func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID st } // maxID is defined and sinceID || minID are as well, so take a slice between them - // this is equivalent to a user asking for posts older than x but newer than y + // this is equivalent to a user asking for items older than x but newer than y if maxID != "" && sinceID != "" { - statuses, err = t.GetXBetweenID(ctx, amount, maxID, minID) + items, err = t.GetXBetweenID(ctx, amount, maxID, minID) } if maxID != "" && minID != "" { - statuses, err = t.GetXBetweenID(ctx, amount, maxID, minID) + items, err = t.GetXBetweenID(ctx, amount, maxID, minID) } // maxID isn't defined, but sinceID || minID are, so take x before - // this is equivalent to a user asking for posts newer than x (eg., refreshing the top of their timeline) + // this is equivalent to a user asking for items newer than x (eg., refreshing the top of their timeline) if maxID == "" && sinceID != "" { - statuses, err = t.GetXBeforeID(ctx, amount, sinceID, true) + items, err = t.GetXBeforeID(ctx, amount, sinceID, true) } if maxID == "" && minID != "" { - statuses, err = t.GetXBeforeID(ctx, amount, minID, true) + items, err = t.GetXBeforeID(ctx, amount, minID, true) } - return statuses, err + return items, err } -func (t *timeline) GetXFromTop(ctx context.Context, amount int) ([]*apimodel.Status, error) { - // make a slice of statuses with the length we need to return - statuses := make([]*apimodel.Status, 0, amount) +func (t *timeline) GetXFromTop(ctx context.Context, amount int) ([]Preparable, error) { + // make a slice of preparedItems with the length we need to return + preparedItems := make([]Preparable, 0, amount) - if t.preparedPosts.data == nil { - t.preparedPosts.data = &list.List{} + if t.preparedItems.data == nil { + t.preparedItems.data = &list.List{} } - // make sure we have enough posts prepared to return - if t.preparedPosts.data.Len() < amount { + // make sure we have enough items prepared to return + if t.preparedItems.data.Len() < amount { if err := t.PrepareFromTop(ctx, amount); err != nil { return nil, err } } - // work through the prepared posts from the top and return + // work through the prepared items from the top and return var served int - for e := t.preparedPosts.data.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*preparedPostsEntry) + for e := t.preparedItems.data.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*preparedItemsEntry) if !ok { - return nil, errors.New("GetXFromTop: could not parse e as a preparedPostsEntry") + return nil, errors.New("GetXFromTop: could not parse e as a preparedItemsEntry") } - statuses = append(statuses, entry.prepared) + preparedItems = append(preparedItems, entry.prepared) served++ if served >= amount { break } } - return statuses, nil + return preparedItems, nil } -func (t *timeline) GetXBehindID(ctx context.Context, amount int, behindID string, attempts *int) ([]*apimodel.Status, error) { +func (t *timeline) GetXBehindID(ctx context.Context, amount int, behindID string, attempts *int) ([]Preparable, error) { l := logrus.WithFields(logrus.Fields{ "func": "GetXBehindID", "amount": amount, @@ -148,11 +147,11 @@ func (t *timeline) GetXBehindID(ctx context.Context, amount int, behindID string newAttempts++ attempts = &newAttempts - // make a slice of statuses with the length we need to return - statuses := make([]*apimodel.Status, 0, amount) + // make a slice of items with the length we need to return + items := make([]Preparable, 0, amount) - if t.preparedPosts.data == nil { - t.preparedPosts.data = &list.List{} + if t.preparedItems.data == nil { + t.preparedItems.data = &list.List{} } // iterate through the modified list until we hit the mark we're looking for @@ -160,14 +159,14 @@ func (t *timeline) GetXBehindID(ctx context.Context, amount int, behindID string var behindIDMark *list.Element findMarkLoop: - for e := t.preparedPosts.data.Front(); e != nil; e = e.Next() { + for e := t.preparedItems.data.Front(); e != nil; e = e.Next() { position++ - entry, ok := e.Value.(*preparedPostsEntry) + entry, ok := e.Value.(*preparedItemsEntry) if !ok { return nil, errors.New("GetXBehindID: could not parse e as a preparedPostsEntry") } - if entry.statusID <= behindID { + if entry.itemID <= behindID { l.Trace("found behindID mark") behindIDMark = e break findMarkLoop @@ -175,33 +174,33 @@ findMarkLoop: } // we didn't find it, so we need to make sure it's indexed and prepared and then try again - // this can happen when a user asks for really old posts + // this can happen when a user asks for really old items if behindIDMark == nil { if err := t.PrepareBehind(ctx, behindID, amount); err != nil { return nil, fmt.Errorf("GetXBehindID: error preparing behind and including ID %s", behindID) } - oldestID, err := t.OldestPreparedPostID(ctx) + oldestID, err := t.OldestPreparedItemID(ctx) if err != nil { return nil, err } if oldestID == "" { l.Tracef("oldestID is empty so we can't return behindID %s", behindID) - return statuses, nil + return items, nil } if oldestID == behindID { l.Tracef("given behindID %s is the same as oldestID %s so there's nothing to return behind it", behindID, oldestID) - return statuses, nil + return items, nil } if *attempts > retries { l.Tracef("exceeded retries looking for behindID %s", behindID) - return statuses, nil + return items, nil } l.Trace("trying GetXBehindID again") return t.GetXBehindID(ctx, amount, behindID, attempts) } - // make sure we have enough posts prepared behind it to return what we're being asked for - if t.preparedPosts.data.Len() < amount+position { + // make sure we have enough items prepared behind it to return what we're being asked for + if t.preparedItems.data.Len() < amount+position { if err := t.PrepareBehind(ctx, behindID, amount); err != nil { return nil, err } @@ -211,40 +210,40 @@ findMarkLoop: var served int serveloop: for e := behindIDMark.Next(); e != nil; e = e.Next() { - entry, ok := e.Value.(*preparedPostsEntry) + entry, ok := e.Value.(*preparedItemsEntry) if !ok { return nil, errors.New("GetXBehindID: could not parse e as a preparedPostsEntry") } // serve up to the amount requested - statuses = append(statuses, entry.prepared) + items = append(items, entry.prepared) served++ if served >= amount { break serveloop } } - return statuses, nil + return items, nil } -func (t *timeline) GetXBeforeID(ctx context.Context, amount int, beforeID string, startFromTop bool) ([]*apimodel.Status, error) { - // make a slice of statuses with the length we need to return - statuses := make([]*apimodel.Status, 0, amount) +func (t *timeline) GetXBeforeID(ctx context.Context, amount int, beforeID string, startFromTop bool) ([]Preparable, error) { + // make a slice of items with the length we need to return + items := make([]Preparable, 0, amount) - if t.preparedPosts.data == nil { - t.preparedPosts.data = &list.List{} + if t.preparedItems.data == nil { + t.preparedItems.data = &list.List{} } // iterate through the modified list until we hit the mark we're looking for, or as close as possible to it var beforeIDMark *list.Element findMarkLoop: - for e := t.preparedPosts.data.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*preparedPostsEntry) + for e := t.preparedItems.data.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*preparedItemsEntry) if !ok { return nil, errors.New("GetXBeforeID: could not parse e as a preparedPostsEntry") } - if entry.statusID >= beforeID { + if entry.itemID >= beforeID { beforeIDMark = e } else { break findMarkLoop @@ -252,26 +251,26 @@ findMarkLoop: } if beforeIDMark == nil { - return statuses, nil + return items, nil } var served int if startFromTop { - // start serving from the front/top and keep going until we hit mark or get x amount statuses + // start serving from the front/top and keep going until we hit mark or get x amount items serveloopFromTop: - for e := t.preparedPosts.data.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*preparedPostsEntry) + for e := t.preparedItems.data.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*preparedItemsEntry) if !ok { return nil, errors.New("GetXBeforeID: could not parse e as a preparedPostsEntry") } - if entry.statusID == beforeID { + if entry.itemID == beforeID { break serveloopFromTop } // serve up to the amount requested - statuses = append(statuses, entry.prepared) + items = append(items, entry.prepared) served++ if served >= amount { break serveloopFromTop @@ -281,13 +280,13 @@ findMarkLoop: // start serving from the entry right before the mark serveloopFromBottom: for e := beforeIDMark.Prev(); e != nil; e = e.Prev() { - entry, ok := e.Value.(*preparedPostsEntry) + entry, ok := e.Value.(*preparedItemsEntry) if !ok { return nil, errors.New("GetXBeforeID: could not parse e as a preparedPostsEntry") } // serve up to the amount requested - statuses = append(statuses, entry.prepared) + items = append(items, entry.prepared) served++ if served >= amount { break serveloopFromBottom @@ -295,29 +294,29 @@ findMarkLoop: } } - return statuses, nil + return items, nil } -func (t *timeline) GetXBetweenID(ctx context.Context, amount int, behindID string, beforeID string) ([]*apimodel.Status, error) { - // make a slice of statuses with the length we need to return - statuses := make([]*apimodel.Status, 0, amount) +func (t *timeline) GetXBetweenID(ctx context.Context, amount int, behindID string, beforeID string) ([]Preparable, error) { + // make a slice of items with the length we need to return + items := make([]Preparable, 0, amount) - if t.preparedPosts.data == nil { - t.preparedPosts.data = &list.List{} + if t.preparedItems.data == nil { + t.preparedItems.data = &list.List{} } // iterate through the modified list until we hit the mark we're looking for var position int var behindIDMark *list.Element findMarkLoop: - for e := t.preparedPosts.data.Front(); e != nil; e = e.Next() { + for e := t.preparedItems.data.Front(); e != nil; e = e.Next() { position++ - entry, ok := e.Value.(*preparedPostsEntry) + entry, ok := e.Value.(*preparedItemsEntry) if !ok { return nil, errors.New("GetXBetweenID: could not parse e as a preparedPostsEntry") } - if entry.statusID == behindID { + if entry.itemID == behindID { behindIDMark = e break findMarkLoop } @@ -325,11 +324,11 @@ findMarkLoop: // we didn't find it if behindIDMark == nil { - return nil, fmt.Errorf("GetXBetweenID: couldn't find status with ID %s", behindID) + return nil, fmt.Errorf("GetXBetweenID: couldn't find item with ID %s", behindID) } - // make sure we have enough posts prepared behind it to return what we're being asked for - if t.preparedPosts.data.Len() < amount+position { + // make sure we have enough items prepared behind it to return what we're being asked for + if t.preparedItems.data.Len() < amount+position { if err := t.PrepareBehind(ctx, behindID, amount); err != nil { return nil, err } @@ -339,22 +338,22 @@ findMarkLoop: var served int serveloop: for e := behindIDMark.Next(); e != nil; e = e.Next() { - entry, ok := e.Value.(*preparedPostsEntry) + entry, ok := e.Value.(*preparedItemsEntry) if !ok { return nil, errors.New("GetXBetweenID: could not parse e as a preparedPostsEntry") } - if entry.statusID == beforeID { + if entry.itemID == beforeID { break serveloop } // serve up to the amount requested - statuses = append(statuses, entry.prepared) + items = append(items, entry.prepared) served++ if served >= amount { break serveloop } } - return statuses, nil + return items, nil } diff --git a/internal/timeline/get_test.go b/internal/timeline/get_test.go index b3a19f488..a1640790e 100644 --- a/internal/timeline/get_test.go +++ b/internal/timeline/get_test.go @@ -24,7 +24,9 @@ import ( "time" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/timeline" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -43,18 +45,26 @@ func (suite *GetTestSuite) SetupTest() { suite.db = testrig.NewTestDB() suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.filter = visibility.NewFilter(suite.db) testrig.StandardDBSetup(suite.db, nil) // let's take local_account_1 as the timeline owner - tl, err := timeline.NewTimeline(context.Background(), suite.testAccounts["local_account_1"].ID, suite.db, suite.tc) + tl, err := timeline.NewTimeline( + context.Background(), + suite.testAccounts["local_account_1"].ID, + processing.StatusGrabFunction(suite.db), + processing.StatusFilterFunction(suite.db, suite.filter), + processing.StatusPrepareFunction(suite.db, suite.tc), + processing.StatusSkipInsertFunction(), + ) if err != nil { suite.FailNow(err.Error()) } // prepare the timeline by just shoving all test statuses in it -- let's not be fussy about who sees what for _, s := range suite.testStatuses { - _, err := tl.IndexAndPrepareOne(context.Background(), s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID) + _, err := tl.IndexAndPrepareOne(context.Background(), s.GetID(), s.BoostOfID, s.AccountID, s.BoostOfAccountID) if err != nil { suite.FailNow(err.Error()) } @@ -81,10 +91,10 @@ func (suite *GetTestSuite) TestGetDefault() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } } } @@ -102,10 +112,10 @@ func (suite *GetTestSuite) TestGetDefaultPrepareNext() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } } @@ -127,10 +137,10 @@ func (suite *GetTestSuite) TestGetMaxID() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } } } @@ -149,10 +159,10 @@ func (suite *GetTestSuite) TestGetMaxIDPrepareNext() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } } @@ -174,10 +184,10 @@ func (suite *GetTestSuite) TestGetMinID() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } } } @@ -196,10 +206,10 @@ func (suite *GetTestSuite) TestGetSinceID() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } } } @@ -218,10 +228,10 @@ func (suite *GetTestSuite) TestGetSinceIDPrepareNext() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } } @@ -243,10 +253,10 @@ func (suite *GetTestSuite) TestGetBetweenID() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } } } @@ -265,10 +275,10 @@ func (suite *GetTestSuite) TestGetBetweenIDPrepareNext() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } } @@ -289,10 +299,10 @@ func (suite *GetTestSuite) TestGetXFromTop() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } } } @@ -314,12 +324,12 @@ func (suite *GetTestSuite) TestGetXBehindID() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } - suite.Less(s.ID, "01F8MHBQCBTDKN6X5VHGMMN4MA") + suite.Less(s.GetID(), "01F8MHBQCBTDKN6X5VHGMMN4MA") } } @@ -353,12 +363,12 @@ func (suite *GetTestSuite) TestGetXBehindNonexistentReasonableID() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } - suite.Less(s.ID, "01F8MHBCN8120SYH7D5S050MGK") + suite.Less(s.GetID(), "01F8MHBCN8120SYH7D5S050MGK") } } @@ -380,12 +390,12 @@ func (suite *GetTestSuite) TestGetXBehindVeryHighID() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } - suite.Less(s.ID, "9998MHBQCBTDKN6X5VHGMMN4MA") + suite.Less(s.GetID(), "9998MHBQCBTDKN6X5VHGMMN4MA") } } @@ -403,12 +413,12 @@ func (suite *GetTestSuite) TestGetXBeforeID() { var highest string for i, s := range statuses { if i == 0 { - highest = s.ID + highest = s.GetID() } else { - suite.Less(s.ID, highest) - highest = s.ID + suite.Less(s.GetID(), highest) + highest = s.GetID() } - suite.Greater(s.ID, "01F8MHBQCBTDKN6X5VHGMMN4MA") + suite.Greater(s.GetID(), "01F8MHBQCBTDKN6X5VHGMMN4MA") } } @@ -426,12 +436,12 @@ func (suite *GetTestSuite) TestGetXBeforeIDNoStartFromTop() { var lowest string for i, s := range statuses { if i == 0 { - lowest = s.ID + lowest = s.GetID() } else { - suite.Greater(s.ID, lowest) - lowest = s.ID + suite.Greater(s.GetID(), lowest) + lowest = s.GetID() } - suite.Greater(s.ID, "01F8MHBQCBTDKN6X5VHGMMN4MA") + suite.Greater(s.GetID(), "01F8MHBQCBTDKN6X5VHGMMN4MA") } } diff --git a/internal/timeline/index.go b/internal/timeline/index.go index 3d940af80..bda3a9c6c 100644 --- a/internal/timeline/index.go +++ b/internal/timeline/index.go @@ -23,173 +23,166 @@ import ( "context" "errors" "fmt" - "time" "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (t *timeline) IndexBefore(ctx context.Context, statusID string, include bool, amount int) error { +func (t *timeline) IndexBefore(ctx context.Context, itemID string, amount int) error { + l := logrus.WithFields(logrus.Fields{ + "func": "IndexBefore", + "amount": amount, + }) + // lazily initialize index if it hasn't been done already - if t.postIndex.data == nil { - t.postIndex.data = &list.List{} - t.postIndex.data.Init() + if t.itemIndex.data == nil { + t.itemIndex.data = &list.List{} + t.itemIndex.data.Init() } - filtered := []*gtsmodel.Status{} - offsetStatus := statusID - - if include { - // if we have the status with given statusID in the database, include it in the results set as well - s := >smodel.Status{} - if err := t.db.GetByID(ctx, statusID, s); err == nil { - filtered = append(filtered, s) - } - } + toIndex := []Timelineable{} + offsetID := itemID - i := 0 + l.Trace("entering grabloop") grabloop: - for ; len(filtered) < amount && i < 5; i++ { // try the grabloop 5 times only - statuses, err := t.db.GetHomeTimeline(ctx, t.accountID, "", "", offsetStatus, amount, false) + for i := 0; len(toIndex) < amount && i < 5; i++ { // try the grabloop 5 times only + // first grab items using the caller-provided grab function + l.Trace("grabbing...") + items, stop, err := t.grabFunction(ctx, t.accountID, "", "", offsetID, amount) if err != nil { - if err == db.ErrNoEntries { - break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail - } - return fmt.Errorf("IndexBefore: error getting statuses from db: %s", err) + return err + } + if stop { + break grabloop } - for _, s := range statuses { - timelineable, err := t.filter.StatusHometimelineable(ctx, s, t.account) + l.Trace("filtering...") + // now filter each item using the caller-provided filter function + for _, item := range items { + shouldIndex, err := t.filterFunction(ctx, t.accountID, item) if err != nil { - continue + return err } - if timelineable { - filtered = append(filtered, s) + if shouldIndex { + toIndex = append(toIndex, item) } - offsetStatus = s.ID + offsetID = item.GetID() } } + l.Trace("left grabloop") - for _, s := range filtered { - if _, err := t.IndexOne(ctx, s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID); err != nil { - return fmt.Errorf("IndexBefore: error indexing status with id %s: %s", s.ID, err) + // index the items we got + for _, s := range toIndex { + if _, err := t.IndexOne(ctx, s.GetID(), s.GetBoostOfID(), s.GetAccountID(), s.GetBoostOfAccountID()); err != nil { + return fmt.Errorf("IndexBehind: error indexing item with id %s: %s", s.GetID(), err) } } return nil } -func (t *timeline) IndexBehind(ctx context.Context, statusID string, include bool, amount int) error { +func (t *timeline) IndexBehind(ctx context.Context, itemID string, amount int) error { l := logrus.WithFields(logrus.Fields{ - "func": "IndexBehind", - "include": include, - "amount": amount, + "func": "IndexBehind", + "amount": amount, }) // lazily initialize index if it hasn't been done already - if t.postIndex.data == nil { - t.postIndex.data = &list.List{} - t.postIndex.data.Init() + if t.itemIndex.data == nil { + t.itemIndex.data = &list.List{} + t.itemIndex.data.Init() } - // If we're already indexedBehind given statusID by the required amount, we can return nil. - // First find position of statusID (or as near as possible). + // If we're already indexedBehind given itemID by the required amount, we can return nil. + // First find position of itemID (or as near as possible). var position int positionLoop: - for e := t.postIndex.data.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*postIndexEntry) + for e := t.itemIndex.data.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*itemIndexEntry) if !ok { - return errors.New("IndexBehind: could not parse e as a postIndexEntry") + return errors.New("IndexBehind: could not parse e as an itemIndexEntry") } - if entry.statusID <= statusID { + if entry.itemID <= itemID { // we've found it break positionLoop } position++ } - // now check if the length of indexed posts exceeds the amount of posts required (position of statusID, plus amount of posts requested after that) - if t.postIndex.data.Len() > position+amount { + + // now check if the length of indexed items exceeds the amount of items required (position of itemID, plus amount of posts requested after that) + if t.itemIndex.data.Len() > position+amount { // we have enough indexed behind already to satisfy amount, so don't need to make db calls - l.Trace("returning nil since we already have enough posts indexed") + l.Trace("returning nil since we already have enough items indexed") return nil } - filtered := []*gtsmodel.Status{} - offsetStatus := statusID + toIndex := []Timelineable{} + offsetID := itemID - if include { - // if we have the status with given statusID in the database, include it in the results set as well - s := >smodel.Status{} - if err := t.db.GetByID(ctx, statusID, s); err == nil { - filtered = append(filtered, s) - } - } - - i := 0 + l.Trace("entering grabloop") grabloop: - for ; len(filtered) < amount && i < 5; i++ { // try the grabloop 5 times only - l.Tracef("entering grabloop; i is %d; len(filtered) is %d", i, len(filtered)) - statuses, err := t.db.GetHomeTimeline(ctx, t.accountID, offsetStatus, "", "", amount, false) + for i := 0; len(toIndex) < amount && i < 5; i++ { // try the grabloop 5 times only + // first grab items using the caller-provided grab function + l.Trace("grabbing...") + items, stop, err := t.grabFunction(ctx, t.accountID, offsetID, "", "", amount) if err != nil { - if err == db.ErrNoEntries { - break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail - } - return fmt.Errorf("IndexBehind: error getting statuses from db: %s", err) + return err + } + if stop { + break grabloop } - l.Tracef("got %d statuses", len(statuses)) - for _, s := range statuses { - timelineable, err := t.filter.StatusHometimelineable(ctx, s, t.account) + l.Trace("filtering...") + // now filter each item using the caller-provided filter function + for _, item := range items { + shouldIndex, err := t.filterFunction(ctx, t.accountID, item) if err != nil { - l.Tracef("status was not hometimelineable: %s", err) - continue + return err } - if timelineable { - filtered = append(filtered, s) + if shouldIndex { + toIndex = append(toIndex, item) } - offsetStatus = s.ID + offsetID = item.GetID() } } l.Trace("left grabloop") - for _, s := range filtered { - if _, err := t.IndexOne(ctx, s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID); err != nil { - return fmt.Errorf("IndexBehind: error indexing status with id %s: %s", s.ID, err) + // index the items we got + for _, s := range toIndex { + if _, err := t.IndexOne(ctx, s.GetID(), s.GetBoostOfID(), s.GetAccountID(), s.GetBoostOfAccountID()); err != nil { + return fmt.Errorf("IndexBehind: error indexing item with id %s: %s", s.GetID(), err) } } - l.Trace("exiting function") return nil } -func (t *timeline) IndexOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) { +func (t *timeline) IndexOne(ctx context.Context, itemID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) { t.Lock() defer t.Unlock() - postIndexEntry := &postIndexEntry{ - statusID: statusID, + postIndexEntry := &itemIndexEntry{ + itemID: itemID, boostOfID: boostOfID, accountID: accountID, boostOfAccountID: boostOfAccountID, } - return t.postIndex.insertIndexed(postIndexEntry) + return t.itemIndex.insertIndexed(ctx, postIndexEntry) } -func (t *timeline) IndexAndPrepareOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) { +func (t *timeline) IndexAndPrepareOne(ctx context.Context, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) { t.Lock() defer t.Unlock() - postIndexEntry := &postIndexEntry{ - statusID: statusID, + postIndexEntry := &itemIndexEntry{ + itemID: statusID, boostOfID: boostOfID, accountID: accountID, boostOfAccountID: boostOfAccountID, } - inserted, err := t.postIndex.insertIndexed(postIndexEntry) + inserted, err := t.itemIndex.insertIndexed(ctx, postIndexEntry) if err != nil { return inserted, fmt.Errorf("IndexAndPrepareOne: error inserting indexed: %s", err) } @@ -203,32 +196,32 @@ func (t *timeline) IndexAndPrepareOne(ctx context.Context, statusCreatedAt time. return inserted, nil } -func (t *timeline) OldestIndexedPostID(ctx context.Context) (string, error) { +func (t *timeline) OldestIndexedItemID(ctx context.Context) (string, error) { var id string - if t.postIndex == nil || t.postIndex.data == nil || t.postIndex.data.Back() == nil { + if t.itemIndex == nil || t.itemIndex.data == nil || t.itemIndex.data.Back() == nil { // return an empty string if postindex hasn't been initialized yet return id, nil } - e := t.postIndex.data.Back() - entry, ok := e.Value.(*postIndexEntry) + e := t.itemIndex.data.Back() + entry, ok := e.Value.(*itemIndexEntry) if !ok { - return id, errors.New("OldestIndexedPostID: could not parse e as a postIndexEntry") + return id, errors.New("OldestIndexedItemID: could not parse e as itemIndexEntry") } - return entry.statusID, nil + return entry.itemID, nil } -func (t *timeline) NewestIndexedPostID(ctx context.Context) (string, error) { +func (t *timeline) NewestIndexedItemID(ctx context.Context) (string, error) { var id string - if t.postIndex == nil || t.postIndex.data == nil || t.postIndex.data.Front() == nil { + if t.itemIndex == nil || t.itemIndex.data == nil || t.itemIndex.data.Front() == nil { // return an empty string if postindex hasn't been initialized yet return id, nil } - e := t.postIndex.data.Front() - entry, ok := e.Value.(*postIndexEntry) + e := t.itemIndex.data.Front() + entry, ok := e.Value.(*itemIndexEntry) if !ok { - return id, errors.New("NewestIndexedPostID: could not parse e as a postIndexEntry") + return id, errors.New("NewestIndexedItemID: could not parse e as itemIndexEntry") } - return entry.statusID, nil + return entry.itemID, nil } diff --git a/internal/timeline/index_test.go b/internal/timeline/index_test.go index e09421801..3f4822dcb 100644 --- a/internal/timeline/index_test.go +++ b/internal/timeline/index_test.go @@ -25,7 +25,9 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/timeline" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -44,11 +46,19 @@ func (suite *IndexTestSuite) SetupTest() { suite.db = testrig.NewTestDB() suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.filter = visibility.NewFilter(suite.db) testrig.StandardDBSetup(suite.db, nil) // let's take local_account_1 as the timeline owner, and start with an empty timeline - tl, err := timeline.NewTimeline(context.Background(), suite.testAccounts["local_account_1"].ID, suite.db, suite.tc) + tl, err := timeline.NewTimeline( + context.Background(), + suite.testAccounts["local_account_1"].ID, + processing.StatusGrabFunction(suite.db), + processing.StatusFilterFunction(suite.db, suite.filter), + processing.StatusPrepareFunction(suite.db, suite.tc), + processing.StatusSkipInsertFunction(), + ) if err != nil { suite.FailNow(err.Error()) } @@ -61,82 +71,82 @@ func (suite *IndexTestSuite) TearDownTest() { func (suite *IndexTestSuite) TestIndexBeforeLowID() { // index 10 before the lowest status ID possible - err := suite.timeline.IndexBefore(context.Background(), "00000000000000000000000000", true, 10) + err := suite.timeline.IndexBefore(context.Background(), "00000000000000000000000000", 10) suite.NoError(err) // the oldest indexed post should be the lowest one we have in our testrig - postID, err := suite.timeline.OldestIndexedPostID(context.Background()) + postID, err := suite.timeline.OldestIndexedItemID(context.Background()) suite.NoError(err) suite.Equal("01F8MHAYFKS4KMXF8K5Y1C0KRN", postID) - indexLength := suite.timeline.PostIndexLength(context.Background()) + indexLength := suite.timeline.ItemIndexLength(context.Background()) suite.Equal(10, indexLength) } func (suite *IndexTestSuite) TestIndexBeforeHighID() { // index 10 before the highest status ID possible - err := suite.timeline.IndexBefore(context.Background(), "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", true, 10) + err := suite.timeline.IndexBefore(context.Background(), "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", 10) suite.NoError(err) // the oldest indexed post should be empty - postID, err := suite.timeline.OldestIndexedPostID(context.Background()) + postID, err := suite.timeline.OldestIndexedItemID(context.Background()) suite.NoError(err) suite.Empty(postID) // indexLength should be 0 - indexLength := suite.timeline.PostIndexLength(context.Background()) + indexLength := suite.timeline.ItemIndexLength(context.Background()) suite.Equal(0, indexLength) } func (suite *IndexTestSuite) TestIndexBehindHighID() { // index 10 behind the highest status ID possible - err := suite.timeline.IndexBehind(context.Background(), "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", true, 10) + err := suite.timeline.IndexBehind(context.Background(), "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", 10) suite.NoError(err) // the newest indexed post should be the highest one we have in our testrig - postID, err := suite.timeline.NewestIndexedPostID(context.Background()) + postID, err := suite.timeline.NewestIndexedItemID(context.Background()) suite.NoError(err) suite.Equal("01FN3VJGFH10KR7S2PB0GFJZYG", postID) // indexLength should be 10 because that's all this user has hometimelineable - indexLength := suite.timeline.PostIndexLength(context.Background()) + indexLength := suite.timeline.ItemIndexLength(context.Background()) suite.Equal(10, indexLength) } func (suite *IndexTestSuite) TestIndexBehindLowID() { // index 10 behind the lowest status ID possible - err := suite.timeline.IndexBehind(context.Background(), "00000000000000000000000000", true, 10) + err := suite.timeline.IndexBehind(context.Background(), "00000000000000000000000000", 10) suite.NoError(err) // the newest indexed post should be empty - postID, err := suite.timeline.NewestIndexedPostID(context.Background()) + postID, err := suite.timeline.NewestIndexedItemID(context.Background()) suite.NoError(err) suite.Empty(postID) // indexLength should be 0 - indexLength := suite.timeline.PostIndexLength(context.Background()) + indexLength := suite.timeline.ItemIndexLength(context.Background()) suite.Equal(0, indexLength) } -func (suite *IndexTestSuite) TestOldestIndexedPostIDEmpty() { +func (suite *IndexTestSuite) TestOldestIndexedItemIDEmpty() { // the oldest indexed post should be an empty string since there's nothing indexed yet - postID, err := suite.timeline.OldestIndexedPostID(context.Background()) + postID, err := suite.timeline.OldestIndexedItemID(context.Background()) suite.NoError(err) suite.Empty(postID) // indexLength should be 0 - indexLength := suite.timeline.PostIndexLength(context.Background()) + indexLength := suite.timeline.ItemIndexLength(context.Background()) suite.Equal(0, indexLength) } -func (suite *IndexTestSuite) TestNewestIndexedPostIDEmpty() { +func (suite *IndexTestSuite) TestNewestIndexedItemIDEmpty() { // the newest indexed post should be an empty string since there's nothing indexed yet - postID, err := suite.timeline.NewestIndexedPostID(context.Background()) + postID, err := suite.timeline.NewestIndexedItemID(context.Background()) suite.NoError(err) suite.Empty(postID) // indexLength should be 0 - indexLength := suite.timeline.PostIndexLength(context.Background()) + indexLength := suite.timeline.ItemIndexLength(context.Background()) suite.Equal(0, indexLength) } @@ -144,12 +154,12 @@ func (suite *IndexTestSuite) TestIndexAlreadyIndexed() { testStatus := suite.testStatuses["local_account_1_status_1"] // index one post -- it should be indexed - indexed, err := suite.timeline.IndexOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) + indexed, err := suite.timeline.IndexOne(context.Background(), testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) suite.NoError(err) suite.True(indexed) // try to index the same post again -- it should not be indexed - indexed, err = suite.timeline.IndexOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) + indexed, err = suite.timeline.IndexOne(context.Background(), testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) suite.NoError(err) suite.False(indexed) } @@ -158,12 +168,12 @@ func (suite *IndexTestSuite) TestIndexAndPrepareAlreadyIndexedAndPrepared() { testStatus := suite.testStatuses["local_account_1_status_1"] // index and prepare one post -- it should be indexed - indexed, err := suite.timeline.IndexAndPrepareOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) + indexed, err := suite.timeline.IndexAndPrepareOne(context.Background(), testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) suite.NoError(err) suite.True(indexed) // try to index and prepare the same post again -- it should not be indexed - indexed, err = suite.timeline.IndexAndPrepareOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) + indexed, err = suite.timeline.IndexAndPrepareOne(context.Background(), testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) suite.NoError(err) suite.False(indexed) } @@ -179,12 +189,12 @@ func (suite *IndexTestSuite) TestIndexBoostOfAlreadyIndexed() { } // index one post -- it should be indexed - indexed, err := suite.timeline.IndexOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) + indexed, err := suite.timeline.IndexOne(context.Background(), testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) suite.NoError(err) suite.True(indexed) // try to index the a boost of that post -- it should not be indexed - indexed, err = suite.timeline.IndexOne(context.Background(), boostOfTestStatus.CreatedAt, boostOfTestStatus.ID, boostOfTestStatus.BoostOfID, boostOfTestStatus.AccountID, boostOfTestStatus.BoostOfAccountID) + indexed, err = suite.timeline.IndexOne(context.Background(), boostOfTestStatus.ID, boostOfTestStatus.BoostOfID, boostOfTestStatus.AccountID, boostOfTestStatus.BoostOfAccountID) suite.NoError(err) suite.False(indexed) } diff --git a/internal/timeline/postindex.go b/internal/timeline/itemindex.go index 5fe795d5a..968650e07 100644 --- a/internal/timeline/postindex.go +++ b/internal/timeline/itemindex.go @@ -20,21 +20,23 @@ package timeline import ( "container/list" + "context" "errors" ) -type postIndex struct { - data *list.List +type itemIndex struct { + data *list.List + skipInsert SkipInsertFunction } -type postIndexEntry struct { - statusID string +type itemIndexEntry struct { + itemID string boostOfID string accountID string boostOfAccountID string } -func (p *postIndex) insertIndexed(i *postIndexEntry) (bool, error) { +func (p *itemIndex) insertIndexed(ctx context.Context, i *itemIndexEntry) (bool, error) { if p.data == nil { p.data = &list.List{} } @@ -47,36 +49,30 @@ func (p *postIndex) insertIndexed(i *postIndexEntry) (bool, error) { var insertMark *list.Element var position int - // We need to iterate through the index to make sure we put this post in the appropriate place according to when it was created. - // We also need to make sure we're not inserting a duplicate post -- this can happen sometimes and it's not nice UX (*shudder*). + // We need to iterate through the index to make sure we put this item in the appropriate place according to when it was created. + // We also need to make sure we're not inserting a duplicate item -- this can happen sometimes and it's not nice UX (*shudder*). for e := p.data.Front(); e != nil; e = e.Next() { position++ - entry, ok := e.Value.(*postIndexEntry) + entry, ok := e.Value.(*itemIndexEntry) if !ok { - return false, errors.New("index: could not parse e as a postIndexEntry") + return false, errors.New("index: could not parse e as an itemIndexEntry") } - // don't insert this if it's a boost of a status we've seen recently - if i.boostOfID != "" { - if i.boostOfID == entry.boostOfID || i.boostOfID == entry.statusID { - if position < boostReinsertionDepth { - return false, nil - } - } + skip, err := p.skipInsert(ctx, i.itemID, i.accountID, i.boostOfID, i.boostOfAccountID, entry.itemID, entry.accountID, entry.boostOfID, entry.boostOfAccountID, position) + if err != nil { + return false, err + } + if skip { + return false, nil } - // if the post to index is newer than e, insert it before e in the list + // if the item to index is newer than e, insert it before e in the list if insertMark == nil { - if i.statusID > entry.statusID { + if i.itemID > entry.itemID { insertMark = e } } - - // make sure we don't insert a duplicate - if entry.statusID == i.statusID { - return false, nil - } } if insertMark != nil { @@ -84,7 +80,7 @@ func (p *postIndex) insertIndexed(i *postIndexEntry) (bool, error) { return true, nil } - // if we reach this point it's the oldest post we've seen so put it at the back + // if we reach this point it's the oldest item we've seen so put it at the back p.data.PushBack(i) return true, nil } diff --git a/internal/timeline/manager.go b/internal/timeline/manager.go index 5aa74ef91..02a388aba 100644 --- a/internal/timeline/manager.go +++ b/internal/timeline/manager.go @@ -25,10 +25,6 @@ import ( "sync" "github.com/sirupsen/logrus" - apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/typeutils" ) const ( @@ -37,71 +33,75 @@ const ( // Manager abstracts functions for creating timelines for multiple accounts, and adding, removing, and fetching entries from those timelines. // -// By the time a status hits the manager interface, it should already have been filtered and it should be established that the status indeed -// belongs in the home timeline of the given account ID. +// By the time a timelineable hits the manager interface, it should already have been filtered and it should be established that the item indeed +// belongs in the timeline of the given account ID. // -// The manager makes a distinction between *indexed* posts and *prepared* posts. +// The manager makes a distinction between *indexed* items and *prepared* items. // -// Indexed posts consist of just that post's ID (in the database) and the time it was created. An indexed post takes up very little memory, so -// it's not a huge priority to keep trimming the indexed posts list. +// Indexed items consist of just that item's ID (in the database) and the time it was created. An indexed item takes up very little memory, so +// it's not a huge priority to keep trimming the indexed items list. // -// Prepared posts consist of the post's database ID, the time it was created, AND the apimodel representation of that post, for quick serialization. -// Prepared posts of course take up more memory than indexed posts, so they should be regularly pruned if they're not being actively served. +// Prepared items consist of the item's database ID, the time it was created, AND the apimodel representation of that item, for quick serialization. +// Prepared items of course take up more memory than indexed items, so they should be regularly pruned if they're not being actively served. type Manager interface { - // Ingest takes one status and indexes it into the timeline for the given account ID. + // Ingest takes one item and indexes it into the timeline for the given account ID. // - // It should already be established before calling this function that the status/post actually belongs in the timeline! + // It should already be established before calling this function that the item actually belongs in the timeline! // - // The returned bool indicates whether the status was actually put in the timeline. This could be false in cases where - // the status is a boost, but a boost of the original post or the post itself already exists recently in the timeline. - Ingest(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error) - // IngestAndPrepare takes one status and indexes it into the timeline for the given account ID, and then immediately prepares it for serving. - // This is useful in cases where we know the status will need to be shown at the top of a user's timeline immediately (eg., a new status is created). + // The returned bool indicates whether the item was actually put in the timeline. This could be false in cases where + // the item is a boosted status, but a boost of the original status or the status itself already exists recently in the timeline. + Ingest(ctx context.Context, item Timelineable, timelineAccountID string) (bool, error) + // IngestAndPrepare takes one timelineable and indexes it into the timeline for the given account ID, and then immediately prepares it for serving. + // This is useful in cases where we know the item will need to be shown at the top of a user's timeline immediately (eg., a new status is created). // - // It should already be established before calling this function that the status/post actually belongs in the timeline! + // It should already be established before calling this function that the item actually belongs in the timeline! // - // The returned bool indicates whether the status was actually put in the timeline. This could be false in cases where - // the status is a boost, but a boost of the original post or the post itself already exists recently in the timeline. - IngestAndPrepare(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error) - // HomeTimeline returns limit n amount of entries from the home timeline of the given account ID, in descending chronological order. - // If maxID is provided, it will return entries from that maxID onwards, inclusive. - HomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, error) - // GetIndexedLength returns the amount of posts/statuses that have been *indexed* for the given account ID. + // The returned bool indicates whether the item was actually put in the timeline. This could be false in cases where + // a status is a boost, but a boost of the original status or the status itself already exists recently in the timeline. + IngestAndPrepare(ctx context.Context, item Timelineable, timelineAccountID string) (bool, error) + // GetTimeline returns limit n amount of prepared entries from the timeline of the given account ID, in descending chronological order. + // If maxID is provided, it will return prepared entries from that maxID onwards, inclusive. + GetTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]Preparable, error) + // GetIndexedLength returns the amount of items that have been *indexed* for the given account ID. GetIndexedLength(ctx context.Context, timelineAccountID string) int - // GetDesiredIndexLength returns the amount of posts that we, ideally, index for each user. + // GetDesiredIndexLength returns the amount of items that we, ideally, index for each user. GetDesiredIndexLength(ctx context.Context) int - // GetOldestIndexedID returns the status ID for the oldest post that we have indexed for the given account. + // GetOldestIndexedID returns the id ID for the oldest item that we have indexed for the given account. GetOldestIndexedID(ctx context.Context, timelineAccountID string) (string, error) - // PrepareXFromTop prepares limit n amount of posts, based on their indexed representations, from the top of the index. + // PrepareXFromTop prepares limit n amount of items, based on their indexed representations, from the top of the index. PrepareXFromTop(ctx context.Context, timelineAccountID string, limit int) error - // Remove removes one status from the timeline of the given timelineAccountID - Remove(ctx context.Context, timelineAccountID string, statusID string) (int, error) - // WipeStatusFromAllTimelines removes one status from the index and prepared posts of all timelines - WipeStatusFromAllTimelines(ctx context.Context, statusID string) error - // WipeStatusesFromAccountID removes all statuses by the given accountID from the timelineAccountID's timelines. - WipeStatusesFromAccountID(ctx context.Context, timelineAccountID string, accountID string) error + // Remove removes one item from the timeline of the given timelineAccountID + Remove(ctx context.Context, timelineAccountID string, itemID string) (int, error) + // WipeItemFromAllTimelines removes one item from the index and prepared items of all timelines + WipeItemFromAllTimelines(ctx context.Context, itemID string) error + // WipeStatusesFromAccountID removes all items by the given accountID from the timelineAccountID's timelines. + WipeItemsFromAccountID(ctx context.Context, timelineAccountID string, accountID string) error } -// NewManager returns a new timeline manager with the given database, typeconverter, config, and log. -func NewManager(db db.DB, tc typeutils.TypeConverter) Manager { +// NewManager returns a new timeline manager. +func NewManager(grabFunction GrabFunction, filterFunction FilterFunction, prepareFunction PrepareFunction, skipInsertFunction SkipInsertFunction) Manager { return &manager{ - accountTimelines: sync.Map{}, - db: db, - tc: tc, + accountTimelines: sync.Map{}, + grabFunction: grabFunction, + filterFunction: filterFunction, + prepareFunction: prepareFunction, + skipInsertFunction: skipInsertFunction, } } type manager struct { - accountTimelines sync.Map - db db.DB - tc typeutils.TypeConverter + accountTimelines sync.Map + grabFunction GrabFunction + filterFunction FilterFunction + prepareFunction PrepareFunction + skipInsertFunction SkipInsertFunction } -func (m *manager) Ingest(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error) { +func (m *manager) Ingest(ctx context.Context, item Timelineable, timelineAccountID string) (bool, error) { l := logrus.WithFields(logrus.Fields{ "func": "Ingest", "timelineAccountID": timelineAccountID, - "statusID": status.ID, + "itemID": item.GetID(), }) t, err := m.getOrCreateTimeline(ctx, timelineAccountID) @@ -109,15 +109,15 @@ func (m *manager) Ingest(ctx context.Context, status *gtsmodel.Status, timelineA return false, err } - l.Trace("ingesting status") - return t.IndexOne(ctx, status.CreatedAt, status.ID, status.BoostOfID, status.AccountID, status.BoostOfAccountID) + l.Trace("ingesting item") + return t.IndexOne(ctx, item.GetID(), item.GetBoostOfID(), item.GetAccountID(), item.GetBoostOfAccountID()) } -func (m *manager) IngestAndPrepare(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error) { +func (m *manager) IngestAndPrepare(ctx context.Context, item Timelineable, timelineAccountID string) (bool, error) { l := logrus.WithFields(logrus.Fields{ "func": "IngestAndPrepare", "timelineAccountID": timelineAccountID, - "statusID": status.ID, + "itemID": item.GetID(), }) t, err := m.getOrCreateTimeline(ctx, timelineAccountID) @@ -125,15 +125,15 @@ func (m *manager) IngestAndPrepare(ctx context.Context, status *gtsmodel.Status, return false, err } - l.Trace("ingesting status") - return t.IndexAndPrepareOne(ctx, status.CreatedAt, status.ID, status.BoostOfID, status.AccountID, status.BoostOfAccountID) + l.Trace("ingesting item") + return t.IndexAndPrepareOne(ctx, item.GetID(), item.GetBoostOfID(), item.GetAccountID(), item.GetBoostOfAccountID()) } -func (m *manager) Remove(ctx context.Context, timelineAccountID string, statusID string) (int, error) { +func (m *manager) Remove(ctx context.Context, timelineAccountID string, itemID string) (int, error) { l := logrus.WithFields(logrus.Fields{ "func": "Remove", "timelineAccountID": timelineAccountID, - "statusID": statusID, + "itemID": itemID, }) t, err := m.getOrCreateTimeline(ctx, timelineAccountID) @@ -141,13 +141,13 @@ func (m *manager) Remove(ctx context.Context, timelineAccountID string, statusID return 0, err } - l.Trace("removing status") - return t.Remove(ctx, statusID) + l.Trace("removing item") + return t.Remove(ctx, itemID) } -func (m *manager) HomeTimeline(ctx context.Context, timelineAccountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, error) { +func (m *manager) GetTimeline(ctx context.Context, timelineAccountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]Preparable, error) { l := logrus.WithFields(logrus.Fields{ - "func": "HomeTimelineGet", + "func": "GetTimeline", "timelineAccountID": timelineAccountID, }) @@ -156,11 +156,11 @@ func (m *manager) HomeTimeline(ctx context.Context, timelineAccountID string, ma return nil, err } - statuses, err := t.Get(ctx, limit, maxID, sinceID, minID, true) + items, err := t.Get(ctx, limit, maxID, sinceID, minID, true) if err != nil { l.Errorf("error getting statuses: %s", err) } - return statuses, nil + return items, nil } func (m *manager) GetIndexedLength(ctx context.Context, timelineAccountID string) int { @@ -169,7 +169,7 @@ func (m *manager) GetIndexedLength(ctx context.Context, timelineAccountID string return 0 } - return t.PostIndexLength(ctx) + return t.ItemIndexLength(ctx) } func (m *manager) GetDesiredIndexLength(ctx context.Context) int { @@ -182,7 +182,7 @@ func (m *manager) GetOldestIndexedID(ctx context.Context, timelineAccountID stri return "", err } - return t.OldestIndexedPostID(ctx) + return t.OldestIndexedItemID(ctx) } func (m *manager) PrepareXFromTop(ctx context.Context, timelineAccountID string, limit int) error { @@ -194,7 +194,7 @@ func (m *manager) PrepareXFromTop(ctx context.Context, timelineAccountID string, return t.PrepareFromTop(ctx, limit) } -func (m *manager) WipeStatusFromAllTimelines(ctx context.Context, statusID string) error { +func (m *manager) WipeItemFromAllTimelines(ctx context.Context, statusID string) error { errors := []string{} m.accountTimelines.Range(func(k interface{}, i interface{}) bool { t, ok := i.(Timeline) @@ -217,7 +217,7 @@ func (m *manager) WipeStatusFromAllTimelines(ctx context.Context, statusID strin return err } -func (m *manager) WipeStatusesFromAccountID(ctx context.Context, timelineAccountID string, accountID string) error { +func (m *manager) WipeItemsFromAccountID(ctx context.Context, timelineAccountID string, accountID string) error { t, err := m.getOrCreateTimeline(ctx, timelineAccountID) if err != nil { return err @@ -232,7 +232,7 @@ func (m *manager) getOrCreateTimeline(ctx context.Context, timelineAccountID str i, ok := m.accountTimelines.Load(timelineAccountID) if !ok { var err error - t, err = NewTimeline(ctx, timelineAccountID, m.db, m.tc) + t, err = NewTimeline(ctx, timelineAccountID, m.grabFunction, m.filterFunction, m.prepareFunction, m.skipInsertFunction) if err != nil { return nil, err } diff --git a/internal/timeline/manager_test.go b/internal/timeline/manager_test.go index 4e81bfe17..11b5dfb78 100644 --- a/internal/timeline/manager_test.go +++ b/internal/timeline/manager_test.go @@ -23,6 +23,9 @@ import ( "testing" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/timeline" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -41,10 +44,16 @@ func (suite *ManagerTestSuite) SetupTest() { suite.db = testrig.NewTestDB() suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.filter = visibility.NewFilter(suite.db) testrig.StandardDBSetup(suite.db, nil) - manager := testrig.NewTestTimelineManager(suite.db) + manager := timeline.NewManager( + processing.StatusGrabFunction(suite.db), + processing.StatusFilterFunction(suite.db, suite.filter), + processing.StatusPrepareFunction(suite.db, suite.tc), + processing.StatusSkipInsertFunction(), + ) suite.manager = manager } @@ -78,12 +87,12 @@ func (suite *ManagerTestSuite) TestManagerIntegration() { suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", oldestIndexed) // get hometimeline - statuses, err := suite.manager.HomeTimeline(context.Background(), testAccount.ID, "", "", "", 20, false) + statuses, err := suite.manager.GetTimeline(context.Background(), testAccount.ID, "", "", "", 20, false) suite.NoError(err) suite.Len(statuses, 14) // now wipe the last status from all timelines, as though it had been deleted by the owner - err = suite.manager.WipeStatusFromAllTimelines(context.Background(), "01F8MH75CBF9JFX4ZAD54N0W0R") + err = suite.manager.WipeItemFromAllTimelines(context.Background(), "01F8MH75CBF9JFX4ZAD54N0W0R") suite.NoError(err) // timeline should be shorter @@ -110,7 +119,7 @@ func (suite *ManagerTestSuite) TestManagerIntegration() { suite.Equal("01F8MHAAY43M6RJ473VQFCVH37", oldestIndexed) // now remove all entries by local_account_2 from the timeline - err = suite.manager.WipeStatusesFromAccountID(context.Background(), testAccount.ID, suite.testAccounts["local_account_2"].ID) + err = suite.manager.WipeItemsFromAccountID(context.Background(), testAccount.ID, suite.testAccounts["local_account_2"].ID) suite.NoError(err) // timeline should be shorter diff --git a/internal/timeline/preparable.go b/internal/timeline/preparable.go new file mode 100644 index 000000000..c38acb450 --- /dev/null +++ b/internal/timeline/preparable.go @@ -0,0 +1,26 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package timeline + +type Preparable interface { + GetID() string + GetAccountID() string + GetBoostOfID() string + GetBoostOfAccountID() string +} diff --git a/internal/timeline/prepare.go b/internal/timeline/prepare.go index 4ddaad8a5..dae9031e5 100644 --- a/internal/timeline/prepare.go +++ b/internal/timeline/prepare.go @@ -26,7 +26,6 @@ import ( "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) func (t *timeline) prepareNextQuery(ctx context.Context, amount int, maxID string, sinceID string, minID string) error { @@ -59,19 +58,19 @@ func (t *timeline) prepareNextQuery(ctx context.Context, amount int, maxID strin return err } -func (t *timeline) PrepareBehind(ctx context.Context, statusID string, amount int) error { - // lazily initialize prepared posts if it hasn't been done already - if t.preparedPosts.data == nil { - t.preparedPosts.data = &list.List{} - t.preparedPosts.data.Init() +func (t *timeline) PrepareBehind(ctx context.Context, itemID string, amount int) error { + // lazily initialize prepared items if it hasn't been done already + if t.preparedItems.data == nil { + t.preparedItems.data = &list.List{} + t.preparedItems.data.Init() } - if err := t.IndexBehind(ctx, statusID, true, amount); err != nil { - return fmt.Errorf("PrepareBehind: error indexing behind id %s: %s", statusID, err) + if err := t.IndexBehind(ctx, itemID, amount); err != nil { + return fmt.Errorf("PrepareBehind: error indexing behind id %s: %s", itemID, err) } - // if the postindex is nil, nothing has been indexed yet so there's nothing to prepare - if t.postIndex.data == nil { + // if the itemindex is nil, nothing has been indexed yet so there's nothing to prepare + if t.itemIndex.data == nil { return nil } @@ -80,25 +79,25 @@ func (t *timeline) PrepareBehind(ctx context.Context, statusID string, amount in t.Lock() defer t.Unlock() prepareloop: - for e := t.postIndex.data.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*postIndexEntry) + for e := t.itemIndex.data.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*itemIndexEntry) if !ok { - return errors.New("PrepareBehind: could not parse e as a postIndexEntry") + return errors.New("PrepareBehind: could not parse e as itemIndexEntry") } if !preparing { // we haven't hit the position we need to prepare from yet - if entry.statusID == statusID { + if entry.itemID == itemID { preparing = true } } if preparing { - if err := t.prepare(ctx, entry.statusID); err != nil { + if err := t.prepare(ctx, entry.itemID); err != nil { // there's been an error if err != db.ErrNoEntries { // it's a real error - return fmt.Errorf("PrepareBehind: error preparing status with id %s: %s", entry.statusID, err) + return fmt.Errorf("PrepareBehind: error preparing item with id %s: %s", entry.itemID, err) } // the status just doesn't exist (anymore) so continue to the next one continue @@ -119,28 +118,28 @@ func (t *timeline) PrepareBefore(ctx context.Context, statusID string, include b defer t.Unlock() // lazily initialize prepared posts if it hasn't been done already - if t.preparedPosts.data == nil { - t.preparedPosts.data = &list.List{} - t.preparedPosts.data.Init() + if t.preparedItems.data == nil { + t.preparedItems.data = &list.List{} + t.preparedItems.data.Init() } // if the postindex is nil, nothing has been indexed yet so there's nothing to prepare - if t.postIndex.data == nil { + if t.itemIndex.data == nil { return nil } var prepared int var preparing bool prepareloop: - for e := t.postIndex.data.Back(); e != nil; e = e.Prev() { - entry, ok := e.Value.(*postIndexEntry) + for e := t.itemIndex.data.Back(); e != nil; e = e.Prev() { + entry, ok := e.Value.(*itemIndexEntry) if !ok { return errors.New("PrepareBefore: could not parse e as a postIndexEntry") } if !preparing { // we haven't hit the position we need to prepare from yet - if entry.statusID == statusID { + if entry.itemID == statusID { preparing = true if !include { continue @@ -149,11 +148,11 @@ prepareloop: } if preparing { - if err := t.prepare(ctx, entry.statusID); err != nil { + if err := t.prepare(ctx, entry.itemID); err != nil { // there's been an error if err != db.ErrNoEntries { // it's a real error - return fmt.Errorf("PrepareBefore: error preparing status with id %s: %s", entry.statusID, err) + return fmt.Errorf("PrepareBefore: error preparing status with id %s: %s", entry.itemID, err) } // the status just doesn't exist (anymore) so continue to the next one continue @@ -176,15 +175,15 @@ func (t *timeline) PrepareFromTop(ctx context.Context, amount int) error { }) // lazily initialize prepared posts if it hasn't been done already - if t.preparedPosts.data == nil { - t.preparedPosts.data = &list.List{} - t.preparedPosts.data.Init() + if t.preparedItems.data == nil { + t.preparedItems.data = &list.List{} + t.preparedItems.data.Init() } // if the postindex is nil, nothing has been indexed yet so index from the highest ID possible - if t.postIndex.data == nil { + if t.itemIndex.data == nil { l.Debug("postindex.data was nil, indexing behind highest possible ID") - if err := t.IndexBehind(ctx, "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", false, amount); err != nil { + if err := t.IndexBehind(ctx, "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", amount); err != nil { return fmt.Errorf("PrepareFromTop: error indexing behind id %s: %s", "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", err) } } @@ -194,21 +193,21 @@ func (t *timeline) PrepareFromTop(ctx context.Context, amount int) error { defer t.Unlock() var prepared int prepareloop: - for e := t.postIndex.data.Front(); e != nil; e = e.Next() { + for e := t.itemIndex.data.Front(); e != nil; e = e.Next() { if e == nil { continue } - entry, ok := e.Value.(*postIndexEntry) + entry, ok := e.Value.(*itemIndexEntry) if !ok { return errors.New("PrepareFromTop: could not parse e as a postIndexEntry") } - if err := t.prepare(ctx, entry.statusID); err != nil { + if err := t.prepare(ctx, entry.itemID); err != nil { // there's been an error if err != db.ErrNoEntries { // it's a real error - return fmt.Errorf("PrepareFromTop: error preparing status with id %s: %s", entry.statusID, err) + return fmt.Errorf("PrepareFromTop: error preparing status with id %s: %s", entry.itemID, err) } // the status just doesn't exist (anymore) so continue to the next one continue @@ -226,57 +225,42 @@ prepareloop: return nil } -func (t *timeline) prepare(ctx context.Context, statusID string) error { - - // start by getting the status out of the database according to its indexed ID - gtsStatus := >smodel.Status{} - if err := t.db.GetByID(ctx, statusID, gtsStatus); err != nil { - return err - } - - // if the account pointer hasn't been set on this timeline already, set it lazily here - if t.account == nil { - timelineOwnerAccount := >smodel.Account{} - if err := t.db.GetByID(ctx, t.accountID, timelineOwnerAccount); err != nil { - return err - } - t.account = timelineOwnerAccount - } - - // serialize the status (or, at least, convert it to a form that's ready to be serialized) - apiModelStatus, err := t.tc.StatusToAPIStatus(ctx, gtsStatus, t.account) +func (t *timeline) prepare(ctx context.Context, itemID string) error { + // trigger the caller-provided prepare function + prepared, err := t.prepareFunction(ctx, t.accountID, itemID) if err != nil { return err } - // shove it in prepared posts as a prepared posts entry - preparedPostsEntry := &preparedPostsEntry{ - statusID: gtsStatus.ID, - boostOfID: gtsStatus.BoostOfID, - accountID: gtsStatus.AccountID, - boostOfAccountID: gtsStatus.BoostOfAccountID, - prepared: apiModelStatus, + // shove it in prepared items as a prepared items entry + preparedItemsEntry := &preparedItemsEntry{ + itemID: prepared.GetID(), + boostOfID: prepared.GetBoostOfID(), + accountID: prepared.GetAccountID(), + boostOfAccountID: prepared.GetBoostOfAccountID(), + prepared: prepared, } - return t.preparedPosts.insertPrepared(preparedPostsEntry) + return t.preparedItems.insertPrepared(ctx, preparedItemsEntry) } -func (t *timeline) OldestPreparedPostID(ctx context.Context) (string, error) { +func (t *timeline) OldestPreparedItemID(ctx context.Context) (string, error) { var id string - if t.preparedPosts == nil || t.preparedPosts.data == nil { - // return an empty string if prepared posts hasn't been initialized yet + if t.preparedItems == nil || t.preparedItems.data == nil { + // return an empty string if prepared items hasn't been initialized yet return id, nil } - e := t.preparedPosts.data.Back() + e := t.preparedItems.data.Back() if e == nil { // return an empty string if there's no back entry (ie., the index list hasn't been initialized yet) return id, nil } - entry, ok := e.Value.(*preparedPostsEntry) + entry, ok := e.Value.(*preparedItemsEntry) if !ok { - return id, errors.New("OldestPreparedPostID: could not parse e as a preparedPostsEntry") + return id, errors.New("OldestPreparedItemID: could not parse e as a preparedItemsEntry") } - return entry.statusID, nil + + return entry.itemID, nil } diff --git a/internal/timeline/preparedposts.go b/internal/timeline/prepareditems.go index 54e0e61f3..07a8c69ee 100644 --- a/internal/timeline/preparedposts.go +++ b/internal/timeline/prepareditems.go @@ -20,24 +20,24 @@ package timeline import ( "container/list" + "context" "errors" - - apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" ) -type preparedPosts struct { - data *list.List +type preparedItems struct { + data *list.List + skipInsert SkipInsertFunction } -type preparedPostsEntry struct { - statusID string +type preparedItemsEntry struct { + itemID string boostOfID string accountID string boostOfAccountID string - prepared *apimodel.Status + prepared Preparable } -func (p *preparedPosts) insertPrepared(i *preparedPostsEntry) error { +func (p *preparedItems) insertPrepared(ctx context.Context, i *preparedItemsEntry) error { if p.data == nil { p.data = &list.List{} } @@ -55,35 +55,28 @@ func (p *preparedPosts) insertPrepared(i *preparedPostsEntry) error { for e := p.data.Front(); e != nil; e = e.Next() { position++ - entry, ok := e.Value.(*preparedPostsEntry) + entry, ok := e.Value.(*preparedItemsEntry) if !ok { return errors.New("index: could not parse e as a preparedPostsEntry") } - // don't insert this if it's a boost of a status we've seen recently - if i.prepared.Reblog != nil { - if entry.prepared.Reblog != nil && i.prepared.Reblog.ID == entry.prepared.Reblog.ID { - if position < boostReinsertionDepth { - return nil - } - } - - if i.prepared.Reblog.ID == entry.statusID { - if position < boostReinsertionDepth { - return nil - } - } + skip, err := p.skipInsert(ctx, i.itemID, i.accountID, i.boostOfID, i.boostOfAccountID, entry.itemID, entry.accountID, entry.boostOfID, entry.boostOfAccountID, position) + if err != nil { + return err + } + if skip { + return nil } // if the post to index is newer than e, insert it before e in the list if insertMark == nil { - if i.statusID > entry.statusID { + if i.itemID > entry.itemID { insertMark = e } } // make sure we don't insert a duplicate - if entry.statusID == i.statusID { + if entry.itemID == i.itemID { return nil } } diff --git a/internal/timeline/remove.go b/internal/timeline/remove.go index 833d7126a..60d8108ec 100644 --- a/internal/timeline/remove.go +++ b/internal/timeline/remove.go @@ -38,39 +38,39 @@ func (t *timeline) Remove(ctx context.Context, statusID string) (int, error) { // remove entr(ies) from the post index removeIndexes := []*list.Element{} - if t.postIndex != nil && t.postIndex.data != nil { - for e := t.postIndex.data.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*postIndexEntry) + if t.itemIndex != nil && t.itemIndex.data != nil { + for e := t.itemIndex.data.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*itemIndexEntry) if !ok { return removed, errors.New("Remove: could not parse e as a postIndexEntry") } - if entry.statusID == statusID { + if entry.itemID == statusID { l.Debug("found status in postIndex") removeIndexes = append(removeIndexes, e) } } } for _, e := range removeIndexes { - t.postIndex.data.Remove(e) + t.itemIndex.data.Remove(e) removed++ } // remove entr(ies) from prepared posts removePrepared := []*list.Element{} - if t.preparedPosts != nil && t.preparedPosts.data != nil { - for e := t.preparedPosts.data.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*preparedPostsEntry) + if t.preparedItems != nil && t.preparedItems.data != nil { + for e := t.preparedItems.data.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*preparedItemsEntry) if !ok { return removed, errors.New("Remove: could not parse e as a preparedPostsEntry") } - if entry.statusID == statusID { + if entry.itemID == statusID { l.Debug("found status in preparedPosts") removePrepared = append(removePrepared, e) } } } for _, e := range removePrepared { - t.preparedPosts.data.Remove(e) + t.preparedItems.data.Remove(e) removed++ } @@ -90,9 +90,9 @@ func (t *timeline) RemoveAllBy(ctx context.Context, accountID string) (int, erro // remove entr(ies) from the post index removeIndexes := []*list.Element{} - if t.postIndex != nil && t.postIndex.data != nil { - for e := t.postIndex.data.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*postIndexEntry) + if t.itemIndex != nil && t.itemIndex.data != nil { + for e := t.itemIndex.data.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*itemIndexEntry) if !ok { return removed, errors.New("Remove: could not parse e as a postIndexEntry") } @@ -103,15 +103,15 @@ func (t *timeline) RemoveAllBy(ctx context.Context, accountID string) (int, erro } } for _, e := range removeIndexes { - t.postIndex.data.Remove(e) + t.itemIndex.data.Remove(e) removed++ } // remove entr(ies) from prepared posts removePrepared := []*list.Element{} - if t.preparedPosts != nil && t.preparedPosts.data != nil { - for e := t.preparedPosts.data.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*preparedPostsEntry) + if t.preparedItems != nil && t.preparedItems.data != nil { + for e := t.preparedItems.data.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*preparedItemsEntry) if !ok { return removed, errors.New("Remove: could not parse e as a preparedPostsEntry") } @@ -122,7 +122,7 @@ func (t *timeline) RemoveAllBy(ctx context.Context, accountID string) (int, erro } } for _, e := range removePrepared { - t.preparedPosts.data.Remove(e) + t.preparedItems.data.Remove(e) removed++ } diff --git a/internal/timeline/timeline.go b/internal/timeline/timeline.go index fc4875d1a..1d82914f8 100644 --- a/internal/timeline/timeline.go +++ b/internal/timeline/timeline.go @@ -21,104 +21,135 @@ package timeline import ( "context" "sync" - "time" - - apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/visibility" ) -const boostReinsertionDepth = 50 - -// Timeline represents a timeline for one account, and contains indexed and prepared posts. +// GrabFunction is used by a Timeline to grab more items to index. +// +// It should be provided to NewTimeline when the caller is creating a timeline +// (of statuses, notifications, etc). +// +// timelineAccountID: the owner of the timeline +// maxID: the maximum item ID desired. +// sinceID: the minimum item ID desired. +// minID: see sinceID +// limit: the maximum amount of items to be returned +// +// If an error is returned, the timeline will stop processing whatever request called GrabFunction, +// and return the error. If no error is returned, but stop = true, this indicates to the caller of GrabFunction +// that there are no more items to return, and processing should continue with the items already grabbed. +type GrabFunction func(ctx context.Context, timelineAccountID string, maxID string, sinceID string, minID string, limit int) (items []Timelineable, stop bool, err error) + +// FilterFunction is used by a Timeline to filter whether or not a grabbed item should be indexed. +type FilterFunction func(ctx context.Context, timelineAccountID string, item Timelineable) (shouldIndex bool, err error) + +// PrepareFunction converts a Timelineable into a Preparable. +// +// For example, this might result in the converstion of a *gtsmodel.Status with the given itemID into a serializable *apimodel.Status. +type PrepareFunction func(ctx context.Context, timelineAccountID string, itemID string) (Preparable, error) + +// SkipInsertFunction indicates whether a new item about to be inserted in the prepared list should be skipped, +// based on the item itself, the next item in the timeline, and the depth at which nextItem has been found in the list. +// +// This will be called for every item found while iterating through a timeline, so callers should be very careful +// not to do anything expensive here. +type SkipInsertFunction func(ctx context.Context, + newItemID string, + newItemAccountID string, + newItemBoostOfID string, + newItemBoostOfAccountID string, + nextItemID string, + nextItemAccountID string, + nextItemBoostOfID string, + nextItemBoostOfAccountID string, + depth int) (bool, error) + +// Timeline represents a timeline for one account, and contains indexed and prepared items. type Timeline interface { /* RETRIEVAL FUNCTIONS */ - // Get returns an amount of statuses with the given parameters. + // Get returns an amount of prepared items with the given parameters. // If prepareNext is true, then the next predicted query will be prepared already in a goroutine, // to make the next call to Get faster. - Get(ctx context.Context, amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]*apimodel.Status, error) - // GetXFromTop returns x amount of posts from the top of the timeline, from newest to oldest. - GetXFromTop(ctx context.Context, amount int) ([]*apimodel.Status, error) - // GetXBehindID returns x amount of posts from the given id onwards, from newest to oldest. - // This will NOT include the status with the given ID. + Get(ctx context.Context, amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]Preparable, error) + // GetXFromTop returns x amount of items from the top of the timeline, from newest to oldest. + GetXFromTop(ctx context.Context, amount int) ([]Preparable, error) + // GetXBehindID returns x amount of items from the given id onwards, from newest to oldest. + // This will NOT include the item with the given ID. // // This corresponds to an api call to /timelines/home?max_id=WHATEVER - GetXBehindID(ctx context.Context, amount int, fromID string, attempts *int) ([]*apimodel.Status, error) - // GetXBeforeID returns x amount of posts up to the given id, from newest to oldest. - // This will NOT include the status with the given ID. + GetXBehindID(ctx context.Context, amount int, fromID string, attempts *int) ([]Preparable, error) + // GetXBeforeID returns x amount of items up to the given id, from newest to oldest. + // This will NOT include the item with the given ID. // // This corresponds to an api call to /timelines/home?since_id=WHATEVER - GetXBeforeID(ctx context.Context, amount int, sinceID string, startFromTop bool) ([]*apimodel.Status, error) - // GetXBetweenID returns x amount of posts from the given maxID, up to the given id, from newest to oldest. - // This will NOT include the status with the given IDs. + GetXBeforeID(ctx context.Context, amount int, sinceID string, startFromTop bool) ([]Preparable, error) + // GetXBetweenID returns x amount of items from the given maxID, up to the given id, from newest to oldest. + // This will NOT include the item with the given IDs. // // This corresponds to an api call to /timelines/home?since_id=WHATEVER&max_id=WHATEVER_ELSE - GetXBetweenID(ctx context.Context, amount int, maxID string, sinceID string) ([]*apimodel.Status, error) + GetXBetweenID(ctx context.Context, amount int, maxID string, sinceID string) ([]Preparable, error) /* INDEXING FUNCTIONS */ - // IndexOne puts a status into the timeline at the appropriate place according to its 'createdAt' property. + // IndexOne puts a item into the timeline at the appropriate place according to its 'createdAt' property. // - // The returned bool indicates whether or not the status was actually inserted into the timeline. This will be false - // if the status is a boost and the original post or another boost of it already exists < boostReinsertionDepth back in the timeline. - IndexOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) + // The returned bool indicates whether or not the item was actually inserted into the timeline. This will be false + // if the item is a boost and the original item or another boost of it already exists < boostReinsertionDepth back in the timeline. + IndexOne(ctx context.Context, itemID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) - // OldestIndexedPostID returns the id of the rearmost (ie., the oldest) indexed post, or an error if something goes wrong. - // If nothing goes wrong but there's no oldest post, an empty string will be returned so make sure to check for this. - OldestIndexedPostID(ctx context.Context) (string, error) - // NewestIndexedPostID returns the id of the frontmost (ie., the newest) indexed post, or an error if something goes wrong. - // If nothing goes wrong but there's no newest post, an empty string will be returned so make sure to check for this. - NewestIndexedPostID(ctx context.Context) (string, error) + // OldestIndexedItemID returns the id of the rearmost (ie., the oldest) indexed item, or an error if something goes wrong. + // If nothing goes wrong but there's no oldest item, an empty string will be returned so make sure to check for this. + OldestIndexedItemID(ctx context.Context) (string, error) + // NewestIndexedItemID returns the id of the frontmost (ie., the newest) indexed item, or an error if something goes wrong. + // If nothing goes wrong but there's no newest item, an empty string will be returned so make sure to check for this. + NewestIndexedItemID(ctx context.Context) (string, error) - IndexBefore(ctx context.Context, statusID string, include bool, amount int) error - IndexBehind(ctx context.Context, statusID string, include bool, amount int) error + IndexBefore(ctx context.Context, itemID string, amount int) error + IndexBehind(ctx context.Context, itemID string, amount int) error /* PREPARATION FUNCTIONS */ - // PrepareXFromTop instructs the timeline to prepare x amount of posts from the top of the timeline. + // PrepareXFromTop instructs the timeline to prepare x amount of items from the top of the timeline. PrepareFromTop(ctx context.Context, amount int) error // PrepareBehind instructs the timeline to prepare the next amount of entries for serialization, from position onwards. - // If include is true, then the given status ID will also be prepared, otherwise only entries behind it will be prepared. - PrepareBehind(ctx context.Context, statusID string, amount int) error - // IndexOne puts a status into the timeline at the appropriate place according to its 'createdAt' property, + // If include is true, then the given item ID will also be prepared, otherwise only entries behind it will be prepared. + PrepareBehind(ctx context.Context, itemID string, amount int) error + // IndexOne puts a item into the timeline at the appropriate place according to its 'createdAt' property, // and then immediately prepares it. // - // The returned bool indicates whether or not the status was actually inserted into the timeline. This will be false - // if the status is a boost and the original post or another boost of it already exists < boostReinsertionDepth back in the timeline. - IndexAndPrepareOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) - // OldestPreparedPostID returns the id of the rearmost (ie., the oldest) prepared post, or an error if something goes wrong. - // If nothing goes wrong but there's no oldest post, an empty string will be returned so make sure to check for this. - OldestPreparedPostID(ctx context.Context) (string, error) + // The returned bool indicates whether or not the item was actually inserted into the timeline. This will be false + // if the item is a boost and the original item or another boost of it already exists < boostReinsertionDepth back in the timeline. + IndexAndPrepareOne(ctx context.Context, itemID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) + // OldestPreparedItemID returns the id of the rearmost (ie., the oldest) prepared item, or an error if something goes wrong. + // If nothing goes wrong but there's no oldest item, an empty string will be returned so make sure to check for this. + OldestPreparedItemID(ctx context.Context) (string, error) /* INFO FUNCTIONS */ - // ActualPostIndexLength returns the actual length of the post index at this point in time. - PostIndexLength(ctx context.Context) int + // ActualPostIndexLength returns the actual length of the item index at this point in time. + ItemIndexLength(ctx context.Context) int /* UTILITY FUNCTIONS */ - // Reset instructs the timeline to reset to its base state -- cache only the minimum amount of posts. + // Reset instructs the timeline to reset to its base state -- cache only the minimum amount of items. Reset() error - // Remove removes a status from both the index and prepared posts. + // Remove removes a item from both the index and prepared items. // - // If a status has multiple entries in a timeline, they will all be removed. + // If a item has multiple entries in a timeline, they will all be removed. // // The returned int indicates the amount of entries that were removed. - Remove(ctx context.Context, statusID string) (int, error) - // RemoveAllBy removes all statuses by the given accountID, from both the index and prepared posts. + Remove(ctx context.Context, itemID string) (int, error) + // RemoveAllBy removes all items by the given accountID, from both the index and prepared items. // // The returned int indicates the amount of entries that were removed. RemoveAllBy(ctx context.Context, accountID string) (int, error) @@ -126,31 +157,34 @@ type Timeline interface { // timeline fulfils the Timeline interface type timeline struct { - postIndex *postIndex - preparedPosts *preparedPosts - accountID string - account *gtsmodel.Account - db db.DB - filter visibility.Filter - tc typeutils.TypeConverter + itemIndex *itemIndex + preparedItems *preparedItems + grabFunction GrabFunction + filterFunction FilterFunction + prepareFunction PrepareFunction + accountID string sync.Mutex } // NewTimeline returns a new Timeline for the given account ID -func NewTimeline(ctx context.Context, accountID string, db db.DB, typeConverter typeutils.TypeConverter) (Timeline, error) { - timelineOwnerAccount := >smodel.Account{} - if err := db.GetByID(ctx, accountID, timelineOwnerAccount); err != nil { - return nil, err - } - +func NewTimeline( + ctx context.Context, + timelineAccountID string, + grabFunction GrabFunction, + filterFunction FilterFunction, + prepareFunction PrepareFunction, + skipInsertFunction SkipInsertFunction) (Timeline, error) { return &timeline{ - postIndex: &postIndex{}, - preparedPosts: &preparedPosts{}, - accountID: accountID, - account: timelineOwnerAccount, - db: db, - filter: visibility.NewFilter(db), - tc: typeConverter, + itemIndex: &itemIndex{ + skipInsert: skipInsertFunction, + }, + preparedItems: &preparedItems{ + skipInsert: skipInsertFunction, + }, + grabFunction: grabFunction, + filterFunction: filterFunction, + prepareFunction: prepareFunction, + accountID: timelineAccountID, }, nil } @@ -158,10 +192,10 @@ func (t *timeline) Reset() error { return nil } -func (t *timeline) PostIndexLength(ctx context.Context) int { - if t.postIndex == nil || t.postIndex.data == nil { +func (t *timeline) ItemIndexLength(ctx context.Context) int { + if t.itemIndex == nil || t.itemIndex.data == nil { return 0 } - return t.postIndex.data.Len() + return t.itemIndex.data.Len() } diff --git a/internal/timeline/timeline_test.go b/internal/timeline/timeline_test.go index 96a938b07..ef6b66535 100644 --- a/internal/timeline/timeline_test.go +++ b/internal/timeline/timeline_test.go @@ -24,12 +24,14 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" ) type TimelineStandardTestSuite struct { suite.Suite - db db.DB - tc typeutils.TypeConverter + db db.DB + tc typeutils.TypeConverter + filter visibility.Filter testAccounts map[string]*gtsmodel.Account testStatuses map[string]*gtsmodel.Status diff --git a/internal/timeline/timelineable.go b/internal/timeline/timelineable.go new file mode 100644 index 000000000..cf2d06775 --- /dev/null +++ b/internal/timeline/timelineable.go @@ -0,0 +1,27 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package timeline + +// Timelineable represents any item that can be put in a timeline. +type Timelineable interface { + GetID() string + GetAccountID() string + GetBoostOfID() string + GetBoostOfAccountID() string +} diff --git a/internal/typeutils/internaltoas.go b/internal/typeutils/internaltoas.go index 8e28bf0f0..a22b926d6 100644 --- a/internal/typeutils/internaltoas.go +++ b/internal/typeutils/internaltoas.go @@ -625,6 +625,9 @@ func (c *converter) MentionToAS(ctx context.Context, m *gtsmodel.Mention) (vocab var domain string if m.TargetAccount.Domain == "" { accountDomain := viper.GetString(config.Keys.AccountDomain) + if accountDomain == "" { + accountDomain = viper.GetString(config.Keys.Host) + } domain = accountDomain } else { domain = m.TargetAccount.Domain |