diff --git a/internal/database/issue_mail.go b/internal/database/issue_mail.go index 74f3507df..c5f6bd3e9 100644 --- a/internal/database/issue_mail.go +++ b/internal/database/issue_mail.go @@ -176,7 +176,7 @@ func mailIssueCommentToParticipants(issue *Issue, doer *User, mentions []string) // and mentioned people. func (issue *Issue) MailParticipants() (err error) { mentions := markup.FindAllMentions(issue.Content) - if err = updateIssueMentions(x, issue.ID, mentions); err != nil { + if err = updateIssueMentions(db, issue.ID, mentions); err != nil { return errors.Newf("UpdateIssueMentions [%d]: %v", issue.ID, err) } diff --git a/internal/database/models.go b/internal/database/models.go index 5a5aee498..4be4fd91e 100644 --- a/internal/database/models.go +++ b/internal/database/models.go @@ -97,7 +97,7 @@ func getGormDB(gormLogger logger.Writer) (*gorm.DB, error) { func NewTestEngine() error { var err error - db, err = getGormDB(&dbutil.Logger{Writer: log.NewConsoleWriter()}) + db, err = getGormDB(&dbutil.Logger{Writer: os.Stdout}) if err != nil { return errors.Newf("connect to database: %v", err) } diff --git a/internal/database/org.go b/internal/database/org.go index 5cc995836..60d0d6c8b 100644 --- a/internal/database/org.go +++ b/internal/database/org.go @@ -484,7 +484,7 @@ func (org *User) GetUserRepositories(userID int64, page, pageSize int) ([]*Repos } var teamRepoIDs []int64 - if err = x.Table("team_repo").In("team_id", teamIDs).Distinct("repo_id").Find(&teamRepoIDs); err != nil { + if err = db.Table("team_repo").Where("team_id IN ?", teamIDs).Distinct("repo_id").Find(&teamRepoIDs).Error; err != nil { return nil, 0, errors.Newf("get team repository IDs: %v", err) } if len(teamRepoIDs) == 0 { @@ -496,22 +496,18 @@ func (org *User) GetUserRepositories(userID int64, page, pageSize int) ([]*Repos page = 1 } repos := make([]*Repository, 0, pageSize) - if err = x.Where("owner_id = ?", org.ID). - And(builder.Or( - builder.And(builder.Expr("is_private = ?", false), builder.Expr("is_unlisted = ?", false)), - builder.In("id", teamRepoIDs))). - Desc("updated_unix"). - Limit(pageSize, (page-1)*pageSize). - Find(&repos); err != nil { + if err = db.Where("owner_id = ?", org.ID). + Where(db.Where("is_private = ? AND is_unlisted = ?", false, false).Or("id IN ?", teamRepoIDs)). + Order("updated_unix DESC"). + Limit(pageSize).Offset((page - 1) * pageSize). + Find(&repos).Error; err != nil { return nil, 0, errors.Newf("get user repositories: %v", err) } - repoCount, err := x.Where("owner_id = ?", org.ID). - And(builder.Or( - builder.Expr("is_private = ?", false), - builder.In("id", teamRepoIDs))). - Count(new(Repository)) - if err != nil { + var repoCount int64 + if err = db.Model(&Repository{}).Where("owner_id = ?", org.ID). + Where(db.Where("is_private = ?", false).Or("id IN ?", teamRepoIDs)). + Count(&repoCount).Error; err != nil { return nil, 0, errors.Newf("count user repositories: %v", err) } @@ -529,7 +525,7 @@ func (org *User) GetUserMirrorRepositories(userID int64) ([]*Repository, error) } var teamRepoIDs []int64 - err = x.Table("team_repo").In("team_id", teamIDs).Distinct("repo_id").Find(&teamRepoIDs) + err = db.Table("team_repo").Where("team_id IN ?", teamIDs).Distinct("repo_id").Find(&teamRepoIDs).Error if err != nil { return nil, errors.Newf("get team repository ids: %v", err) } @@ -539,12 +535,12 @@ func (org *User) GetUserMirrorRepositories(userID int64) ([]*Repository, error) } repos := make([]*Repository, 0, 10) - if err = x.Where("owner_id = ?", org.ID). - And("is_private = ?", false). - Or(builder.In("id", teamRepoIDs)). - And("is_mirror = ?", true). // Don't move up because it's an independent condition - Desc("updated_unix"). - Find(&repos); err != nil { + if err = db.Where("owner_id = ?", org.ID). + Where("is_private = ?", false). + Or("id IN ?", teamRepoIDs). + Where("is_mirror = ?", true). // Don't move up because it's an independent condition + Order("updated_unix DESC"). + Find(&repos).Error; err != nil { return nil, errors.Newf("get user repositories: %v", err) } return repos, nil diff --git a/internal/database/repo.go b/internal/database/repo.go index 2935e6c47..b452036a6 100644 --- a/internal/database/repo.go +++ b/internal/database/repo.go @@ -549,7 +549,7 @@ func (r *Repository) GetAssigneeByID(userID int64) (*User, error) { // GetWriters returns all users that have write access to the repository. func (r *Repository) GetWriters() (_ []*User, err error) { - return r.getUsersWithAccesMode(x, AccessModeWrite) + return r.getUsersWithAccesMode(db, AccessModeWrite) } // GetMilestoneByID returns the milestone belongs to repository by given ID. @@ -1230,7 +1230,7 @@ func CreateRepository(doer, owner *User, opts CreateRepoOptionsLegacy) (_ *Repos EnablePulls: true, } - err := db.Transaction(func(tx *gorm.DB) error { + err = db.Transaction(func(tx *gorm.DB) error { if err := createRepository(tx, doer, owner, repo); err != nil { return err } diff --git a/internal/database/repo_branch.go b/internal/database/repo_branch.go index 97d9ec64a..2ee730432 100644 --- a/internal/database/repo_branch.go +++ b/internal/database/repo_branch.go @@ -8,6 +8,7 @@ import ( "github.com/cockroachdb/errors" "github.com/gogs/git-module" "github.com/unknwon/com" + "gorm.io/gorm" "gogs.io/gogs/internal/errutil" "gogs.io/gogs/internal/tool" @@ -93,8 +94,9 @@ type ProtectBranchWhitelist struct { // IsUserInProtectBranchWhitelist returns true if given user is in the whitelist of a branch in a repository. func IsUserInProtectBranchWhitelist(repoID, userID int64, branch string) bool { - has, err := x.Where("repo_id = ?", repoID).And("user_id = ?", userID).And("name = ?", branch).Get(new(ProtectBranchWhitelist)) - return has && err == nil + var whitelist ProtectBranchWhitelist + err := db.Where("repo_id = ?", repoID).Where("user_id = ?", userID).Where("name = ?", branch).First(&whitelist).Error + return err == nil } // ProtectBranch contains options of a protected branch. @@ -115,11 +117,11 @@ func GetProtectBranchOfRepoByName(repoID int64, name string) (*ProtectBranch, er RepoID: repoID, Name: name, } - has, err := x.Get(protectBranch) - if err != nil { - return nil, err - } else if !has { + err := db.Where("repo_id = ? AND name = ?", repoID, name).First(protectBranch).Error + if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrBranchNotExist{args: map[string]any{"name": name}} + } else if err != nil { + return nil, err } return protectBranch, nil } @@ -136,23 +138,19 @@ func IsBranchOfRepoRequirePullRequest(repoID int64, name string) bool { // UpdateProtectBranch saves branch protection options. // If ID is 0, it creates a new record. Otherwise, updates existing record. func UpdateProtectBranch(protectBranch *ProtectBranch) (err error) { - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if protectBranch.ID == 0 { - if _, err = sess.Insert(protectBranch); err != nil { - return errors.Newf("insert: %v", err) + return db.Transaction(func(tx *gorm.DB) error { + if protectBranch.ID == 0 { + if err := tx.Create(protectBranch).Error; err != nil { + return errors.Newf("insert: %v", err) + } } - } - if _, err = sess.ID(protectBranch.ID).AllCols().Update(protectBranch); err != nil { - return errors.Newf("update: %v", err) - } + if err := tx.Model(&ProtectBranch{}).Where("id = ?", protectBranch.ID).Updates(protectBranch).Error; err != nil { + return errors.Newf("update: %v", err) + } - return sess.Commit() + return nil + }) } // UpdateOrgProtectBranch saves branch protection options of organizational repository. @@ -209,7 +207,7 @@ func UpdateOrgProtectBranch(repo *Repository, protectBranch *ProtectBranch, whit // Make sure protectBranch.ID is not 0 for whitelists if protectBranch.ID == 0 { - if _, err = x.Insert(protectBranch); err != nil { + if err = db.Create(protectBranch).Error; err != nil { return errors.Newf("insert: %v", err) } } @@ -247,30 +245,29 @@ func UpdateOrgProtectBranch(repo *Repository, protectBranch *ProtectBranch, whit } } - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if _, err = sess.ID(protectBranch.ID).AllCols().Update(protectBranch); err != nil { - return errors.Newf("Update: %v", err) - } - - // Refresh whitelists - if hasUsersChanged || hasTeamsChanged { - if _, err = sess.Delete(&ProtectBranchWhitelist{ProtectBranchID: protectBranch.ID}); err != nil { - return errors.Newf("delete old protect branch whitelists: %v", err) - } else if _, err = sess.Insert(whitelists); err != nil { - return errors.Newf("insert new protect branch whitelists: %v", err) + return db.Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&ProtectBranch{}).Where("id = ?", protectBranch.ID).Updates(protectBranch).Error; err != nil { + return errors.Newf("Update: %v", err) } - } - return sess.Commit() + // Refresh whitelists + if hasUsersChanged || hasTeamsChanged { + if err := tx.Delete(&ProtectBranchWhitelist{}, "protect_branch_id = ?", protectBranch.ID).Error; err != nil { + return errors.Newf("delete old protect branch whitelists: %v", err) + } + if len(whitelists) > 0 { + if err := tx.Create(&whitelists).Error; err != nil { + return errors.Newf("insert new protect branch whitelists: %v", err) + } + } + } + + return nil + }) } // GetProtectBranchesByRepoID returns a list of *ProtectBranch in given repository. func GetProtectBranchesByRepoID(repoID int64) ([]*ProtectBranch, error) { protectBranches := make([]*ProtectBranch, 0, 2) - return protectBranches, x.Where("repo_id = ? and protected = ?", repoID, true).Asc("name").Find(&protectBranches) + return protectBranches, db.Where("repo_id = ? AND protected = ?", repoID, true).Order("name ASC").Find(&protectBranches).Error } diff --git a/internal/database/repo_collaboration.go b/internal/database/repo_collaboration.go index 52adf4c8a..56cfda00d 100644 --- a/internal/database/repo_collaboration.go +++ b/internal/database/repo_collaboration.go @@ -35,12 +35,12 @@ func IsCollaborator(repoID, userID int64) bool { RepoID: repoID, UserID: userID, } - has, err := x.Get(collaboration) + err := db.Where("repo_id = ? AND user_id = ?", repoID, userID).First(collaboration).Error if err != nil { log.Error("get collaboration [repo_id: %d, user_id: %d]: %v", repoID, userID, err) return false } - return has + return true } func (r *Repository) IsCollaborator(userID int64) bool { @@ -54,27 +54,24 @@ func (r *Repository) AddCollaborator(u *User) error { UserID: u.ID, } - has, err := x.Get(collaboration) - if err != nil { - return err - } else if has { + var existing Collaboration + err := db.Where("repo_id = ? AND user_id = ?", r.ID, u.ID).First(&existing).Error + if err == nil { return nil + } else if !errors.Is(err, gorm.ErrRecordNotFound) { + return err } collaboration.Mode = AccessModeWrite - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if _, err = sess.Insert(collaboration); err != nil { - return err - } else if err = r.recalculateAccesses(sess); err != nil { - return errors.Newf("recalculateAccesses [repo_id: %v]: %v", r.ID, err) - } - - return sess.Commit() + return db.Transaction(func(tx *gorm.DB) error { + if err := tx.Create(collaboration).Error; err != nil { + return err + } + if err := r.recalculateAccesses(tx); err != nil { + return errors.Newf("recalculateAccesses [repo_id: %v]: %v", r.ID, err) + } + return nil + }) } func (r *Repository) getCollaborations(e *gorm.DB) ([]*Collaboration, error) { @@ -121,7 +118,7 @@ func (r *Repository) getCollaborators(e *gorm.DB) ([]*Collaborator, error) { // GetCollaborators returns the collaborators for a repository func (r *Repository) GetCollaborators() ([]*Collaborator, error) { - return r.getCollaborators(x) + return r.getCollaborators(db) } // ChangeCollaborationAccessMode sets new access mode for the collaboration. @@ -135,11 +132,11 @@ func (r *Repository) ChangeCollaborationAccessMode(userID int64, mode AccessMode RepoID: r.ID, UserID: userID, } - has, err := x.Get(collaboration) - if err != nil { - return errors.Newf("get collaboration: %v", err) - } else if !has { + err := db.Where("repo_id = ? AND user_id = ?", r.ID, userID).First(collaboration).Error + if errors.Is(err, gorm.ErrRecordNotFound) { return nil + } else if err != nil { + return errors.Newf("get collaboration: %v", err) } if collaboration.Mode == mode { @@ -160,35 +157,31 @@ func (r *Repository) ChangeCollaborationAccessMode(userID int64, mode AccessMode } } - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } + return db.Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&Collaboration{}).Where("id = ?", collaboration.ID).Updates(collaboration).Error; err != nil { + return errors.Newf("update collaboration: %v", err) + } - if _, err = sess.ID(collaboration.ID).AllCols().Update(collaboration); err != nil { - return errors.Newf("update collaboration: %v", err) - } + access := &Access{ + UserID: userID, + RepoID: r.ID, + } + err := tx.Where("user_id = ? AND repo_id = ?", userID, r.ID).First(access).Error + if err == nil { + if err := tx.Exec("UPDATE access SET mode = ? WHERE user_id = ? AND repo_id = ?", mode, userID, r.ID).Error; err != nil { + return errors.Newf("update access table: %v", err) + } + } else if errors.Is(err, gorm.ErrRecordNotFound) { + access.Mode = mode + if err := tx.Create(access).Error; err != nil { + return errors.Newf("insert access table: %v", err) + } + } else { + return errors.Newf("get access record: %v", err) + } - access := &Access{ - UserID: userID, - RepoID: r.ID, - } - has, err = sess.Get(access) - if err != nil { - return errors.Newf("get access record: %v", err) - } - if has { - _, err = sess.Exec("UPDATE access SET mode = ? WHERE user_id = ? AND repo_id = ?", mode, userID, r.ID) - } else { - access.Mode = mode - _, err = sess.Insert(access) - } - if err != nil { - return errors.Newf("update/insert access table: %v", err) - } - - return sess.Commit() + return nil + }) } // DeleteCollaboration removes collaboration relation between the user and repository. @@ -202,19 +195,20 @@ func DeleteCollaboration(repo *Repository, userID int64) (err error) { UserID: userID, } - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } + return db.Transaction(func(tx *gorm.DB) error { + result := tx.Delete(collaboration, "repo_id = ? AND user_id = ?", repo.ID, userID) + if result.Error != nil { + return result.Error + } else if result.RowsAffected == 0 { + return nil + } - if has, err := sess.Delete(collaboration); err != nil || has == 0 { - return err - } else if err = repo.recalculateAccesses(sess); err != nil { - return err - } + if err := repo.recalculateAccesses(tx); err != nil { + return err + } - return sess.Commit() + return nil + }) } func (r *Repository) DeleteCollaboration(userID int64) error {