summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2021-08-20 12:26:56 +0200
committerLibravatar GitHub <noreply@github.com>2021-08-20 12:26:56 +0200
commit4920229a3b6e1d7dde536bc9ff766542b05d935c (patch)
treea9423beccec5331c372f01eedf38949dfb171e9e /internal/db
parentText/status parsing fixes (#141) (diff)
downloadgotosocial-4920229a3b6e1d7dde536bc9ff766542b05d935c.tar.xz
Database updates (#144)
* start moving some database stuff around * continue moving db stuff around * more fiddling * more updates * and some more * and yet more * i broke SOMETHING but what, it's a mystery * tidy up * vendor ttlcache * use ttlcache * fix up some tests * rename some stuff * little reminder * some more updates
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/account.go66
-rw-r--r--internal/db/admin.go53
-rw-r--r--internal/db/basic.go87
-rw-r--r--internal/db/db.go265
-rw-r--r--internal/db/domain.go36
-rw-r--r--internal/db/error.go28
-rw-r--r--internal/db/instance.go36
-rw-r--r--internal/db/media.go (renamed from internal/db/pg/put.go)18
-rw-r--r--internal/db/mention.go30
-rw-r--r--internal/db/notification.go31
-rw-r--r--internal/db/pg/account.go256
-rw-r--r--internal/db/pg/account_test.go70
-rw-r--r--internal/db/pg/admin.go235
-rw-r--r--internal/db/pg/basic.go205
-rw-r--r--internal/db/pg/blocks.go67
-rw-r--r--internal/db/pg/domain.go83
-rw-r--r--internal/db/pg/get.go75
-rw-r--r--internal/db/pg/instance.go39
-rw-r--r--internal/db/pg/media.go (renamed from internal/db/pg/delete.go)52
-rw-r--r--internal/db/pg/mention.go108
-rw-r--r--internal/db/pg/notification.go135
-rw-r--r--internal/db/pg/pg.go827
-rw-r--r--internal/db/pg/pg_test.go47
-rw-r--r--internal/db/pg/relationship.go276
-rw-r--r--internal/db/pg/status.go318
-rw-r--r--internal/db/pg/status_test.go134
-rw-r--r--internal/db/pg/statuscontext.go104
-rw-r--r--internal/db/pg/timeline.go40
-rw-r--r--internal/db/pg/update.go73
-rw-r--r--internal/db/pg/util.go25
-rw-r--r--internal/db/relationship.go71
-rw-r--r--internal/db/status.go75
-rw-r--r--internal/db/timeline.go44
33 files changed, 2623 insertions, 1386 deletions
diff --git a/internal/db/account.go b/internal/db/account.go
new file mode 100644
index 000000000..0e1575f9b
--- /dev/null
+++ b/internal/db/account.go
@@ -0,0 +1,66 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import (
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+// Account contains functions related to account getting/setting/creation.
+type Account interface {
+ // GetAccountByID returns one account with the given ID, or an error if something goes wrong.
+ GetAccountByID(id string) (*gtsmodel.Account, Error)
+
+ // GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
+ GetAccountByURI(uri string) (*gtsmodel.Account, Error)
+
+ // GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
+ GetAccountByURL(uri string) (*gtsmodel.Account, Error)
+
+ // GetLocalAccountByUsername returns an account on this instance by its username.
+ GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error)
+
+ // GetAccountFaves fetches faves/likes created by the target accountID.
+ GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error)
+
+ // GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
+ CountAccountStatuses(accountID string) (int, Error)
+
+ // GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
+ // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
+ // be very memory intensive so you probably shouldn't do this!
+ // In case of no entries, a 'no entries' error will be returned
+ GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
+
+ GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
+
+ // GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
+ //
+ // The returned time will be zero if account has never posted anything.
+ GetAccountLastPosted(accountID string) (time.Time, Error)
+
+ // SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
+ SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
+
+ // GetInstanceAccount returns the instance account for the given domain.
+ // If domain is empty, this instance account will be returned.
+ GetInstanceAccount(domain string) (*gtsmodel.Account, Error)
+}
diff --git a/internal/db/admin.go b/internal/db/admin.go
new file mode 100644
index 000000000..aa2b22f47
--- /dev/null
+++ b/internal/db/admin.go
@@ -0,0 +1,53 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import (
+ "net"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+// Admin contains functions related to instance administration (new signups etc).
+type Admin interface {
+ // IsUsernameAvailable checks whether a given username is available on our domain.
+ // Returns an error if the username is already taken, or something went wrong in the db.
+ IsUsernameAvailable(username string) Error
+
+ // IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
+ // Return an error if:
+ // A) the email is already associated with an account
+ // B) we block signups from this email domain
+ // C) something went wrong in the db
+ IsEmailAvailable(email string) Error
+
+ // NewSignup creates a new user in the database with the given parameters.
+ // By the time this function is called, it should be assumed that all the parameters have passed validation!
+ NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
+
+ // CreateInstanceAccount creates an account in the database with the same username as the instance host value.
+ // Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
+ // This is needed for things like serving files that belong to the instance and not an individual user/account.
+ CreateInstanceAccount() Error
+
+ // CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
+ // Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
+ // This is needed for things like serving instance information through /api/v1/instance
+ CreateInstanceInstance() Error
+}
diff --git a/internal/db/basic.go b/internal/db/basic.go
new file mode 100644
index 000000000..729920bba
--- /dev/null
+++ b/internal/db/basic.go
@@ -0,0 +1,87 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import "context"
+
+// Basic wraps basic database functionality.
+type Basic interface {
+ // CreateTable creates a table for the given interface.
+ // For implementations that don't use tables, this can just return nil.
+ CreateTable(i interface{}) Error
+
+ // DropTable drops the table for the given interface.
+ // For implementations that don't use tables, this can just return nil.
+ DropTable(i interface{}) Error
+
+ // RegisterTable registers a table for use in many2many relations.
+ // For implementations that don't use tables, or many2many relations, this can just return nil.
+ RegisterTable(i interface{}) Error
+
+ // Stop should stop and close the database connection cleanly, returning an error if this is not possible.
+ // If the database implementation doesn't need to be stopped, this can just return nil.
+ Stop(ctx context.Context) Error
+
+ // IsHealthy should return nil if the database connection is healthy, or an error if not.
+ IsHealthy(ctx context.Context) Error
+
+ // GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
+ // for other implementations (for example, in-memory) it might just be the key of a map.
+ // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
+ // In case of no entries, a 'no entries' error will be returned
+ GetByID(id string, i interface{}) Error
+
+ // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
+ // name of the key to select from.
+ // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
+ // In case of no entries, a 'no entries' error will be returned
+ GetWhere(where []Where, i interface{}) Error
+
+ // GetAll will try to get all entries of type i.
+ // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
+ // In case of no entries, a 'no entries' error will be returned
+ GetAll(i interface{}) Error
+
+ // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
+ // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
+ Put(i interface{}) Error
+
+ // Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
+ // It is up to the implementation to figure out how to store it, and using what key.
+ // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
+ Upsert(i interface{}, conflictColumn string) Error
+
+ // UpdateByID updates i with id id.
+ // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
+ UpdateByID(id string, i interface{}) Error
+
+ // UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
+ UpdateOneByID(id string, key string, value interface{}, i interface{}) Error
+
+ // UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
+ UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error
+
+ // DeleteByID removes i with id id.
+ // If i didn't exist anyway, then no error should be returned.
+ DeleteByID(id string, i interface{}) Error
+
+ // DeleteWhere deletes i where key = value
+ // If i didn't exist anyway, then no error should be returned.
+ DeleteWhere(where []Where, i interface{}) Error
+}
diff --git a/internal/db/db.go b/internal/db/db.go
index d0b23fbc6..d6ac883e4 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -19,9 +19,6 @@
package db
import (
- "context"
- "net"
-
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -30,257 +27,19 @@ const (
DBTypePostgres string = "POSTGRES"
)
-// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres).
-// Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated
-// by whatever is returned from the database.
+// DB provides methods for interacting with an underlying database or other storage mechanism.
type DB interface {
- /*
- BASIC DB FUNCTIONALITY
- */
-
- // CreateTable creates a table for the given interface.
- // For implementations that don't use tables, this can just return nil.
- CreateTable(i interface{}) error
-
- // DropTable drops the table for the given interface.
- // For implementations that don't use tables, this can just return nil.
- DropTable(i interface{}) error
-
- // Stop should stop and close the database connection cleanly, returning an error if this is not possible.
- // If the database implementation doesn't need to be stopped, this can just return nil.
- Stop(ctx context.Context) error
-
- // IsHealthy should return nil if the database connection is healthy, or an error if not.
- IsHealthy(ctx context.Context) error
-
- // GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
- // for other implementations (for example, in-memory) it might just be the key of a map.
- // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- // In case of no entries, a 'no entries' error will be returned
- GetByID(id string, i interface{}) error
-
- // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
- // name of the key to select from.
- // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- // In case of no entries, a 'no entries' error will be returned
- GetWhere(where []Where, i interface{}) error
-
- // GetAll will try to get all entries of type i.
- // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- // In case of no entries, a 'no entries' error will be returned
- GetAll(i interface{}) error
-
- // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
- // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Put(i interface{}) error
-
- // Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
- // It is up to the implementation to figure out how to store it, and using what key.
- // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Upsert(i interface{}, conflictColumn string) error
-
- // UpdateByID updates i with id id.
- // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- UpdateByID(id string, i interface{}) error
-
- // UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
- UpdateOneByID(id string, key string, value interface{}, i interface{}) error
-
- // UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
- UpdateWhere(where []Where, key string, value interface{}, i interface{}) error
-
- // DeleteByID removes i with id id.
- // If i didn't exist anyway, then no error should be returned.
- DeleteByID(id string, i interface{}) error
-
- // DeleteWhere deletes i where key = value
- // If i didn't exist anyway, then no error should be returned.
- DeleteWhere(where []Where, i interface{}) error
-
- /*
- HANDY SHORTCUTS
- */
-
- // AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
- // In other words, it should create the follow, and delete the existing follow request.
- //
- // It will return the newly created follow for further processing.
- AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
-
- // CreateInstanceAccount creates an account in the database with the same username as the instance host value.
- // Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
- // This is needed for things like serving files that belong to the instance and not an individual user/account.
- CreateInstanceAccount() error
-
- // CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
- // Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
- // This is needed for things like serving instance information through /api/v1/instance
- CreateInstanceInstance() error
-
- // GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID.
- // The given account pointer will be set to the result of the query, whatever it is.
- // In case of no entries, a 'no entries' error will be returned
- GetAccountByUserID(userID string, account *gtsmodel.Account) error
-
- // GetLocalAccountByUsername is a shortcut for the common action of fetching an account ON THIS INSTANCE
- // according to its username, which should be unique.
- // The given account pointer will be set to the result of the query, whatever it is.
- // In case of no entries, a 'no entries' error will be returned
- GetLocalAccountByUsername(username string, account *gtsmodel.Account) error
-
- // GetFollowRequestsForAccountID is a shortcut for the common action of fetching a list of follow requests targeting the given account ID.
- // The given slice 'followRequests' will be set to the result of the query, whatever it is.
- // In case of no entries, a 'no entries' error will be returned
- GetFollowRequestsForAccountID(accountID string, followRequests *[]gtsmodel.FollowRequest) error
-
- // GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following.
- // The given slice 'following' will be set to the result of the query, whatever it is.
- // In case of no entries, a 'no entries' error will be returned
- GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error
-
- // GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by.
- // The given slice 'followers' will be set to the result of the query, whatever it is.
- // In case of no entries, a 'no entries' error will be returned
- //
- // If localOnly is set to true, then only followers from *this instance* will be returned.
- GetFollowersByAccountID(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error
-
- // GetFavesByAccountID is a shortcut for the common action of fetching a list of faves made by the given accountID.
- // The given slice 'faves' will be set to the result of the query, whatever it is.
- // In case of no entries, a 'no entries' error will be returned
- GetFavesByAccountID(accountID string, faves *[]gtsmodel.StatusFave) error
-
- // CountStatusesByAccountID is a shortcut for the common action of counting statuses produced by accountID.
- CountStatusesByAccountID(accountID string) (int, error)
-
- // GetStatusesForAccount is a shortcut for getting the most recent statuses. accountID is optional, if not provided
- // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
- // be very memory intensive so you probably shouldn't do this!
- // In case of no entries, a 'no entries' error will be returned
- GetStatusesForAccount(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error)
-
- GetBlocksForAccount(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error)
-
- // GetLastStatusForAccountID simply gets the most recent status by the given account.
- // The given slice 'status' pointer will be set to the result of the query, whatever it is.
- // In case of no entries, a 'no entries' error will be returned
- GetLastStatusForAccountID(accountID string, status *gtsmodel.Status) error
-
- // IsUsernameAvailable checks whether a given username is available on our domain.
- // Returns an error if the username is already taken, or something went wrong in the db.
- IsUsernameAvailable(username string) error
-
- // IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
- // Return an error if:
- // A) the email is already associated with an account
- // B) we block signups from this email domain
- // C) something went wrong in the db
- IsEmailAvailable(email string) error
-
- // NewSignup creates a new user in the database with the given parameters.
- // By the time this function is called, it should be assumed that all the parameters have passed validation!
- NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error)
-
- // SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment.
- SetHeaderOrAvatarForAccountID(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error
-
- // GetHeaderAvatarForAccountID gets the current avatar for the given account ID.
- // The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists.
- GetAvatarForAccountID(avatar *gtsmodel.MediaAttachment, accountID string) error
-
- // GetHeaderForAccountID gets the current header for the given account ID.
- // The passed mediaAttachment pointer will be populated with the value of the header, if it exists.
- GetHeaderForAccountID(header *gtsmodel.MediaAttachment, accountID string) error
-
- // Blocked checks whether a block exists in eiher direction between two accounts.
- // That is, it returns true if account1 blocks account2, OR if account2 blocks account1.
- Blocked(account1 string, account2 string) (bool, error)
-
- // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
- GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error)
-
- // Follows returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
- Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
-
- // FollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
- FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
-
- // Mutuals returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
- Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error)
-
- // GetReplyCountForStatus returns the amount of replies recorded for a status, or an error if something goes wrong
- GetReplyCountForStatus(status *gtsmodel.Status) (int, error)
-
- // GetReblogCountForStatus returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
- GetReblogCountForStatus(status *gtsmodel.Status) (int, error)
-
- // GetFaveCountForStatus returns the amount of faves/likes recorded for a status, or an error if something goes wrong
- GetFaveCountForStatus(status *gtsmodel.Status) (int, error)
-
- // StatusParents get the parent statuses of a given status.
- //
- // If onlyDirect is true, only the immediate parent will be returned.
- StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error)
-
- // StatusChildren gets the child statuses of a given status.
- //
- // If onlyDirect is true, only the immediate children will be returned.
- StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error)
-
- // StatusFavedBy checks if a given status has been faved by a given account ID
- StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error)
-
- // StatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
- StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error)
-
- // StatusMutedBy checks if a given status has been muted by a given account ID
- StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error)
-
- // StatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
- StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error)
-
- // WhoFavedStatus returns a slice of accounts who faved the given status.
- // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
-
- // WhoBoostedStatus returns a slice of accounts who boosted the given status.
- // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
-
- // GetHomeTimelineForAccount returns a slice of statuses from accounts that are followed by the given account id.
- //
- // Statuses should be returned in descending order of when they were created (newest first).
- GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
-
- // GetPublicTimelineForAccount fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
- // It will use the given filters and try to return as many statuses as possible up to the limit.
- //
- // Statuses should be returned in descending order of when they were created (newest first).
- GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
-
- // GetFavedTimelineForAccount fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
- // It will use the given filters and try to return as many statuses as possible up to the limit.
- //
- // Note that unlike the other GetTimeline functions, the returned statuses will be arranged by their FAVE id, not the STATUS id.
- // In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
- //
- // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
- GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error)
-
- // GetNotificationsForAccount returns a list of notifications that pertain to the given accountID.
- GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error)
-
- // GetUserCountForInstance returns the number of known accounts registered with the given domain.
- GetUserCountForInstance(domain string) (int, error)
-
- // GetStatusCountForInstance returns the number of known statuses posted from the given domain.
- GetStatusCountForInstance(domain string) (int, error)
-
- // GetDomainCountForInstance returns the number of known instances known that the given domain federates with.
- GetDomainCountForInstance(domain string) (int, error)
-
- // GetAccountsForInstance returns a slice of accounts from the given instance, arranged by ID.
- GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error)
+ Account
+ Admin
+ Basic
+ Domain
+ Instance
+ Media
+ Mention
+ Notification
+ Relationship
+ Status
+ Timeline
/*
USEFUL CONVERSION FUNCTIONS
diff --git a/internal/db/domain.go b/internal/db/domain.go
new file mode 100644
index 000000000..a6583c80c
--- /dev/null
+++ b/internal/db/domain.go
@@ -0,0 +1,36 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import "net/url"
+
+// Domain contains DB functions related to domains and domain blocks.
+type Domain interface {
+ // IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
+ IsDomainBlocked(domain string) (bool, Error)
+
+ // AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
+ AreDomainsBlocked(domains []string) (bool, Error)
+
+ // IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
+ IsURIBlocked(uri *url.URL) (bool, Error)
+
+ // AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
+ AreURIsBlocked(uris []*url.URL) (bool, Error)
+}
diff --git a/internal/db/error.go b/internal/db/error.go
index 197c7bd68..c13bd78dd 100644
--- a/internal/db/error.go
+++ b/internal/db/error.go
@@ -18,16 +18,18 @@
package db
-// ErrNoEntries is to be returned from the DB interface when no entries are found for a given query.
-type ErrNoEntries struct{}
-
-func (e ErrNoEntries) Error() string {
- return "no entries"
-}
-
-// ErrAlreadyExists is to be returned from the DB interface when an entry already exists for a given query or its constraints.
-type ErrAlreadyExists struct{}
-
-func (e ErrAlreadyExists) Error() string {
- return "already exists"
-}
+import "fmt"
+
+// Error denotes a database error.
+type Error error
+
+var (
+ // ErrNoEntries is returned when a caller expected an entry for a query, but none was found.
+ ErrNoEntries Error = fmt.Errorf("no entries")
+ // ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found.
+ ErrMultipleEntries Error = fmt.Errorf("multiple entries")
+ // ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db.
+ ErrAlreadyExists Error = fmt.Errorf("already exists")
+ // ErrUnknown denotes an unknown database error.
+ ErrUnknown Error = fmt.Errorf("unknown error")
+)
diff --git a/internal/db/instance.go b/internal/db/instance.go
new file mode 100644
index 000000000..1f7c83e4f
--- /dev/null
+++ b/internal/db/instance.go
@@ -0,0 +1,36 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+
+// Instance contains functions for instance-level actions (counting instance users etc.).
+type Instance interface {
+ // CountInstanceUsers returns the number of known accounts registered with the given domain.
+ CountInstanceUsers(domain string) (int, Error)
+
+ // CountInstanceStatuses returns the number of known statuses posted from the given domain.
+ CountInstanceStatuses(domain string) (int, Error)
+
+ // CountInstanceDomains returns the number of known instances known that the given domain federates with.
+ CountInstanceDomains(domain string) (int, Error)
+
+ // GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
+ GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
+}
diff --git a/internal/db/pg/put.go b/internal/db/media.go
index 09beca14b..db4db3411 100644
--- a/internal/db/pg/put.go
+++ b/internal/db/media.go
@@ -16,18 +16,12 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package db
-import (
- "strings"
+import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/db"
-)
-
-func (ps *postgresService) Put(i interface{}) error {
- _, err := ps.conn.Model(i).Insert(i)
- if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
- return db.ErrAlreadyExists{}
- }
- return err
+// Media contains functions related to creating/getting/removing media attachments.
+type Media interface {
+ // GetAttachmentByID gets a single attachment by its ID
+ GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error)
}
diff --git a/internal/db/mention.go b/internal/db/mention.go
new file mode 100644
index 000000000..cb1c56dc1
--- /dev/null
+++ b/internal/db/mention.go
@@ -0,0 +1,30 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+
+// Mention contains functions for getting/creating mentions in the database.
+type Mention interface {
+ // GetMention gets a single mention by ID
+ GetMention(id string) (*gtsmodel.Mention, Error)
+
+ // GetMentions gets multiple mentions.
+ GetMentions(ids []string) ([]*gtsmodel.Mention, Error)
+}
diff --git a/internal/db/notification.go b/internal/db/notification.go
new file mode 100644
index 000000000..326f0f149
--- /dev/null
+++ b/internal/db/notification.go
@@ -0,0 +1,31 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+
+// Notification contains functions for creating and getting notifications.
+type Notification interface {
+ // GetNotifications returns a slice of notifications that pertain to the given accountID.
+ //
+ // Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
+ GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
+ // GetNotification returns one notification according to its id.
+ GetNotification(id string) (*gtsmodel.Notification, Error)
+}
diff --git a/internal/db/pg/account.go b/internal/db/pg/account.go
new file mode 100644
index 000000000..3889c6601
--- /dev/null
+++ b/internal/db/pg/account.go
@@ -0,0 +1,256 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package pg
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/go-pg/pg/v10"
+ "github.com/go-pg/pg/v10/orm"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type accountDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
+ return a.conn.Model(account).
+ Relation("AvatarMediaAttachment").
+ Relation("HeaderMediaAttachment")
+}
+
+func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) {
+ account := &gtsmodel.Account{}
+
+ q := a.newAccountQ(account).
+ Where("account.id = ?", id)
+
+ err := processErrorResponse(q.Select())
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) {
+ account := &gtsmodel.Account{}
+
+ q := a.newAccountQ(account).
+ Where("account.uri = ?", uri)
+
+ err := processErrorResponse(q.Select())
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) {
+ account := &gtsmodel.Account{}
+
+ q := a.newAccountQ(account).
+ Where("account.url = ?", uri)
+
+ err := processErrorResponse(q.Select())
+
+ return account, err
+}
+
+func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) {
+ account := &gtsmodel.Account{}
+
+ q := a.newAccountQ(account)
+
+ if domain == "" {
+ q = q.
+ Where("account.username = ?", domain).
+ Where("account.domain = ?", domain)
+ } else {
+ q = q.
+ Where("account.username = ?", domain).
+ Where("? IS NULL", pg.Ident("domain"))
+ }
+
+ err := processErrorResponse(q.Select())
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) {
+ status := &gtsmodel.Status{}
+
+ q := a.conn.Model(status).
+ Order("id DESC").
+ Limit(1).
+ Where("account_id = ?", accountID).
+ Column("created_at")
+
+ err := processErrorResponse(q.Select())
+
+ return status.CreatedAt, err
+}
+
+func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
+ if mediaAttachment.Avatar && mediaAttachment.Header {
+ return errors.New("one media attachment cannot be both header and avatar")
+ }
+
+ var headerOrAVI string
+ if mediaAttachment.Avatar {
+ headerOrAVI = "avatar"
+ } else if mediaAttachment.Header {
+ headerOrAVI = "header"
+ } else {
+ return errors.New("given media attachment was neither a header nor an avatar")
+ }
+
+ // TODO: there are probably more side effects here that need to be handled
+ if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
+ return err
+ }
+
+ if _, err := a.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) {
+ account := &gtsmodel.Account{}
+
+ q := a.newAccountQ(account).
+ Where("username = ?", username).
+ Where("? IS NULL", pg.Ident("domain"))
+
+ err := processErrorResponse(q.Select())
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) {
+ faves := []*gtsmodel.StatusFave{}
+
+ if err := a.conn.Model(&faves).
+ Where("account_id = ?", accountID).
+ Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return faves, nil
+ }
+ return nil, err
+ }
+ return faves, nil
+}
+
+func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) {
+ return a.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
+}
+
+func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
+ a.log.Debugf("getting statuses for account %s", accountID)
+ statuses := []*gtsmodel.Status{}
+
+ q := a.conn.Model(&statuses).Order("id DESC")
+ if accountID != "" {
+ q = q.Where("account_id = ?", accountID)
+ }
+
+ if limit != 0 {
+ q = q.Limit(limit)
+ }
+
+ if excludeReplies {
+ q = q.Where("? IS NULL", pg.Ident("in_reply_to_id"))
+ }
+
+ if pinnedOnly {
+ q = q.Where("pinned = ?", true)
+ }
+
+ if mediaOnly {
+ q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) {
+ return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil
+ })
+ }
+
+ if maxID != "" {
+ q = q.Where("id < ?", maxID)
+ }
+
+ if err := q.Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return nil, db.ErrNoEntries
+ }
+ return nil, err
+ }
+
+ if len(statuses) == 0 {
+ return nil, db.ErrNoEntries
+ }
+
+ a.log.Debugf("returning statuses for account %s", accountID)
+ return statuses, nil
+}
+
+func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
+ blocks := []*gtsmodel.Block{}
+
+ fq := a.conn.Model(&blocks).
+ Where("block.account_id = ?", accountID).
+ Relation("TargetAccount").
+ Order("block.id DESC")
+
+ if maxID != "" {
+ fq = fq.Where("block.id < ?", maxID)
+ }
+
+ if sinceID != "" {
+ fq = fq.Where("block.id > ?", sinceID)
+ }
+
+ if limit > 0 {
+ fq = fq.Limit(limit)
+ }
+
+ err := fq.Select()
+ if err != nil {
+ if err == pg.ErrNoRows {
+ return nil, "", "", db.ErrNoEntries
+ }
+ return nil, "", "", err
+ }
+
+ if len(blocks) == 0 {
+ return nil, "", "", db.ErrNoEntries
+ }
+
+ accounts := []*gtsmodel.Account{}
+ for _, b := range blocks {
+ accounts = append(accounts, b.TargetAccount)
+ }
+
+ nextMaxID := blocks[len(blocks)-1].ID
+ prevMinID := blocks[0].ID
+ return accounts, nextMaxID, prevMinID, nil
+}
diff --git a/internal/db/pg/account_test.go b/internal/db/pg/account_test.go
new file mode 100644
index 000000000..7ea5ff39a
--- /dev/null
+++ b/internal/db/pg/account_test.go
@@ -0,0 +1,70 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package pg_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type AccountTestSuite struct {
+ PGStandardTestSuite
+}
+
+func (suite *AccountTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *AccountTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *AccountTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
+ account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(account)
+ suite.NotNil(account.AvatarMediaAttachment)
+ suite.NotEmpty(account.AvatarMediaAttachment.URL)
+ suite.NotNil(account.HeaderMediaAttachment)
+ suite.NotEmpty(account.HeaderMediaAttachment.URL)
+}
+
+func TestAccountTestSuite(t *testing.T) {
+ suite.Run(t, new(AccountTestSuite))
+}
diff --git a/internal/db/pg/admin.go b/internal/db/pg/admin.go
new file mode 100644
index 000000000..854f56ef0
--- /dev/null
+++ b/internal/db/pg/admin.go
@@ -0,0 +1,235 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package pg
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "fmt"
+ "net"
+ "net/mail"
+ "strings"
+ "time"
+
+ "github.com/go-pg/pg/v10"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+ "golang.org/x/crypto/bcrypt"
+)
+
+type adminDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (a *adminDB) IsUsernameAvailable(username string) db.Error {
+ // if no error we fail because it means we found something
+ // if error but it's not pg.ErrNoRows then we fail
+ // if err is pg.ErrNoRows we're good, we found nothing so continue
+ if err := a.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
+ return fmt.Errorf("username %s already in use", username)
+ } else if err != pg.ErrNoRows {
+ return fmt.Errorf("db error: %s", err)
+ }
+ return nil
+}
+
+func (a *adminDB) IsEmailAvailable(email string) db.Error {
+ // parse the domain from the email
+ m, err := mail.ParseAddress(email)
+ if err != nil {
+ return fmt.Errorf("error parsing email address %s: %s", email, err)
+ }
+ domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
+
+ // check if the email domain is blocked
+ if err := a.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
+ // fail because we found something
+ return fmt.Errorf("email domain %s is blocked", domain)
+ } else if err != pg.ErrNoRows {
+ // fail because we got an unexpected error
+ return fmt.Errorf("db error: %s", err)
+ }
+
+ // check if this email is associated with a user already
+ if err := a.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
+ // fail because we found something
+ return fmt.Errorf("email %s already in use", email)
+ } else if err != pg.ErrNoRows {
+ // fail because we got an unexpected error
+ return fmt.Errorf("db error: %s", err)
+ }
+ return nil
+}
+
+func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ a.log.Errorf("error creating new rsa key: %s", err)
+ return nil, err
+ }
+
+ // if something went wrong while creating a user, we might already have an account, so check here first...
+ acct := &gtsmodel.Account{}
+ err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
+ if err != nil {
+ // there's been an actual error
+ if err != pg.ErrNoRows {
+ return nil, fmt.Errorf("db error checking existence of account: %s", err)
+ }
+
+ // we just don't have an account yet create one
+ newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
+ newAccountID, err := id.NewRandomULID()
+ if err != nil {
+ return nil, err
+ }
+
+ acct = &gtsmodel.Account{
+ ID: newAccountID,
+ Username: username,
+ DisplayName: username,
+ Reason: reason,
+ URL: newAccountURIs.UserURL,
+ PrivateKey: key,
+ PublicKey: &key.PublicKey,
+ PublicKeyURI: newAccountURIs.PublicKeyURI,
+ ActorType: gtsmodel.ActivityStreamsPerson,
+ URI: newAccountURIs.UserURI,
+ InboxURI: newAccountURIs.InboxURI,
+ OutboxURI: newAccountURIs.OutboxURI,
+ FollowersURI: newAccountURIs.FollowersURI,
+ FollowingURI: newAccountURIs.FollowingURI,
+ FeaturedCollectionURI: newAccountURIs.CollectionURI,
+ }
+ if _, err = a.conn.Model(acct).Insert(); err != nil {
+ return nil, err
+ }
+ }
+
+ pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
+ if err != nil {
+ return nil, fmt.Errorf("error hashing password: %s", err)
+ }
+
+ newUserID, err := id.NewRandomULID()
+ if err != nil {
+ return nil, err
+ }
+
+ u := &gtsmodel.User{
+ ID: newUserID,
+ AccountID: acct.ID,
+ EncryptedPassword: string(pw),
+ SignUpIP: signUpIP.To4(),
+ Locale: locale,
+ UnconfirmedEmail: email,
+ CreatedByApplicationID: appID,
+ Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
+ }
+
+ if emailVerified {
+ u.ConfirmedAt = time.Now()
+ u.Email = email
+ }
+
+ if admin {
+ u.Admin = true
+ u.Moderator = true
+ }
+
+ if _, err = a.conn.Model(u).Insert(); err != nil {
+ return nil, err
+ }
+
+ return u, nil
+}
+
+func (a *adminDB) CreateInstanceAccount() db.Error {
+ username := a.config.Host
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ a.log.Errorf("error creating new rsa key: %s", err)
+ return err
+ }
+
+ aID, err := id.NewRandomULID()
+ if err != nil {
+ return err
+ }
+
+ newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
+ acct := &gtsmodel.Account{
+ ID: aID,
+ Username: a.config.Host,
+ DisplayName: username,
+ URL: newAccountURIs.UserURL,
+ PrivateKey: key,
+ PublicKey: &key.PublicKey,
+ PublicKeyURI: newAccountURIs.PublicKeyURI,
+ ActorType: gtsmodel.ActivityStreamsPerson,
+ URI: newAccountURIs.UserURI,
+ InboxURI: newAccountURIs.InboxURI,
+ OutboxURI: newAccountURIs.OutboxURI,
+ FollowersURI: newAccountURIs.FollowersURI,
+ FollowingURI: newAccountURIs.FollowingURI,
+ FeaturedCollectionURI: newAccountURIs.CollectionURI,
+ }
+ inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert()
+ if err != nil {
+ return err
+ }
+ if inserted {
+ a.log.Infof("created instance account %s with id %s", username, acct.ID)
+ } else {
+ a.log.Infof("instance account %s already exists with id %s", username, acct.ID)
+ }
+ return nil
+}
+
+func (a *adminDB) CreateInstanceInstance() db.Error {
+ iID, err := id.NewRandomULID()
+ if err != nil {
+ return err
+ }
+
+ i := &gtsmodel.Instance{
+ ID: iID,
+ Domain: a.config.Host,
+ Title: a.config.Host,
+ URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
+ }
+ inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert()
+ if err != nil {
+ return err
+ }
+ if inserted {
+ a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID)
+ } else {
+ a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID)
+ }
+ return nil
+}
diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go
new file mode 100644
index 000000000..6e76b4450
--- /dev/null
+++ b/internal/db/pg/basic.go
@@ -0,0 +1,205 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package pg
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+
+ "github.com/go-pg/pg/v10"
+ "github.com/go-pg/pg/v10/orm"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+)
+
+type basicDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (b *basicDB) Put(i interface{}) db.Error {
+ _, err := b.conn.Model(i).Insert(i)
+ if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists
+ }
+ return err
+}
+
+func (b *basicDB) GetByID(id string, i interface{}) db.Error {
+ if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return db.ErrNoEntries
+ }
+ return err
+
+ }
+ return nil
+}
+
+func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := b.conn.Model(i)
+ for _, w := range where {
+
+ if w.Value == nil {
+ q = q.Where("? IS NULL", pg.Ident(w.Key))
+ } else {
+ if w.CaseInsensitive {
+ q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
+ } else {
+ q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
+ }
+ }
+ }
+
+ if err := q.Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return db.ErrNoEntries
+ }
+ return err
+ }
+ return nil
+}
+
+func (b *basicDB) GetAll(i interface{}) db.Error {
+ if err := b.conn.Model(i).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return db.ErrNoEntries
+ }
+ return err
+ }
+ return nil
+}
+
+func (b *basicDB) DeleteByID(id string, i interface{}) db.Error {
+ if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
+ // if there are no rows *anyway* then that's fine
+ // just return err if there's an actual error
+ if err != pg.ErrNoRows {
+ return err
+ }
+ }
+ return nil
+}
+
+func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := b.conn.Model(i)
+ for _, w := range where {
+ q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
+ }
+
+ if _, err := q.Delete(); err != nil {
+ // if there are no rows *anyway* then that's fine
+ // just return err if there's an actual error
+ if err != pg.ErrNoRows {
+ return err
+ }
+ }
+ return nil
+}
+
+func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error {
+ if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
+ if err == pg.ErrNoRows {
+ return db.ErrNoEntries
+ }
+ return err
+ }
+ return nil
+}
+
+func (b *basicDB) UpdateByID(id string, i interface{}) db.Error {
+ if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
+ if err == pg.ErrNoRows {
+ return db.ErrNoEntries
+ }
+ return err
+ }
+ return nil
+}
+
+func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error {
+ _, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
+ return err
+}
+
+func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error {
+ q := b.conn.Model(i)
+
+ for _, w := range where {
+ if w.Value == nil {
+ q = q.Where("? IS NULL", pg.Ident(w.Key))
+ } else {
+ if w.CaseInsensitive {
+ q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
+ } else {
+ q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
+ }
+ }
+ }
+
+ q = q.Set("? = ?", pg.Safe(key), value)
+
+ _, err := q.Update()
+
+ return err
+}
+
+func (b *basicDB) CreateTable(i interface{}) db.Error {
+ return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
+ IfNotExists: true,
+ })
+}
+
+func (b *basicDB) DropTable(i interface{}) db.Error {
+ return b.conn.Model(i).DropTable(&orm.DropTableOptions{
+ IfExists: true,
+ })
+}
+
+func (b *basicDB) RegisterTable(i interface{}) db.Error {
+ orm.RegisterTable(i)
+ return nil
+}
+
+func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
+ return b.conn.Ping(ctx)
+}
+
+func (b *basicDB) Stop(ctx context.Context) db.Error {
+ b.log.Info("closing db connection")
+ if err := b.conn.Close(); err != nil {
+ // only cancel if there's a problem closing the db
+ b.cancel()
+ return err
+ }
+ return nil
+}
diff --git a/internal/db/pg/blocks.go b/internal/db/pg/blocks.go
deleted file mode 100644
index a6fc1f859..000000000
--- a/internal/db/pg/blocks.go
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package pg
-
-import (
- "github.com/go-pg/pg/v10"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-func (ps *postgresService) GetBlocksForAccount(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) {
- blocks := []*gtsmodel.Block{}
-
- fq := ps.conn.Model(&blocks).
- Where("block.account_id = ?", accountID).
- Relation("TargetAccount").
- Order("block.id DESC")
-
- if maxID != "" {
- fq = fq.Where("block.id < ?", maxID)
- }
-
- if sinceID != "" {
- fq = fq.Where("block.id > ?", sinceID)
- }
-
- if limit > 0 {
- fq = fq.Limit(limit)
- }
-
- err := fq.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries{}
- }
- return nil, "", "", err
- }
-
- if len(blocks) == 0 {
- return nil, "", "", db.ErrNoEntries{}
- }
-
- accounts := []*gtsmodel.Account{}
- for _, b := range blocks {
- accounts = append(accounts, b.TargetAccount)
- }
-
- nextMaxID := blocks[len(blocks)-1].ID
- prevMinID := blocks[0].ID
- return accounts, nextMaxID, prevMinID, nil
-}
diff --git a/internal/db/pg/domain.go b/internal/db/pg/domain.go
new file mode 100644
index 000000000..4e9b2ab48
--- /dev/null
+++ b/internal/db/pg/domain.go
@@ -0,0 +1,83 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package pg
+
+import (
+ "context"
+ "net/url"
+
+ "github.com/go-pg/pg/v10"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+)
+
+type domainDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) {
+ if domain == "" {
+ return false, nil
+ }
+
+ blocked, err := d.conn.
+ Model(&gtsmodel.DomainBlock{}).
+ Where("LOWER(domain) = LOWER(?)", domain).
+ Exists()
+
+ err = processErrorResponse(err)
+
+ return blocked, err
+}
+
+func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) {
+ // filter out any doubles
+ uniqueDomains := util.UniqueStrings(domains)
+
+ for _, domain := range uniqueDomains {
+ if blocked, err := d.IsDomainBlocked(domain); err != nil {
+ return false, err
+ } else if blocked {
+ return blocked, nil
+ }
+ }
+
+ // no blocks found
+ return false, nil
+}
+
+func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) {
+ domain := uri.Hostname()
+ return d.IsDomainBlocked(domain)
+}
+
+func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) {
+ domains := []string{}
+ for _, uri := range uris {
+ domains = append(domains, uri.Hostname())
+ }
+
+ return d.AreDomainsBlocked(domains)
+}
diff --git a/internal/db/pg/get.go b/internal/db/pg/get.go
deleted file mode 100644
index d48c43520..000000000
--- a/internal/db/pg/get.go
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package pg
-
-import (
- "errors"
-
- "github.com/go-pg/pg/v10"
- "github.com/superseriousbusiness/gotosocial/internal/db"
-)
-
-func (ps *postgresService) GetByID(id string, i interface{}) error {
- if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
-
- }
- return nil
-}
-
-func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
- if len(where) == 0 {
- return errors.New("no queries provided")
- }
-
- q := ps.conn.Model(i)
- for _, w := range where {
-
- if w.Value == nil {
- q = q.Where("? IS NULL", pg.Ident(w.Key))
- } else {
- if w.CaseInsensitive {
- q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
- } else {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
- }
- }
-
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) GetAll(i interface{}) error {
- if err := ps.conn.Model(i).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
-}
diff --git a/internal/db/pg/instance.go b/internal/db/pg/instance.go
index c551b2a49..968832ca5 100644
--- a/internal/db/pg/instance.go
+++ b/internal/db/pg/instance.go
@@ -19,15 +19,26 @@
package pg
import (
+ "context"
+
"github.com/go-pg/pg/v10"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
- q := ps.conn.Model(&[]*gtsmodel.Account{})
+type instanceDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) {
+ q := i.conn.Model(&[]*gtsmodel.Account{})
- if domain == ps.config.Host {
+ if domain == i.config.Host {
// if the domain is *this* domain, just count where the domain field is null
q = q.Where("? IS NULL", pg.Ident("domain"))
} else {
@@ -40,10 +51,10 @@ func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
return q.Count()
}
-func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) {
- q := ps.conn.Model(&[]*gtsmodel.Status{})
+func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) {
+ q := i.conn.Model(&[]*gtsmodel.Status{})
- if domain == ps.config.Host {
+ if domain == i.config.Host {
// if the domain is *this* domain, just count where local is true
q = q.Where("local = ?", true)
} else {
@@ -55,10 +66,10 @@ func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error)
return q.Count()
}
-func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) {
- q := ps.conn.Model(&[]*gtsmodel.Instance{})
+func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) {
+ q := i.conn.Model(&[]*gtsmodel.Instance{})
- if domain == ps.config.Host {
+ if domain == i.config.Host {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
@@ -70,12 +81,12 @@ func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error)
return q.Count()
}
-func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) {
- ps.log.Debug("GetAccountsForInstance")
+func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
+ i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
- q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
+ q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
@@ -88,13 +99,13 @@ func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, l
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return nil, err
}
if len(accounts) == 0 {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return accounts, nil
diff --git a/internal/db/pg/delete.go b/internal/db/pg/media.go
index 0f288353e..618030af3 100644
--- a/internal/db/pg/delete.go
+++ b/internal/db/pg/media.go
@@ -19,39 +19,35 @@
package pg
import (
- "errors"
+ "context"
"github.com/go-pg/pg/v10"
+ "github.com/go-pg/pg/v10/orm"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (ps *postgresService) DeleteByID(id string, i interface{}) error {
- if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
- // if there are no rows *anyway* then that's fine
- // just return err if there's an actual error
- if err != pg.ErrNoRows {
- return err
- }
- }
- return nil
+type mediaDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
}
-func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error {
- if len(where) == 0 {
- return errors.New("no queries provided")
- }
-
- q := ps.conn.Model(i)
- for _, w := range where {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
-
- if _, err := q.Delete(); err != nil {
- // if there are no rows *anyway* then that's fine
- // just return err if there's an actual error
- if err != pg.ErrNoRows {
- return err
- }
- }
- return nil
+func (m *mediaDB) newMediaQ(i interface{}) *orm.Query {
+ return m.conn.Model(i).
+ Relation("Account")
+}
+
+func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) {
+ attachment := &gtsmodel.MediaAttachment{}
+
+ q := m.newMediaQ(attachment).
+ Where("media_attachment.id = ?", id)
+
+ err := processErrorResponse(q.Select())
+
+ return attachment, err
}
diff --git a/internal/db/pg/mention.go b/internal/db/pg/mention.go
new file mode 100644
index 000000000..b31f07b67
--- /dev/null
+++ b/internal/db/pg/mention.go
@@ -0,0 +1,108 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package pg
+
+import (
+ "context"
+
+ "github.com/go-pg/pg/v10"
+ "github.com/go-pg/pg/v10/orm"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type mentionDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+ cache cache.Cache
+}
+
+func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) {
+ if m.cache == nil {
+ m.cache = cache.New()
+ }
+
+ if err := m.cache.Store(id, mention); err != nil {
+ m.log.Panicf("mentionDB: error storing in cache: %s", err)
+ }
+}
+
+func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
+ if m.cache == nil {
+ m.cache = cache.New()
+ return nil, false
+ }
+
+ mI, err := m.cache.Fetch(id)
+ if err != nil || mI == nil {
+ return nil, false
+ }
+
+ mention, ok := mI.(*gtsmodel.Mention)
+ if !ok {
+ m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id)
+ }
+
+ return mention, true
+}
+
+func (m *mentionDB) newMentionQ(i interface{}) *orm.Query {
+ return m.conn.Model(i).
+ Relation("Status").
+ Relation("OriginAccount").
+ Relation("TargetAccount")
+}
+
+func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
+ if mention, cached := m.mentionCached(id); cached {
+ return mention, nil
+ }
+
+ mention := &gtsmodel.Mention{}
+
+ q := m.newMentionQ(mention).
+ Where("mention.id = ?", id)
+
+ err := processErrorResponse(q.Select())
+
+ if err == nil && mention != nil {
+ m.cacheMention(id, mention)
+ }
+
+ return mention, err
+}
+
+func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) {
+ mentions := []*gtsmodel.Mention{}
+
+ for _, i := range ids {
+ mention, err := m.GetMention(i)
+ if err != nil {
+ return nil, processErrorResponse(err)
+ }
+ mentions = append(mentions, mention)
+ }
+
+ return mentions, nil
+}
diff --git a/internal/db/pg/notification.go b/internal/db/pg/notification.go
new file mode 100644
index 000000000..281a76d85
--- /dev/null
+++ b/internal/db/pg/notification.go
@@ -0,0 +1,135 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package pg
+
+import (
+ "context"
+
+ "github.com/go-pg/pg/v10"
+ "github.com/go-pg/pg/v10/orm"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type notificationDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+ cache cache.Cache
+}
+
+func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) {
+ if n.cache == nil {
+ n.cache = cache.New()
+ }
+
+ if err := n.cache.Store(id, notification); err != nil {
+ n.log.Panicf("notificationDB: error storing in cache: %s", err)
+ }
+}
+
+func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) {
+ if n.cache == nil {
+ n.cache = cache.New()
+ return nil, false
+ }
+
+ nI, err := n.cache.Fetch(id)
+ if err != nil || nI == nil {
+ return nil, false
+ }
+
+ notification, ok := nI.(*gtsmodel.Notification)
+ if !ok {
+ n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id)
+ }
+
+ return notification, true
+}
+
+func (n *notificationDB) newNotificationQ(i interface{}) *orm.Query {
+ return n.conn.Model(i).
+ Relation("OriginAccount").
+ Relation("TargetAccount").
+ Relation("Status")
+}
+
+func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.Error) {
+ if notification, cached := n.notificationCached(id); cached {
+ return notification, nil
+ }
+
+ notification := &gtsmodel.Notification{}
+
+ q := n.newNotificationQ(notification).
+ Where("notification.id = ?", id)
+
+ err := processErrorResponse(q.Select())
+
+ if err == nil && notification != nil {
+ n.cacheNotification(id, notification)
+ }
+
+ return notification, err
+}
+
+func (n *notificationDB) GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
+ // begin by selecting just the IDs
+ notifIDs := []*gtsmodel.Notification{}
+ q := n.conn.
+ Model(&notifIDs).
+ Column("id").
+ Where("target_account_id = ?", accountID).
+ Order("id DESC")
+
+ if maxID != "" {
+ q = q.Where("id < ?", maxID)
+ }
+
+ if sinceID != "" {
+ q = q.Where("id > ?", sinceID)
+ }
+
+ if limit != 0 {
+ q = q.Limit(limit)
+ }
+
+ err := processErrorResponse(q.Select())
+ if err != nil {
+ return nil, err
+ }
+
+ // now we have the IDs, select the notifs one by one
+ // reason for this is that for each notif, we can instead get it from our cache if it's cached
+ notifications := []*gtsmodel.Notification{}
+ for _, notifID := range notifIDs {
+ notif, err := n.GetNotification(notifID.ID)
+ errP := processErrorResponse(err)
+ if errP != nil {
+ return nil, errP
+ }
+ notifications = append(notifications, notif)
+ }
+
+ return notifications, nil
+}
diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go
index d49c50114..0437baf02 100644
--- a/internal/db/pg/pg.go
+++ b/internal/db/pg/pg.go
@@ -20,15 +20,11 @@ package pg
import (
"context"
- "crypto/rand"
- "crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
- "net"
- "net/mail"
"os"
"strings"
"time"
@@ -41,12 +37,26 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
- "github.com/superseriousbusiness/gotosocial/internal/util"
- "golang.org/x/crypto/bcrypt"
)
+var registerTables []interface{} = []interface{}{
+ &gtsmodel.StatusToEmoji{},
+ &gtsmodel.StatusToTag{},
+}
+
// postgresService satisfies the DB interface
type postgresService struct {
+ db.Account
+ db.Admin
+ db.Basic
+ db.Domain
+ db.Instance
+ db.Media
+ db.Mention
+ db.Notification
+ db.Relationship
+ db.Status
+ db.Timeline
config *config.Config
conn *pg.DB
log *logrus.Logger
@@ -56,6 +66,11 @@ type postgresService struct {
// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
+ for _, t := range registerTables {
+ // https://pg.uptrace.dev/orm/many-to-many-relation/
+ orm.RegisterTable(t)
+ }
+
opts, err := derivePGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create postgres service: %s", err)
@@ -91,6 +106,72 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
log.Infof("connected to postgres version: %s", version)
ps := &postgresService{
+ Account: &accountDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Admin: &adminDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Basic: &basicDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Domain: &domainDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Instance: &instanceDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Media: &mediaDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Mention: &mentionDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Notification: &notificationDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Relationship: &relationshipDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Status: &statusDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Timeline: &timelineDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
config: c,
conn: conn,
log: log,
@@ -200,724 +281,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
}
/*
- BASIC DB FUNCTIONALITY
-*/
-
-func (ps *postgresService) CreateTable(i interface{}) error {
- return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- })
-}
-
-func (ps *postgresService) DropTable(i interface{}) error {
- return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
- IfExists: true,
- })
-}
-
-func (ps *postgresService) Stop(ctx context.Context) error {
- ps.log.Info("closing db connection")
- if err := ps.conn.Close(); err != nil {
- // only cancel if there's a problem closing the db
- ps.cancel()
- return err
- }
- return nil
-}
-
-func (ps *postgresService) IsHealthy(ctx context.Context) error {
- return ps.conn.Ping(ctx)
-}
-
-func (ps *postgresService) CreateSchema(ctx context.Context) error {
- models := []interface{}{
- (*gtsmodel.Account)(nil),
- (*gtsmodel.Status)(nil),
- (*gtsmodel.User)(nil),
- }
- ps.log.Info("creating db schema")
-
- for _, model := range models {
- err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- })
- if err != nil {
- return err
- }
- }
-
- ps.log.Info("db schema created")
- return nil
-}
-
-/*
- HANDY SHORTCUTS
-*/
-
-func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
- // make sure the original follow request exists
- fr := &gtsmodel.FollowRequest{}
- if err := ps.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
- if err == pg.ErrMultiRows {
- return nil, db.ErrNoEntries{}
- }
- return nil, err
- }
-
- // create a new follow to 'replace' the request with
- follow := &gtsmodel.Follow{
- ID: fr.ID,
- AccountID: originAccountID,
- TargetAccountID: targetAccountID,
- URI: fr.URI,
- }
-
- // if the follow already exists, just update the URI -- we don't need to do anything else
- if _, err := ps.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
- return nil, err
- }
-
- // now remove the follow request
- if _, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
- return nil, err
- }
-
- return follow, nil
-}
-
-func (ps *postgresService) CreateInstanceAccount() error {
- username := ps.config.Host
- key, err := rsa.GenerateKey(rand.Reader, 2048)
- if err != nil {
- ps.log.Errorf("error creating new rsa key: %s", err)
- return err
- }
-
- aID, err := id.NewRandomULID()
- if err != nil {
- return err
- }
-
- newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
- a := &gtsmodel.Account{
- ID: aID,
- Username: ps.config.Host,
- DisplayName: username,
- URL: newAccountURIs.UserURL,
- PrivateKey: key,
- PublicKey: &key.PublicKey,
- PublicKeyURI: newAccountURIs.PublicKeyURI,
- ActorType: gtsmodel.ActivityStreamsPerson,
- URI: newAccountURIs.UserURI,
- InboxURI: newAccountURIs.InboxURI,
- OutboxURI: newAccountURIs.OutboxURI,
- FollowersURI: newAccountURIs.FollowersURI,
- FollowingURI: newAccountURIs.FollowingURI,
- FeaturedCollectionURI: newAccountURIs.CollectionURI,
- }
- inserted, err := ps.conn.Model(a).Where("username = ?", username).SelectOrInsert()
- if err != nil {
- return err
- }
- if inserted {
- ps.log.Infof("created instance account %s with id %s", username, a.ID)
- } else {
- ps.log.Infof("instance account %s already exists with id %s", username, a.ID)
- }
- return nil
-}
-
-func (ps *postgresService) CreateInstanceInstance() error {
- iID, err := id.NewRandomULID()
- if err != nil {
- return err
- }
-
- i := &gtsmodel.Instance{
- ID: iID,
- Domain: ps.config.Host,
- Title: ps.config.Host,
- URI: fmt.Sprintf("%s://%s", ps.config.Protocol, ps.config.Host),
- }
- inserted, err := ps.conn.Model(i).Where("domain = ?", ps.config.Host).SelectOrInsert()
- if err != nil {
- return err
- }
- if inserted {
- ps.log.Infof("created instance instance %s with id %s", ps.config.Host, i.ID)
- } else {
- ps.log.Infof("instance instance %s already exists with id %s", ps.config.Host, i.ID)
- }
- return nil
-}
-
-func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error {
- user := &gtsmodel.User{
- ID: userID,
- }
- if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) GetLocalAccountByUsername(username string, account *gtsmodel.Account) error {
- if err := ps.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]gtsmodel.FollowRequest) error {
- if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return nil
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error {
- if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return nil
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error {
-
- q := ps.conn.Model(followers)
-
- if localOnly {
- // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
- whereGroup := func(q *pg.Query) (*pg.Query, error) {
- q = q.
- WhereOr("? IS NULL", pg.Ident("a.domain")).
- WhereOr("a.domain = ?", "")
- return q, nil
- }
-
- q = q.ColumnExpr("follow.*").
- Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
- Where("follow.target_account_id = ?", accountID).
- WhereGroup(whereGroup)
- } else {
- q = q.Where("target_account_id = ?", accountID)
- }
-
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
- return nil
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) GetFavesByAccountID(accountID string, faves *[]gtsmodel.StatusFave) error {
- if err := ps.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return nil
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) CountStatusesByAccountID(accountID string) (int, error) {
- count, err := ps.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
- if err != nil {
- if err == pg.ErrNoRows {
- return 0, nil
- }
- return 0, err
- }
- return count, nil
-}
-
-func (ps *postgresService) GetStatusesForAccount(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) {
- ps.log.Debugf("getting statuses for account %s", accountID)
- statuses := []*gtsmodel.Status{}
-
- q := ps.conn.Model(&statuses).Order("id DESC")
- if accountID != "" {
- q = q.Where("account_id = ?", accountID)
- }
-
- if limit != 0 {
- q = q.Limit(limit)
- }
-
- if excludeReplies {
- q = q.Where("? IS NULL", pg.Ident("in_reply_to_id"))
- }
-
- if pinnedOnly {
- q = q.Where("pinned = ?", true)
- }
-
- if mediaOnly {
- q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) {
- return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil
- })
- }
-
- if maxID != "" {
- q = q.Where("id < ?", maxID)
- }
-
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries{}
- }
- return nil, err
- }
-
- if len(statuses) == 0 {
- return nil, db.ErrNoEntries{}
- }
-
- ps.log.Debugf("returning statuses for account %s", accountID)
- return statuses, nil
-}
-
-func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *gtsmodel.Status) error {
- if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
-
-}
-
-func (ps *postgresService) IsUsernameAvailable(username string) error {
- // if no error we fail because it means we found something
- // if error but it's not pg.ErrNoRows then we fail
- // if err is pg.ErrNoRows we're good, we found nothing so continue
- if err := ps.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
- return fmt.Errorf("username %s already in use", username)
- } else if err != pg.ErrNoRows {
- return fmt.Errorf("db error: %s", err)
- }
- return nil
-}
-
-func (ps *postgresService) IsEmailAvailable(email string) error {
- // parse the domain from the email
- m, err := mail.ParseAddress(email)
- if err != nil {
- return fmt.Errorf("error parsing email address %s: %s", email, err)
- }
- domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
-
- // check if the email domain is blocked
- if err := ps.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
- // fail because we found something
- return fmt.Errorf("email domain %s is blocked", domain)
- } else if err != pg.ErrNoRows {
- // fail because we got an unexpected error
- return fmt.Errorf("db error: %s", err)
- }
-
- // check if this email is associated with a user already
- if err := ps.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
- // fail because we found something
- return fmt.Errorf("email %s already in use", email)
- } else if err != pg.ErrNoRows {
- // fail because we got an unexpected error
- return fmt.Errorf("db error: %s", err)
- }
- return nil
-}
-
-func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error) {
- key, err := rsa.GenerateKey(rand.Reader, 2048)
- if err != nil {
- ps.log.Errorf("error creating new rsa key: %s", err)
- return nil, err
- }
-
- // if something went wrong while creating a user, we might already have an account, so check here first...
- a := &gtsmodel.Account{}
- err = ps.conn.Model(a).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
- if err != nil {
- // there's been an actual error
- if err != pg.ErrNoRows {
- return nil, fmt.Errorf("db error checking existence of account: %s", err)
- }
-
- // we just don't have an account yet create one
- newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
- newAccountID, err := id.NewRandomULID()
- if err != nil {
- return nil, err
- }
-
- a = &gtsmodel.Account{
- ID: newAccountID,
- Username: username,
- DisplayName: username,
- Reason: reason,
- URL: newAccountURIs.UserURL,
- PrivateKey: key,
- PublicKey: &key.PublicKey,
- PublicKeyURI: newAccountURIs.PublicKeyURI,
- ActorType: gtsmodel.ActivityStreamsPerson,
- URI: newAccountURIs.UserURI,
- InboxURI: newAccountURIs.InboxURI,
- OutboxURI: newAccountURIs.OutboxURI,
- FollowersURI: newAccountURIs.FollowersURI,
- FollowingURI: newAccountURIs.FollowingURI,
- FeaturedCollectionURI: newAccountURIs.CollectionURI,
- }
- if _, err = ps.conn.Model(a).Insert(); err != nil {
- return nil, err
- }
- }
-
- pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
- if err != nil {
- return nil, fmt.Errorf("error hashing password: %s", err)
- }
-
- newUserID, err := id.NewRandomULID()
- if err != nil {
- return nil, err
- }
-
- u := &gtsmodel.User{
- ID: newUserID,
- AccountID: a.ID,
- EncryptedPassword: string(pw),
- SignUpIP: signUpIP.To4(),
- Locale: locale,
- UnconfirmedEmail: email,
- CreatedByApplicationID: appID,
- Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
- }
-
- if emailVerified {
- u.ConfirmedAt = time.Now()
- u.Email = email
- }
-
- if admin {
- u.Admin = true
- u.Moderator = true
- }
-
- if _, err = ps.conn.Model(u).Insert(); err != nil {
- return nil, err
- }
-
- return u, nil
-}
-
-func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
- if mediaAttachment.Avatar && mediaAttachment.Header {
- return errors.New("one media attachment cannot be both header and avatar")
- }
-
- var headerOrAVI string
- if mediaAttachment.Avatar {
- headerOrAVI = "avatar"
- } else if mediaAttachment.Header {
- headerOrAVI = "header"
- } else {
- return errors.New("given media attachment was neither a header nor an avatar")
- }
-
- // TODO: there are probably more side effects here that need to be handled
- if _, err := ps.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
- return err
- }
-
- if _, err := ps.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
- return err
- }
- return nil
-}
-
-func (ps *postgresService) GetHeaderForAccountID(header *gtsmodel.MediaAttachment, accountID string) error {
- acct := &gtsmodel.Account{}
- if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
-
- if acct.HeaderMediaAttachmentID == "" {
- return db.ErrNoEntries{}
- }
-
- if err := ps.conn.Model(header).Where("id = ?", acct.HeaderMediaAttachmentID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) GetAvatarForAccountID(avatar *gtsmodel.MediaAttachment, accountID string) error {
- acct := &gtsmodel.Account{}
- if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
-
- if acct.AvatarMediaAttachmentID == "" {
- return db.ErrNoEntries{}
- }
-
- if err := ps.conn.Model(avatar).Where("id = ?", acct.AvatarMediaAttachmentID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) Blocked(account1 string, account2 string) (bool, error) {
- // TODO: check domain blocks as well
- var blocked bool
- if err := ps.conn.Model(&gtsmodel.Block{}).
- Where("account_id = ?", account1).Where("target_account_id = ?", account2).
- WhereOr("target_account_id = ?", account1).Where("account_id = ?", account2).
- Select(); err != nil {
- if err == pg.ErrNoRows {
- blocked = false
- return blocked, nil
- }
- return blocked, err
- }
- blocked = true
- return blocked, nil
-}
-
-func (ps *postgresService) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) {
- r := &gtsmodel.Relationship{
- ID: targetAccount,
- }
-
- // check if the requesting account follows the target account
- follow := &gtsmodel.Follow{}
- if err := ps.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
- if err != pg.ErrNoRows {
- // a proper error
- return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
- }
- // no follow exists so these are all false
- r.Following = false
- r.ShowingReblogs = false
- r.Notifying = false
- } else {
- // follow exists so we can fill these fields out...
- r.Following = true
- r.ShowingReblogs = follow.ShowReblogs
- r.Notifying = follow.Notify
- }
-
- // check if the target account follows the requesting account
- followedBy, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
- if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
- }
- r.FollowedBy = followedBy
-
- // check if the requesting account blocks the target account
- blocking, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
- if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
- }
- r.Blocking = blocking
-
- // check if the target account blocks the requesting account
- blockedBy, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
- if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
- }
- r.BlockedBy = blockedBy
-
- // check if there's a pending following request from requesting account to target account
- requested, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
- if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
- }
- r.Requested = requested
-
- return r, nil
-}
-
-func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
- if sourceAccount == nil || targetAccount == nil {
- return false, nil
- }
-
- return ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
-}
-
-func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
- if sourceAccount == nil || targetAccount == nil {
- return false, nil
- }
-
- return ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
-}
-
-func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) {
- if account1 == nil || account2 == nil {
- return false, nil
- }
-
- // make sure account 1 follows account 2
- f1, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
- if err != nil {
- if err == pg.ErrNoRows {
- return false, nil
- }
- return false, err
- }
-
- // make sure account 2 follows account 1
- f2, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists()
- if err != nil {
- if err == pg.ErrNoRows {
- return false, nil
- }
- return false, err
- }
-
- return f1 && f2, nil
-}
-
-func (ps *postgresService) GetReplyCountForStatus(status *gtsmodel.Status) (int, error) {
- return ps.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
-}
-
-func (ps *postgresService) GetReblogCountForStatus(status *gtsmodel.Status) (int, error) {
- return ps.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
-}
-
-func (ps *postgresService) GetFaveCountForStatus(status *gtsmodel.Status) (int, error) {
- return ps.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
-}
-
-func (ps *postgresService) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error) {
- return ps.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (ps *postgresService) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error) {
- return ps.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (ps *postgresService) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error) {
- return ps.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (ps *postgresService) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error) {
- return ps.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
- accounts := []*gtsmodel.Account{}
-
- faves := []*gtsmodel.StatusFave{}
- if err := ps.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return accounts, nil // no rows just means nobody has faved this status, so that's fine
- }
- return nil, err // an actual error has occurred
- }
-
- for _, f := range faves {
- acc := &gtsmodel.Account{}
- if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
- if err == pg.ErrNoRows {
- continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
- }
- return nil, err // an actual error has occurred
- }
- accounts = append(accounts, acc)
- }
- return accounts, nil
-}
-
-func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
- accounts := []*gtsmodel.Account{}
-
- boosts := []*gtsmodel.Status{}
- if err := ps.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return accounts, nil // no rows just means nobody has boosted this status, so that's fine
- }
- return nil, err // an actual error has occurred
- }
-
- for _, f := range boosts {
- acc := &gtsmodel.Account{}
- if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
- if err == pg.ErrNoRows {
- continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
- }
- return nil, err // an actual error has occurred
- }
- accounts = append(accounts, acc)
- }
- return accounts, nil
-}
-
-func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) {
- notifications := []*gtsmodel.Notification{}
-
- q := ps.conn.Model(&notifications).Where("target_account_id = ?", accountID)
-
- if maxID != "" {
- q = q.Where("id < ?", maxID)
- }
-
- if sinceID != "" {
- q = q.Where("id > ?", sinceID)
- }
-
- if limit != 0 {
- q = q.Limit(limit)
- }
-
- q = q.Order("created_at DESC")
-
- if err := q.Select(); err != nil {
- if err != pg.ErrNoRows {
- return nil, err
- }
-
- }
- return notifications, nil
-}
-
-/*
CONVERSION FUNCTIONS
*/
@@ -988,14 +351,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
// id, createdAt and updatedAt will be populated by the db, so we have everything we need!
menchies = append(menchies, &gtsmodel.Mention{
- StatusID: statusID,
- OriginAccountID: ogAccount.ID,
- OriginAccountURI: ogAccount.URI,
- TargetAccountID: mentionedAccount.ID,
- NameString: a,
- MentionedAccountURI: mentionedAccount.URI,
- MentionedAccountURL: mentionedAccount.URL,
- GTSAccount: mentionedAccount,
+ StatusID: statusID,
+ OriginAccountID: ogAccount.ID,
+ OriginAccountURI: ogAccount.URI,
+ TargetAccountID: mentionedAccount.ID,
+ NameString: a,
+ TargetAccountURI: mentionedAccount.URI,
+ TargetAccountURL: mentionedAccount.URL,
+ OriginAccount: mentionedAccount,
})
}
return menchies, nil
diff --git a/internal/db/pg/pg_test.go b/internal/db/pg/pg_test.go
new file mode 100644
index 000000000..c1e10abdf
--- /dev/null
+++ b/internal/db/pg/pg_test.go
@@ -0,0 +1,47 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package pg_test
+
+import (
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+)
+
+type PGStandardTestSuite struct {
+ // standard suite interfaces
+ suite.Suite
+ config *config.Config
+ db db.DB
+ log *logrus.Logger
+
+ // standard suite models
+ testTokens map[string]*oauth.Token
+ testClients map[string]*oauth.Client
+ testApplications map[string]*gtsmodel.Application
+ testUsers map[string]*gtsmodel.User
+ testAccounts map[string]*gtsmodel.Account
+ testAttachments map[string]*gtsmodel.MediaAttachment
+ testStatuses map[string]*gtsmodel.Status
+ testTags map[string]*gtsmodel.Tag
+ testMentions map[string]*gtsmodel.Mention
+}
diff --git a/internal/db/pg/relationship.go b/internal/db/pg/relationship.go
new file mode 100644
index 000000000..76bd50c76
--- /dev/null
+++ b/internal/db/pg/relationship.go
@@ -0,0 +1,276 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package pg
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/go-pg/pg/v10"
+ "github.com/go-pg/pg/v10/orm"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type relationshipDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query {
+ return r.conn.Model(block).
+ Relation("Account").
+ Relation("TargetAccount")
+}
+
+func (r *relationshipDB) newFollowQ(follow interface{}) *orm.Query {
+ return r.conn.Model(follow).
+ Relation("Account").
+ Relation("TargetAccount")
+}
+
+func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
+ q := r.conn.
+ Model(&gtsmodel.Block{}).
+ Where("account_id = ?", account1).
+ Where("target_account_id = ?", account2)
+
+ if eitherDirection {
+ q = q.
+ WhereOr("target_account_id = ?", account1).
+ Where("account_id = ?", account2)
+ }
+
+ return q.Exists()
+}
+
+func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) {
+ block := &gtsmodel.Block{}
+
+ q := r.newBlockQ(block).
+ Where("block.account_id = ?", account1).
+ Where("block.target_account_id = ?", account2)
+
+ err := processErrorResponse(q.Select())
+
+ return block, err
+}
+
+func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
+ rel := &gtsmodel.Relationship{
+ ID: targetAccount,
+ }
+
+ // check if the requesting account follows the target account
+ follow := &gtsmodel.Follow{}
+ if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
+ if err != pg.ErrNoRows {
+ // a proper error
+ return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
+ }
+ // no follow exists so these are all false
+ rel.Following = false
+ rel.ShowingReblogs = false
+ rel.Notifying = false
+ } else {
+ // follow exists so we can fill these fields out...
+ rel.Following = true
+ rel.ShowingReblogs = follow.ShowReblogs
+ rel.Notifying = follow.Notify
+ }
+
+ // check if the target account follows the requesting account
+ followedBy, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
+ }
+ rel.FollowedBy = followedBy
+
+ // check if the requesting account blocks the target account
+ blocking, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
+ }
+ rel.Blocking = blocking
+
+ // check if the target account blocks the requesting account
+ blockedBy, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ }
+ rel.BlockedBy = blockedBy
+
+ // check if there's a pending following request from requesting account to target account
+ requested, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ }
+ rel.Requested = requested
+
+ return rel, nil
+}
+
+func (r *relationshipDB) IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
+ if sourceAccount == nil || targetAccount == nil {
+ return false, nil
+ }
+
+ q := r.conn.
+ Model(&gtsmodel.Follow{}).
+ Where("account_id = ?", sourceAccount.ID).
+ Where("target_account_id = ?", targetAccount.ID)
+
+ return q.Exists()
+}
+
+func (r *relationshipDB) IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
+ if sourceAccount == nil || targetAccount == nil {
+ return false, nil
+ }
+
+ q := r.conn.
+ Model(&gtsmodel.FollowRequest{}).
+ Where("account_id = ?", sourceAccount.ID).
+ Where("target_account_id = ?", targetAccount.ID)
+
+ return q.Exists()
+}
+
+func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
+ if account1 == nil || account2 == nil {
+ return false, nil
+ }
+
+ // make sure account 1 follows account 2
+ f1, err := r.IsFollowing(account1, account2)
+ if err != nil {
+ return false, processErrorResponse(err)
+ }
+
+ // make sure account 2 follows account 1
+ f2, err := r.IsFollowing(account2, account1)
+ if err != nil {
+ return false, processErrorResponse(err)
+ }
+
+ return f1 && f2, nil
+}
+
+func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
+ // make sure the original follow request exists
+ fr := &gtsmodel.FollowRequest{}
+ if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
+ if err == pg.ErrMultiRows {
+ return nil, db.ErrNoEntries
+ }
+ return nil, err
+ }
+
+ // create a new follow to 'replace' the request with
+ follow := &gtsmodel.Follow{
+ ID: fr.ID,
+ AccountID: originAccountID,
+ TargetAccountID: targetAccountID,
+ URI: fr.URI,
+ }
+
+ // if the follow already exists, just update the URI -- we don't need to do anything else
+ if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
+ return nil, err
+ }
+
+ // now remove the follow request
+ if _, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
+ return nil, err
+ }
+
+ return follow, nil
+}
+
+func (r *relationshipDB) GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
+ followRequests := []*gtsmodel.FollowRequest{}
+
+ q := r.newFollowQ(&followRequests).
+ Where("target_account_id = ?", accountID)
+
+ err := processErrorResponse(q.Select())
+
+ return followRequests, err
+}
+
+func (r *relationshipDB) GetAccountFollows(accountID string) ([]*gtsmodel.Follow, db.Error) {
+ follows := []*gtsmodel.Follow{}
+
+ q := r.newFollowQ(&follows).
+ Where("account_id = ?", accountID)
+
+ err := processErrorResponse(q.Select())
+
+ return follows, err
+}
+
+func (r *relationshipDB) CountAccountFollows(accountID string, localOnly bool) (int, db.Error) {
+ return r.conn.
+ Model(&[]*gtsmodel.Follow{}).
+ Where("account_id = ?", accountID).
+ Count()
+}
+
+func (r *relationshipDB) GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
+
+ follows := []*gtsmodel.Follow{}
+
+ q := r.conn.Model(&follows)
+
+ if localOnly {
+ // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
+ whereGroup := func(q *pg.Query) (*pg.Query, error) {
+ q = q.
+ WhereOr("? IS NULL", pg.Ident("a.domain")).
+ WhereOr("a.domain = ?", "")
+ return q, nil
+ }
+
+ q = q.ColumnExpr("follow.*").
+ Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
+ Where("follow.target_account_id = ?", accountID).
+ WhereGroup(whereGroup)
+ } else {
+ q = q.Where("target_account_id = ?", accountID)
+ }
+
+ if err := q.Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return follows, nil
+ }
+ return nil, err
+ }
+ return follows, nil
+}
+
+func (r *relationshipDB) CountAccountFollowedBy(accountID string, localOnly bool) (int, db.Error) {
+ return r.conn.
+ Model(&[]*gtsmodel.Follow{}).
+ Where("target_account_id = ?", accountID).
+ Count()
+}
diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go
new file mode 100644
index 000000000..99790428e
--- /dev/null
+++ b/internal/db/pg/status.go
@@ -0,0 +1,318 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package pg
+
+import (
+ "container/list"
+ "context"
+ "errors"
+ "time"
+
+ "github.com/go-pg/pg/v10"
+ "github.com/go-pg/pg/v10/orm"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type statusDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+ cache cache.Cache
+}
+
+func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
+ if s.cache == nil {
+ s.cache = cache.New()
+ }
+
+ if err := s.cache.Store(id, status); err != nil {
+ s.log.Panicf("statusDB: error storing in cache: %s", err)
+ }
+}
+
+func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
+ if s.cache == nil {
+ s.cache = cache.New()
+ return nil, false
+ }
+
+ sI, err := s.cache.Fetch(id)
+ if err != nil || sI == nil {
+ return nil, false
+ }
+
+ status, ok := sI.(*gtsmodel.Status)
+ if !ok {
+ s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
+ }
+
+ return status, true
+}
+
+func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
+ return s.conn.Model(status).
+ Relation("Attachments").
+ Relation("Tags").
+ Relation("Mentions").
+ Relation("Emojis").
+ Relation("Account").
+ Relation("InReplyTo").
+ Relation("InReplyToAccount").
+ Relation("BoostOf").
+ Relation("BoostOfAccount").
+ Relation("CreatedWithApplication")
+}
+
+func (s *statusDB) newFaveQ(faves interface{}) *orm.Query {
+ return s.conn.Model(faves).
+ Relation("Account").
+ Relation("TargetAccount").
+ Relation("Status")
+}
+
+func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(id); cached {
+ return status, nil
+ }
+
+ status := &gtsmodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("status.id = ?", id)
+
+ err := processErrorResponse(q.Select())
+
+ if err == nil && status != nil {
+ s.cacheStatus(id, status)
+ }
+
+ return status, err
+}
+
+func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(uri); cached {
+ return status, nil
+ }
+
+ status := &gtsmodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("LOWER(status.uri) = LOWER(?)", uri)
+
+ err := processErrorResponse(q.Select())
+
+ if err == nil && status != nil {
+ s.cacheStatus(uri, status)
+ }
+
+ return status, err
+}
+
+func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(uri); cached {
+ return status, nil
+ }
+
+ status := &gtsmodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("LOWER(status.url) = LOWER(?)", uri)
+
+ err := processErrorResponse(q.Select())
+
+ if err == nil && status != nil {
+ s.cacheStatus(uri, status)
+ }
+
+ return status, err
+}
+
+func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
+ transaction := func(tx *pg.Tx) error {
+ // create links between this status and any emojis it uses
+ for _, i := range status.EmojiIDs {
+ if _, err := tx.Model(&gtsmodel.StatusToEmoji{
+ StatusID: status.ID,
+ EmojiID: i,
+ }).Insert(); err != nil {
+ return err
+ }
+ }
+
+ // create links between this status and any tags it uses
+ for _, i := range status.TagIDs {
+ if _, err := tx.Model(&gtsmodel.StatusToTag{
+ StatusID: status.ID,
+ TagID: i,
+ }).Insert(); err != nil {
+ return err
+ }
+ }
+
+ // change the status ID of the media attachments to the new status
+ for _, a := range status.Attachments {
+ a.StatusID = status.ID
+ a.UpdatedAt = time.Now()
+ if _, err := s.conn.Model(a).
+ Where("id = ?", a.ID).
+ Update(); err != nil {
+ return err
+ }
+ }
+
+ _, err := tx.Model(status).Insert()
+ return err
+ }
+
+ return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction))
+}
+
+func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
+ parents := []*gtsmodel.Status{}
+ s.statusParent(status, &parents, onlyDirect)
+
+ return parents, nil
+}
+
+func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
+ if status.InReplyToID == "" {
+ return
+ }
+
+ parentStatus, err := s.GetStatusByID(status.InReplyToID)
+ if err == nil {
+ *foundStatuses = append(*foundStatuses, parentStatus)
+ }
+
+ if onlyDirect {
+ return
+ }
+
+ s.statusParent(parentStatus, foundStatuses, false)
+}
+
+func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
+ foundStatuses := &list.List{}
+ foundStatuses.PushFront(status)
+ s.statusChildren(status, foundStatuses, onlyDirect, minID)
+
+ children := []*gtsmodel.Status{}
+ for e := foundStatuses.Front(); e != nil; e = e.Next() {
+ entry, ok := e.Value.(*gtsmodel.Status)
+ if !ok {
+ panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
+ }
+
+ // only append children, not the overall parent status
+ if entry.ID != status.ID {
+ children = append(children, entry)
+ }
+ }
+
+ return children, nil
+}
+
+func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
+ immediateChildren := []*gtsmodel.Status{}
+
+ q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
+ if minID != "" {
+ q = q.Where("status.id > ?", minID)
+ }
+
+ if err := q.Select(); err != nil {
+ return
+ }
+
+ for _, child := range immediateChildren {
+ insertLoop:
+ for e := foundStatuses.Front(); e != nil; e = e.Next() {
+ entry, ok := e.Value.(*gtsmodel.Status)
+ if !ok {
+ panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
+ }
+
+ if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
+ foundStatuses.InsertAfter(child, e)
+ break insertLoop
+ }
+ }
+
+ // only do one loop if we only want direct children
+ if onlyDirect {
+ return
+ }
+ s.statusChildren(child, foundStatuses, false, minID)
+ }
+}
+
+func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
+}
+
+func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
+}
+
+func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
+}
+
+func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
+}
+
+func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
+}
+
+func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ return s.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
+}
+
+func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ return s.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
+}
+
+func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
+ faves := []*gtsmodel.StatusFave{}
+
+ q := s.newFaveQ(&faves).
+ Where("status_id = ?", status.ID)
+
+ err := processErrorResponse(q.Select())
+
+ return faves, err
+}
+
+func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
+ reblogs := []*gtsmodel.Status{}
+
+ q := s.newStatusQ(&reblogs).
+ Where("boost_of_id = ?", status.ID)
+
+ err := processErrorResponse(q.Select())
+
+ return reblogs, err
+}
diff --git a/internal/db/pg/status_test.go b/internal/db/pg/status_test.go
new file mode 100644
index 000000000..8a185757c
--- /dev/null
+++ b/internal/db/pg/status_test.go
@@ -0,0 +1,134 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package pg_test
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type StatusTestSuite struct {
+ PGStandardTestSuite
+}
+
+func (suite *StatusTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *StatusTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *StatusTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *StatusTestSuite) TestGetStatusByID() {
+ status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.Nil(status.BoostOf)
+ suite.Nil(status.BoostOfAccount)
+ suite.Nil(status.InReplyTo)
+ suite.Nil(status.InReplyToAccount)
+}
+
+func (suite *StatusTestSuite) TestGetStatusByURI() {
+ status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.Nil(status.BoostOf)
+ suite.Nil(status.BoostOfAccount)
+ suite.Nil(status.InReplyTo)
+ suite.Nil(status.InReplyToAccount)
+}
+
+func (suite *StatusTestSuite) TestGetStatusWithExtras() {
+ status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.NotEmpty(status.Tags)
+ suite.NotEmpty(status.Attachments)
+ suite.NotEmpty(status.Emojis)
+}
+
+func (suite *StatusTestSuite) TestGetStatusWithMention() {
+ status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.NotEmpty(status.Mentions)
+ suite.NotEmpty(status.MentionIDs)
+ suite.NotNil(status.InReplyTo)
+ suite.NotNil(status.InReplyToAccount)
+}
+
+func (suite *StatusTestSuite) TestGetStatusTwice() {
+ before1 := time.Now()
+ _, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
+ suite.NoError(err)
+ after1 := time.Now()
+ duration1 := after1.Sub(before1)
+ fmt.Println(duration1.Nanoseconds())
+
+ before2 := time.Now()
+ _, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
+ suite.NoError(err)
+ after2 := time.Now()
+ duration2 := after2.Sub(before2)
+ fmt.Println(duration2.Nanoseconds())
+
+ // second retrieval should be several orders faster since it will be cached now
+ suite.Less(duration2, duration1)
+}
+
+func TestStatusTestSuite(t *testing.T) {
+ suite.Run(t, new(StatusTestSuite))
+}
diff --git a/internal/db/pg/statuscontext.go b/internal/db/pg/statuscontext.go
deleted file mode 100644
index 2ff1a20bb..000000000
--- a/internal/db/pg/statuscontext.go
+++ /dev/null
@@ -1,104 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package pg
-
-import (
- "container/list"
- "errors"
-
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-func (ps *postgresService) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) {
- parents := []*gtsmodel.Status{}
- ps.statusParent(status, &parents, onlyDirect)
-
- return parents, nil
-}
-
-func (ps *postgresService) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
- if status.InReplyToID == "" {
- return
- }
-
- parentStatus := &gtsmodel.Status{}
- if err := ps.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil {
- *foundStatuses = append(*foundStatuses, parentStatus)
- }
-
- if onlyDirect {
- return
- }
- ps.statusParent(parentStatus, foundStatuses, false)
-}
-
-func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) {
- foundStatuses := &list.List{}
- foundStatuses.PushFront(status)
- ps.statusChildren(status, foundStatuses, onlyDirect, minID)
-
- children := []*gtsmodel.Status{}
- for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
- // only append children, not the overall parent status
- if entry.ID != status.ID {
- children = append(children, entry)
- }
- }
-
- return children, nil
-}
-
-func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
- immediateChildren := []*gtsmodel.Status{}
-
- q := ps.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
- if minID != "" {
- q = q.Where("status.id > ?", minID)
- }
-
- if err := q.Select(); err != nil {
- return
- }
-
- for _, child := range immediateChildren {
- insertLoop:
- for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
- if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
- foundStatuses.InsertAfter(child, e)
- break insertLoop
- }
- }
-
- // only do one loop if we only want direct children
- if onlyDirect {
- return
- }
- ps.statusChildren(child, foundStatuses, false, minID)
- }
-}
diff --git a/internal/db/pg/timeline.go b/internal/db/pg/timeline.go
index 585ca3067..fa8b07aab 100644
--- a/internal/db/pg/timeline.go
+++ b/internal/db/pg/timeline.go
@@ -19,16 +19,26 @@
package pg
import (
+ "context"
"sort"
"github.com/go-pg/pg/v10"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
+type timelineDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
- q := ps.conn.Model(&statuses)
+ q := t.conn.Model(&statuses)
q = q.ColumnExpr("status.*").
// Find out who accountID follows.
@@ -74,22 +84,22 @@ func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID str
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return statuses, nil
}
-func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
+func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
- q := ps.conn.Model(&statuses).
+ q := t.conn.Model(&statuses).
Where("visibility = ?", gtsmodel.VisibilityPublic).
Where("? IS NULL", pg.Ident("in_reply_to_id")).
Where("? IS NULL", pg.Ident("in_reply_to_uri")).
@@ -119,13 +129,13 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return statuses, nil
@@ -133,11 +143,11 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
-func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) {
+func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
faves := []*gtsmodel.StatusFave{}
- fq := ps.conn.Model(&faves).
+ fq := t.conn.Model(&faves).
Where("account_id = ?", accountID).
Order("id DESC")
@@ -156,13 +166,13 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries{}
+ return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(faves) == 0 {
- return nil, "", "", db.ErrNoEntries{}
+ return nil, "", "", db.ErrNoEntries
}
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
@@ -175,16 +185,16 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
}
statuses := []*gtsmodel.Status{}
- err = ps.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
+ err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
if err != nil {
if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries{}
+ return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(statuses) == 0 {
- return nil, "", "", db.ErrNoEntries{}
+ return nil, "", "", db.ErrNoEntries
}
// arrange statuses by fave ID
diff --git a/internal/db/pg/update.go b/internal/db/pg/update.go
deleted file mode 100644
index f6bc70ad9..000000000
--- a/internal/db/pg/update.go
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package pg
-
-import (
- "fmt"
-
- "github.com/go-pg/pg/v10"
- "github.com/superseriousbusiness/gotosocial/internal/db"
-)
-
-func (ps *postgresService) Upsert(i interface{}, conflictColumn string) error {
- if _, err := ps.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) UpdateByID(id string, i interface{}) error {
- if _, err := ps.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
- _, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
- return err
-}
-
-func (ps *postgresService) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) error {
- q := ps.conn.Model(i)
-
- for _, w := range where {
- if w.Value == nil {
- q = q.Where("? IS NULL", pg.Ident(w.Key))
- } else {
- if w.CaseInsensitive {
- q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
- } else {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
- }
- }
-
- q = q.Set("? = ?", pg.Safe(key), value)
-
- _, err := q.Update()
-
- return err
-}
diff --git a/internal/db/pg/util.go b/internal/db/pg/util.go
new file mode 100644
index 000000000..17c09b720
--- /dev/null
+++ b/internal/db/pg/util.go
@@ -0,0 +1,25 @@
+package pg
+
+import (
+ "strings"
+
+ "github.com/go-pg/pg/v10"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+)
+
+// processErrorResponse parses the given error and returns an appropriate DBError.
+func processErrorResponse(err error) db.Error {
+ switch err {
+ case nil:
+ return nil
+ case pg.ErrNoRows:
+ return db.ErrNoEntries
+ case pg.ErrMultiRows:
+ return db.ErrMultipleEntries
+ default:
+ if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists
+ }
+ return err
+ }
+}
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
new file mode 100644
index 000000000..85f64d72b
--- /dev/null
+++ b/internal/db/relationship.go
@@ -0,0 +1,71 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+
+// Relationship contains functions for getting or modifying the relationship between two accounts.
+type Relationship interface {
+ // IsBlocked checks whether account 1 has a block in place against block2.
+ // If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
+ IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error)
+
+ // GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
+ //
+ // Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
+ // not if you're just checking for the existence of a block.
+ GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error)
+
+ // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
+ GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
+
+ // IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
+ IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+
+ // IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
+ IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+
+ // IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
+ IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
+
+ // AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
+ // In other words, it should create the follow, and delete the existing follow request.
+ //
+ // It will return the newly created follow for further processing.
+ AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
+
+ // GetAccountFollowRequests returns all follow requests targeting the given account.
+ GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error)
+
+ // GetAccountFollows returns a slice of follows owned by the given accountID.
+ GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error)
+
+ // CountAccountFollows returns the amount of accounts that the given accountID is following.
+ //
+ // If localOnly is set to true, then only follows from *this instance* will be returned.
+ CountAccountFollows(accountID string, localOnly bool) (int, Error)
+
+ // GetAccountFollowedBy fetches follows that target given accountID.
+ //
+ // If localOnly is set to true, then only follows from *this instance* will be returned.
+ GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
+
+ // CountAccountFollowedBy returns the amounts that the given ID is followed by.
+ CountAccountFollowedBy(accountID string, localOnly bool) (int, Error)
+}
diff --git a/internal/db/status.go b/internal/db/status.go
new file mode 100644
index 000000000..9d206c198
--- /dev/null
+++ b/internal/db/status.go
@@ -0,0 +1,75 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+
+// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
+type Status interface {
+ // GetStatusByID returns one status from the database, with all rel fields populated (if possible).
+ GetStatusByID(id string) (*gtsmodel.Status, Error)
+
+ // GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
+ GetStatusByURI(uri string) (*gtsmodel.Status, Error)
+
+ // GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
+ GetStatusByURL(uri string) (*gtsmodel.Status, Error)
+
+ // PutStatus stores one status in the database.
+ PutStatus(status *gtsmodel.Status) Error
+
+ // CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong
+ CountStatusReplies(status *gtsmodel.Status) (int, Error)
+
+ // CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
+ CountStatusReblogs(status *gtsmodel.Status) (int, Error)
+
+ // CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong
+ CountStatusFaves(status *gtsmodel.Status) (int, Error)
+
+ // GetStatusParents gets the parent statuses of a given status.
+ //
+ // If onlyDirect is true, only the immediate parent will be returned.
+ GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
+
+ // GetStatusChildren gets the child statuses of a given status.
+ //
+ // If onlyDirect is true, only the immediate children will be returned.
+ GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
+
+ // IsStatusFavedBy checks if a given status has been faved by a given account ID
+ IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+
+ // IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
+ IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+
+ // IsStatusMutedBy checks if a given status has been muted by a given account ID
+ IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+
+ // IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
+ IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+
+ // GetStatusFaves returns a slice of faves/likes of the given status.
+ // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
+ GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
+
+ // GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status.
+ // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
+ GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
+}
diff --git a/internal/db/timeline.go b/internal/db/timeline.go
new file mode 100644
index 000000000..74aa5c781
--- /dev/null
+++ b/internal/db/timeline.go
@@ -0,0 +1,44 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+
+// Timeline contains functionality for retrieving home/public/faved etc timelines for an account.
+type Timeline interface {
+ // GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id.
+ //
+ // Statuses should be returned in descending order of when they were created (newest first).
+ GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
+
+ // GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
+ // It will use the given filters and try to return as many statuses as possible up to the limit.
+ //
+ // Statuses should be returned in descending order of when they were created (newest first).
+ GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
+
+ // GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
+ // It will use the given filters and try to return as many statuses as possible up to the limit.
+ //
+ // Note that unlike the other GetTimeline functions, the returned statuses will be arranged by their FAVE id, not the STATUS id.
+ // In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
+ //
+ // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
+ GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
+}