From ec0ae5d50c59315a3c597b1cf24d4c5508c718e5 Mon Sep 17 00:00:00 2001 From: Ethan Koenig Date: Tue, 14 Mar 2017 20:51:46 -0400 Subject: [PATCH] Refactor and fix incorrect comment (#1247) --- cmd/serv.go | 2 +- models/access.go | 24 ++++++++++++------------ models/access_test.go | 16 ++++++++-------- models/issue.go | 27 +++++++++------------------ models/org_team.go | 34 ++++++++++++++++++++-------------- models/org_team_test.go | 2 +- models/release.go | 2 +- models/repo.go | 2 +- models/ssh_key.go | 2 +- models/user.go | 6 +++--- modules/context/repo.go | 2 +- modules/lfs/server.go | 4 ++-- routers/api/v1/api.go | 2 +- routers/api/v1/org/team.go | 6 +++--- routers/api/v1/repo/fork.go | 2 +- routers/api/v1/repo/release.go | 2 +- routers/api/v1/repo/repo.go | 6 +++--- routers/api/v1/user/star.go | 7 ++----- routers/api/v1/user/watch.go | 6 +----- routers/repo/http.go | 4 ++-- 20 files changed, 74 insertions(+), 84 deletions(-) diff --git a/cmd/serv.go b/cmd/serv.go index 5b1caf4d3..925cb2a67 100644 --- a/cmd/serv.go +++ b/cmd/serv.go @@ -232,7 +232,7 @@ func runServ(c *cli.Context) error { fail("internal error", "Failed to get user by key ID(%d): %v", keyID, err) } - mode, err := models.AccessLevel(user, repo) + mode, err := models.AccessLevel(user.ID, repo) if err != nil { fail("Internal error", "Failed to check access: %v", err) } else if mode < requestedMode { diff --git a/models/access.go b/models/access.go index 49a8838ea..98ead19a0 100644 --- a/models/access.go +++ b/models/access.go @@ -59,21 +59,21 @@ type Access struct { Mode AccessMode } -func accessLevel(e Engine, user *User, repo *Repository) (AccessMode, error) { +func accessLevel(e Engine, userID int64, repo *Repository) (AccessMode, error) { mode := AccessModeNone if !repo.IsPrivate { mode = AccessModeRead } - if user == nil { + if userID == 0 { return mode, nil } - if user.ID == repo.OwnerID { + if userID == repo.OwnerID { return AccessModeOwner, nil } - a := &Access{UserID: user.ID, RepoID: repo.ID} + a := &Access{UserID: userID, RepoID: repo.ID} if has, err := e.Get(a); !has || err != nil { return mode, err } @@ -81,19 +81,19 @@ func accessLevel(e Engine, user *User, repo *Repository) (AccessMode, error) { } // AccessLevel returns the Access a user has to a repository. Will return NoneAccess if the -// user does not have access. User can be nil! -func AccessLevel(user *User, repo *Repository) (AccessMode, error) { - return accessLevel(x, user, repo) +// user does not have access. +func AccessLevel(userID int64, repo *Repository) (AccessMode, error) { + return accessLevel(x, userID, repo) } -func hasAccess(e Engine, user *User, repo *Repository, testMode AccessMode) (bool, error) { - mode, err := accessLevel(e, user, repo) +func hasAccess(e Engine, userID int64, repo *Repository, testMode AccessMode) (bool, error) { + mode, err := accessLevel(e, userID, repo) return testMode <= mode, err } -// HasAccess returns true if someone has the request access level. User can be nil! -func HasAccess(user *User, repo *Repository, testMode AccessMode) (bool, error) { - return hasAccess(x, user, repo, testMode) +// HasAccess returns true if user has access to repo +func HasAccess(userID int64, repo *Repository, testMode AccessMode) (bool, error) { + return hasAccess(x, userID, repo, testMode) } type repoAccess struct { diff --git a/models/access_test.go b/models/access_test.go index 6b3cce502..29d40d958 100644 --- a/models/access_test.go +++ b/models/access_test.go @@ -25,19 +25,19 @@ func TestAccessLevel(t *testing.T) { repo1 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 2, IsPrivate: false}).(*Repository) repo2 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 3, IsPrivate: true}).(*Repository) - level, err := AccessLevel(user1, repo1) + level, err := AccessLevel(user1.ID, repo1) assert.NoError(t, err) assert.Equal(t, AccessModeOwner, level) - level, err = AccessLevel(user1, repo2) + level, err = AccessLevel(user1.ID, repo2) assert.NoError(t, err) assert.Equal(t, AccessModeWrite, level) - level, err = AccessLevel(user2, repo1) + level, err = AccessLevel(user2.ID, repo1) assert.NoError(t, err) assert.Equal(t, AccessModeRead, level) - level, err = AccessLevel(user2, repo2) + level, err = AccessLevel(user2.ID, repo2) assert.NoError(t, err) assert.Equal(t, AccessModeNone, level) } @@ -51,19 +51,19 @@ func TestHasAccess(t *testing.T) { repo2 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 3, IsPrivate: true}).(*Repository) for _, accessMode := range accessModes { - has, err := HasAccess(user1, repo1, accessMode) + has, err := HasAccess(user1.ID, repo1, accessMode) assert.NoError(t, err) assert.True(t, has) - has, err = HasAccess(user1, repo2, accessMode) + has, err = HasAccess(user1.ID, repo2, accessMode) assert.NoError(t, err) assert.Equal(t, accessMode <= AccessModeWrite, has) - has, err = HasAccess(user2, repo1, accessMode) + has, err = HasAccess(user2.ID, repo1, accessMode) assert.NoError(t, err) assert.Equal(t, accessMode <= AccessModeRead, has) - has, err = HasAccess(user2, repo2, accessMode) + has, err = HasAccess(user2.ID, repo2, accessMode) assert.NoError(t, err) assert.Equal(t, accessMode <= AccessModeNone, has) } diff --git a/models/issue.go b/models/issue.go index c740d8fec..d5e20eb20 100644 --- a/models/issue.go +++ b/models/issue.go @@ -374,7 +374,7 @@ func (issue *Issue) RemoveLabel(doer *User, label *Label) error { return err } - if has, err := HasAccess(doer, issue.Repo, AccessModeWrite); err != nil { + if has, err := HasAccess(doer.ID, issue.Repo, AccessModeWrite); err != nil { return err } else if !has { return ErrLabelNotExist{} @@ -415,7 +415,7 @@ func (issue *Issue) ClearLabels(doer *User) (err error) { return err } - if has, err := hasAccess(sess, doer, issue.Repo, AccessModeWrite); err != nil { + if has, err := hasAccess(sess, doer.ID, issue.Repo, AccessModeWrite); err != nil { return err } else if !has { return ErrLabelNotExist{} @@ -809,23 +809,14 @@ func newIssue(e *xorm.Session, doer *User, opts NewIssueOptions) (err error) { } } - if opts.Issue.AssigneeID > 0 { - assignee, err := getUserByID(e, opts.Issue.AssigneeID) - if err != nil && !IsErrUserNotExist(err) { - return fmt.Errorf("getUserByID: %v", err) + if assigneeID := opts.Issue.AssigneeID; assigneeID > 0 { + valid, err := hasAccess(e, assigneeID, opts.Repo, AccessModeWrite) + if err != nil { + return fmt.Errorf("hasAccess [user_id: %d, repo_id: %d]: %v", assigneeID, opts.Repo.ID, err) } - - // Assume assignee is invalid and drop silently. - opts.Issue.AssigneeID = 0 - if assignee != nil { - valid, err := hasAccess(e, assignee, opts.Repo, AccessModeWrite) - if err != nil { - return fmt.Errorf("hasAccess [user_id: %d, repo_id: %d]: %v", assignee.ID, opts.Repo.ID, err) - } - if valid { - opts.Issue.AssigneeID = assignee.ID - opts.Issue.Assignee = assignee - } + if !valid { + opts.Issue.AssigneeID = 0 + opts.Issue.Assignee = nil } } diff --git a/models/org_team.go b/models/org_team.go index 84282da83..db25e6154 100644 --- a/models/org_team.go +++ b/models/org_team.go @@ -139,18 +139,19 @@ func (t *Team) removeRepository(e Engine, repo *Repository, recalculate bool) (e } } - if err = t.getMembers(e); err != nil { - return fmt.Errorf("get team members: %v", err) + teamUsers, err := getTeamUsersByTeamID(e, t.ID) + if err != nil { + return fmt.Errorf("getTeamUsersByTeamID: %v", err) } - for _, u := range t.Members { - has, err := hasAccess(e, u, repo, AccessModeRead) + for _, teamUser:= range teamUsers { + has, err := hasAccess(e, teamUser.UID, repo, AccessModeRead) if err != nil { return err } else if has { continue } - if err = watchRepo(e, u.ID, repo.ID, false); err != nil { + if err = watchRepo(e, teamUser.UID, repo.ID, false); err != nil { return err } } @@ -399,20 +400,25 @@ func IsTeamMember(orgID, teamID, userID int64) bool { return isTeamMember(x, orgID, teamID, userID) } -func getTeamMembers(e Engine, teamID int64) (_ []*User, err error) { +func getTeamUsersByTeamID(e Engine, teamID int64) ([]*TeamUser, error) { teamUsers := make([]*TeamUser, 0, 10) - if err = e. + return teamUsers, e. Where("team_id=?", teamID). - Find(&teamUsers); err != nil { + Find(&teamUsers) +} + +func getTeamMembers(e Engine, teamID int64) (_ []*User, err error) { + teamUsers, err := getTeamUsersByTeamID(e, teamID) + if err != nil { return nil, fmt.Errorf("get team-users: %v", err) } - members := make([]*User, 0, len(teamUsers)) - for i := range teamUsers { - member := new(User) - if _, err = e.Id(teamUsers[i].UID).Get(member); err != nil { - return nil, fmt.Errorf("get user '%d': %v", teamUsers[i].UID, err) + members := make([]*User, len(teamUsers)) + for i, teamUser := range teamUsers { + member, err := getUserByID(e, teamUser.UID) + if err != nil { + return nil, fmt.Errorf("get user '%d': %v", teamUser.UID, err) } - members = append(members, member) + members[i] = member } return members, nil } diff --git a/models/org_team_test.go b/models/org_team_test.go index db0a81468..506de6e6b 100644 --- a/models/org_team_test.go +++ b/models/org_team_test.go @@ -243,7 +243,7 @@ func TestDeleteTeam(t *testing.T) { // check that team members don't have "leftover" access to repos user := AssertExistsAndLoadBean(t, &User{ID: 4}).(*User) repo := AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository) - accessMode, err := AccessLevel(user, repo) + accessMode, err := AccessLevel(user.ID, repo) assert.NoError(t, err) assert.True(t, accessMode < AccessModeWrite) } diff --git a/models/release.go b/models/release.go index 3e3db98ab..ae3200d39 100644 --- a/models/release.go +++ b/models/release.go @@ -365,7 +365,7 @@ func DeleteReleaseByID(id int64, u *User, delTag bool) error { return fmt.Errorf("GetRepositoryByID: %v", err) } - has, err := HasAccess(u, repo, AccessModeWrite) + has, err := HasAccess(u.ID, repo, AccessModeWrite) if err != nil { return fmt.Errorf("HasAccess: %v", err) } else if !has { diff --git a/models/repo.go b/models/repo.go index d44f4ba48..eac0f015f 100644 --- a/models/repo.go +++ b/models/repo.go @@ -531,7 +531,7 @@ func (repo *Repository) ComposeCompareURL(oldCommitID, newCommitID string) strin // HasAccess returns true when user has access to this repository func (repo *Repository) HasAccess(u *User) bool { - has, _ := HasAccess(u, repo, AccessModeRead) + has, _ := HasAccess(u.ID, repo, AccessModeRead) return has } diff --git a/models/ssh_key.go b/models/ssh_key.go index 802333f48..ba7007028 100644 --- a/models/ssh_key.go +++ b/models/ssh_key.go @@ -794,7 +794,7 @@ func DeleteDeployKey(doer *User, id int64) error { if err != nil { return fmt.Errorf("GetRepositoryByID: %v", err) } - yes, err := HasAccess(doer, repo, AccessModeAdmin) + yes, err := HasAccess(doer.ID, repo, AccessModeAdmin) if err != nil { return fmt.Errorf("HasAccess: %v", err) } else if !yes { diff --git a/models/user.go b/models/user.go index 7cdff1a46..cfc01936f 100644 --- a/models/user.go +++ b/models/user.go @@ -478,7 +478,7 @@ func (u *User) DeleteAvatar() error { // IsAdminOfRepo returns true if user has admin or higher access of repository. func (u *User) IsAdminOfRepo(repo *Repository) bool { - has, err := HasAccess(u, repo, AccessModeAdmin) + has, err := HasAccess(u.ID, repo, AccessModeAdmin) if err != nil { log.Error(3, "HasAccess: %v", err) } @@ -487,7 +487,7 @@ func (u *User) IsAdminOfRepo(repo *Repository) bool { // IsWriterOfRepo returns true if user has write access to given repository. func (u *User) IsWriterOfRepo(repo *Repository) bool { - has, err := HasAccess(u, repo, AccessModeWrite) + has, err := HasAccess(u.ID, repo, AccessModeWrite) if err != nil { log.Error(3, "HasAccess: %v", err) } @@ -1103,7 +1103,7 @@ func GetUserByID(id int64) (*User, error) { // GetAssigneeByID returns the user with write access of repository by given ID. func GetAssigneeByID(repo *Repository, userID int64) (*User, error) { - has, err := HasAccess(&User{ID: userID}, repo, AccessModeWrite) + has, err := HasAccess(userID, repo, AccessModeWrite) if err != nil { return nil, err } else if !has { diff --git a/modules/context/repo.go b/modules/context/repo.go index 1ae98545a..76af49ffd 100644 --- a/modules/context/repo.go +++ b/modules/context/repo.go @@ -219,7 +219,7 @@ func RepoAssignment(args ...bool) macaron.Handler { if ctx.IsSigned && ctx.User.IsAdmin { ctx.Repo.AccessMode = models.AccessModeOwner } else { - mode, err := models.AccessLevel(ctx.User, repo) + mode, err := models.AccessLevel(ctx.User.ID, repo) if err != nil { ctx.Handle(500, "AccessLevel", err) return diff --git a/modules/lfs/server.go b/modules/lfs/server.go index 1bdeadc46..782d972fa 100644 --- a/modules/lfs/server.go +++ b/modules/lfs/server.go @@ -463,7 +463,7 @@ func authenticate(ctx *context.Context, repository *models.Repository, authoriza } if ctx.IsSigned { - accessCheck, _ := models.HasAccess(ctx.User, repository, accessMode) + accessCheck, _ := models.HasAccess(ctx.User.ID, repository, accessMode) return accessCheck } @@ -499,7 +499,7 @@ func authenticate(ctx *context.Context, repository *models.Repository, authoriza return false } - accessCheck, _ := models.HasAccess(userModel, repository, accessMode) + accessCheck, _ := models.HasAccess(userModel.ID, repository, accessMode) return accessCheck } diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index 611a8f91d..ca4592c60 100644 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -70,7 +70,7 @@ func repoAssignment() macaron.Handler { if ctx.IsSigned && ctx.User.IsAdmin { ctx.Repo.AccessMode = models.AccessModeOwner } else { - mode, err := models.AccessLevel(ctx.User, repo) + mode, err := models.AccessLevel(ctx.User.ID, repo) if err != nil { ctx.Error(500, "AccessLevel", err) return diff --git a/routers/api/v1/org/team.go b/routers/api/v1/org/team.go index f9d93399a..dbd6ccc46 100644 --- a/routers/api/v1/org/team.go +++ b/routers/api/v1/org/team.go @@ -131,7 +131,7 @@ func GetTeamRepos(ctx *context.APIContext) { } repos := make([]*api.Repository, len(team.Repos)) for i, repo := range team.Repos { - access, err := models.AccessLevel(ctx.User, repo) + access, err := models.AccessLevel(ctx.User.ID, repo) if err != nil { ctx.Error(500, "GetTeamRepos", err) return @@ -161,7 +161,7 @@ func AddTeamRepository(ctx *context.APIContext) { if ctx.Written() { return } - if access, err := models.AccessLevel(ctx.User, repo); err != nil { + if access, err := models.AccessLevel(ctx.User.ID, repo); err != nil { ctx.Error(500, "AccessLevel", err) return } else if access < models.AccessModeAdmin { @@ -181,7 +181,7 @@ func RemoveTeamRepository(ctx *context.APIContext) { if ctx.Written() { return } - if access, err := models.AccessLevel(ctx.User, repo); err != nil { + if access, err := models.AccessLevel(ctx.User.ID, repo); err != nil { ctx.Error(500, "AccessLevel", err) return } else if access < models.AccessModeAdmin { diff --git a/routers/api/v1/repo/fork.go b/routers/api/v1/repo/fork.go index e8f57ace7..9c6da754e 100644 --- a/routers/api/v1/repo/fork.go +++ b/routers/api/v1/repo/fork.go @@ -20,7 +20,7 @@ func ListForks(ctx *context.APIContext) { } apiForks := make([]*api.Repository, len(forks)) for i, fork := range forks { - access, err := models.AccessLevel(ctx.User, fork) + access, err := models.AccessLevel(ctx.User.ID, fork) if err != nil { ctx.Error(500, "AccessLevel", err) return diff --git a/routers/api/v1/repo/release.go b/routers/api/v1/repo/release.go index 7dacb8018..a367e5571 100644 --- a/routers/api/v1/repo/release.go +++ b/routers/api/v1/repo/release.go @@ -40,7 +40,7 @@ func ListReleases(ctx *context.APIContext) { return } rels := make([]*api.Release, len(releases)) - access, err := models.AccessLevel(ctx.User, ctx.Repo.Repository) + access, err := models.AccessLevel(ctx.User.ID, ctx.Repo.Repository) if err != nil { ctx.Error(500, "AccessLevel", err) return diff --git a/routers/api/v1/repo/repo.go b/routers/api/v1/repo/repo.go index a43246624..317b0c57a 100644 --- a/routers/api/v1/repo/repo.go +++ b/routers/api/v1/repo/repo.go @@ -64,7 +64,7 @@ func Search(ctx *context.APIContext) { }) return } - accessMode, err := models.AccessLevel(ctx.User, repo) + accessMode, err := models.AccessLevel(ctx.User.ID, repo) if err != nil { ctx.JSON(500, map[string]interface{}{ "ok": false, @@ -218,7 +218,7 @@ func Migrate(ctx *context.APIContext, form auth.MigrateRepoForm) { // see https://github.com/gogits/go-gogs-client/wiki/Repositories#get func Get(ctx *context.APIContext) { repo := ctx.Repo.Repository - access, err := models.AccessLevel(ctx.User, repo) + access, err := models.AccessLevel(ctx.User.ID, repo) if err != nil { ctx.Error(500, "GetRepository", err) return @@ -238,7 +238,7 @@ func GetByID(ctx *context.APIContext) { return } - access, err := models.AccessLevel(ctx.User, repo) + access, err := models.AccessLevel(ctx.User.ID, repo) if err != nil { ctx.Error(500, "GetRepositoryByID", err) return diff --git a/routers/api/v1/user/star.go b/routers/api/v1/user/star.go index 0937fd190..47d3ed507 100644 --- a/routers/api/v1/user/star.go +++ b/routers/api/v1/user/star.go @@ -18,13 +18,10 @@ func getStarredRepos(userID int64, private bool) ([]*api.Repository, error) { if err != nil { return nil, err } - user, err := models.GetUserByID(userID) - if err != nil { - return nil, err - } + repos := make([]*api.Repository, len(starredRepos)) for i, starred := range starredRepos { - access, err := models.AccessLevel(user, starred) + access, err := models.AccessLevel(userID, starred) if err != nil { return nil, err } diff --git a/routers/api/v1/user/watch.go b/routers/api/v1/user/watch.go index 6a9ad670d..2a94e219f 100644 --- a/routers/api/v1/user/watch.go +++ b/routers/api/v1/user/watch.go @@ -31,14 +31,10 @@ func getWatchedRepos(userID int64, private bool) ([]*api.Repository, error) { if err != nil { return nil, err } - user, err := models.GetUserByID(userID) - if err != nil { - return nil, err - } repos := make([]*api.Repository, len(watchedRepos)) for i, watched := range watchedRepos { - access, err := models.AccessLevel(user, watched) + access, err := models.AccessLevel(userID, watched) if err != nil { return nil, err } diff --git a/routers/repo/http.go b/routers/repo/http.go index dc2965184..12fcbcfb3 100644 --- a/routers/repo/http.go +++ b/routers/repo/http.go @@ -152,13 +152,13 @@ func HTTP(ctx *context.Context) { } if !isPublicPull { - has, err := models.HasAccess(authUser, repo, accessMode) + has, err := models.HasAccess(authUser.ID, repo, accessMode) if err != nil { ctx.Handle(http.StatusInternalServerError, "HasAccess", err) return } else if !has { if accessMode == models.AccessModeRead { - has, err = models.HasAccess(authUser, repo, models.AccessModeWrite) + has, err = models.HasAccess(authUser.ID, repo, models.AccessModeWrite) if err != nil { ctx.Handle(http.StatusInternalServerError, "HasAccess2", err) return