diff options
author | 2021-08-20 12:26:56 +0200 | |
---|---|---|
committer | 2021-08-20 12:26:56 +0200 | |
commit | 4920229a3b6e1d7dde536bc9ff766542b05d935c (patch) | |
tree | a9423beccec5331c372f01eedf38949dfb171e9e /internal/db | |
parent | Text/status parsing fixes (#141) (diff) | |
download | gotosocial-4920229a3b6e1d7dde536bc9ff766542b05d935c.tar.xz |
Database updates (#144)
* start moving some database stuff around
* continue moving db stuff around
* more fiddling
* more updates
* and some more
* and yet more
* i broke SOMETHING but what, it's a mystery
* tidy up
* vendor ttlcache
* use ttlcache
* fix up some tests
* rename some stuff
* little reminder
* some more updates
Diffstat (limited to 'internal/db')
33 files changed, 2623 insertions, 1386 deletions
diff --git a/internal/db/account.go b/internal/db/account.go new file mode 100644 index 000000000..0e1575f9b --- /dev/null +++ b/internal/db/account.go @@ -0,0 +1,66 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import ( + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Account contains functions related to account getting/setting/creation. +type Account interface { + // GetAccountByID returns one account with the given ID, or an error if something goes wrong. + GetAccountByID(id string) (*gtsmodel.Account, Error) + + // GetAccountByURI returns one account with the given URI, or an error if something goes wrong. + GetAccountByURI(uri string) (*gtsmodel.Account, Error) + + // GetAccountByURL returns one account with the given URL, or an error if something goes wrong. + GetAccountByURL(uri string) (*gtsmodel.Account, Error) + + // GetLocalAccountByUsername returns an account on this instance by its username. + GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error) + + // GetAccountFaves fetches faves/likes created by the target accountID. + GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error) + + // GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID. + CountAccountStatuses(accountID string) (int, Error) + + // GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided + // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can + // be very memory intensive so you probably shouldn't do this! + // In case of no entries, a 'no entries' error will be returned + GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error) + + GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error) + + // GetAccountLastPosted simply gets the timestamp of the most recent post by the account. + // + // The returned time will be zero if account has never posted anything. + GetAccountLastPosted(accountID string) (time.Time, Error) + + // SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment. + SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error + + // GetInstanceAccount returns the instance account for the given domain. + // If domain is empty, this instance account will be returned. + GetInstanceAccount(domain string) (*gtsmodel.Account, Error) +} diff --git a/internal/db/admin.go b/internal/db/admin.go new file mode 100644 index 000000000..aa2b22f47 --- /dev/null +++ b/internal/db/admin.go @@ -0,0 +1,53 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import ( + "net" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Admin contains functions related to instance administration (new signups etc). +type Admin interface { + // IsUsernameAvailable checks whether a given username is available on our domain. + // Returns an error if the username is already taken, or something went wrong in the db. + IsUsernameAvailable(username string) Error + + // IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain. + // Return an error if: + // A) the email is already associated with an account + // B) we block signups from this email domain + // C) something went wrong in the db + IsEmailAvailable(email string) Error + + // NewSignup creates a new user in the database with the given parameters. + // By the time this function is called, it should be assumed that all the parameters have passed validation! + NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error) + + // CreateInstanceAccount creates an account in the database with the same username as the instance host value. + // Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'. + // This is needed for things like serving files that belong to the instance and not an individual user/account. + CreateInstanceAccount() Error + + // CreateInstanceInstance creates an instance in the database with the same domain as the instance host value. + // Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'. + // This is needed for things like serving instance information through /api/v1/instance + CreateInstanceInstance() Error +} diff --git a/internal/db/basic.go b/internal/db/basic.go new file mode 100644 index 000000000..729920bba --- /dev/null +++ b/internal/db/basic.go @@ -0,0 +1,87 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import "context" + +// Basic wraps basic database functionality. +type Basic interface { + // CreateTable creates a table for the given interface. + // For implementations that don't use tables, this can just return nil. + CreateTable(i interface{}) Error + + // DropTable drops the table for the given interface. + // For implementations that don't use tables, this can just return nil. + DropTable(i interface{}) Error + + // RegisterTable registers a table for use in many2many relations. + // For implementations that don't use tables, or many2many relations, this can just return nil. + RegisterTable(i interface{}) Error + + // Stop should stop and close the database connection cleanly, returning an error if this is not possible. + // If the database implementation doesn't need to be stopped, this can just return nil. + Stop(ctx context.Context) Error + + // IsHealthy should return nil if the database connection is healthy, or an error if not. + IsHealthy(ctx context.Context) Error + + // GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry, + // for other implementations (for example, in-memory) it might just be the key of a map. + // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. + // In case of no entries, a 'no entries' error will be returned + GetByID(id string, i interface{}) Error + + // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the + // name of the key to select from. + // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. + // In case of no entries, a 'no entries' error will be returned + GetWhere(where []Where, i interface{}) Error + + // GetAll will try to get all entries of type i. + // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. + // In case of no entries, a 'no entries' error will be returned + GetAll(i interface{}) Error + + // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key. + // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. + Put(i interface{}) Error + + // Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/ + // It is up to the implementation to figure out how to store it, and using what key. + // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. + Upsert(i interface{}, conflictColumn string) Error + + // UpdateByID updates i with id id. + // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. + UpdateByID(id string, i interface{}) Error + + // UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value. + UpdateOneByID(id string, key string, value interface{}, i interface{}) Error + + // UpdateWhere updates column key of interface i with the given value, where the given parameters apply. + UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error + + // DeleteByID removes i with id id. + // If i didn't exist anyway, then no error should be returned. + DeleteByID(id string, i interface{}) Error + + // DeleteWhere deletes i where key = value + // If i didn't exist anyway, then no error should be returned. + DeleteWhere(where []Where, i interface{}) Error +} diff --git a/internal/db/db.go b/internal/db/db.go index d0b23fbc6..d6ac883e4 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -19,9 +19,6 @@ package db import ( - "context" - "net" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -30,257 +27,19 @@ const ( DBTypePostgres string = "POSTGRES" ) -// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres). -// Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated -// by whatever is returned from the database. +// DB provides methods for interacting with an underlying database or other storage mechanism. type DB interface { - /* - BASIC DB FUNCTIONALITY - */ - - // CreateTable creates a table for the given interface. - // For implementations that don't use tables, this can just return nil. - CreateTable(i interface{}) error - - // DropTable drops the table for the given interface. - // For implementations that don't use tables, this can just return nil. - DropTable(i interface{}) error - - // Stop should stop and close the database connection cleanly, returning an error if this is not possible. - // If the database implementation doesn't need to be stopped, this can just return nil. - Stop(ctx context.Context) error - - // IsHealthy should return nil if the database connection is healthy, or an error if not. - IsHealthy(ctx context.Context) error - - // GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry, - // for other implementations (for example, in-memory) it might just be the key of a map. - // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - // In case of no entries, a 'no entries' error will be returned - GetByID(id string, i interface{}) error - - // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the - // name of the key to select from. - // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - // In case of no entries, a 'no entries' error will be returned - GetWhere(where []Where, i interface{}) error - - // GetAll will try to get all entries of type i. - // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - // In case of no entries, a 'no entries' error will be returned - GetAll(i interface{}) error - - // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key. - // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - Put(i interface{}) error - - // Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/ - // It is up to the implementation to figure out how to store it, and using what key. - // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - Upsert(i interface{}, conflictColumn string) error - - // UpdateByID updates i with id id. - // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - UpdateByID(id string, i interface{}) error - - // UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value. - UpdateOneByID(id string, key string, value interface{}, i interface{}) error - - // UpdateWhere updates column key of interface i with the given value, where the given parameters apply. - UpdateWhere(where []Where, key string, value interface{}, i interface{}) error - - // DeleteByID removes i with id id. - // If i didn't exist anyway, then no error should be returned. - DeleteByID(id string, i interface{}) error - - // DeleteWhere deletes i where key = value - // If i didn't exist anyway, then no error should be returned. - DeleteWhere(where []Where, i interface{}) error - - /* - HANDY SHORTCUTS - */ - - // AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table. - // In other words, it should create the follow, and delete the existing follow request. - // - // It will return the newly created follow for further processing. - AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) - - // CreateInstanceAccount creates an account in the database with the same username as the instance host value. - // Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'. - // This is needed for things like serving files that belong to the instance and not an individual user/account. - CreateInstanceAccount() error - - // CreateInstanceInstance creates an instance in the database with the same domain as the instance host value. - // Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'. - // This is needed for things like serving instance information through /api/v1/instance - CreateInstanceInstance() error - - // GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID. - // The given account pointer will be set to the result of the query, whatever it is. - // In case of no entries, a 'no entries' error will be returned - GetAccountByUserID(userID string, account *gtsmodel.Account) error - - // GetLocalAccountByUsername is a shortcut for the common action of fetching an account ON THIS INSTANCE - // according to its username, which should be unique. - // The given account pointer will be set to the result of the query, whatever it is. - // In case of no entries, a 'no entries' error will be returned - GetLocalAccountByUsername(username string, account *gtsmodel.Account) error - - // GetFollowRequestsForAccountID is a shortcut for the common action of fetching a list of follow requests targeting the given account ID. - // The given slice 'followRequests' will be set to the result of the query, whatever it is. - // In case of no entries, a 'no entries' error will be returned - GetFollowRequestsForAccountID(accountID string, followRequests *[]gtsmodel.FollowRequest) error - - // GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following. - // The given slice 'following' will be set to the result of the query, whatever it is. - // In case of no entries, a 'no entries' error will be returned - GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error - - // GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by. - // The given slice 'followers' will be set to the result of the query, whatever it is. - // In case of no entries, a 'no entries' error will be returned - // - // If localOnly is set to true, then only followers from *this instance* will be returned. - GetFollowersByAccountID(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error - - // GetFavesByAccountID is a shortcut for the common action of fetching a list of faves made by the given accountID. - // The given slice 'faves' will be set to the result of the query, whatever it is. - // In case of no entries, a 'no entries' error will be returned - GetFavesByAccountID(accountID string, faves *[]gtsmodel.StatusFave) error - - // CountStatusesByAccountID is a shortcut for the common action of counting statuses produced by accountID. - CountStatusesByAccountID(accountID string) (int, error) - - // GetStatusesForAccount is a shortcut for getting the most recent statuses. accountID is optional, if not provided - // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can - // be very memory intensive so you probably shouldn't do this! - // In case of no entries, a 'no entries' error will be returned - GetStatusesForAccount(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) - - GetBlocksForAccount(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) - - // GetLastStatusForAccountID simply gets the most recent status by the given account. - // The given slice 'status' pointer will be set to the result of the query, whatever it is. - // In case of no entries, a 'no entries' error will be returned - GetLastStatusForAccountID(accountID string, status *gtsmodel.Status) error - - // IsUsernameAvailable checks whether a given username is available on our domain. - // Returns an error if the username is already taken, or something went wrong in the db. - IsUsernameAvailable(username string) error - - // IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain. - // Return an error if: - // A) the email is already associated with an account - // B) we block signups from this email domain - // C) something went wrong in the db - IsEmailAvailable(email string) error - - // NewSignup creates a new user in the database with the given parameters. - // By the time this function is called, it should be assumed that all the parameters have passed validation! - NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error) - - // SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment. - SetHeaderOrAvatarForAccountID(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error - - // GetHeaderAvatarForAccountID gets the current avatar for the given account ID. - // The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists. - GetAvatarForAccountID(avatar *gtsmodel.MediaAttachment, accountID string) error - - // GetHeaderForAccountID gets the current header for the given account ID. - // The passed mediaAttachment pointer will be populated with the value of the header, if it exists. - GetHeaderForAccountID(header *gtsmodel.MediaAttachment, accountID string) error - - // Blocked checks whether a block exists in eiher direction between two accounts. - // That is, it returns true if account1 blocks account2, OR if account2 blocks account1. - Blocked(account1 string, account2 string) (bool, error) - - // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount. - GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) - - // Follows returns true if sourceAccount follows target account, or an error if something goes wrong while finding out. - Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) - - // FollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out. - FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) - - // Mutuals returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out. - Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) - - // GetReplyCountForStatus returns the amount of replies recorded for a status, or an error if something goes wrong - GetReplyCountForStatus(status *gtsmodel.Status) (int, error) - - // GetReblogCountForStatus returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong - GetReblogCountForStatus(status *gtsmodel.Status) (int, error) - - // GetFaveCountForStatus returns the amount of faves/likes recorded for a status, or an error if something goes wrong - GetFaveCountForStatus(status *gtsmodel.Status) (int, error) - - // StatusParents get the parent statuses of a given status. - // - // If onlyDirect is true, only the immediate parent will be returned. - StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) - - // StatusChildren gets the child statuses of a given status. - // - // If onlyDirect is true, only the immediate children will be returned. - StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) - - // StatusFavedBy checks if a given status has been faved by a given account ID - StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error) - - // StatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID - StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error) - - // StatusMutedBy checks if a given status has been muted by a given account ID - StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error) - - // StatusBookmarkedBy checks if a given status has been bookmarked by a given account ID - StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error) - - // WhoFavedStatus returns a slice of accounts who faved the given status. - // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) - - // WhoBoostedStatus returns a slice of accounts who boosted the given status. - // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) - - // GetHomeTimelineForAccount returns a slice of statuses from accounts that are followed by the given account id. - // - // Statuses should be returned in descending order of when they were created (newest first). - GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) - - // GetPublicTimelineForAccount fetches the account's PUBLIC timeline -- ie., posts and replies that are public. - // It will use the given filters and try to return as many statuses as possible up to the limit. - // - // Statuses should be returned in descending order of when they were created (newest first). - GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) - - // GetFavedTimelineForAccount fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved. - // It will use the given filters and try to return as many statuses as possible up to the limit. - // - // Note that unlike the other GetTimeline functions, the returned statuses will be arranged by their FAVE id, not the STATUS id. - // In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created. - // - // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers. - GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) - - // GetNotificationsForAccount returns a list of notifications that pertain to the given accountID. - GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) - - // GetUserCountForInstance returns the number of known accounts registered with the given domain. - GetUserCountForInstance(domain string) (int, error) - - // GetStatusCountForInstance returns the number of known statuses posted from the given domain. - GetStatusCountForInstance(domain string) (int, error) - - // GetDomainCountForInstance returns the number of known instances known that the given domain federates with. - GetDomainCountForInstance(domain string) (int, error) - - // GetAccountsForInstance returns a slice of accounts from the given instance, arranged by ID. - GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) + Account + Admin + Basic + Domain + Instance + Media + Mention + Notification + Relationship + Status + Timeline /* USEFUL CONVERSION FUNCTIONS diff --git a/internal/db/domain.go b/internal/db/domain.go new file mode 100644 index 000000000..a6583c80c --- /dev/null +++ b/internal/db/domain.go @@ -0,0 +1,36 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import "net/url" + +// Domain contains DB functions related to domains and domain blocks. +type Domain interface { + // IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`). + IsDomainBlocked(domain string) (bool, Error) + + // AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found. + AreDomainsBlocked(domains []string) (bool, Error) + + // IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`). + IsURIBlocked(uri *url.URL) (bool, Error) + + // AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found. + AreURIsBlocked(uris []*url.URL) (bool, Error) +} diff --git a/internal/db/error.go b/internal/db/error.go index 197c7bd68..c13bd78dd 100644 --- a/internal/db/error.go +++ b/internal/db/error.go @@ -18,16 +18,18 @@ package db -// ErrNoEntries is to be returned from the DB interface when no entries are found for a given query. -type ErrNoEntries struct{} - -func (e ErrNoEntries) Error() string { - return "no entries" -} - -// ErrAlreadyExists is to be returned from the DB interface when an entry already exists for a given query or its constraints. -type ErrAlreadyExists struct{} - -func (e ErrAlreadyExists) Error() string { - return "already exists" -} +import "fmt" + +// Error denotes a database error. +type Error error + +var ( + // ErrNoEntries is returned when a caller expected an entry for a query, but none was found. + ErrNoEntries Error = fmt.Errorf("no entries") + // ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found. + ErrMultipleEntries Error = fmt.Errorf("multiple entries") + // ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db. + ErrAlreadyExists Error = fmt.Errorf("already exists") + // ErrUnknown denotes an unknown database error. + ErrUnknown Error = fmt.Errorf("unknown error") +) diff --git a/internal/db/instance.go b/internal/db/instance.go new file mode 100644 index 000000000..1f7c83e4f --- /dev/null +++ b/internal/db/instance.go @@ -0,0 +1,36 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + +// Instance contains functions for instance-level actions (counting instance users etc.). +type Instance interface { + // CountInstanceUsers returns the number of known accounts registered with the given domain. + CountInstanceUsers(domain string) (int, Error) + + // CountInstanceStatuses returns the number of known statuses posted from the given domain. + CountInstanceStatuses(domain string) (int, Error) + + // CountInstanceDomains returns the number of known instances known that the given domain federates with. + CountInstanceDomains(domain string) (int, Error) + + // GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID. + GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) +} diff --git a/internal/db/pg/put.go b/internal/db/media.go index 09beca14b..db4db3411 100644 --- a/internal/db/pg/put.go +++ b/internal/db/media.go @@ -16,18 +16,12 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package db -import ( - "strings" +import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/db" -) - -func (ps *postgresService) Put(i interface{}) error { - _, err := ps.conn.Model(i).Insert(i) - if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists{} - } - return err +// Media contains functions related to creating/getting/removing media attachments. +type Media interface { + // GetAttachmentByID gets a single attachment by its ID + GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error) } diff --git a/internal/db/mention.go b/internal/db/mention.go new file mode 100644 index 000000000..cb1c56dc1 --- /dev/null +++ b/internal/db/mention.go @@ -0,0 +1,30 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + +// Mention contains functions for getting/creating mentions in the database. +type Mention interface { + // GetMention gets a single mention by ID + GetMention(id string) (*gtsmodel.Mention, Error) + + // GetMentions gets multiple mentions. + GetMentions(ids []string) ([]*gtsmodel.Mention, Error) +} diff --git a/internal/db/notification.go b/internal/db/notification.go new file mode 100644 index 000000000..326f0f149 --- /dev/null +++ b/internal/db/notification.go @@ -0,0 +1,31 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + +// Notification contains functions for creating and getting notifications. +type Notification interface { + // GetNotifications returns a slice of notifications that pertain to the given accountID. + // + // Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest). + GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error) + // GetNotification returns one notification according to its id. + GetNotification(id string) (*gtsmodel.Notification, Error) +} diff --git a/internal/db/pg/account.go b/internal/db/pg/account.go new file mode 100644 index 000000000..3889c6601 --- /dev/null +++ b/internal/db/pg/account.go @@ -0,0 +1,256 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package pg + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type accountDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query { + return a.conn.Model(account). + Relation("AvatarMediaAttachment"). + Relation("HeaderMediaAttachment") +} + +func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) { + account := >smodel.Account{} + + q := a.newAccountQ(account). + Where("account.id = ?", id) + + err := processErrorResponse(q.Select()) + + return account, err +} + +func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) { + account := >smodel.Account{} + + q := a.newAccountQ(account). + Where("account.uri = ?", uri) + + err := processErrorResponse(q.Select()) + + return account, err +} + +func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) { + account := >smodel.Account{} + + q := a.newAccountQ(account). + Where("account.url = ?", uri) + + err := processErrorResponse(q.Select()) + + return account, err +} + +func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) { + account := >smodel.Account{} + + q := a.newAccountQ(account) + + if domain == "" { + q = q. + Where("account.username = ?", domain). + Where("account.domain = ?", domain) + } else { + q = q. + Where("account.username = ?", domain). + Where("? IS NULL", pg.Ident("domain")) + } + + err := processErrorResponse(q.Select()) + + return account, err +} + +func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) { + status := >smodel.Status{} + + q := a.conn.Model(status). + Order("id DESC"). + Limit(1). + Where("account_id = ?", accountID). + Column("created_at") + + err := processErrorResponse(q.Select()) + + return status.CreatedAt, err +} + +func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { + if mediaAttachment.Avatar && mediaAttachment.Header { + return errors.New("one media attachment cannot be both header and avatar") + } + + var headerOrAVI string + if mediaAttachment.Avatar { + headerOrAVI = "avatar" + } else if mediaAttachment.Header { + headerOrAVI = "header" + } else { + return errors.New("given media attachment was neither a header nor an avatar") + } + + // TODO: there are probably more side effects here that need to be handled + if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil { + return err + } + + if _, err := a.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil { + return err + } + return nil +} + +func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) { + account := >smodel.Account{} + + q := a.newAccountQ(account). + Where("username = ?", username). + Where("? IS NULL", pg.Ident("domain")) + + err := processErrorResponse(q.Select()) + + return account, err +} + +func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) { + faves := []*gtsmodel.StatusFave{} + + if err := a.conn.Model(&faves). + Where("account_id = ?", accountID). + Select(); err != nil { + if err == pg.ErrNoRows { + return faves, nil + } + return nil, err + } + return faves, nil +} + +func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) { + return a.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count() +} + +func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) { + a.log.Debugf("getting statuses for account %s", accountID) + statuses := []*gtsmodel.Status{} + + q := a.conn.Model(&statuses).Order("id DESC") + if accountID != "" { + q = q.Where("account_id = ?", accountID) + } + + if limit != 0 { + q = q.Limit(limit) + } + + if excludeReplies { + q = q.Where("? IS NULL", pg.Ident("in_reply_to_id")) + } + + if pinnedOnly { + q = q.Where("pinned = ?", true) + } + + if mediaOnly { + q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) { + return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil + }) + } + + if maxID != "" { + q = q.Where("id < ?", maxID) + } + + if err := q.Select(); err != nil { + if err == pg.ErrNoRows { + return nil, db.ErrNoEntries + } + return nil, err + } + + if len(statuses) == 0 { + return nil, db.ErrNoEntries + } + + a.log.Debugf("returning statuses for account %s", accountID) + return statuses, nil +} + +func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { + blocks := []*gtsmodel.Block{} + + fq := a.conn.Model(&blocks). + Where("block.account_id = ?", accountID). + Relation("TargetAccount"). + Order("block.id DESC") + + if maxID != "" { + fq = fq.Where("block.id < ?", maxID) + } + + if sinceID != "" { + fq = fq.Where("block.id > ?", sinceID) + } + + if limit > 0 { + fq = fq.Limit(limit) + } + + err := fq.Select() + if err != nil { + if err == pg.ErrNoRows { + return nil, "", "", db.ErrNoEntries + } + return nil, "", "", err + } + + if len(blocks) == 0 { + return nil, "", "", db.ErrNoEntries + } + + accounts := []*gtsmodel.Account{} + for _, b := range blocks { + accounts = append(accounts, b.TargetAccount) + } + + nextMaxID := blocks[len(blocks)-1].ID + prevMinID := blocks[0].ID + return accounts, nextMaxID, prevMinID, nil +} diff --git a/internal/db/pg/account_test.go b/internal/db/pg/account_test.go new file mode 100644 index 000000000..7ea5ff39a --- /dev/null +++ b/internal/db/pg/account_test.go @@ -0,0 +1,70 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package pg_test + +import ( + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type AccountTestSuite struct { + PGStandardTestSuite +} + +func (suite *AccountTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testTags = testrig.NewTestTags() + suite.testMentions = testrig.NewTestMentions() +} + +func (suite *AccountTestSuite) SetupTest() { + suite.config = testrig.NewTestConfig() + suite.db = testrig.NewTestDB() + suite.log = testrig.NewTestLog() + + testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *AccountTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { + account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(account) + suite.NotNil(account.AvatarMediaAttachment) + suite.NotEmpty(account.AvatarMediaAttachment.URL) + suite.NotNil(account.HeaderMediaAttachment) + suite.NotEmpty(account.HeaderMediaAttachment.URL) +} + +func TestAccountTestSuite(t *testing.T) { + suite.Run(t, new(AccountTestSuite)) +} diff --git a/internal/db/pg/admin.go b/internal/db/pg/admin.go new file mode 100644 index 000000000..854f56ef0 --- /dev/null +++ b/internal/db/pg/admin.go @@ -0,0 +1,235 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package pg + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "fmt" + "net" + "net/mail" + "strings" + "time" + + "github.com/go-pg/pg/v10" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/util" + "golang.org/x/crypto/bcrypt" +) + +type adminDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (a *adminDB) IsUsernameAvailable(username string) db.Error { + // if no error we fail because it means we found something + // if error but it's not pg.ErrNoRows then we fail + // if err is pg.ErrNoRows we're good, we found nothing so continue + if err := a.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil { + return fmt.Errorf("username %s already in use", username) + } else if err != pg.ErrNoRows { + return fmt.Errorf("db error: %s", err) + } + return nil +} + +func (a *adminDB) IsEmailAvailable(email string) db.Error { + // parse the domain from the email + m, err := mail.ParseAddress(email) + if err != nil { + return fmt.Errorf("error parsing email address %s: %s", email, err) + } + domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ + + // check if the email domain is blocked + if err := a.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil { + // fail because we found something + return fmt.Errorf("email domain %s is blocked", domain) + } else if err != pg.ErrNoRows { + // fail because we got an unexpected error + return fmt.Errorf("db error: %s", err) + } + + // check if this email is associated with a user already + if err := a.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil { + // fail because we found something + return fmt.Errorf("email %s already in use", email) + } else if err != pg.ErrNoRows { + // fail because we got an unexpected error + return fmt.Errorf("db error: %s", err) + } + return nil +} + +func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + a.log.Errorf("error creating new rsa key: %s", err) + return nil, err + } + + // if something went wrong while creating a user, we might already have an account, so check here first... + acct := >smodel.Account{} + err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select() + if err != nil { + // there's been an actual error + if err != pg.ErrNoRows { + return nil, fmt.Errorf("db error checking existence of account: %s", err) + } + + // we just don't have an account yet create one + newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host) + newAccountID, err := id.NewRandomULID() + if err != nil { + return nil, err + } + + acct = >smodel.Account{ + ID: newAccountID, + Username: username, + DisplayName: username, + Reason: reason, + URL: newAccountURIs.UserURL, + PrivateKey: key, + PublicKey: &key.PublicKey, + PublicKeyURI: newAccountURIs.PublicKeyURI, + ActorType: gtsmodel.ActivityStreamsPerson, + URI: newAccountURIs.UserURI, + InboxURI: newAccountURIs.InboxURI, + OutboxURI: newAccountURIs.OutboxURI, + FollowersURI: newAccountURIs.FollowersURI, + FollowingURI: newAccountURIs.FollowingURI, + FeaturedCollectionURI: newAccountURIs.CollectionURI, + } + if _, err = a.conn.Model(acct).Insert(); err != nil { + return nil, err + } + } + + pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, fmt.Errorf("error hashing password: %s", err) + } + + newUserID, err := id.NewRandomULID() + if err != nil { + return nil, err + } + + u := >smodel.User{ + ID: newUserID, + AccountID: acct.ID, + EncryptedPassword: string(pw), + SignUpIP: signUpIP.To4(), + Locale: locale, + UnconfirmedEmail: email, + CreatedByApplicationID: appID, + Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user + } + + if emailVerified { + u.ConfirmedAt = time.Now() + u.Email = email + } + + if admin { + u.Admin = true + u.Moderator = true + } + + if _, err = a.conn.Model(u).Insert(); err != nil { + return nil, err + } + + return u, nil +} + +func (a *adminDB) CreateInstanceAccount() db.Error { + username := a.config.Host + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + a.log.Errorf("error creating new rsa key: %s", err) + return err + } + + aID, err := id.NewRandomULID() + if err != nil { + return err + } + + newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host) + acct := >smodel.Account{ + ID: aID, + Username: a.config.Host, + DisplayName: username, + URL: newAccountURIs.UserURL, + PrivateKey: key, + PublicKey: &key.PublicKey, + PublicKeyURI: newAccountURIs.PublicKeyURI, + ActorType: gtsmodel.ActivityStreamsPerson, + URI: newAccountURIs.UserURI, + InboxURI: newAccountURIs.InboxURI, + OutboxURI: newAccountURIs.OutboxURI, + FollowersURI: newAccountURIs.FollowersURI, + FollowingURI: newAccountURIs.FollowingURI, + FeaturedCollectionURI: newAccountURIs.CollectionURI, + } + inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert() + if err != nil { + return err + } + if inserted { + a.log.Infof("created instance account %s with id %s", username, acct.ID) + } else { + a.log.Infof("instance account %s already exists with id %s", username, acct.ID) + } + return nil +} + +func (a *adminDB) CreateInstanceInstance() db.Error { + iID, err := id.NewRandomULID() + if err != nil { + return err + } + + i := >smodel.Instance{ + ID: iID, + Domain: a.config.Host, + Title: a.config.Host, + URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host), + } + inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert() + if err != nil { + return err + } + if inserted { + a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID) + } else { + a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID) + } + return nil +} diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go new file mode 100644 index 000000000..6e76b4450 --- /dev/null +++ b/internal/db/pg/basic.go @@ -0,0 +1,205 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package pg + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" +) + +type basicDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (b *basicDB) Put(i interface{}) db.Error { + _, err := b.conn.Model(i).Insert(i) + if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return db.ErrAlreadyExists + } + return err +} + +func (b *basicDB) GetByID(id string, i interface{}) db.Error { + if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil { + if err == pg.ErrNoRows { + return db.ErrNoEntries + } + return err + + } + return nil +} + +func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error { + if len(where) == 0 { + return errors.New("no queries provided") + } + + q := b.conn.Model(i) + for _, w := range where { + + if w.Value == nil { + q = q.Where("? IS NULL", pg.Ident(w.Key)) + } else { + if w.CaseInsensitive { + q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) + } else { + q = q.Where("? = ?", pg.Safe(w.Key), w.Value) + } + } + } + + if err := q.Select(); err != nil { + if err == pg.ErrNoRows { + return db.ErrNoEntries + } + return err + } + return nil +} + +func (b *basicDB) GetAll(i interface{}) db.Error { + if err := b.conn.Model(i).Select(); err != nil { + if err == pg.ErrNoRows { + return db.ErrNoEntries + } + return err + } + return nil +} + +func (b *basicDB) DeleteByID(id string, i interface{}) db.Error { + if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil { + // if there are no rows *anyway* then that's fine + // just return err if there's an actual error + if err != pg.ErrNoRows { + return err + } + } + return nil +} + +func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error { + if len(where) == 0 { + return errors.New("no queries provided") + } + + q := b.conn.Model(i) + for _, w := range where { + q = q.Where("? = ?", pg.Safe(w.Key), w.Value) + } + + if _, err := q.Delete(); err != nil { + // if there are no rows *anyway* then that's fine + // just return err if there's an actual error + if err != pg.ErrNoRows { + return err + } + } + return nil +} + +func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error { + if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil { + if err == pg.ErrNoRows { + return db.ErrNoEntries + } + return err + } + return nil +} + +func (b *basicDB) UpdateByID(id string, i interface{}) db.Error { + if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil { + if err == pg.ErrNoRows { + return db.ErrNoEntries + } + return err + } + return nil +} + +func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error { + _, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update() + return err +} + +func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error { + q := b.conn.Model(i) + + for _, w := range where { + if w.Value == nil { + q = q.Where("? IS NULL", pg.Ident(w.Key)) + } else { + if w.CaseInsensitive { + q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) + } else { + q = q.Where("? = ?", pg.Safe(w.Key), w.Value) + } + } + } + + q = q.Set("? = ?", pg.Safe(key), value) + + _, err := q.Update() + + return err +} + +func (b *basicDB) CreateTable(i interface{}) db.Error { + return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{ + IfNotExists: true, + }) +} + +func (b *basicDB) DropTable(i interface{}) db.Error { + return b.conn.Model(i).DropTable(&orm.DropTableOptions{ + IfExists: true, + }) +} + +func (b *basicDB) RegisterTable(i interface{}) db.Error { + orm.RegisterTable(i) + return nil +} + +func (b *basicDB) IsHealthy(ctx context.Context) db.Error { + return b.conn.Ping(ctx) +} + +func (b *basicDB) Stop(ctx context.Context) db.Error { + b.log.Info("closing db connection") + if err := b.conn.Close(); err != nil { + // only cancel if there's a problem closing the db + b.cancel() + return err + } + return nil +} diff --git a/internal/db/pg/blocks.go b/internal/db/pg/blocks.go deleted file mode 100644 index a6fc1f859..000000000 --- a/internal/db/pg/blocks.go +++ /dev/null @@ -1,67 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "github.com/go-pg/pg/v10" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -func (ps *postgresService) GetBlocksForAccount(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) { - blocks := []*gtsmodel.Block{} - - fq := ps.conn.Model(&blocks). - Where("block.account_id = ?", accountID). - Relation("TargetAccount"). - Order("block.id DESC") - - if maxID != "" { - fq = fq.Where("block.id < ?", maxID) - } - - if sinceID != "" { - fq = fq.Where("block.id > ?", sinceID) - } - - if limit > 0 { - fq = fq.Limit(limit) - } - - err := fq.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries{} - } - return nil, "", "", err - } - - if len(blocks) == 0 { - return nil, "", "", db.ErrNoEntries{} - } - - accounts := []*gtsmodel.Account{} - for _, b := range blocks { - accounts = append(accounts, b.TargetAccount) - } - - nextMaxID := blocks[len(blocks)-1].ID - prevMinID := blocks[0].ID - return accounts, nextMaxID, prevMinID, nil -} diff --git a/internal/db/pg/domain.go b/internal/db/pg/domain.go new file mode 100644 index 000000000..4e9b2ab48 --- /dev/null +++ b/internal/db/pg/domain.go @@ -0,0 +1,83 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package pg + +import ( + "context" + "net/url" + + "github.com/go-pg/pg/v10" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +type domainDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) { + if domain == "" { + return false, nil + } + + blocked, err := d.conn. + Model(>smodel.DomainBlock{}). + Where("LOWER(domain) = LOWER(?)", domain). + Exists() + + err = processErrorResponse(err) + + return blocked, err +} + +func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) { + // filter out any doubles + uniqueDomains := util.UniqueStrings(domains) + + for _, domain := range uniqueDomains { + if blocked, err := d.IsDomainBlocked(domain); err != nil { + return false, err + } else if blocked { + return blocked, nil + } + } + + // no blocks found + return false, nil +} + +func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) { + domain := uri.Hostname() + return d.IsDomainBlocked(domain) +} + +func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) { + domains := []string{} + for _, uri := range uris { + domains = append(domains, uri.Hostname()) + } + + return d.AreDomainsBlocked(domains) +} diff --git a/internal/db/pg/get.go b/internal/db/pg/get.go deleted file mode 100644 index d48c43520..000000000 --- a/internal/db/pg/get.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "errors" - - "github.com/go-pg/pg/v10" - "github.com/superseriousbusiness/gotosocial/internal/db" -) - -func (ps *postgresService) GetByID(id string, i interface{}) error { - if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - - } - return nil -} - -func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error { - if len(where) == 0 { - return errors.New("no queries provided") - } - - q := ps.conn.Model(i) - for _, w := range where { - - if w.Value == nil { - q = q.Where("? IS NULL", pg.Ident(w.Key)) - } else { - if w.CaseInsensitive { - q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) - } else { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - } - } - - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil -} - -func (ps *postgresService) GetAll(i interface{}) error { - if err := ps.conn.Model(i).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil -} diff --git a/internal/db/pg/instance.go b/internal/db/pg/instance.go index c551b2a49..968832ca5 100644 --- a/internal/db/pg/instance.go +++ b/internal/db/pg/instance.go @@ -19,15 +19,26 @@ package pg import ( + "context" + "github.com/go-pg/pg/v10" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Account{}) +type instanceDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) { + q := i.conn.Model(&[]*gtsmodel.Account{}) - if domain == ps.config.Host { + if domain == i.config.Host { // if the domain is *this* domain, just count where the domain field is null q = q.Where("? IS NULL", pg.Ident("domain")) } else { @@ -40,10 +51,10 @@ func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) { return q.Count() } -func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Status{}) +func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) { + q := i.conn.Model(&[]*gtsmodel.Status{}) - if domain == ps.config.Host { + if domain == i.config.Host { // if the domain is *this* domain, just count where local is true q = q.Where("local = ?", true) } else { @@ -55,10 +66,10 @@ func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) return q.Count() } -func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Instance{}) +func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) { + q := i.conn.Model(&[]*gtsmodel.Instance{}) - if domain == ps.config.Host { + if domain == i.config.Host { // if the domain is *this* domain, just count other instances it knows about // exclude domains that are blocked q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) @@ -70,12 +81,12 @@ func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) return q.Count() } -func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) { - ps.log.Debug("GetAccountsForInstance") +func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { + i.log.Debug("GetAccountsForInstance") accounts := []*gtsmodel.Account{} - q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC") + q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC") if maxID != "" { q = q.Where("id < ?", maxID) @@ -88,13 +99,13 @@ func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, l err := q.Select() if err != nil { if err == pg.ErrNoRows { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return nil, err } if len(accounts) == 0 { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return accounts, nil diff --git a/internal/db/pg/delete.go b/internal/db/pg/media.go index 0f288353e..618030af3 100644 --- a/internal/db/pg/delete.go +++ b/internal/db/pg/media.go @@ -19,39 +19,35 @@ package pg import ( - "errors" + "context" "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (ps *postgresService) DeleteByID(id string, i interface{}) error { - if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil { - // if there are no rows *anyway* then that's fine - // just return err if there's an actual error - if err != pg.ErrNoRows { - return err - } - } - return nil +type mediaDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc } -func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error { - if len(where) == 0 { - return errors.New("no queries provided") - } - - q := ps.conn.Model(i) - for _, w := range where { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - - if _, err := q.Delete(); err != nil { - // if there are no rows *anyway* then that's fine - // just return err if there's an actual error - if err != pg.ErrNoRows { - return err - } - } - return nil +func (m *mediaDB) newMediaQ(i interface{}) *orm.Query { + return m.conn.Model(i). + Relation("Account") +} + +func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) { + attachment := >smodel.MediaAttachment{} + + q := m.newMediaQ(attachment). + Where("media_attachment.id = ?", id) + + err := processErrorResponse(q.Select()) + + return attachment, err } diff --git a/internal/db/pg/mention.go b/internal/db/pg/mention.go new file mode 100644 index 000000000..b31f07b67 --- /dev/null +++ b/internal/db/pg/mention.go @@ -0,0 +1,108 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package pg + +import ( + "context" + + "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type mentionDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc + cache cache.Cache +} + +func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) { + if m.cache == nil { + m.cache = cache.New() + } + + if err := m.cache.Store(id, mention); err != nil { + m.log.Panicf("mentionDB: error storing in cache: %s", err) + } +} + +func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) { + if m.cache == nil { + m.cache = cache.New() + return nil, false + } + + mI, err := m.cache.Fetch(id) + if err != nil || mI == nil { + return nil, false + } + + mention, ok := mI.(*gtsmodel.Mention) + if !ok { + m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id) + } + + return mention, true +} + +func (m *mentionDB) newMentionQ(i interface{}) *orm.Query { + return m.conn.Model(i). + Relation("Status"). + Relation("OriginAccount"). + Relation("TargetAccount") +} + +func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) { + if mention, cached := m.mentionCached(id); cached { + return mention, nil + } + + mention := >smodel.Mention{} + + q := m.newMentionQ(mention). + Where("mention.id = ?", id) + + err := processErrorResponse(q.Select()) + + if err == nil && mention != nil { + m.cacheMention(id, mention) + } + + return mention, err +} + +func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) { + mentions := []*gtsmodel.Mention{} + + for _, i := range ids { + mention, err := m.GetMention(i) + if err != nil { + return nil, processErrorResponse(err) + } + mentions = append(mentions, mention) + } + + return mentions, nil +} diff --git a/internal/db/pg/notification.go b/internal/db/pg/notification.go new file mode 100644 index 000000000..281a76d85 --- /dev/null +++ b/internal/db/pg/notification.go @@ -0,0 +1,135 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package pg + +import ( + "context" + + "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type notificationDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc + cache cache.Cache +} + +func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) { + if n.cache == nil { + n.cache = cache.New() + } + + if err := n.cache.Store(id, notification); err != nil { + n.log.Panicf("notificationDB: error storing in cache: %s", err) + } +} + +func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) { + if n.cache == nil { + n.cache = cache.New() + return nil, false + } + + nI, err := n.cache.Fetch(id) + if err != nil || nI == nil { + return nil, false + } + + notification, ok := nI.(*gtsmodel.Notification) + if !ok { + n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id) + } + + return notification, true +} + +func (n *notificationDB) newNotificationQ(i interface{}) *orm.Query { + return n.conn.Model(i). + Relation("OriginAccount"). + Relation("TargetAccount"). + Relation("Status") +} + +func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.Error) { + if notification, cached := n.notificationCached(id); cached { + return notification, nil + } + + notification := >smodel.Notification{} + + q := n.newNotificationQ(notification). + Where("notification.id = ?", id) + + err := processErrorResponse(q.Select()) + + if err == nil && notification != nil { + n.cacheNotification(id, notification) + } + + return notification, err +} + +func (n *notificationDB) GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { + // begin by selecting just the IDs + notifIDs := []*gtsmodel.Notification{} + q := n.conn. + Model(¬ifIDs). + Column("id"). + Where("target_account_id = ?", accountID). + Order("id DESC") + + if maxID != "" { + q = q.Where("id < ?", maxID) + } + + if sinceID != "" { + q = q.Where("id > ?", sinceID) + } + + if limit != 0 { + q = q.Limit(limit) + } + + err := processErrorResponse(q.Select()) + if err != nil { + return nil, err + } + + // now we have the IDs, select the notifs one by one + // reason for this is that for each notif, we can instead get it from our cache if it's cached + notifications := []*gtsmodel.Notification{} + for _, notifID := range notifIDs { + notif, err := n.GetNotification(notifID.ID) + errP := processErrorResponse(err) + if errP != nil { + return nil, errP + } + notifications = append(notifications, notif) + } + + return notifications, nil +} diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go index d49c50114..0437baf02 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/pg/pg.go @@ -20,15 +20,11 @@ package pg import ( "context" - "crypto/rand" - "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/pem" "errors" "fmt" - "net" - "net/mail" "os" "strings" "time" @@ -41,12 +37,26 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" - "github.com/superseriousbusiness/gotosocial/internal/util" - "golang.org/x/crypto/bcrypt" ) +var registerTables []interface{} = []interface{}{ + >smodel.StatusToEmoji{}, + >smodel.StatusToTag{}, +} + // postgresService satisfies the DB interface type postgresService struct { + db.Account + db.Admin + db.Basic + db.Domain + db.Instance + db.Media + db.Mention + db.Notification + db.Relationship + db.Status + db.Timeline config *config.Config conn *pg.DB log *logrus.Logger @@ -56,6 +66,11 @@ type postgresService struct { // NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. // Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection. func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { + for _, t := range registerTables { + // https://pg.uptrace.dev/orm/many-to-many-relation/ + orm.RegisterTable(t) + } + opts, err := derivePGOptions(c) if err != nil { return nil, fmt.Errorf("could not create postgres service: %s", err) @@ -91,6 +106,72 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge log.Infof("connected to postgres version: %s", version) ps := &postgresService{ + Account: &accountDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Admin: &adminDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Basic: &basicDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Domain: &domainDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Instance: &instanceDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Media: &mediaDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Mention: &mentionDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Notification: ¬ificationDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Relationship: &relationshipDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Status: &statusDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Timeline: &timelineDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, config: c, conn: conn, log: log, @@ -200,724 +281,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { } /* - BASIC DB FUNCTIONALITY -*/ - -func (ps *postgresService) CreateTable(i interface{}) error { - return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }) -} - -func (ps *postgresService) DropTable(i interface{}) error { - return ps.conn.Model(i).DropTable(&orm.DropTableOptions{ - IfExists: true, - }) -} - -func (ps *postgresService) Stop(ctx context.Context) error { - ps.log.Info("closing db connection") - if err := ps.conn.Close(); err != nil { - // only cancel if there's a problem closing the db - ps.cancel() - return err - } - return nil -} - -func (ps *postgresService) IsHealthy(ctx context.Context) error { - return ps.conn.Ping(ctx) -} - -func (ps *postgresService) CreateSchema(ctx context.Context) error { - models := []interface{}{ - (*gtsmodel.Account)(nil), - (*gtsmodel.Status)(nil), - (*gtsmodel.User)(nil), - } - ps.log.Info("creating db schema") - - for _, model := range models { - err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }) - if err != nil { - return err - } - } - - ps.log.Info("db schema created") - return nil -} - -/* - HANDY SHORTCUTS -*/ - -func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) { - // make sure the original follow request exists - fr := >smodel.FollowRequest{} - if err := ps.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil { - if err == pg.ErrMultiRows { - return nil, db.ErrNoEntries{} - } - return nil, err - } - - // create a new follow to 'replace' the request with - follow := >smodel.Follow{ - ID: fr.ID, - AccountID: originAccountID, - TargetAccountID: targetAccountID, - URI: fr.URI, - } - - // if the follow already exists, just update the URI -- we don't need to do anything else - if _, err := ps.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil { - return nil, err - } - - // now remove the follow request - if _, err := ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil { - return nil, err - } - - return follow, nil -} - -func (ps *postgresService) CreateInstanceAccount() error { - username := ps.config.Host - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - ps.log.Errorf("error creating new rsa key: %s", err) - return err - } - - aID, err := id.NewRandomULID() - if err != nil { - return err - } - - newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host) - a := >smodel.Account{ - ID: aID, - Username: ps.config.Host, - DisplayName: username, - URL: newAccountURIs.UserURL, - PrivateKey: key, - PublicKey: &key.PublicKey, - PublicKeyURI: newAccountURIs.PublicKeyURI, - ActorType: gtsmodel.ActivityStreamsPerson, - URI: newAccountURIs.UserURI, - InboxURI: newAccountURIs.InboxURI, - OutboxURI: newAccountURIs.OutboxURI, - FollowersURI: newAccountURIs.FollowersURI, - FollowingURI: newAccountURIs.FollowingURI, - FeaturedCollectionURI: newAccountURIs.CollectionURI, - } - inserted, err := ps.conn.Model(a).Where("username = ?", username).SelectOrInsert() - if err != nil { - return err - } - if inserted { - ps.log.Infof("created instance account %s with id %s", username, a.ID) - } else { - ps.log.Infof("instance account %s already exists with id %s", username, a.ID) - } - return nil -} - -func (ps *postgresService) CreateInstanceInstance() error { - iID, err := id.NewRandomULID() - if err != nil { - return err - } - - i := >smodel.Instance{ - ID: iID, - Domain: ps.config.Host, - Title: ps.config.Host, - URI: fmt.Sprintf("%s://%s", ps.config.Protocol, ps.config.Host), - } - inserted, err := ps.conn.Model(i).Where("domain = ?", ps.config.Host).SelectOrInsert() - if err != nil { - return err - } - if inserted { - ps.log.Infof("created instance instance %s with id %s", ps.config.Host, i.ID) - } else { - ps.log.Infof("instance instance %s already exists with id %s", ps.config.Host, i.ID) - } - return nil -} - -func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error { - user := >smodel.User{ - ID: userID, - } - if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil -} - -func (ps *postgresService) GetLocalAccountByUsername(username string, account *gtsmodel.Account) error { - if err := ps.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil -} - -func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]gtsmodel.FollowRequest) error { - if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil { - if err == pg.ErrNoRows { - return nil - } - return err - } - return nil -} - -func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error { - if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil { - if err == pg.ErrNoRows { - return nil - } - return err - } - return nil -} - -func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error { - - q := ps.conn.Model(followers) - - if localOnly { - // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe - whereGroup := func(q *pg.Query) (*pg.Query, error) { - q = q. - WhereOr("? IS NULL", pg.Ident("a.domain")). - WhereOr("a.domain = ?", "") - return q, nil - } - - q = q.ColumnExpr("follow.*"). - Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)"). - Where("follow.target_account_id = ?", accountID). - WhereGroup(whereGroup) - } else { - q = q.Where("target_account_id = ?", accountID) - } - - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { - return nil - } - return err - } - return nil -} - -func (ps *postgresService) GetFavesByAccountID(accountID string, faves *[]gtsmodel.StatusFave) error { - if err := ps.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil { - if err == pg.ErrNoRows { - return nil - } - return err - } - return nil -} - -func (ps *postgresService) CountStatusesByAccountID(accountID string) (int, error) { - count, err := ps.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count() - if err != nil { - if err == pg.ErrNoRows { - return 0, nil - } - return 0, err - } - return count, nil -} - -func (ps *postgresService) GetStatusesForAccount(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) { - ps.log.Debugf("getting statuses for account %s", accountID) - statuses := []*gtsmodel.Status{} - - q := ps.conn.Model(&statuses).Order("id DESC") - if accountID != "" { - q = q.Where("account_id = ?", accountID) - } - - if limit != 0 { - q = q.Limit(limit) - } - - if excludeReplies { - q = q.Where("? IS NULL", pg.Ident("in_reply_to_id")) - } - - if pinnedOnly { - q = q.Where("pinned = ?", true) - } - - if mediaOnly { - q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) { - return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil - }) - } - - if maxID != "" { - q = q.Where("id < ?", maxID) - } - - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries{} - } - return nil, err - } - - if len(statuses) == 0 { - return nil, db.ErrNoEntries{} - } - - ps.log.Debugf("returning statuses for account %s", accountID) - return statuses, nil -} - -func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *gtsmodel.Status) error { - if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil - -} - -func (ps *postgresService) IsUsernameAvailable(username string) error { - // if no error we fail because it means we found something - // if error but it's not pg.ErrNoRows then we fail - // if err is pg.ErrNoRows we're good, we found nothing so continue - if err := ps.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil { - return fmt.Errorf("username %s already in use", username) - } else if err != pg.ErrNoRows { - return fmt.Errorf("db error: %s", err) - } - return nil -} - -func (ps *postgresService) IsEmailAvailable(email string) error { - // parse the domain from the email - m, err := mail.ParseAddress(email) - if err != nil { - return fmt.Errorf("error parsing email address %s: %s", email, err) - } - domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ - - // check if the email domain is blocked - if err := ps.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil { - // fail because we found something - return fmt.Errorf("email domain %s is blocked", domain) - } else if err != pg.ErrNoRows { - // fail because we got an unexpected error - return fmt.Errorf("db error: %s", err) - } - - // check if this email is associated with a user already - if err := ps.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil { - // fail because we found something - return fmt.Errorf("email %s already in use", email) - } else if err != pg.ErrNoRows { - // fail because we got an unexpected error - return fmt.Errorf("db error: %s", err) - } - return nil -} - -func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - ps.log.Errorf("error creating new rsa key: %s", err) - return nil, err - } - - // if something went wrong while creating a user, we might already have an account, so check here first... - a := >smodel.Account{} - err = ps.conn.Model(a).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select() - if err != nil { - // there's been an actual error - if err != pg.ErrNoRows { - return nil, fmt.Errorf("db error checking existence of account: %s", err) - } - - // we just don't have an account yet create one - newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host) - newAccountID, err := id.NewRandomULID() - if err != nil { - return nil, err - } - - a = >smodel.Account{ - ID: newAccountID, - Username: username, - DisplayName: username, - Reason: reason, - URL: newAccountURIs.UserURL, - PrivateKey: key, - PublicKey: &key.PublicKey, - PublicKeyURI: newAccountURIs.PublicKeyURI, - ActorType: gtsmodel.ActivityStreamsPerson, - URI: newAccountURIs.UserURI, - InboxURI: newAccountURIs.InboxURI, - OutboxURI: newAccountURIs.OutboxURI, - FollowersURI: newAccountURIs.FollowersURI, - FollowingURI: newAccountURIs.FollowingURI, - FeaturedCollectionURI: newAccountURIs.CollectionURI, - } - if _, err = ps.conn.Model(a).Insert(); err != nil { - return nil, err - } - } - - pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return nil, fmt.Errorf("error hashing password: %s", err) - } - - newUserID, err := id.NewRandomULID() - if err != nil { - return nil, err - } - - u := >smodel.User{ - ID: newUserID, - AccountID: a.ID, - EncryptedPassword: string(pw), - SignUpIP: signUpIP.To4(), - Locale: locale, - UnconfirmedEmail: email, - CreatedByApplicationID: appID, - Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user - } - - if emailVerified { - u.ConfirmedAt = time.Now() - u.Email = email - } - - if admin { - u.Admin = true - u.Moderator = true - } - - if _, err = ps.conn.Model(u).Insert(); err != nil { - return nil, err - } - - return u, nil -} - -func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error { - if mediaAttachment.Avatar && mediaAttachment.Header { - return errors.New("one media attachment cannot be both header and avatar") - } - - var headerOrAVI string - if mediaAttachment.Avatar { - headerOrAVI = "avatar" - } else if mediaAttachment.Header { - headerOrAVI = "header" - } else { - return errors.New("given media attachment was neither a header nor an avatar") - } - - // TODO: there are probably more side effects here that need to be handled - if _, err := ps.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil { - return err - } - - if _, err := ps.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil { - return err - } - return nil -} - -func (ps *postgresService) GetHeaderForAccountID(header *gtsmodel.MediaAttachment, accountID string) error { - acct := >smodel.Account{} - if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - - if acct.HeaderMediaAttachmentID == "" { - return db.ErrNoEntries{} - } - - if err := ps.conn.Model(header).Where("id = ?", acct.HeaderMediaAttachmentID).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil -} - -func (ps *postgresService) GetAvatarForAccountID(avatar *gtsmodel.MediaAttachment, accountID string) error { - acct := >smodel.Account{} - if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - - if acct.AvatarMediaAttachmentID == "" { - return db.ErrNoEntries{} - } - - if err := ps.conn.Model(avatar).Where("id = ?", acct.AvatarMediaAttachmentID).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil -} - -func (ps *postgresService) Blocked(account1 string, account2 string) (bool, error) { - // TODO: check domain blocks as well - var blocked bool - if err := ps.conn.Model(>smodel.Block{}). - Where("account_id = ?", account1).Where("target_account_id = ?", account2). - WhereOr("target_account_id = ?", account1).Where("account_id = ?", account2). - Select(); err != nil { - if err == pg.ErrNoRows { - blocked = false - return blocked, nil - } - return blocked, err - } - blocked = true - return blocked, nil -} - -func (ps *postgresService) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) { - r := >smodel.Relationship{ - ID: targetAccount, - } - - // check if the requesting account follows the target account - follow := >smodel.Follow{} - if err := ps.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil { - if err != pg.ErrNoRows { - // a proper error - return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) - } - // no follow exists so these are all false - r.Following = false - r.ShowingReblogs = false - r.Notifying = false - } else { - // follow exists so we can fill these fields out... - r.Following = true - r.ShowingReblogs = follow.ShowReblogs - r.Notifying = follow.Notify - } - - // check if the target account follows the requesting account - followedBy, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() - if err != nil { - return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) - } - r.FollowedBy = followedBy - - // check if the requesting account blocks the target account - blocking, err := ps.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() - if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) - } - r.Blocking = blocking - - // check if the target account blocks the requesting account - blockedBy, err := ps.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() - if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) - } - r.BlockedBy = blockedBy - - // check if there's a pending following request from requesting account to target account - requested, err := ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() - if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) - } - r.Requested = requested - - return r, nil -} - -func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) { - if sourceAccount == nil || targetAccount == nil { - return false, nil - } - - return ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() -} - -func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) { - if sourceAccount == nil || targetAccount == nil { - return false, nil - } - - return ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() -} - -func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) { - if account1 == nil || account2 == nil { - return false, nil - } - - // make sure account 1 follows account 2 - f1, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists() - if err != nil { - if err == pg.ErrNoRows { - return false, nil - } - return false, err - } - - // make sure account 2 follows account 1 - f2, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists() - if err != nil { - if err == pg.ErrNoRows { - return false, nil - } - return false, err - } - - return f1 && f2, nil -} - -func (ps *postgresService) GetReplyCountForStatus(status *gtsmodel.Status) (int, error) { - return ps.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count() -} - -func (ps *postgresService) GetReblogCountForStatus(status *gtsmodel.Status) (int, error) { - return ps.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count() -} - -func (ps *postgresService) GetFaveCountForStatus(status *gtsmodel.Status) (int, error) { - return ps.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count() -} - -func (ps *postgresService) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error) { - return ps.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (ps *postgresService) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error) { - return ps.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (ps *postgresService) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error) { - return ps.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (ps *postgresService) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error) { - return ps.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) { - accounts := []*gtsmodel.Account{} - - faves := []*gtsmodel.StatusFave{} - if err := ps.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil { - if err == pg.ErrNoRows { - return accounts, nil // no rows just means nobody has faved this status, so that's fine - } - return nil, err // an actual error has occurred - } - - for _, f := range faves { - acc := >smodel.Account{} - if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil { - if err == pg.ErrNoRows { - continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it - } - return nil, err // an actual error has occurred - } - accounts = append(accounts, acc) - } - return accounts, nil -} - -func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) { - accounts := []*gtsmodel.Account{} - - boosts := []*gtsmodel.Status{} - if err := ps.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil { - if err == pg.ErrNoRows { - return accounts, nil // no rows just means nobody has boosted this status, so that's fine - } - return nil, err // an actual error has occurred - } - - for _, f := range boosts { - acc := >smodel.Account{} - if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil { - if err == pg.ErrNoRows { - continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it - } - return nil, err // an actual error has occurred - } - accounts = append(accounts, acc) - } - return accounts, nil -} - -func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) { - notifications := []*gtsmodel.Notification{} - - q := ps.conn.Model(¬ifications).Where("target_account_id = ?", accountID) - - if maxID != "" { - q = q.Where("id < ?", maxID) - } - - if sinceID != "" { - q = q.Where("id > ?", sinceID) - } - - if limit != 0 { - q = q.Limit(limit) - } - - q = q.Order("created_at DESC") - - if err := q.Select(); err != nil { - if err != pg.ErrNoRows { - return nil, err - } - - } - return notifications, nil -} - -/* CONVERSION FUNCTIONS */ @@ -988,14 +351,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori // id, createdAt and updatedAt will be populated by the db, so we have everything we need! menchies = append(menchies, >smodel.Mention{ - StatusID: statusID, - OriginAccountID: ogAccount.ID, - OriginAccountURI: ogAccount.URI, - TargetAccountID: mentionedAccount.ID, - NameString: a, - MentionedAccountURI: mentionedAccount.URI, - MentionedAccountURL: mentionedAccount.URL, - GTSAccount: mentionedAccount, + StatusID: statusID, + OriginAccountID: ogAccount.ID, + OriginAccountURI: ogAccount.URI, + TargetAccountID: mentionedAccount.ID, + NameString: a, + TargetAccountURI: mentionedAccount.URI, + TargetAccountURL: mentionedAccount.URL, + OriginAccount: mentionedAccount, }) } return menchies, nil diff --git a/internal/db/pg/pg_test.go b/internal/db/pg/pg_test.go new file mode 100644 index 000000000..c1e10abdf --- /dev/null +++ b/internal/db/pg/pg_test.go @@ -0,0 +1,47 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package pg_test + +import ( + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +type PGStandardTestSuite struct { + // standard suite interfaces + suite.Suite + config *config.Config + db db.DB + log *logrus.Logger + + // standard suite models + testTokens map[string]*oauth.Token + testClients map[string]*oauth.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + testAttachments map[string]*gtsmodel.MediaAttachment + testStatuses map[string]*gtsmodel.Status + testTags map[string]*gtsmodel.Tag + testMentions map[string]*gtsmodel.Mention +} diff --git a/internal/db/pg/relationship.go b/internal/db/pg/relationship.go new file mode 100644 index 000000000..76bd50c76 --- /dev/null +++ b/internal/db/pg/relationship.go @@ -0,0 +1,276 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package pg + +import ( + "context" + "fmt" + + "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type relationshipDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query { + return r.conn.Model(block). + Relation("Account"). + Relation("TargetAccount") +} + +func (r *relationshipDB) newFollowQ(follow interface{}) *orm.Query { + return r.conn.Model(follow). + Relation("Account"). + Relation("TargetAccount") +} + +func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) { + q := r.conn. + Model(>smodel.Block{}). + Where("account_id = ?", account1). + Where("target_account_id = ?", account2) + + if eitherDirection { + q = q. + WhereOr("target_account_id = ?", account1). + Where("account_id = ?", account2) + } + + return q.Exists() +} + +func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) { + block := >smodel.Block{} + + q := r.newBlockQ(block). + Where("block.account_id = ?", account1). + Where("block.target_account_id = ?", account2) + + err := processErrorResponse(q.Select()) + + return block, err +} + +func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { + rel := >smodel.Relationship{ + ID: targetAccount, + } + + // check if the requesting account follows the target account + follow := >smodel.Follow{} + if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil { + if err != pg.ErrNoRows { + // a proper error + return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) + } + // no follow exists so these are all false + rel.Following = false + rel.ShowingReblogs = false + rel.Notifying = false + } else { + // follow exists so we can fill these fields out... + rel.Following = true + rel.ShowingReblogs = follow.ShowReblogs + rel.Notifying = follow.Notify + } + + // check if the target account follows the requesting account + followedBy, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) + } + rel.FollowedBy = followedBy + + // check if the requesting account blocks the target account + blocking, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) + } + rel.Blocking = blocking + + // check if the target account blocks the requesting account + blockedBy, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) + } + rel.BlockedBy = blockedBy + + // check if there's a pending following request from requesting account to target account + requested, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) + } + rel.Requested = requested + + return rel, nil +} + +func (r *relationshipDB) IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { + if sourceAccount == nil || targetAccount == nil { + return false, nil + } + + q := r.conn. + Model(>smodel.Follow{}). + Where("account_id = ?", sourceAccount.ID). + Where("target_account_id = ?", targetAccount.ID) + + return q.Exists() +} + +func (r *relationshipDB) IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { + if sourceAccount == nil || targetAccount == nil { + return false, nil + } + + q := r.conn. + Model(>smodel.FollowRequest{}). + Where("account_id = ?", sourceAccount.ID). + Where("target_account_id = ?", targetAccount.ID) + + return q.Exists() +} + +func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { + if account1 == nil || account2 == nil { + return false, nil + } + + // make sure account 1 follows account 2 + f1, err := r.IsFollowing(account1, account2) + if err != nil { + return false, processErrorResponse(err) + } + + // make sure account 2 follows account 1 + f2, err := r.IsFollowing(account2, account1) + if err != nil { + return false, processErrorResponse(err) + } + + return f1 && f2, nil +} + +func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { + // make sure the original follow request exists + fr := >smodel.FollowRequest{} + if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil { + if err == pg.ErrMultiRows { + return nil, db.ErrNoEntries + } + return nil, err + } + + // create a new follow to 'replace' the request with + follow := >smodel.Follow{ + ID: fr.ID, + AccountID: originAccountID, + TargetAccountID: targetAccountID, + URI: fr.URI, + } + + // if the follow already exists, just update the URI -- we don't need to do anything else + if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil { + return nil, err + } + + // now remove the follow request + if _, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil { + return nil, err + } + + return follow, nil +} + +func (r *relationshipDB) GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, db.Error) { + followRequests := []*gtsmodel.FollowRequest{} + + q := r.newFollowQ(&followRequests). + Where("target_account_id = ?", accountID) + + err := processErrorResponse(q.Select()) + + return followRequests, err +} + +func (r *relationshipDB) GetAccountFollows(accountID string) ([]*gtsmodel.Follow, db.Error) { + follows := []*gtsmodel.Follow{} + + q := r.newFollowQ(&follows). + Where("account_id = ?", accountID) + + err := processErrorResponse(q.Select()) + + return follows, err +} + +func (r *relationshipDB) CountAccountFollows(accountID string, localOnly bool) (int, db.Error) { + return r.conn. + Model(&[]*gtsmodel.Follow{}). + Where("account_id = ?", accountID). + Count() +} + +func (r *relationshipDB) GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { + + follows := []*gtsmodel.Follow{} + + q := r.conn.Model(&follows) + + if localOnly { + // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe + whereGroup := func(q *pg.Query) (*pg.Query, error) { + q = q. + WhereOr("? IS NULL", pg.Ident("a.domain")). + WhereOr("a.domain = ?", "") + return q, nil + } + + q = q.ColumnExpr("follow.*"). + Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)"). + Where("follow.target_account_id = ?", accountID). + WhereGroup(whereGroup) + } else { + q = q.Where("target_account_id = ?", accountID) + } + + if err := q.Select(); err != nil { + if err == pg.ErrNoRows { + return follows, nil + } + return nil, err + } + return follows, nil +} + +func (r *relationshipDB) CountAccountFollowedBy(accountID string, localOnly bool) (int, db.Error) { + return r.conn. + Model(&[]*gtsmodel.Follow{}). + Where("target_account_id = ?", accountID). + Count() +} diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go new file mode 100644 index 000000000..99790428e --- /dev/null +++ b/internal/db/pg/status.go @@ -0,0 +1,318 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package pg + +import ( + "container/list" + "context" + "errors" + "time" + + "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type statusDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc + cache cache.Cache +} + +func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) { + if s.cache == nil { + s.cache = cache.New() + } + + if err := s.cache.Store(id, status); err != nil { + s.log.Panicf("statusDB: error storing in cache: %s", err) + } +} + +func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) { + if s.cache == nil { + s.cache = cache.New() + return nil, false + } + + sI, err := s.cache.Fetch(id) + if err != nil || sI == nil { + return nil, false + } + + status, ok := sI.(*gtsmodel.Status) + if !ok { + s.log.Panicf("statusDB: cached interface with key %s was not a status", id) + } + + return status, true +} + +func (s *statusDB) newStatusQ(status interface{}) *orm.Query { + return s.conn.Model(status). + Relation("Attachments"). + Relation("Tags"). + Relation("Mentions"). + Relation("Emojis"). + Relation("Account"). + Relation("InReplyTo"). + Relation("InReplyToAccount"). + Relation("BoostOf"). + Relation("BoostOfAccount"). + Relation("CreatedWithApplication") +} + +func (s *statusDB) newFaveQ(faves interface{}) *orm.Query { + return s.conn.Model(faves). + Relation("Account"). + Relation("TargetAccount"). + Relation("Status") +} + +func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(id); cached { + return status, nil + } + + status := >smodel.Status{} + + q := s.newStatusQ(status). + Where("status.id = ?", id) + + err := processErrorResponse(q.Select()) + + if err == nil && status != nil { + s.cacheStatus(id, status) + } + + return status, err +} + +func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(uri); cached { + return status, nil + } + + status := >smodel.Status{} + + q := s.newStatusQ(status). + Where("LOWER(status.uri) = LOWER(?)", uri) + + err := processErrorResponse(q.Select()) + + if err == nil && status != nil { + s.cacheStatus(uri, status) + } + + return status, err +} + +func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(uri); cached { + return status, nil + } + + status := >smodel.Status{} + + q := s.newStatusQ(status). + Where("LOWER(status.url) = LOWER(?)", uri) + + err := processErrorResponse(q.Select()) + + if err == nil && status != nil { + s.cacheStatus(uri, status) + } + + return status, err +} + +func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error { + transaction := func(tx *pg.Tx) error { + // create links between this status and any emojis it uses + for _, i := range status.EmojiIDs { + if _, err := tx.Model(>smodel.StatusToEmoji{ + StatusID: status.ID, + EmojiID: i, + }).Insert(); err != nil { + return err + } + } + + // create links between this status and any tags it uses + for _, i := range status.TagIDs { + if _, err := tx.Model(>smodel.StatusToTag{ + StatusID: status.ID, + TagID: i, + }).Insert(); err != nil { + return err + } + } + + // change the status ID of the media attachments to the new status + for _, a := range status.Attachments { + a.StatusID = status.ID + a.UpdatedAt = time.Now() + if _, err := s.conn.Model(a). + Where("id = ?", a.ID). + Update(); err != nil { + return err + } + } + + _, err := tx.Model(status).Insert() + return err + } + + return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction)) +} + +func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { + parents := []*gtsmodel.Status{} + s.statusParent(status, &parents, onlyDirect) + + return parents, nil +} + +func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { + if status.InReplyToID == "" { + return + } + + parentStatus, err := s.GetStatusByID(status.InReplyToID) + if err == nil { + *foundStatuses = append(*foundStatuses, parentStatus) + } + + if onlyDirect { + return + } + + s.statusParent(parentStatus, foundStatuses, false) +} + +func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { + foundStatuses := &list.List{} + foundStatuses.PushFront(status) + s.statusChildren(status, foundStatuses, onlyDirect, minID) + + children := []*gtsmodel.Status{} + for e := foundStatuses.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*gtsmodel.Status) + if !ok { + panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) + } + + // only append children, not the overall parent status + if entry.ID != status.ID { + children = append(children, entry) + } + } + + return children, nil +} + +func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { + immediateChildren := []*gtsmodel.Status{} + + q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID) + if minID != "" { + q = q.Where("status.id > ?", minID) + } + + if err := q.Select(); err != nil { + return + } + + for _, child := range immediateChildren { + insertLoop: + for e := foundStatuses.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*gtsmodel.Status) + if !ok { + panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) + } + + if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { + foundStatuses.InsertAfter(child, e) + break insertLoop + } + } + + // only do one loop if we only want direct children + if onlyDirect { + return + } + s.statusChildren(child, foundStatuses, false, minID) + } +} + +func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) { + return s.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count() +} + +func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) { + return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count() +} + +func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) { + return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count() +} + +func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { + return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() +} + +func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { + return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists() +} + +func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { + return s.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() +} + +func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { + return s.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() +} + +func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { + faves := []*gtsmodel.StatusFave{} + + q := s.newFaveQ(&faves). + Where("status_id = ?", status.ID) + + err := processErrorResponse(q.Select()) + + return faves, err +} + +func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { + reblogs := []*gtsmodel.Status{} + + q := s.newStatusQ(&reblogs). + Where("boost_of_id = ?", status.ID) + + err := processErrorResponse(q.Select()) + + return reblogs, err +} diff --git a/internal/db/pg/status_test.go b/internal/db/pg/status_test.go new file mode 100644 index 000000000..8a185757c --- /dev/null +++ b/internal/db/pg/status_test.go @@ -0,0 +1,134 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package pg_test + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type StatusTestSuite struct { + PGStandardTestSuite +} + +func (suite *StatusTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testTags = testrig.NewTestTags() + suite.testMentions = testrig.NewTestMentions() +} + +func (suite *StatusTestSuite) SetupTest() { + suite.config = testrig.NewTestConfig() + suite.db = testrig.NewTestDB() + suite.log = testrig.NewTestLog() + + testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *StatusTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *StatusTestSuite) TestGetStatusByID() { + status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(status) + suite.NotNil(status.Account) + suite.NotNil(status.CreatedWithApplication) + suite.Nil(status.BoostOf) + suite.Nil(status.BoostOfAccount) + suite.Nil(status.InReplyTo) + suite.Nil(status.InReplyToAccount) +} + +func (suite *StatusTestSuite) TestGetStatusByURI() { + status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(status) + suite.NotNil(status.Account) + suite.NotNil(status.CreatedWithApplication) + suite.Nil(status.BoostOf) + suite.Nil(status.BoostOfAccount) + suite.Nil(status.InReplyTo) + suite.Nil(status.InReplyToAccount) +} + +func (suite *StatusTestSuite) TestGetStatusWithExtras() { + status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(status) + suite.NotNil(status.Account) + suite.NotNil(status.CreatedWithApplication) + suite.NotEmpty(status.Tags) + suite.NotEmpty(status.Attachments) + suite.NotEmpty(status.Emojis) +} + +func (suite *StatusTestSuite) TestGetStatusWithMention() { + status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(status) + suite.NotNil(status.Account) + suite.NotNil(status.CreatedWithApplication) + suite.NotEmpty(status.Mentions) + suite.NotEmpty(status.MentionIDs) + suite.NotNil(status.InReplyTo) + suite.NotNil(status.InReplyToAccount) +} + +func (suite *StatusTestSuite) TestGetStatusTwice() { + before1 := time.Now() + _, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) + suite.NoError(err) + after1 := time.Now() + duration1 := after1.Sub(before1) + fmt.Println(duration1.Nanoseconds()) + + before2 := time.Now() + _, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) + suite.NoError(err) + after2 := time.Now() + duration2 := after2.Sub(before2) + fmt.Println(duration2.Nanoseconds()) + + // second retrieval should be several orders faster since it will be cached now + suite.Less(duration2, duration1) +} + +func TestStatusTestSuite(t *testing.T) { + suite.Run(t, new(StatusTestSuite)) +} diff --git a/internal/db/pg/statuscontext.go b/internal/db/pg/statuscontext.go deleted file mode 100644 index 2ff1a20bb..000000000 --- a/internal/db/pg/statuscontext.go +++ /dev/null @@ -1,104 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "container/list" - "errors" - - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -func (ps *postgresService) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) { - parents := []*gtsmodel.Status{} - ps.statusParent(status, &parents, onlyDirect) - - return parents, nil -} - -func (ps *postgresService) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { - if status.InReplyToID == "" { - return - } - - parentStatus := >smodel.Status{} - if err := ps.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil { - *foundStatuses = append(*foundStatuses, parentStatus) - } - - if onlyDirect { - return - } - ps.statusParent(parentStatus, foundStatuses, false) -} - -func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) { - foundStatuses := &list.List{} - foundStatuses.PushFront(status) - ps.statusChildren(status, foundStatuses, onlyDirect, minID) - - children := []*gtsmodel.Status{} - for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - - // only append children, not the overall parent status - if entry.ID != status.ID { - children = append(children, entry) - } - } - - return children, nil -} - -func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { - immediateChildren := []*gtsmodel.Status{} - - q := ps.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID) - if minID != "" { - q = q.Where("status.id > ?", minID) - } - - if err := q.Select(); err != nil { - return - } - - for _, child := range immediateChildren { - insertLoop: - for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - - if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { - foundStatuses.InsertAfter(child, e) - break insertLoop - } - } - - // only do one loop if we only want direct children - if onlyDirect { - return - } - ps.statusChildren(child, foundStatuses, false, minID) - } -} diff --git a/internal/db/pg/timeline.go b/internal/db/pg/timeline.go index 585ca3067..fa8b07aab 100644 --- a/internal/db/pg/timeline.go +++ b/internal/db/pg/timeline.go @@ -19,16 +19,26 @@ package pg import ( + "context" "sort" "github.com/go-pg/pg/v10" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) { +type timelineDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { statuses := []*gtsmodel.Status{} - q := ps.conn.Model(&statuses) + q := t.conn.Model(&statuses) q = q.ColumnExpr("status.*"). // Find out who accountID follows. @@ -74,22 +84,22 @@ func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID str err := q.Select() if err != nil { if err == pg.ErrNoRows { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return nil, err } if len(statuses) == 0 { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return statuses, nil } -func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) { +func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { statuses := []*gtsmodel.Status{} - q := ps.conn.Model(&statuses). + q := t.conn.Model(&statuses). Where("visibility = ?", gtsmodel.VisibilityPublic). Where("? IS NULL", pg.Ident("in_reply_to_id")). Where("? IS NULL", pg.Ident("in_reply_to_uri")). @@ -119,13 +129,13 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s err := q.Select() if err != nil { if err == pg.ErrNoRows { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return nil, err } if len(statuses) == 0 { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return statuses, nil @@ -133,11 +143,11 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! // It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds. -func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) { +func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) { faves := []*gtsmodel.StatusFave{} - fq := ps.conn.Model(&faves). + fq := t.conn.Model(&faves). Where("account_id = ?", accountID). Order("id DESC") @@ -156,13 +166,13 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st err := fq.Select() if err != nil { if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries{} + return nil, "", "", db.ErrNoEntries } return nil, "", "", err } if len(faves) == 0 { - return nil, "", "", db.ErrNoEntries{} + return nil, "", "", db.ErrNoEntries } // map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID @@ -175,16 +185,16 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st } statuses := []*gtsmodel.Status{} - err = ps.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select() + err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select() if err != nil { if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries{} + return nil, "", "", db.ErrNoEntries } return nil, "", "", err } if len(statuses) == 0 { - return nil, "", "", db.ErrNoEntries{} + return nil, "", "", db.ErrNoEntries } // arrange statuses by fave ID diff --git a/internal/db/pg/update.go b/internal/db/pg/update.go deleted file mode 100644 index f6bc70ad9..000000000 --- a/internal/db/pg/update.go +++ /dev/null @@ -1,73 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "fmt" - - "github.com/go-pg/pg/v10" - "github.com/superseriousbusiness/gotosocial/internal/db" -) - -func (ps *postgresService) Upsert(i interface{}, conflictColumn string) error { - if _, err := ps.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil -} - -func (ps *postgresService) UpdateByID(id string, i interface{}) error { - if _, err := ps.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil -} - -func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error { - _, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update() - return err -} - -func (ps *postgresService) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) error { - q := ps.conn.Model(i) - - for _, w := range where { - if w.Value == nil { - q = q.Where("? IS NULL", pg.Ident(w.Key)) - } else { - if w.CaseInsensitive { - q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) - } else { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - } - } - - q = q.Set("? = ?", pg.Safe(key), value) - - _, err := q.Update() - - return err -} diff --git a/internal/db/pg/util.go b/internal/db/pg/util.go new file mode 100644 index 000000000..17c09b720 --- /dev/null +++ b/internal/db/pg/util.go @@ -0,0 +1,25 @@ +package pg + +import ( + "strings" + + "github.com/go-pg/pg/v10" + "github.com/superseriousbusiness/gotosocial/internal/db" +) + +// processErrorResponse parses the given error and returns an appropriate DBError. +func processErrorResponse(err error) db.Error { + switch err { + case nil: + return nil + case pg.ErrNoRows: + return db.ErrNoEntries + case pg.ErrMultiRows: + return db.ErrMultipleEntries + default: + if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return db.ErrAlreadyExists + } + return err + } +} diff --git a/internal/db/relationship.go b/internal/db/relationship.go new file mode 100644 index 000000000..85f64d72b --- /dev/null +++ b/internal/db/relationship.go @@ -0,0 +1,71 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + +// Relationship contains functions for getting or modifying the relationship between two accounts. +type Relationship interface { + // IsBlocked checks whether account 1 has a block in place against block2. + // If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1. + IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error) + + // GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't. + // + // Because this is slower than Blocked, only use it if you need the actual Block struct for some reason, + // not if you're just checking for the existence of a block. + GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error) + + // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount. + GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error) + + // IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out. + IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) + + // IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out. + IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) + + // IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out. + IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error) + + // AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table. + // In other words, it should create the follow, and delete the existing follow request. + // + // It will return the newly created follow for further processing. + AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error) + + // GetAccountFollowRequests returns all follow requests targeting the given account. + GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error) + + // GetAccountFollows returns a slice of follows owned by the given accountID. + GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error) + + // CountAccountFollows returns the amount of accounts that the given accountID is following. + // + // If localOnly is set to true, then only follows from *this instance* will be returned. + CountAccountFollows(accountID string, localOnly bool) (int, Error) + + // GetAccountFollowedBy fetches follows that target given accountID. + // + // If localOnly is set to true, then only follows from *this instance* will be returned. + GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error) + + // CountAccountFollowedBy returns the amounts that the given ID is followed by. + CountAccountFollowedBy(accountID string, localOnly bool) (int, Error) +} diff --git a/internal/db/status.go b/internal/db/status.go new file mode 100644 index 000000000..9d206c198 --- /dev/null +++ b/internal/db/status.go @@ -0,0 +1,75 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + +// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses. +type Status interface { + // GetStatusByID returns one status from the database, with all rel fields populated (if possible). + GetStatusByID(id string) (*gtsmodel.Status, Error) + + // GetStatusByURI returns one status from the database, with all rel fields populated (if possible). + GetStatusByURI(uri string) (*gtsmodel.Status, Error) + + // GetStatusByURL returns one status from the database, with all rel fields populated (if possible). + GetStatusByURL(uri string) (*gtsmodel.Status, Error) + + // PutStatus stores one status in the database. + PutStatus(status *gtsmodel.Status) Error + + // CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong + CountStatusReplies(status *gtsmodel.Status) (int, Error) + + // CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong + CountStatusReblogs(status *gtsmodel.Status) (int, Error) + + // CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong + CountStatusFaves(status *gtsmodel.Status) (int, Error) + + // GetStatusParents gets the parent statuses of a given status. + // + // If onlyDirect is true, only the immediate parent will be returned. + GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error) + + // GetStatusChildren gets the child statuses of a given status. + // + // If onlyDirect is true, only the immediate children will be returned. + GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error) + + // IsStatusFavedBy checks if a given status has been faved by a given account ID + IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error) + + // IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID + IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error) + + // IsStatusMutedBy checks if a given status has been muted by a given account ID + IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error) + + // IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID + IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error) + + // GetStatusFaves returns a slice of faves/likes of the given status. + // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. + GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error) + + // GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status. + // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. + GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error) +} diff --git a/internal/db/timeline.go b/internal/db/timeline.go new file mode 100644 index 000000000..74aa5c781 --- /dev/null +++ b/internal/db/timeline.go @@ -0,0 +1,44 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + +// Timeline contains functionality for retrieving home/public/faved etc timelines for an account. +type Timeline interface { + // GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id. + // + // Statuses should be returned in descending order of when they were created (newest first). + GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) + + // GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public. + // It will use the given filters and try to return as many statuses as possible up to the limit. + // + // Statuses should be returned in descending order of when they were created (newest first). + GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) + + // GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved. + // It will use the given filters and try to return as many statuses as possible up to the limit. + // + // Note that unlike the other GetTimeline functions, the returned statuses will be arranged by their FAVE id, not the STATUS id. + // In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created. + // + // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers. + GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error) +} |