From 73cf0e2614263c8c4ad6d55d06e753b28d2b6091 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 31 Mar 2020 15:47:00 +0800 Subject: [PATCH] Fix milestones too many SQL variables bug (#10880) * Fix milestones too many SQL variables bug * Fix test * Don't display repositories with no milestone and fix tests * Remove unused code and add some comments --- models/issue_milestone.go | 58 ++++++++++--- models/issue_milestone_test.go | 3 +- models/repo_list.go | 32 ++++++-- routers/user/home.go | 143 ++++++++++++++------------------- routers/user/home_test.go | 4 +- 5 files changed, 135 insertions(+), 105 deletions(-) diff --git a/models/issue_milestone.go b/models/issue_milestone.go index ba39e6ebc..6bef35ce6 100644 --- a/models/issue_milestone.go +++ b/models/issue_milestone.go @@ -525,10 +525,12 @@ func DeleteMilestoneByRepoID(repoID, id int64) error { return sess.Commit() } -// CountMilestonesByRepoIDs map from repoIDs to number of milestones matching the options` -func CountMilestonesByRepoIDs(repoIDs []int64, isClosed bool) (map[int64]int64, error) { +// CountMilestones map from repo conditions to number of milestones matching the options` +func CountMilestones(repoCond builder.Cond, isClosed bool) (map[int64]int64, error) { sess := x.Where("is_closed = ?", isClosed) - sess.In("repo_id", repoIDs) + if repoCond.IsValid() { + sess.In("repo_id", builder.Select("id").From("repository").Where(repoCond)) + } countsSlice := make([]*struct { RepoID int64 @@ -548,11 +550,21 @@ func CountMilestonesByRepoIDs(repoIDs []int64, isClosed bool) (map[int64]int64, return countMap, nil } -// GetMilestonesByRepoIDs returns a list of milestones of given repositories and status. -func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) { +// CountMilestonesByRepoIDs map from repoIDs to number of milestones matching the options` +func CountMilestonesByRepoIDs(repoIDs []int64, isClosed bool) (map[int64]int64, error) { + return CountMilestones( + builder.In("repo_id", repoIDs), + isClosed, + ) +} + +// SearchMilestones search milestones +func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType string) (MilestoneList, error) { miles := make([]*Milestone, 0, setting.UI.IssuePagingNum) sess := x.Where("is_closed = ?", isClosed) - sess.In("repo_id", repoIDs) + if repoCond.IsValid() { + sess.In("repo_id", builder.Select("id").From("repository").Where(repoCond)) + } if page > 0 { sess = sess.Limit(setting.UI.IssuePagingNum, (page-1)*setting.UI.IssuePagingNum) } @@ -574,25 +586,45 @@ func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType s return miles, sess.Find(&miles) } +// GetMilestonesByRepoIDs returns a list of milestones of given repositories and status. +func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) { + return SearchMilestones( + builder.In("repo_id", repoIDs), + page, + isClosed, + sortType, + ) +} + // MilestonesStats represents milestone statistic information. type MilestonesStats struct { OpenCount, ClosedCount int64 } +// Total returns the total counts of milestones +func (m MilestonesStats) Total() int64 { + return m.OpenCount + m.ClosedCount +} + // GetMilestonesStats returns milestone statistic information for dashboard by given conditions. -func GetMilestonesStats(userRepoIDs []int64) (*MilestonesStats, error) { +func GetMilestonesStats(repoCond builder.Cond) (*MilestonesStats, error) { var err error stats := &MilestonesStats{} - stats.OpenCount, err = x.Where("is_closed = ?", false). - And(builder.In("repo_id", userRepoIDs)). - Count(new(Milestone)) + sess := x.Where("is_closed = ?", false) + if repoCond.IsValid() { + sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond))) + } + stats.OpenCount, err = sess.Count(new(Milestone)) if err != nil { return nil, err } - stats.ClosedCount, err = x.Where("is_closed = ?", true). - And(builder.In("repo_id", userRepoIDs)). - Count(new(Milestone)) + + sess = x.Where("is_closed = ?", true) + if repoCond.IsValid() { + sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond))) + } + stats.ClosedCount, err = sess.Count(new(Milestone)) if err != nil { return nil, err } diff --git a/models/issue_milestone_test.go b/models/issue_milestone_test.go index 778ebfbda..607d36c31 100644 --- a/models/issue_milestone_test.go +++ b/models/issue_milestone_test.go @@ -11,6 +11,7 @@ import ( api "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/modules/timeutil" + "xorm.io/builder" "github.com/stretchr/testify/assert" ) @@ -370,7 +371,7 @@ func TestGetMilestonesStats(t *testing.T) { repo1 := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository) repo2 := AssertExistsAndLoadBean(t, &Repository{ID: 2}).(*Repository) - milestoneStats, err := GetMilestonesStats([]int64{repo1.ID, repo2.ID}) + milestoneStats, err := GetMilestonesStats(builder.In("repo_id", []int64{repo1.ID, repo2.ID})) assert.NoError(t, err) assert.EqualValues(t, repo1.NumOpenMilestones+repo2.NumOpenMilestones, milestoneStats.OpenCount) assert.EqualValues(t, repo1.NumClosedMilestones+repo2.NumClosedMilestones, milestoneStats.ClosedCount) diff --git a/models/repo_list.go b/models/repo_list.go index 7ceb88f08..1632e64eb 100644 --- a/models/repo_list.go +++ b/models/repo_list.go @@ -163,6 +163,10 @@ type SearchRepoOptions struct { TopicOnly bool // include description in keyword search IncludeDescription bool + // None -> include has milestones AND has no milestone + // True -> include just has milestones + // False -> include just has no milestone + HasMilestones util.OptionalBool } //SearchOrderBy is used to sort the result @@ -294,6 +298,14 @@ func SearchRepositoryCondition(opts *SearchRepoOptions) builder.Cond { if opts.Actor != nil && opts.Actor.IsRestricted { cond = cond.And(accessibleRepositoryCondition(opts.Actor)) } + + switch opts.HasMilestones { + case util.OptionalBoolTrue: + cond = cond.And(builder.Gt{"num_milestones": 0}) + case util.OptionalBoolFalse: + cond = cond.And(builder.Eq{"num_milestones": 0}.Or(builder.IsNull{"num_milestones"})) + } + return cond } @@ -301,7 +313,11 @@ func SearchRepositoryCondition(opts *SearchRepoOptions) builder.Cond { // it returns results in given range and number of total results. func SearchRepository(opts *SearchRepoOptions) (RepositoryList, int64, error) { cond := SearchRepositoryCondition(opts) + return SearchRepositoryByCondition(opts, cond, true) +} +// SearchRepositoryByCondition search repositories by condition +func SearchRepositoryByCondition(opts *SearchRepoOptions, cond builder.Cond, loadAttributes bool) (RepositoryList, int64, error) { if opts.Page <= 0 { opts.Page = 1 } @@ -326,16 +342,18 @@ func SearchRepository(opts *SearchRepoOptions) (RepositoryList, int64, error) { } repos := make(RepositoryList, 0, opts.PageSize) - if err = sess. - Where(cond). - OrderBy(opts.OrderBy.String()). - Limit(opts.PageSize, (opts.Page-1)*opts.PageSize). - Find(&repos); err != nil { + sess.Where(cond).OrderBy(opts.OrderBy.String()) + if opts.PageSize > 0 { + sess.Limit(opts.PageSize, (opts.Page-1)*opts.PageSize) + } + if err = sess.Find(&repos); err != nil { return nil, 0, fmt.Errorf("Repo: %v", err) } - if err = repos.loadAttributes(sess); err != nil { - return nil, 0, fmt.Errorf("LoadAttributes: %v", err) + if loadAttributes { + if err = repos.loadAttributes(sess); err != nil { + return nil, 0, fmt.Errorf("LoadAttributes: %v", err) + } } return repos, count, nil diff --git a/routers/user/home.go b/routers/user/home.go index e310b97ad..71a65d6d9 100644 --- a/routers/user/home.go +++ b/routers/user/home.go @@ -25,7 +25,7 @@ import ( "github.com/keybase/go-crypto/openpgp" "github.com/keybase/go-crypto/openpgp/armor" - "github.com/unknwon/com" + "xorm.io/builder" ) const ( @@ -173,135 +173,114 @@ func Milestones(ctx *context.Context) { return } - sortType := ctx.Query("sort") - page := ctx.QueryInt("page") + var ( + repoOpts = models.SearchRepoOptions{ + Actor: ctxUser, + OwnerID: ctxUser.ID, + Private: true, + AllPublic: false, // Include also all public repositories of users and public organisations + AllLimited: false, // Include also all public repositories of limited organisations + HasMilestones: util.OptionalBoolTrue, // Just needs display repos has milestones + } + + userRepoCond = models.SearchRepositoryCondition(&repoOpts) // all repo condition user could visit + repoCond = userRepoCond + repoIDs []int64 + + reposQuery = ctx.Query("repos") + isShowClosed = ctx.Query("state") == "closed" + sortType = ctx.Query("sort") + page = ctx.QueryInt("page") + ) + if page <= 1 { page = 1 } - reposQuery := ctx.Query("repos") - isShowClosed := ctx.Query("state") == "closed" - - // Get repositories. - var err error - var userRepoIDs []int64 - if ctxUser.IsOrganization() { - env, err := ctxUser.AccessibleReposEnv(ctx.User.ID) - if err != nil { - ctx.ServerError("AccessibleReposEnv", err) - return - } - userRepoIDs, err = env.RepoIDs(1, ctxUser.NumRepos) - if err != nil { - ctx.ServerError("env.RepoIDs", err) - return - } - userRepoIDs, err = models.FilterOutRepoIdsWithoutUnitAccess(ctx.User, userRepoIDs, models.UnitTypeIssues, models.UnitTypePullRequests) - if err != nil { - ctx.ServerError("FilterOutRepoIdsWithoutUnitAccess", err) - return - } - } else { - userRepoIDs, err = ctxUser.GetAccessRepoIDs(models.UnitTypeIssues, models.UnitTypePullRequests) - if err != nil { - ctx.ServerError("ctxUser.GetAccessRepoIDs", err) - return - } - } - if len(userRepoIDs) == 0 { - userRepoIDs = []int64{-1} - } - - var repoIDs []int64 if len(reposQuery) != 0 { if issueReposQueryPattern.MatchString(reposQuery) { // remove "[" and "]" from string reposQuery = reposQuery[1 : len(reposQuery)-1] //for each ID (delimiter ",") add to int to repoIDs - reposSet := false + for _, rID := range strings.Split(reposQuery, ",") { // Ensure nonempty string entries if rID != "" && rID != "0" { - reposSet = true rIDint64, err := strconv.ParseInt(rID, 10, 64) // If the repo id specified by query is not parseable or not accessible by user, just ignore it. - if err == nil && com.IsSliceContainsInt64(userRepoIDs, rIDint64) { + if err == nil { repoIDs = append(repoIDs, rIDint64) } } } - if reposSet && len(repoIDs) == 0 { - // force an empty result - repoIDs = []int64{-1} + if len(repoIDs) > 0 { + // Don't just let repoCond = builder.In("id", repoIDs) because user may has no permission on repoIDs + // But the original repoCond has a limitation + repoCond = repoCond.And(builder.In("id", repoIDs)) } } else { log.Warn("issueReposQueryPattern not match with query") } } - if len(repoIDs) == 0 { - repoIDs = userRepoIDs - } - - counts, err := models.CountMilestonesByRepoIDs(userRepoIDs, isShowClosed) + counts, err := models.CountMilestones(userRepoCond, isShowClosed) if err != nil { ctx.ServerError("CountMilestonesByRepoIDs", err) return } - milestones, err := models.GetMilestonesByRepoIDs(repoIDs, page, isShowClosed, sortType) + milestones, err := models.SearchMilestones(repoCond, page, isShowClosed, sortType) if err != nil { ctx.ServerError("GetMilestonesByRepoIDs", err) return } - showReposMap := make(map[int64]*models.Repository, len(counts)) - for rID := range counts { - if rID == -1 { - break - } - repo, err := models.GetRepositoryByID(rID) - if err != nil { - if models.IsErrRepoNotExist(err) { - ctx.NotFound("GetRepositoryByID", err) - return - } else if err != nil { - ctx.ServerError("GetRepositoryByID", fmt.Errorf("[%d]%v", rID, err)) - return - } - } - showReposMap[rID] = repo - } - - showRepos := models.RepositoryListOfMap(showReposMap) - sort.Sort(showRepos) - if err = showRepos.LoadAttributes(); err != nil { - ctx.ServerError("LoadAttributes", err) + showRepos, _, err := models.SearchRepositoryByCondition(&repoOpts, userRepoCond, false) + if err != nil { + ctx.ServerError("SearchRepositoryByCondition", err) return } + sort.Sort(showRepos) - for _, m := range milestones { - m.Repo = showReposMap[m.RepoID] - m.RenderedContent = string(markdown.Render([]byte(m.Content), m.Repo.Link(), m.Repo.ComposeMetas())) - if m.Repo.IsTimetrackerEnabled() { - err := m.LoadTotalTrackedTime() + for i := 0; i < len(milestones); { + for _, repo := range showRepos { + if milestones[i].RepoID == repo.ID { + milestones[i].Repo = repo + break + } + } + if milestones[i].Repo == nil { + log.Warn("Cannot find milestone %d 's repository %d", milestones[i].ID, milestones[i].RepoID) + milestones = append(milestones[:i], milestones[i+1:]...) + continue + } + + milestones[i].RenderedContent = string(markdown.Render([]byte(milestones[i].Content), milestones[i].Repo.Link(), milestones[i].Repo.ComposeMetas())) + if milestones[i].Repo.IsTimetrackerEnabled() { + err := milestones[i].LoadTotalTrackedTime() if err != nil { ctx.ServerError("LoadTotalTrackedTime", err) return } } + i++ } - milestoneStats, err := models.GetMilestonesStats(repoIDs) + milestoneStats, err := models.GetMilestonesStats(repoCond) if err != nil { ctx.ServerError("GetMilestoneStats", err) return } - totalMilestoneStats, err := models.GetMilestonesStats(userRepoIDs) - if err != nil { - ctx.ServerError("GetMilestoneStats", err) - return + var totalMilestoneStats *models.MilestonesStats + if len(repoIDs) == 0 { + totalMilestoneStats = milestoneStats + } else { + totalMilestoneStats, err = models.GetMilestonesStats(userRepoCond) + if err != nil { + ctx.ServerError("GetMilestoneStats", err) + return + } } var pagerCount int @@ -320,7 +299,7 @@ func Milestones(ctx *context.Context) { ctx.Data["Counts"] = counts ctx.Data["MilestoneStats"] = milestoneStats ctx.Data["SortType"] = sortType - if len(repoIDs) != len(userRepoIDs) { + if milestoneStats.Total() != totalMilestoneStats.Total() { ctx.Data["RepoIDs"] = repoIDs } ctx.Data["IsShowClosed"] = isShowClosed diff --git a/routers/user/home_test.go b/routers/user/home_test.go index 39186d93e..ff48953d4 100644 --- a/routers/user/home_test.go +++ b/routers/user/home_test.go @@ -48,7 +48,7 @@ func TestMilestones(t *testing.T) { assert.EqualValues(t, "furthestduedate", ctx.Data["SortType"]) assert.EqualValues(t, 1, ctx.Data["Total"]) assert.Len(t, ctx.Data["Milestones"], 1) - assert.Len(t, ctx.Data["Repos"], 1) + assert.Len(t, ctx.Data["Repos"], 2) // both repo 42 and 1 have milestones and both are owned by user 2 } func TestMilestonesForSpecificRepo(t *testing.T) { @@ -68,5 +68,5 @@ func TestMilestonesForSpecificRepo(t *testing.T) { assert.EqualValues(t, "furthestduedate", ctx.Data["SortType"]) assert.EqualValues(t, 1, ctx.Data["Total"]) assert.Len(t, ctx.Data["Milestones"], 1) - assert.Len(t, ctx.Data["Repos"], 1) + assert.Len(t, ctx.Data["Repos"], 2) // both repo 42 and 1 have milestones and both are owned by user 2 }