diff --git a/internal/database/issue.go b/internal/database/issue.go index f55fda552..0d6f48e10 100644 --- a/internal/database/issue.go +++ b/internal/database/issue.go @@ -8,8 +8,8 @@ import ( "github.com/cockroachdb/errors" "github.com/unknwon/com" + "gorm.io/gorm" log "unknwon.dev/clog/v2" - "xorm.io/xorm" api "github.com/gogs/go-gogs-client" @@ -65,25 +65,22 @@ func (issue *Issue) BeforeUpdate() { issue.DeadlineUnix = issue.Deadline.Unix() } -func (issue *Issue) AfterSet(colName string, _ xorm.Cell) { - switch colName { - case "deadline_unix": - issue.Deadline = time.Unix(issue.DeadlineUnix, 0).Local() - case "created_unix": - issue.Created = time.Unix(issue.CreatedUnix, 0).Local() - case "updated_unix": - issue.Updated = time.Unix(issue.UpdatedUnix, 0).Local() - } +func (issue *Issue) AfterFind(tx *gorm.DB) error { + issue.Deadline = time.Unix(issue.DeadlineUnix, 0).Local() + issue.Created = time.Unix(issue.CreatedUnix, 0).Local() + issue.Updated = time.Unix(issue.UpdatedUnix, 0).Local() + return nil } // Deprecated: Use Users.GetByID instead. -func getUserByID(e Engine, id int64) (*User, error) { +func getUserByID(db *gorm.DB, id int64) (*User, error) { u := new(User) - has, err := e.ID(id).Get(u) + err := db.First(u, id).Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUserNotExist{args: errutil.Args{"userID": id}} + } return nil, err - } else if !has { - return nil, ErrUserNotExist{args: errutil.Args{"userID": id}} } // TODO(unknwon): Rely on AfterFind hook to sanitize user full name. @@ -91,16 +88,16 @@ func getUserByID(e Engine, id int64) (*User, error) { return u, nil } -func (issue *Issue) loadAttributes(e Engine) (err error) { +func (issue *Issue) loadAttributes(db *gorm.DB) (err error) { if issue.Repo == nil { - issue.Repo, err = getRepositoryByID(e, issue.RepoID) + issue.Repo, err = getRepositoryByID(db, issue.RepoID) if err != nil { return errors.Newf("getRepositoryByID [%d]: %v", issue.RepoID, err) } } if issue.Poster == nil { - issue.Poster, err = getUserByID(e, issue.PosterID) + issue.Poster, err = getUserByID(db, issue.PosterID) if err != nil { if IsErrUserNotExist(err) { issue.PosterID = -1 @@ -112,21 +109,21 @@ func (issue *Issue) loadAttributes(e Engine) (err error) { } if issue.Labels == nil { - issue.Labels, err = getLabelsByIssueID(e, issue.ID) + issue.Labels, err = getLabelsByIssueID(db, issue.ID) if err != nil { return errors.Newf("getLabelsByIssueID [%d]: %v", issue.ID, err) } } if issue.Milestone == nil && issue.MilestoneID > 0 { - issue.Milestone, err = getMilestoneByRepoID(e, issue.RepoID, issue.MilestoneID) + issue.Milestone, err = getMilestoneByRepoID(db, issue.RepoID, issue.MilestoneID) if err != nil { return errors.Newf("getMilestoneByRepoID [repo_id: %d, milestone_id: %d]: %v", issue.RepoID, issue.MilestoneID, err) } } if issue.Assignee == nil && issue.AssigneeID > 0 { - issue.Assignee, err = getUserByID(e, issue.AssigneeID) + issue.Assignee, err = getUserByID(db, issue.AssigneeID) if err != nil { return errors.Newf("getUserByID.(assignee) [%d]: %v", issue.AssigneeID, err) } @@ -134,21 +131,21 @@ func (issue *Issue) loadAttributes(e Engine) (err error) { if issue.IsPull && issue.PullRequest == nil { // It is possible pull request is not yet created. - issue.PullRequest, err = getPullRequestByIssueID(e, issue.ID) + issue.PullRequest, err = getPullRequestByIssueID(db, issue.ID) if err != nil && !IsErrPullRequestNotExist(err) { return errors.Newf("getPullRequestByIssueID [%d]: %v", issue.ID, err) } } if issue.Attachments == nil { - issue.Attachments, err = getAttachmentsByIssueID(e, issue.ID) + issue.Attachments, err = getAttachmentsByIssueID(db, issue.ID) if err != nil { return errors.Newf("getAttachmentsByIssueID [%d]: %v", issue.ID, err) } } if issue.Comments == nil { - issue.Comments, err = getCommentsByIssueID(e, issue.ID) + issue.Comments, err = getCommentsByIssueID(db, issue.ID) if err != nil { return errors.Newf("getCommentsByIssueID [%d]: %v", issue.ID, err) } @@ -158,7 +155,7 @@ func (issue *Issue) loadAttributes(e Engine) (err error) { } func (issue *Issue) LoadAttributes() error { - return issue.loadAttributes(x) + return issue.loadAttributes(db) } func (issue *Issue) HTMLURL() string { @@ -229,13 +226,13 @@ func (issue *Issue) IsPoster(uid int64) bool { return issue.PosterID == uid } -func (issue *Issue) hasLabel(e Engine, labelID int64) bool { - return hasIssueLabel(e, issue.ID, labelID) +func (issue *Issue) hasLabel(db *gorm.DB, labelID int64) bool { + return hasIssueLabel(db, issue.ID, labelID) } // HasLabel returns true if issue has been labeled by given ID. func (issue *Issue) HasLabel(labelID int64) bool { - return issue.hasLabel(x, labelID) + return issue.hasLabel(db, labelID) } func (issue *Issue) sendLabelUpdatedWebhook(doer *User) { @@ -267,8 +264,8 @@ func (issue *Issue) sendLabelUpdatedWebhook(doer *User) { } } -func (issue *Issue) addLabel(e *xorm.Session, label *Label) error { - return newIssueLabel(e, issue, label) +func (issue *Issue) addLabel(tx *gorm.DB, label *Label) error { + return newIssueLabel(tx, issue, label) } // AddLabel adds a new label to the issue. @@ -281,8 +278,8 @@ func (issue *Issue) AddLabel(doer *User, label *Label) error { return nil } -func (issue *Issue) addLabels(e *xorm.Session, labels []*Label) error { - return newIssueLabels(e, issue, labels) +func (issue *Issue) addLabels(tx *gorm.DB, labels []*Label) error { + return newIssueLabels(tx, issue, labels) } // AddLabels adds a list of new labels to the issue. @@ -295,20 +292,20 @@ func (issue *Issue) AddLabels(doer *User, labels []*Label) error { return nil } -func (issue *Issue) getLabels(e Engine) (err error) { +func (issue *Issue) getLabels(db *gorm.DB) (err error) { if len(issue.Labels) > 0 { return nil } - issue.Labels, err = getLabelsByIssueID(e, issue.ID) + issue.Labels, err = getLabelsByIssueID(db, issue.ID) if err != nil { return errors.Newf("getLabelsByIssueID: %v", err) } return nil } -func (issue *Issue) removeLabel(e *xorm.Session, label *Label) error { - return deleteIssueLabel(e, issue, label) +func (issue *Issue) removeLabel(tx *gorm.DB, label *Label) error { + return deleteIssueLabel(tx, issue, label) } // RemoveLabel removes a label from issue by given ID. @@ -321,8 +318,8 @@ func (issue *Issue) RemoveLabel(doer *User, label *Label) error { return nil } -func (issue *Issue) clearLabels(e *xorm.Session) (err error) { - if err = issue.getLabels(e); err != nil { +func (issue *Issue) clearLabels(tx *gorm.DB) (err error) { + if err = issue.getLabels(tx); err != nil { return errors.Newf("getLabels: %v", err) } @@ -330,7 +327,7 @@ func (issue *Issue) clearLabels(e *xorm.Session) (err error) { labels := make([]*Label, len(issue.Labels)) copy(labels, issue.Labels) for i := range labels { - if err = issue.removeLabel(e, labels[i]); err != nil { + if err = issue.removeLabel(tx, labels[i]); err != nil { return errors.Newf("removeLabel: %v", err) } } @@ -339,20 +336,13 @@ func (issue *Issue) clearLabels(e *xorm.Session) (err error) { } func (issue *Issue) ClearLabels(doer *User) (err error) { - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { + err = db.Transaction(func(tx *gorm.DB) error { + return issue.clearLabels(tx) + }) + if err != nil { return err } - if err = issue.clearLabels(sess); err != nil { - return err - } - - if err = sess.Commit(); err != nil { - return errors.Newf("commit: %v", err) - } - if issue.IsPull { err = issue.PullRequest.LoadIssue() if err != nil { @@ -383,20 +373,16 @@ func (issue *Issue) ClearLabels(doer *User) (err error) { } // ReplaceLabels removes all current labels and add new labels to the issue. -func (issue *Issue) ReplaceLabels(labels []*Label) (err error) { - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if err = issue.clearLabels(sess); err != nil { - return errors.Newf("clearLabels: %v", err) - } else if err = issue.addLabels(sess, labels); err != nil { - return errors.Newf("addLabels: %v", err) - } - - return sess.Commit() +func (issue *Issue) ReplaceLabels(labels []*Label) error { + return db.Transaction(func(tx *gorm.DB) error { + if err := issue.clearLabels(tx); err != nil { + return errors.Newf("clearLabels: %v", err) + } + if err := issue.addLabels(tx, labels); err != nil { + return errors.Newf("addLabels: %v", err) + } + return nil + }) } func (issue *Issue) GetAssignee() (err error) { @@ -416,32 +402,52 @@ func (issue *Issue) ReadBy(uid int64) error { return UpdateIssueUserByRead(uid, issue.ID) } -func updateIssueCols(e Engine, issue *Issue, cols ...string) error { - cols = append(cols, "updated_unix") - _, err := e.ID(issue.ID).Cols(cols...).Update(issue) - return err +func updateIssueCols(db *gorm.DB, issue *Issue, cols ...string) error { + updates := make(map[string]any) + for _, col := range cols { + switch col { + case "is_closed": + updates["is_closed"] = issue.IsClosed + case "priority": + updates["priority"] = issue.Priority + case "milestone_id": + updates["milestone_id"] = issue.MilestoneID + case "assignee_id": + updates["assignee_id"] = issue.AssigneeID + case "num_comments": + updates["num_comments"] = issue.NumComments + case "deadline_unix": + updates["deadline_unix"] = issue.DeadlineUnix + case "title": + updates["title"] = issue.Title + case "content": + updates["content"] = issue.Content + } + } + updates["updated_unix"] = time.Now().Unix() + return db.Model(&Issue{}).Where("id = ?", issue.ID).Updates(updates).Error } // UpdateIssueCols only updates values of specific columns for given issue. func UpdateIssueCols(issue *Issue, cols ...string) error { - return updateIssueCols(x, issue, cols...) + return updateIssueCols(db, issue, cols...) } -func (issue *Issue) changeStatus(e *xorm.Session, doer *User, repo *Repository, isClosed bool) (err error) { +func (issue *Issue) changeStatus(tx *gorm.DB, doer *User, repo *Repository, isClosed bool) (err error) { // Nothing should be performed if current status is same as target status if issue.IsClosed == isClosed { return nil } issue.IsClosed = isClosed - if err = updateIssueCols(e, issue, "is_closed"); err != nil { + if err = updateIssueCols(tx, issue, "is_closed"); err != nil { return err - } else if err = updateIssueUsersByStatus(e, issue.ID, isClosed); err != nil { + } else if err = updateIssueUsersByStatus(tx, issue.ID, isClosed); err != nil { return err } // Update issue count of labels - if err = issue.getLabels(e); err != nil { + if err = issue.getLabels(tx); err != nil { return err } for idx := range issue.Labels { @@ -450,18 +456,18 @@ func (issue *Issue) changeStatus(e *xorm.Session, doer *User, repo *Repository, } else { issue.Labels[idx].NumClosedIssues-- } - if err = updateLabel(e, issue.Labels[idx]); err != nil { + if err = updateLabel(tx, issue.Labels[idx]); err != nil { return err } } // Update issue count of milestone - if err = changeMilestoneIssueStats(e, issue); err != nil { + if err = changeMilestoneIssueStats(tx, issue); err != nil { return err } // New action comment - if _, err = createStatusComment(e, doer, repo, issue); err != nil { + if _, err = createStatusComment(tx, doer, repo, issue); err != nil { return err } @@ -470,20 +476,13 @@ func (issue *Issue) changeStatus(e *xorm.Session, doer *User, repo *Repository, // ChangeStatus changes issue status to open or closed. func (issue *Issue) ChangeStatus(doer *User, repo *Repository, isClosed bool) (err error) { - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { + err = db.Transaction(func(tx *gorm.DB) error { + return issue.changeStatus(tx, doer, repo, isClosed) + }) + if err != nil { return err } - if err = issue.changeStatus(sess, doer, repo, isClosed); err != nil { - return err - } - - if err = sess.Commit(); err != nil { - return errors.Newf("commit: %v", err) - } - if issue.IsPull { // Merge pull request calls issue.changeStatus so we need to handle separately. issue.PullRequest.Issue = issue @@ -661,12 +660,12 @@ type NewIssueOptions struct { IsPull bool } -func newIssue(e *xorm.Session, opts NewIssueOptions) (err error) { +func newIssue(tx *gorm.DB, opts NewIssueOptions) (err error) { opts.Issue.Title = strings.TrimSpace(opts.Issue.Title) opts.Issue.Index = opts.Repo.NextIssueIndex() if opts.Issue.MilestoneID > 0 { - milestone, err := getMilestoneByRepoID(e, opts.Issue.RepoID, opts.Issue.MilestoneID) + milestone, err := getMilestoneByRepoID(tx, opts.Issue.RepoID, opts.Issue.MilestoneID) if err != nil && !IsErrMilestoneNotExist(err) { return errors.Newf("getMilestoneByID: %v", err) } @@ -676,14 +675,14 @@ func newIssue(e *xorm.Session, opts NewIssueOptions) (err error) { if milestone != nil { opts.Issue.MilestoneID = milestone.ID opts.Issue.Milestone = milestone - if err = changeMilestoneAssign(e, opts.Issue, -1); err != nil { + if err = changeMilestoneAssign(tx, opts.Issue, -1); err != nil { return err } } } if opts.Issue.AssigneeID > 0 { - assignee, err := getUserByID(e, opts.Issue.AssigneeID) + assignee, err := getUserByID(tx, opts.Issue.AssigneeID) if err != nil && !IsErrUserNotExist(err) { return errors.Newf("get user by ID: %v", err) } @@ -743,36 +742,29 @@ func newIssue(e *xorm.Session, opts NewIssueOptions) (err error) { for i := 0; i < len(attachments); i++ { attachments[i].IssueID = opts.Issue.ID - if _, err = e.ID(attachments[i].ID).Update(attachments[i]); err != nil { + if err = tx.Model(&Attachment{}).Where("id = ?", attachments[i].ID).Updates(attachments[i]).Error; err != nil { return errors.Newf("update attachment [id: %d]: %v", attachments[i].ID, err) } } } - return opts.Issue.loadAttributes(e) + return opts.Issue.loadAttributes(tx) } // NewIssue creates new issue with labels and attachments for repository. func NewIssue(repo *Repository, issue *Issue, labelIDs []int64, uuids []string) (err error) { - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if err = newIssue(sess, NewIssueOptions{ - Repo: repo, - Issue: issue, - LableIDs: labelIDs, - Attachments: uuids, - }); err != nil { + err = db.Transaction(func(tx *gorm.DB) error { + return newIssue(tx, NewIssueOptions{ + Repo: repo, + Issue: issue, + LableIDs: labelIDs, + Attachments: uuids, + }) + }) + if err != nil { return errors.Newf("new issue: %v", err) } - if err = sess.Commit(); err != nil { - return errors.Newf("commit: %v", err) - } - if err = NotifyWatchers(&Action{ ActUserID: issue.Poster.ID, ActUserName: issue.Poster.Name, @@ -852,11 +844,12 @@ func GetRawIssueByIndex(repoID, index int64) (*Issue, error) { RepoID: repoID, Index: index, } - has, err := x.Get(issue) + err := db.Where("repo_id = ? AND `index` = ?", repoID, index).First(issue).Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrIssueNotExist{args: map[string]any{"repoID": repoID, "index": index}} + } return nil, err - } else if !has { - return nil, ErrIssueNotExist{args: map[string]any{"repoID": repoID, "index": index}} } return issue, nil } @@ -870,28 +863,29 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) { return issue, issue.LoadAttributes() } -func getRawIssueByID(e Engine, id int64) (*Issue, error) { +func getRawIssueByID(db *gorm.DB, id int64) (*Issue, error) { issue := new(Issue) - has, err := e.ID(id).Get(issue) + err := db.First(issue, id).Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrIssueNotExist{args: map[string]any{"issueID": id}} + } return nil, err - } else if !has { - return nil, ErrIssueNotExist{args: map[string]any{"issueID": id}} } return issue, nil } -func getIssueByID(e Engine, id int64) (*Issue, error) { - issue, err := getRawIssueByID(e, id) +func getIssueByID(db *gorm.DB, id int64) (*Issue, error) { + issue, err := getRawIssueByID(db, id) if err != nil { return nil, err } - return issue, issue.loadAttributes(e) + return issue, issue.loadAttributes(db) } // GetIssueByID returns an issue by given ID. func GetIssueByID(id int64) (*Issue, error) { - return getIssueByID(x, id) + return getIssueByID(db, id) } type IssuesOptions struct { @@ -910,93 +904,95 @@ type IssuesOptions struct { } // buildIssuesQuery returns nil if it foresees there won't be any value returned. -func buildIssuesQuery(opts *IssuesOptions) *xorm.Session { - sess := x.NewSession() +func buildIssuesQuery(opts *IssuesOptions) *gorm.DB { + query := db.Model(&Issue{}) if opts.Page <= 0 { opts.Page = 1 } if opts.RepoID > 0 { - sess.Where("issue.repo_id=?", opts.RepoID).And("issue.is_closed=?", opts.IsClosed) + query = query.Where("issue.repo_id = ?", opts.RepoID).Where("issue.is_closed = ?", opts.IsClosed) } else if opts.RepoIDs != nil { // In case repository IDs are provided but actually no repository has issue. if len(opts.RepoIDs) == 0 { return nil } - sess.In("issue.repo_id", opts.RepoIDs).And("issue.is_closed=?", opts.IsClosed) + query = query.Where("issue.repo_id IN ?", opts.RepoIDs).Where("issue.is_closed = ?", opts.IsClosed) } else { - sess.Where("issue.is_closed=?", opts.IsClosed) + query = query.Where("issue.is_closed = ?", opts.IsClosed) } if opts.AssigneeID > 0 { - sess.And("issue.assignee_id=?", opts.AssigneeID) + query = query.Where("issue.assignee_id = ?", opts.AssigneeID) } else if opts.PosterID > 0 { - sess.And("issue.poster_id=?", opts.PosterID) + query = query.Where("issue.poster_id = ?", opts.PosterID) } if opts.MilestoneID > 0 { - sess.And("issue.milestone_id=?", opts.MilestoneID) + query = query.Where("issue.milestone_id = ?", opts.MilestoneID) } - sess.And("issue.is_pull=?", opts.IsPull) + query = query.Where("issue.is_pull = ?", opts.IsPull) switch opts.SortType { case "oldest": - sess.Asc("issue.created_unix") + query = query.Order("issue.created_unix ASC") case "recentupdate": - sess.Desc("issue.updated_unix") + query = query.Order("issue.updated_unix DESC") case "leastupdate": - sess.Asc("issue.updated_unix") + query = query.Order("issue.updated_unix ASC") case "mostcomment": - sess.Desc("issue.num_comments") + query = query.Order("issue.num_comments DESC") case "leastcomment": - sess.Asc("issue.num_comments") + query = query.Order("issue.num_comments ASC") case "priority": - sess.Desc("issue.priority") + query = query.Order("issue.priority DESC") default: - sess.Desc("issue.created_unix") + query = query.Order("issue.created_unix DESC") } if len(opts.Labels) > 0 && opts.Labels != "0" { labelIDs := strings.Split(opts.Labels, ",") if len(labelIDs) > 0 { - sess.Join("INNER", "issue_label", "issue.id = issue_label.issue_id").In("issue_label.label_id", labelIDs) + query = query.Joins("INNER JOIN issue_label ON issue.id = issue_label.issue_id").Where("issue_label.label_id IN ?", labelIDs) } } if opts.IsMention { - sess.Join("INNER", "issue_user", "issue.id = issue_user.issue_id").And("issue_user.is_mentioned = ?", true) + query = query.Joins("INNER JOIN issue_user ON issue.id = issue_user.issue_id").Where("issue_user.is_mentioned = ?", true) if opts.UserID > 0 { - sess.And("issue_user.uid = ?", opts.UserID) + query = query.Where("issue_user.uid = ?", opts.UserID) } } - return sess + return query } // IssuesCount returns the number of issues by given conditions. func IssuesCount(opts *IssuesOptions) (int64, error) { - sess := buildIssuesQuery(opts) - if sess == nil { + query := buildIssuesQuery(opts) + if query == nil { return 0, nil } - return sess.Count(&Issue{}) + var count int64 + err := query.Count(&count).Error + return count, err } // Issues returns a list of issues by given conditions. func Issues(opts *IssuesOptions) ([]*Issue, error) { - sess := buildIssuesQuery(opts) - if sess == nil { + query := buildIssuesQuery(opts) + if query == nil { return make([]*Issue, 0), nil } - sess.Limit(conf.UI.IssuePagingNum, (opts.Page-1)*conf.UI.IssuePagingNum) + query = query.Limit(conf.UI.IssuePagingNum).Offset((opts.Page - 1) * conf.UI.IssuePagingNum) issues := make([]*Issue, 0, conf.UI.IssuePagingNum) - if err := sess.Find(&issues); err != nil { + if err := query.Find(&issues).Error; err != nil { return nil, errors.Newf("find: %v", err) } @@ -1013,10 +1009,10 @@ func Issues(opts *IssuesOptions) ([]*Issue, error) { // GetParticipantsByIssueID returns all users who are participated in comments of an issue. func GetParticipantsByIssueID(issueID int64) ([]*User, error) { userIDs := make([]int64, 0, 5) - if err := x.Table("comment").Cols("poster_id"). + if err := db.Table("comment"). + Select("DISTINCT poster_id"). Where("issue_id = ?", issueID). - Distinct("poster_id"). - Find(&userIDs); err != nil { + Pluck("poster_id", &userIDs).Error; err != nil { return nil, errors.Newf("get poster IDs: %v", err) } if len(userIDs) == 0 { @@ -1024,7 +1020,7 @@ func GetParticipantsByIssueID(issueID int64) ([]*User, error) { } users := make([]*User, 0, len(userIDs)) - return users, x.In("id", userIDs).Find(&users) + return users, db.Where("id IN ?", userIDs).Find(&users).Error } // .___ ____ ___ @@ -1048,8 +1044,8 @@ type IssueUser struct { IsClosed bool } -func newIssueUsers(e *xorm.Session, repo *Repository, issue *Issue) error { - assignees, err := repo.getAssignees(e) +func newIssueUsers(tx *gorm.DB, repo *Repository, issue *Issue) error { + assignees, err := repo.getAssignees(tx) if err != nil { return errors.Newf("getAssignees: %v", err) } @@ -1082,7 +1078,7 @@ func newIssueUsers(e *xorm.Session, repo *Repository, issue *Issue) error { }) } - if _, err = e.Insert(issueUsers); err != nil { + if err = tx.Create(issueUsers).Error; err != nil { return err } return nil @@ -1090,17 +1086,9 @@ func newIssueUsers(e *xorm.Session, repo *Repository, issue *Issue) error { // NewIssueUsers adds new issue-user relations for new issue of repository. func NewIssueUsers(repo *Repository, issue *Issue) (err error) { - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if err = newIssueUsers(sess, repo, issue); err != nil { - return err - } - - return sess.Commit() + return db.Transaction(func(tx *gorm.DB) error { + return newIssueUsers(tx, repo, issue) + }) } // PairsContains returns true when pairs list contains given issue. @@ -1117,7 +1105,7 @@ func PairsContains(ius []*IssueUser, issueID, uid int64) int { // GetIssueUsers returns issue-user pairs by given repository and user. func GetIssueUsers(rid, uid int64, isClosed bool) ([]*IssueUser, error) { ius := make([]*IssueUser, 0, 10) - err := x.Where("is_closed=?", isClosed).Find(&ius, &IssueUser{RepoID: rid, UserID: uid}) + err := db.Where("repo_id = ? AND uid = ? AND is_closed = ?", rid, uid, isClosed).Find(&ius).Error return ius, err } @@ -1128,34 +1116,33 @@ func GetIssueUserPairsByRepoIds(rids []int64, isClosed bool, page int) ([]*Issue } ius := make([]*IssueUser, 0, 10) - sess := x.Limit(20, (page-1)*20).Where("is_closed=?", isClosed).In("repo_id", rids) - err := sess.Find(&ius) + err := db.Limit(20).Offset((page-1)*20).Where("is_closed = ? AND repo_id IN ?", isClosed, rids).Find(&ius).Error return ius, err } // GetIssueUserPairsByMode returns issue-user pairs by given repository and user. func GetIssueUserPairsByMode(userID, repoID int64, filterMode FilterMode, isClosed bool, page int) ([]*IssueUser, error) { ius := make([]*IssueUser, 0, 10) - sess := x.Limit(20, (page-1)*20).Where("uid=?", userID).And("is_closed=?", isClosed) + query := db.Limit(20).Offset((page-1)*20).Where("uid = ? AND is_closed = ?", userID, isClosed) if repoID > 0 { - sess.And("repo_id=?", repoID) + query = query.Where("repo_id = ?", repoID) } switch filterMode { case FilterModeAssign: - sess.And("is_assigned=?", true) + query = query.Where("is_assigned = ?", true) case FilterModeCreate: - sess.And("is_poster=?", true) + query = query.Where("is_poster = ?", true) default: return ius, nil } - err := sess.Find(&ius) + err := query.Find(&ius).Error return ius, err } // updateIssueMentions extracts mentioned people from content and // updates issue-user relations for them. -func updateIssueMentions(e Engine, issueID int64, mentions []string) error { +func updateIssueMentions(db *gorm.DB, issueID int64, mentions []string) error { if len(mentions) == 0 { return nil } @@ -1165,7 +1152,7 @@ func updateIssueMentions(e Engine, issueID int64, mentions []string) error { } users := make([]*User, 0, len(mentions)) - if err := e.In("lower_name", mentions).Asc("lower_name").Find(&users); err != nil { + if err := db.Where("lower_name IN ?", mentions).Order("lower_name ASC").Find(&users).Error; err != nil { return errors.Newf("find mentioned users: %v", err) } @@ -1177,7 +1164,7 @@ func updateIssueMentions(e Engine, issueID int64, mentions []string) error { } memberIDs := make([]int64, 0, user.NumMembers) - orgUsers, err := getOrgUsersByOrgID(e, user.ID, 0) + orgUsers, err := getOrgUsersByOrgID(db, user.ID, 0) if err != nil { return errors.Newf("getOrgUsersByOrgID [%d]: %v", user.ID, err) } @@ -1189,7 +1176,7 @@ func updateIssueMentions(e Engine, issueID int64, mentions []string) error { ids = append(ids, memberIDs...) } - if err := updateIssueUsersByMentions(e, issueID, ids); err != nil { + if err := updateIssueUsersByMentions(db, issueID, ids); err != nil { return errors.Newf("UpdateIssueUsersByMentions: %v", err) } @@ -1238,60 +1225,43 @@ type IssueStatsOptions struct { func GetIssueStats(opts *IssueStatsOptions) *IssueStats { stats := &IssueStats{} - countSession := func(opts *IssueStatsOptions) *xorm.Session { - sess := x.Where("issue.repo_id = ?", opts.RepoID).And("is_pull = ?", opts.IsPull) + countSession := func(opts *IssueStatsOptions) *gorm.DB { + query := db.Table("issue").Where("issue.repo_id = ? AND is_pull = ?", opts.RepoID, opts.IsPull) if len(opts.Labels) > 0 && opts.Labels != "0" { labelIDs := tool.StringsToInt64s(strings.Split(opts.Labels, ",")) if len(labelIDs) > 0 { - sess.Join("INNER", "issue_label", "issue.id = issue_id").In("label_id", labelIDs) + query = query.Joins("INNER JOIN issue_label ON issue.id = issue_id").Where("label_id IN ?", labelIDs) } } if opts.MilestoneID > 0 { - sess.And("issue.milestone_id = ?", opts.MilestoneID) + query = query.Where("issue.milestone_id = ?", opts.MilestoneID) } if opts.AssigneeID > 0 { - sess.And("assignee_id = ?", opts.AssigneeID) + query = query.Where("assignee_id = ?", opts.AssigneeID) } - return sess + return query } switch opts.FilterMode { case FilterModeYourRepos, FilterModeAssign: - stats.OpenCount, _ = countSession(opts). - And("is_closed = ?", false). - Count(new(Issue)) - - stats.ClosedCount, _ = countSession(opts). - And("is_closed = ?", true). - Count(new(Issue)) + countSession(opts).Where("is_closed = ?", false).Count(&stats.OpenCount) + countSession(opts).Where("is_closed = ?", true).Count(&stats.ClosedCount) case FilterModeCreate: - stats.OpenCount, _ = countSession(opts). - And("poster_id = ?", opts.UserID). - And("is_closed = ?", false). - Count(new(Issue)) - - stats.ClosedCount, _ = countSession(opts). - And("poster_id = ?", opts.UserID). - And("is_closed = ?", true). - Count(new(Issue)) + countSession(opts).Where("poster_id = ? AND is_closed = ?", opts.UserID, false).Count(&stats.OpenCount) + countSession(opts).Where("poster_id = ? AND is_closed = ?", opts.UserID, true).Count(&stats.ClosedCount) case FilterModeMention: - stats.OpenCount, _ = countSession(opts). - Join("INNER", "issue_user", "issue.id = issue_user.issue_id"). - And("issue_user.uid = ?", opts.UserID). - And("issue_user.is_mentioned = ?", true). - And("issue.is_closed = ?", false). - Count(new(Issue)) - - stats.ClosedCount, _ = countSession(opts). - Join("INNER", "issue_user", "issue.id = issue_user.issue_id"). - And("issue_user.uid = ?", opts.UserID). - And("issue_user.is_mentioned = ?", true). - And("issue.is_closed = ?", true). - Count(new(Issue)) + countSession(opts). + Joins("INNER JOIN issue_user ON issue.id = issue_user.issue_id"). + Where("issue_user.uid = ? AND issue_user.is_mentioned = ? AND issue.is_closed = ?", opts.UserID, true, false). + Count(&stats.OpenCount) + countSession(opts). + Joins("INNER JOIN issue_user ON issue.id = issue_user.issue_id"). + Where("issue_user.uid = ? AND issue_user.is_mentioned = ? AND issue.is_closed = ?", opts.UserID, true, true). + Count(&stats.ClosedCount) } return stats } @@ -1300,29 +1270,23 @@ func GetIssueStats(opts *IssueStatsOptions) *IssueStats { func GetUserIssueStats(repoID, userID int64, repoIDs []int64, filterMode FilterMode, isPull bool) *IssueStats { stats := &IssueStats{} hasAnyRepo := repoID > 0 || len(repoIDs) > 0 - countSession := func(isClosed, isPull bool, repoID int64, repoIDs []int64) *xorm.Session { - sess := x.Where("issue.is_closed = ?", isClosed).And("issue.is_pull = ?", isPull) + countSession := func(isClosed, isPull bool, repoID int64, repoIDs []int64) *gorm.DB { + query := db.Table("issue").Where("issue.is_closed = ? AND issue.is_pull = ?", isClosed, isPull) if repoID > 0 { - sess.And("repo_id = ?", repoID) + query = query.Where("repo_id = ?", repoID) } else if len(repoIDs) > 0 { - sess.In("repo_id", repoIDs) + query = query.Where("repo_id IN ?", repoIDs) } - return sess + return query } - stats.AssignCount, _ = countSession(false, isPull, repoID, nil). - And("assignee_id = ?", userID). - Count(new(Issue)) - - stats.CreateCount, _ = countSession(false, isPull, repoID, nil). - And("poster_id = ?", userID). - Count(new(Issue)) + countSession(false, isPull, repoID, nil).Where("assignee_id = ?", userID).Count(&stats.AssignCount) + countSession(false, isPull, repoID, nil).Where("poster_id = ?", userID).Count(&stats.CreateCount) if hasAnyRepo { - stats.YourReposCount, _ = countSession(false, isPull, repoID, repoIDs). - Count(new(Issue)) + countSession(false, isPull, repoID, repoIDs).Count(&stats.YourReposCount) } switch filterMode { @@ -1330,25 +1294,14 @@ func GetUserIssueStats(repoID, userID int64, repoIDs []int64, filterMode FilterM if !hasAnyRepo { break } - - stats.OpenCount, _ = countSession(false, isPull, repoID, repoIDs). - Count(new(Issue)) - stats.ClosedCount, _ = countSession(true, isPull, repoID, repoIDs). - Count(new(Issue)) + countSession(false, isPull, repoID, repoIDs).Count(&stats.OpenCount) + countSession(true, isPull, repoID, repoIDs).Count(&stats.ClosedCount) case FilterModeAssign: - stats.OpenCount, _ = countSession(false, isPull, repoID, nil). - And("assignee_id = ?", userID). - Count(new(Issue)) - stats.ClosedCount, _ = countSession(true, isPull, repoID, nil). - And("assignee_id = ?", userID). - Count(new(Issue)) + countSession(false, isPull, repoID, nil).Where("assignee_id = ?", userID).Count(&stats.OpenCount) + countSession(true, isPull, repoID, nil).Where("assignee_id = ?", userID).Count(&stats.ClosedCount) case FilterModeCreate: - stats.OpenCount, _ = countSession(false, isPull, repoID, nil). - And("poster_id = ?", userID). - Count(new(Issue)) - stats.ClosedCount, _ = countSession(true, isPull, repoID, nil). - And("poster_id = ?", userID). - Count(new(Issue)) + countSession(false, isPull, repoID, nil).Where("poster_id = ?", userID).Count(&stats.OpenCount) + countSession(true, isPull, repoID, nil).Where("poster_id = ?", userID).Count(&stats.ClosedCount) } return stats @@ -1356,12 +1309,8 @@ func GetUserIssueStats(repoID, userID int64, repoIDs []int64, filterMode FilterM // GetRepoIssueStats returns number of open and closed repository issues by given filter mode. func GetRepoIssueStats(repoID, userID int64, filterMode FilterMode, isPull bool) (numOpen, numClosed int64) { - countSession := func(isClosed, isPull bool, repoID int64) *xorm.Session { - sess := x.Where("issue.repo_id = ?", isClosed). - And("is_pull = ?", isPull). - And("repo_id = ?", repoID) - - return sess + countSession := func(isClosed, isPull bool, repoID int64) *gorm.DB { + return db.Table("issue").Where("issue.is_closed = ? AND is_pull = ? AND repo_id = ?", isClosed, isPull, repoID) } openCountSession := countSession(false, isPull, repoID) @@ -1369,92 +1318,81 @@ func GetRepoIssueStats(repoID, userID int64, filterMode FilterMode, isPull bool) switch filterMode { case FilterModeAssign: - openCountSession.And("assignee_id = ?", userID) - closedCountSession.And("assignee_id = ?", userID) + openCountSession = openCountSession.Where("assignee_id = ?", userID) + closedCountSession = closedCountSession.Where("assignee_id = ?", userID) case FilterModeCreate: - openCountSession.And("poster_id = ?", userID) - closedCountSession.And("poster_id = ?", userID) + openCountSession = openCountSession.Where("poster_id = ?", userID) + closedCountSession = closedCountSession.Where("poster_id = ?", userID) } - openResult, _ := openCountSession.Count(new(Issue)) - closedResult, _ := closedCountSession.Count(new(Issue)) + openCountSession.Count(&numOpen) + closedCountSession.Count(&numClosed) - return openResult, closedResult + return numOpen, numClosed } -func updateIssue(e Engine, issue *Issue) error { - _, err := e.ID(issue.ID).AllCols().Update(issue) - return err +func updateIssue(db *gorm.DB, issue *Issue) error { + return db.Model(&Issue{}).Where("id = ?", issue.ID).Updates(issue).Error } // UpdateIssue updates all fields of given issue. func UpdateIssue(issue *Issue) error { - return updateIssue(x, issue) + return updateIssue(db, issue) } -func updateIssueUsersByStatus(e Engine, issueID int64, isClosed bool) error { - _, err := e.Exec("UPDATE `issue_user` SET is_closed=? WHERE issue_id=?", isClosed, issueID) - return err +func updateIssueUsersByStatus(db *gorm.DB, issueID int64, isClosed bool) error { + return db.Exec("UPDATE `issue_user` SET is_closed = ? WHERE issue_id = ?", isClosed, issueID).Error } // UpdateIssueUsersByStatus updates issue-user relations by issue status. func UpdateIssueUsersByStatus(issueID int64, isClosed bool) error { - return updateIssueUsersByStatus(x, issueID, isClosed) + return updateIssueUsersByStatus(db, issueID, isClosed) } -func updateIssueUserByAssignee(e *xorm.Session, issue *Issue) (err error) { - if _, err = e.Exec("UPDATE `issue_user` SET is_assigned = ? WHERE issue_id = ?", false, issue.ID); err != nil { +func updateIssueUserByAssignee(tx *gorm.DB, issue *Issue) (err error) { + if err = tx.Exec("UPDATE `issue_user` SET is_assigned = ? WHERE issue_id = ?", false, issue.ID).Error; err != nil { return err } // Assignee ID equals to 0 means clear assignee. if issue.AssigneeID > 0 { - if _, err = e.Exec("UPDATE `issue_user` SET is_assigned = ? WHERE uid = ? AND issue_id = ?", true, issue.AssigneeID, issue.ID); err != nil { + if err = tx.Exec("UPDATE `issue_user` SET is_assigned = ? WHERE uid = ? AND issue_id = ?", true, issue.AssigneeID, issue.ID).Error; err != nil { return err } } - return updateIssue(e, issue) + return updateIssue(tx, issue) } // UpdateIssueUserByAssignee updates issue-user relation for assignee. func UpdateIssueUserByAssignee(issue *Issue) (err error) { - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if err = updateIssueUserByAssignee(sess, issue); err != nil { - return err - } - - return sess.Commit() + return db.Transaction(func(tx *gorm.DB) error { + return updateIssueUserByAssignee(tx, issue) + }) } // UpdateIssueUserByRead updates issue-user relation for reading. func UpdateIssueUserByRead(uid, issueID int64) error { - _, err := x.Exec("UPDATE `issue_user` SET is_read=? WHERE uid=? AND issue_id=?", true, uid, issueID) - return err + return db.Exec("UPDATE `issue_user` SET is_read = ? WHERE uid = ? AND issue_id = ?", true, uid, issueID).Error } // updateIssueUsersByMentions updates issue-user pairs by mentioning. -func updateIssueUsersByMentions(e Engine, issueID int64, uids []int64) error { +func updateIssueUsersByMentions(db *gorm.DB, issueID int64, uids []int64) error { for _, uid := range uids { iu := &IssueUser{ UserID: uid, IssueID: issueID, } - has, err := e.Get(iu) - if err != nil { + err := db.Where("uid = ? AND issue_id = ?", uid, issueID).First(iu).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } iu.IsMentioned = true - if has { - _, err = e.ID(iu.ID).AllCols().Update(iu) + if errors.Is(err, gorm.ErrRecordNotFound) { + err = db.Create(iu).Error } else { - _, err = e.Insert(iu) + err = db.Model(&IssueUser{}).Where("id = ?", iu.ID).Updates(iu).Error } if err != nil { return err diff --git a/internal/database/pull.go b/internal/database/pull.go index 90a32e052..46a9d4849 100644 --- a/internal/database/pull.go +++ b/internal/database/pull.go @@ -9,8 +9,8 @@ import ( "github.com/cockroachdb/errors" "github.com/unknwon/com" + "gorm.io/gorm" log "unknwon.dev/clog/v2" - "xorm.io/xorm" "github.com/gogs/git-module" api "github.com/gogs/go-gogs-client" @@ -71,35 +71,31 @@ func (pr *PullRequest) BeforeUpdate() { } // Note: don't try to get Issue because will end up recursive querying. -func (pr *PullRequest) AfterSet(colName string, _ xorm.Cell) { - switch colName { - case "merged_unix": - if !pr.HasMerged { - return - } - +func (pr *PullRequest) AfterFind(tx *gorm.DB) error { + if pr.HasMerged { pr.Merged = time.Unix(pr.MergedUnix, 0).Local() } + return nil } // Note: don't try to get Issue because will end up recursive querying. -func (pr *PullRequest) loadAttributes(e Engine) (err error) { +func (pr *PullRequest) loadAttributes(db *gorm.DB) (err error) { if pr.HeadRepo == nil { - pr.HeadRepo, err = getRepositoryByID(e, pr.HeadRepoID) + pr.HeadRepo, err = getRepositoryByID(db, pr.HeadRepoID) if err != nil && !IsErrRepoNotExist(err) { return errors.Newf("get head repository by ID: %v", err) } } if pr.BaseRepo == nil { - pr.BaseRepo, err = getRepositoryByID(e, pr.BaseRepoID) + pr.BaseRepo, err = getRepositoryByID(db, pr.BaseRepoID) if err != nil { return errors.Newf("get base repository by ID: %v", err) } } if pr.HasMerged && pr.Merger == nil { - pr.Merger, err = getUserByID(e, pr.MergerID) + pr.Merger, err = getUserByID(db, pr.MergerID) if IsErrUserNotExist(err) { pr.MergerID = -1 pr.Merger = NewGhostUser() @@ -112,7 +108,7 @@ func (pr *PullRequest) loadAttributes(e Engine) (err error) { } func (pr *PullRequest) LoadAttributes() error { - return pr.loadAttributes(x) + return pr.loadAttributes(db) } func (pr *PullRequest) LoadIssue() (err error) { @@ -199,15 +195,10 @@ func (pr *PullRequest) Merge(doer *User, baseGitRepo *git.Repository, mergeStyle go AddTestPullRequestTask(doer, pr.BaseRepo.ID, pr.BaseBranch, false) }() - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if err = pr.Issue.changeStatus(sess, doer, pr.Issue.Repo, true); err != nil { - return errors.Newf("Issue.changeStatus: %v", err) - } + return db.Transaction(func(tx *gorm.DB) error { + if err := pr.Issue.changeStatus(tx, doer, pr.Issue.Repo, true); err != nil { + return errors.Newf("Issue.changeStatus: %v", err) + } headRepoPath := RepoPath(pr.HeadUserName, pr.HeadRepo.Name) headGitRepo, err := git.Open(headRepoPath) @@ -443,30 +434,25 @@ func (pr *PullRequest) testPatch() (err error) { // NewPullRequest creates new pull request with labels for repository. func NewPullRequest(repo *Repository, pull *Issue, labelIDs []int64, uuids []string, pr *PullRequest, patch []byte) (err error) { - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } + err = db.Transaction(func(tx *gorm.DB) error { + if err := newIssue(tx, NewIssueOptions{ + Repo: repo, + Issue: pull, + LableIDs: labelIDs, + Attachments: uuids, + IsPull: true, + }); err != nil { + return errors.Newf("newIssue: %v", err) + } - if err = newIssue(sess, NewIssueOptions{ - Repo: repo, - Issue: pull, - LableIDs: labelIDs, - Attachments: uuids, - IsPull: true, - }); err != nil { - return errors.Newf("newIssue: %v", err) - } + pr.Index = pull.Index + if err := repo.SavePatch(pr.Index, patch); err != nil { + return errors.Newf("SavePatch: %v", err) + } - pr.Index = pull.Index - if err = repo.SavePatch(pr.Index, patch); err != nil { - return errors.Newf("SavePatch: %v", err) - } - - pr.BaseRepo = repo - if err = pr.testPatch(); err != nil { - return errors.Newf("testPatch: %v", err) + pr.BaseRepo = repo + if err := pr.testPatch(); err != nil { + return errors.Newf("testPatch: %v", err) } // No conflict appears after test means mergeable. if pr.Status == PullRequestStatusChecking { @@ -517,18 +503,20 @@ func NewPullRequest(repo *Repository, pull *Issue, labelIDs []int64, uuids []str // by given head/base and repo/branch. func GetUnmergedPullRequest(headRepoID, baseRepoID int64, headBranch, baseBranch string) (*PullRequest, error) { pr := new(PullRequest) - has, err := x.Where("head_repo_id=? AND head_branch=? AND base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?", - headRepoID, headBranch, baseRepoID, baseBranch, false, false). - Join("INNER", "issue", "issue.id=pull_request.issue_id").Get(pr) + err := db.Joins("INNER JOIN issue ON issue.id = pull_request.issue_id"). + Where("pull_request.head_repo_id = ? AND pull_request.head_branch = ? AND pull_request.base_repo_id = ? AND pull_request.base_branch = ? AND pull_request.has_merged = ? AND issue.is_closed = ?", + headRepoID, headBranch, baseRepoID, baseBranch, false, false). + First(pr).Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrPullRequestNotExist{args: map[string]any{ + "headRepoID": headRepoID, + "baseRepoID": baseRepoID, + "headBranch": headBranch, + "baseBranch": baseBranch, + }} + } return nil, err - } else if !has { - return nil, ErrPullRequestNotExist{args: map[string]any{ - "headRepoID": headRepoID, - "baseRepoID": baseRepoID, - "headBranch": headBranch, - "baseBranch": baseBranch, - }} } return pr, nil @@ -538,18 +526,22 @@ func GetUnmergedPullRequest(headRepoID, baseRepoID int64, headBranch, baseBranch // by given head information (repo and branch). func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequest, error) { prs := make([]*PullRequest, 0, 2) - return prs, x.Where("head_repo_id = ? AND head_branch = ? AND has_merged = ? AND issue.is_closed = ?", - repoID, branch, false, false). - Join("INNER", "issue", "issue.id = pull_request.issue_id").Find(&prs) + err := db.Joins("INNER JOIN issue ON issue.id = pull_request.issue_id"). + Where("pull_request.head_repo_id = ? AND pull_request.head_branch = ? AND pull_request.has_merged = ? AND issue.is_closed = ?", + repoID, branch, false, false). + Find(&prs).Error + return prs, err } // GetUnmergedPullRequestsByBaseInfo returns all pull requests that are open and has not been merged // by given base information (repo and branch). func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequest, error) { prs := make([]*PullRequest, 0, 2) - return prs, x.Where("base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?", - repoID, branch, false, false). - Join("INNER", "issue", "issue.id=pull_request.issue_id").Find(&prs) + err := db.Joins("INNER JOIN issue ON issue.id = pull_request.issue_id"). + Where("pull_request.base_repo_id = ? AND pull_request.base_branch = ? AND pull_request.has_merged = ? AND issue.is_closed = ?", + repoID, branch, false, false). + Find(&prs).Error + return prs, err } var _ errutil.NotFound = (*ErrPullRequestNotExist)(nil) @@ -571,50 +563,65 @@ func (ErrPullRequestNotExist) NotFound() bool { return true } -func getPullRequestByID(e Engine, id int64) (*PullRequest, error) { +func getPullRequestByID(db *gorm.DB, id int64) (*PullRequest, error) { pr := new(PullRequest) - has, err := e.ID(id).Get(pr) + err := db.First(pr, id).Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrPullRequestNotExist{args: map[string]any{"pullRequestID": id}} + } return nil, err - } else if !has { - return nil, ErrPullRequestNotExist{args: map[string]any{"pullRequestID": id}} } - return pr, pr.loadAttributes(e) + return pr, pr.loadAttributes(db) } // GetPullRequestByID returns a pull request by given ID. func GetPullRequestByID(id int64) (*PullRequest, error) { - return getPullRequestByID(x, id) + return getPullRequestByID(db, id) } -func getPullRequestByIssueID(e Engine, issueID int64) (*PullRequest, error) { - pr := &PullRequest{ - IssueID: issueID, - } - has, err := e.Get(pr) +func getPullRequestByIssueID(db *gorm.DB, issueID int64) (*PullRequest, error) { + pr := &PullRequest{} + err := db.Where("issue_id = ?", issueID).First(pr).Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrPullRequestNotExist{args: map[string]any{"issueID": issueID}} + } return nil, err - } else if !has { - return nil, ErrPullRequestNotExist{args: map[string]any{"issueID": issueID}} } - return pr, pr.loadAttributes(e) + return pr, pr.loadAttributes(db) } // GetPullRequestByIssueID returns pull request by given issue ID. func GetPullRequestByIssueID(issueID int64) (*PullRequest, error) { - return getPullRequestByIssueID(x, issueID) + return getPullRequestByIssueID(db, issueID) } // Update updates all fields of pull request. func (pr *PullRequest) Update() error { - _, err := x.Id(pr.ID).AllCols().Update(pr) - return err + return db.Model(&PullRequest{}).Where("id = ?", pr.ID).Updates(pr).Error } // Update updates specific fields of pull request. func (pr *PullRequest) UpdateCols(cols ...string) error { - _, err := x.Id(pr.ID).Cols(cols...).Update(pr) - return err + updates := make(map[string]any) + for _, col := range cols { + switch col { + case "status": + updates["status"] = pr.Status + case "merge_base": + updates["merge_base"] = pr.MergeBase + case "has_merged": + updates["has_merged"] = pr.HasMerged + case "merged_commit_id": + updates["merged_commit_id"] = pr.MergedCommitID + case "merger_id": + updates["merger_id"] = pr.MergerID + case "merged_unix": + updates["merged_unix"] = pr.MergedUnix + } + } + return db.Model(&PullRequest{}).Where("id = ?", pr.ID).Updates(updates).Error } // UpdatePatch generates and saves a new patch. @@ -711,7 +718,7 @@ func (pr *PullRequest) AddToTaskQueue() { type PullRequestList []*PullRequest -func (prs PullRequestList) loadAttributes(e Engine) (err error) { +func (prs PullRequestList) loadAttributes(db *gorm.DB) (err error) { if len(prs) == 0 { return nil } @@ -726,7 +733,7 @@ func (prs PullRequestList) loadAttributes(e Engine) (err error) { issueIDs = append(issueIDs, issueID) } issues := make([]*Issue, 0, len(issueIDs)) - if err = e.Where("id > 0").In("id", issueIDs).Find(&issues); err != nil { + if err = db.Where("id IN ?", issueIDs).Find(&issues).Error; err != nil { return errors.Newf("find issues: %v", err) } for i := range issues { @@ -839,24 +846,22 @@ func (pr *PullRequest) checkAndUpdateStatus() { // TODO: test more pull requests at same time. func TestPullRequests() { prs := make([]*PullRequest, 0, 10) - _ = x.Iterate(PullRequest{ - Status: PullRequestStatusChecking, - }, - func(idx int, bean any) error { - pr := bean.(*PullRequest) + db.Where("status = ?", PullRequestStatusChecking).FindInBatches(&prs, 100, func(tx *gorm.DB, batch int) error { + for i := range prs { + pr := prs[i] if err := pr.LoadAttributes(); err != nil { log.Error("LoadAttributes: %v", err) - return nil + continue } if err := pr.testPatch(); err != nil { log.Error("testPatch: %v", err) - return nil + continue } - prs = append(prs, pr) - return nil - }) + } + return nil + }) // Update pull request status. for _, pr := range prs {