diff options
Diffstat (limited to 'internal/db/bundb/account_test.go')
-rw-r--r-- | internal/db/bundb/account_test.go | 179 |
1 files changed, 143 insertions, 36 deletions
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index b7e8aaadc..2241ab783 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -21,6 +21,8 @@ import ( "context" "crypto/rand" "crypto/rsa" + "errors" + "reflect" "strings" "testing" "time" @@ -61,44 +63,149 @@ func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() { suite.Len(statuses, 1) } -func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { - account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID) - if err != nil { - suite.FailNow(err.Error()) +func (suite *AccountTestSuite) TestGetAccountBy() { + t := suite.T() + + // Create a new context for this test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Sentinel error to mark avoiding a test case. + sentinelErr := errors.New("sentinel") + + // isEqual checks if 2 account models are equal. + isEqual := func(a1, a2 gtsmodel.Account) bool { + // Clear populated sub-models. + a1.HeaderMediaAttachment = nil + a2.HeaderMediaAttachment = nil + a1.AvatarMediaAttachment = nil + a2.AvatarMediaAttachment = nil + a1.Emojis = nil + a2.Emojis = nil + + // Clear database-set fields. + a1.CreatedAt = time.Time{} + a2.CreatedAt = time.Time{} + a1.UpdatedAt = time.Time{} + a2.UpdatedAt = time.Time{} + + // Manually compare keys. + pk1 := a1.PublicKey + pv1 := a1.PrivateKey + pk2 := a2.PublicKey + pv2 := a2.PrivateKey + a1.PublicKey = nil + a1.PrivateKey = nil + a2.PublicKey = nil + a2.PrivateKey = nil + + return reflect.DeepEqual(a1, a2) && + ((pk1 == nil && pk2 == nil) || pk1.Equal(pk2)) && + ((pv1 == nil && pv2 == nil) || pv1.Equal(pv2)) } - suite.NotNil(account) - suite.NotNil(account.AvatarMediaAttachment) - suite.NotEmpty(account.AvatarMediaAttachment.URL) - suite.NotNil(account.HeaderMediaAttachment) - suite.NotEmpty(account.HeaderMediaAttachment.URL) -} - -func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() { - testAccount1 := suite.testAccounts["local_account_1"] - account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain) - suite.NoError(err) - suite.NotNil(account1) - - testAccount2 := suite.testAccounts["remote_account_1"] - account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain) - suite.NoError(err) - suite.NotNil(account2) -} - -func (suite *AccountTestSuite) TestGetAccountByUsernameDomainMixedCase() { - testAccount := suite.testAccounts["remote_account_2"] - account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount.Username, testAccount.Domain) - suite.NoError(err) - suite.NotNil(account1) - - account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToUpper(testAccount.Username), testAccount.Domain) - suite.NoError(err) - suite.NotNil(account2) - - account3, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToLower(testAccount.Username), testAccount.Domain) - suite.NoError(err) - suite.NotNil(account3) + for _, account := range suite.testAccounts { + for lookup, dbfunc := range map[string]func() (*gtsmodel.Account, error){ + "id": func() (*gtsmodel.Account, error) { + return suite.db.GetAccountByID(ctx, account.ID) + }, + + "uri": func() (*gtsmodel.Account, error) { + return suite.db.GetAccountByURI(ctx, account.URI) + }, + + "url": func() (*gtsmodel.Account, error) { + if account.URL == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByURL(ctx, account.URL) + }, + + "username@domain": func() (*gtsmodel.Account, error) { + return suite.db.GetAccountByUsernameDomain(ctx, account.Username, account.Domain) + }, + + "username_upper@domain": func() (*gtsmodel.Account, error) { + return suite.db.GetAccountByUsernameDomain(ctx, strings.ToUpper(account.Username), account.Domain) + }, + + "username_lower@domain": func() (*gtsmodel.Account, error) { + return suite.db.GetAccountByUsernameDomain(ctx, strings.ToLower(account.Username), account.Domain) + }, + + "public_key_uri": func() (*gtsmodel.Account, error) { + if account.PublicKeyURI == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByPubkeyID(ctx, account.PublicKeyURI) + }, + + "inbox_uri": func() (*gtsmodel.Account, error) { + if account.InboxURI == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByInboxURI(ctx, account.InboxURI) + }, + + "outbox_uri": func() (*gtsmodel.Account, error) { + if account.OutboxURI == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByOutboxURI(ctx, account.OutboxURI) + }, + + "following_uri": func() (*gtsmodel.Account, error) { + if account.FollowingURI == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByFollowingURI(ctx, account.FollowingURI) + }, + + "followers_uri": func() (*gtsmodel.Account, error) { + if account.FollowersURI == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByFollowersURI(ctx, account.FollowersURI) + }, + } { + + // Clear database caches. + suite.state.Caches.Init() + + t.Logf("checking database lookup %q", lookup) + + // Perform database function. + checkAcc, err := dbfunc() + if err != nil { + if err == sentinelErr { + continue + } + + t.Errorf("error encountered for database lookup %q: %v", lookup, err) + continue + } + + // Check received account data. + if !isEqual(*checkAcc, *account) { + t.Errorf("account does not contain expected data: %+v", checkAcc) + continue + } + + // Check that avatar attachment populated. + if account.AvatarMediaAttachmentID != "" && + (checkAcc.AvatarMediaAttachment == nil || checkAcc.AvatarMediaAttachment.ID != account.AvatarMediaAttachmentID) { + t.Errorf("account avatar media attachment not correctly populated for: %+v", account) + continue + } + + // Check that header attachment populated. + if account.HeaderMediaAttachmentID != "" && + (checkAcc.HeaderMediaAttachment == nil || checkAcc.HeaderMediaAttachment.ID != account.HeaderMediaAttachmentID) { + t.Errorf("account header media attachment not correctly populated for: %+v", account) + continue + } + } + } } func (suite *AccountTestSuite) TestUpdateAccount() { |