Refactor and fix incorrect comment (#1247)

This commit is contained in:
Ethan Koenig 2017-03-14 20:51:46 -04:00 committed by Lunny Xiao
parent 7d8f9d1c46
commit ec0ae5d50c
20 changed files with 74 additions and 84 deletions

View file

@ -232,7 +232,7 @@ func runServ(c *cli.Context) error {
fail("internal error", "Failed to get user by key ID(%d): %v", keyID, err) fail("internal error", "Failed to get user by key ID(%d): %v", keyID, err)
} }
mode, err := models.AccessLevel(user, repo) mode, err := models.AccessLevel(user.ID, repo)
if err != nil { if err != nil {
fail("Internal error", "Failed to check access: %v", err) fail("Internal error", "Failed to check access: %v", err)
} else if mode < requestedMode { } else if mode < requestedMode {

View file

@ -59,21 +59,21 @@ type Access struct {
Mode AccessMode Mode AccessMode
} }
func accessLevel(e Engine, user *User, repo *Repository) (AccessMode, error) { func accessLevel(e Engine, userID int64, repo *Repository) (AccessMode, error) {
mode := AccessModeNone mode := AccessModeNone
if !repo.IsPrivate { if !repo.IsPrivate {
mode = AccessModeRead mode = AccessModeRead
} }
if user == nil { if userID == 0 {
return mode, nil return mode, nil
} }
if user.ID == repo.OwnerID { if userID == repo.OwnerID {
return AccessModeOwner, nil return AccessModeOwner, nil
} }
a := &Access{UserID: user.ID, RepoID: repo.ID} a := &Access{UserID: userID, RepoID: repo.ID}
if has, err := e.Get(a); !has || err != nil { if has, err := e.Get(a); !has || err != nil {
return mode, err return mode, err
} }
@ -81,19 +81,19 @@ func accessLevel(e Engine, user *User, repo *Repository) (AccessMode, error) {
} }
// AccessLevel returns the Access a user has to a repository. Will return NoneAccess if the // AccessLevel returns the Access a user has to a repository. Will return NoneAccess if the
// user does not have access. User can be nil! // user does not have access.
func AccessLevel(user *User, repo *Repository) (AccessMode, error) { func AccessLevel(userID int64, repo *Repository) (AccessMode, error) {
return accessLevel(x, user, repo) return accessLevel(x, userID, repo)
} }
func hasAccess(e Engine, user *User, repo *Repository, testMode AccessMode) (bool, error) { func hasAccess(e Engine, userID int64, repo *Repository, testMode AccessMode) (bool, error) {
mode, err := accessLevel(e, user, repo) mode, err := accessLevel(e, userID, repo)
return testMode <= mode, err return testMode <= mode, err
} }
// HasAccess returns true if someone has the request access level. User can be nil! // HasAccess returns true if user has access to repo
func HasAccess(user *User, repo *Repository, testMode AccessMode) (bool, error) { func HasAccess(userID int64, repo *Repository, testMode AccessMode) (bool, error) {
return hasAccess(x, user, repo, testMode) return hasAccess(x, userID, repo, testMode)
} }
type repoAccess struct { type repoAccess struct {

View file

@ -25,19 +25,19 @@ func TestAccessLevel(t *testing.T) {
repo1 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 2, IsPrivate: false}).(*Repository) repo1 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 2, IsPrivate: false}).(*Repository)
repo2 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 3, IsPrivate: true}).(*Repository) repo2 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 3, IsPrivate: true}).(*Repository)
level, err := AccessLevel(user1, repo1) level, err := AccessLevel(user1.ID, repo1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, AccessModeOwner, level) assert.Equal(t, AccessModeOwner, level)
level, err = AccessLevel(user1, repo2) level, err = AccessLevel(user1.ID, repo2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, AccessModeWrite, level) assert.Equal(t, AccessModeWrite, level)
level, err = AccessLevel(user2, repo1) level, err = AccessLevel(user2.ID, repo1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, AccessModeRead, level) assert.Equal(t, AccessModeRead, level)
level, err = AccessLevel(user2, repo2) level, err = AccessLevel(user2.ID, repo2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, AccessModeNone, level) assert.Equal(t, AccessModeNone, level)
} }
@ -51,19 +51,19 @@ func TestHasAccess(t *testing.T) {
repo2 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 3, IsPrivate: true}).(*Repository) repo2 := AssertExistsAndLoadBean(t, &Repository{OwnerID: 3, IsPrivate: true}).(*Repository)
for _, accessMode := range accessModes { for _, accessMode := range accessModes {
has, err := HasAccess(user1, repo1, accessMode) has, err := HasAccess(user1.ID, repo1, accessMode)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
has, err = HasAccess(user1, repo2, accessMode) has, err = HasAccess(user1.ID, repo2, accessMode)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, accessMode <= AccessModeWrite, has) assert.Equal(t, accessMode <= AccessModeWrite, has)
has, err = HasAccess(user2, repo1, accessMode) has, err = HasAccess(user2.ID, repo1, accessMode)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, accessMode <= AccessModeRead, has) assert.Equal(t, accessMode <= AccessModeRead, has)
has, err = HasAccess(user2, repo2, accessMode) has, err = HasAccess(user2.ID, repo2, accessMode)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, accessMode <= AccessModeNone, has) assert.Equal(t, accessMode <= AccessModeNone, has)
} }

View file

@ -374,7 +374,7 @@ func (issue *Issue) RemoveLabel(doer *User, label *Label) error {
return err return err
} }
if has, err := HasAccess(doer, issue.Repo, AccessModeWrite); err != nil { if has, err := HasAccess(doer.ID, issue.Repo, AccessModeWrite); err != nil {
return err return err
} else if !has { } else if !has {
return ErrLabelNotExist{} return ErrLabelNotExist{}
@ -415,7 +415,7 @@ func (issue *Issue) ClearLabels(doer *User) (err error) {
return err return err
} }
if has, err := hasAccess(sess, doer, issue.Repo, AccessModeWrite); err != nil { if has, err := hasAccess(sess, doer.ID, issue.Repo, AccessModeWrite); err != nil {
return err return err
} else if !has { } else if !has {
return ErrLabelNotExist{} return ErrLabelNotExist{}
@ -809,23 +809,14 @@ func newIssue(e *xorm.Session, doer *User, opts NewIssueOptions) (err error) {
} }
} }
if opts.Issue.AssigneeID > 0 { if assigneeID := opts.Issue.AssigneeID; assigneeID > 0 {
assignee, err := getUserByID(e, opts.Issue.AssigneeID) valid, err := hasAccess(e, assigneeID, opts.Repo, AccessModeWrite)
if err != nil && !IsErrUserNotExist(err) { if err != nil {
return fmt.Errorf("getUserByID: %v", err) return fmt.Errorf("hasAccess [user_id: %d, repo_id: %d]: %v", assigneeID, opts.Repo.ID, err)
} }
if !valid {
// Assume assignee is invalid and drop silently. opts.Issue.AssigneeID = 0
opts.Issue.AssigneeID = 0 opts.Issue.Assignee = nil
if assignee != nil {
valid, err := hasAccess(e, assignee, opts.Repo, AccessModeWrite)
if err != nil {
return fmt.Errorf("hasAccess [user_id: %d, repo_id: %d]: %v", assignee.ID, opts.Repo.ID, err)
}
if valid {
opts.Issue.AssigneeID = assignee.ID
opts.Issue.Assignee = assignee
}
} }
} }

View file

@ -139,18 +139,19 @@ func (t *Team) removeRepository(e Engine, repo *Repository, recalculate bool) (e
} }
} }
if err = t.getMembers(e); err != nil { teamUsers, err := getTeamUsersByTeamID(e, t.ID)
return fmt.Errorf("get team members: %v", err) if err != nil {
return fmt.Errorf("getTeamUsersByTeamID: %v", err)
} }
for _, u := range t.Members { for _, teamUser:= range teamUsers {
has, err := hasAccess(e, u, repo, AccessModeRead) has, err := hasAccess(e, teamUser.UID, repo, AccessModeRead)
if err != nil { if err != nil {
return err return err
} else if has { } else if has {
continue continue
} }
if err = watchRepo(e, u.ID, repo.ID, false); err != nil { if err = watchRepo(e, teamUser.UID, repo.ID, false); err != nil {
return err return err
} }
} }
@ -399,20 +400,25 @@ func IsTeamMember(orgID, teamID, userID int64) bool {
return isTeamMember(x, orgID, teamID, userID) return isTeamMember(x, orgID, teamID, userID)
} }
func getTeamMembers(e Engine, teamID int64) (_ []*User, err error) { func getTeamUsersByTeamID(e Engine, teamID int64) ([]*TeamUser, error) {
teamUsers := make([]*TeamUser, 0, 10) teamUsers := make([]*TeamUser, 0, 10)
if err = e. return teamUsers, e.
Where("team_id=?", teamID). Where("team_id=?", teamID).
Find(&teamUsers); err != nil { Find(&teamUsers)
}
func getTeamMembers(e Engine, teamID int64) (_ []*User, err error) {
teamUsers, err := getTeamUsersByTeamID(e, teamID)
if err != nil {
return nil, fmt.Errorf("get team-users: %v", err) return nil, fmt.Errorf("get team-users: %v", err)
} }
members := make([]*User, 0, len(teamUsers)) members := make([]*User, len(teamUsers))
for i := range teamUsers { for i, teamUser := range teamUsers {
member := new(User) member, err := getUserByID(e, teamUser.UID)
if _, err = e.Id(teamUsers[i].UID).Get(member); err != nil { if err != nil {
return nil, fmt.Errorf("get user '%d': %v", teamUsers[i].UID, err) return nil, fmt.Errorf("get user '%d': %v", teamUser.UID, err)
} }
members = append(members, member) members[i] = member
} }
return members, nil return members, nil
} }

View file

@ -243,7 +243,7 @@ func TestDeleteTeam(t *testing.T) {
// check that team members don't have "leftover" access to repos // check that team members don't have "leftover" access to repos
user := AssertExistsAndLoadBean(t, &User{ID: 4}).(*User) user := AssertExistsAndLoadBean(t, &User{ID: 4}).(*User)
repo := AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository) repo := AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository)
accessMode, err := AccessLevel(user, repo) accessMode, err := AccessLevel(user.ID, repo)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, accessMode < AccessModeWrite) assert.True(t, accessMode < AccessModeWrite)
} }

View file

@ -365,7 +365,7 @@ func DeleteReleaseByID(id int64, u *User, delTag bool) error {
return fmt.Errorf("GetRepositoryByID: %v", err) return fmt.Errorf("GetRepositoryByID: %v", err)
} }
has, err := HasAccess(u, repo, AccessModeWrite) has, err := HasAccess(u.ID, repo, AccessModeWrite)
if err != nil { if err != nil {
return fmt.Errorf("HasAccess: %v", err) return fmt.Errorf("HasAccess: %v", err)
} else if !has { } else if !has {

View file

@ -531,7 +531,7 @@ func (repo *Repository) ComposeCompareURL(oldCommitID, newCommitID string) strin
// HasAccess returns true when user has access to this repository // HasAccess returns true when user has access to this repository
func (repo *Repository) HasAccess(u *User) bool { func (repo *Repository) HasAccess(u *User) bool {
has, _ := HasAccess(u, repo, AccessModeRead) has, _ := HasAccess(u.ID, repo, AccessModeRead)
return has return has
} }

View file

@ -794,7 +794,7 @@ func DeleteDeployKey(doer *User, id int64) error {
if err != nil { if err != nil {
return fmt.Errorf("GetRepositoryByID: %v", err) return fmt.Errorf("GetRepositoryByID: %v", err)
} }
yes, err := HasAccess(doer, repo, AccessModeAdmin) yes, err := HasAccess(doer.ID, repo, AccessModeAdmin)
if err != nil { if err != nil {
return fmt.Errorf("HasAccess: %v", err) return fmt.Errorf("HasAccess: %v", err)
} else if !yes { } else if !yes {

View file

@ -478,7 +478,7 @@ func (u *User) DeleteAvatar() error {
// IsAdminOfRepo returns true if user has admin or higher access of repository. // IsAdminOfRepo returns true if user has admin or higher access of repository.
func (u *User) IsAdminOfRepo(repo *Repository) bool { func (u *User) IsAdminOfRepo(repo *Repository) bool {
has, err := HasAccess(u, repo, AccessModeAdmin) has, err := HasAccess(u.ID, repo, AccessModeAdmin)
if err != nil { if err != nil {
log.Error(3, "HasAccess: %v", err) log.Error(3, "HasAccess: %v", err)
} }
@ -487,7 +487,7 @@ func (u *User) IsAdminOfRepo(repo *Repository) bool {
// IsWriterOfRepo returns true if user has write access to given repository. // IsWriterOfRepo returns true if user has write access to given repository.
func (u *User) IsWriterOfRepo(repo *Repository) bool { func (u *User) IsWriterOfRepo(repo *Repository) bool {
has, err := HasAccess(u, repo, AccessModeWrite) has, err := HasAccess(u.ID, repo, AccessModeWrite)
if err != nil { if err != nil {
log.Error(3, "HasAccess: %v", err) log.Error(3, "HasAccess: %v", err)
} }
@ -1103,7 +1103,7 @@ func GetUserByID(id int64) (*User, error) {
// GetAssigneeByID returns the user with write access of repository by given ID. // GetAssigneeByID returns the user with write access of repository by given ID.
func GetAssigneeByID(repo *Repository, userID int64) (*User, error) { func GetAssigneeByID(repo *Repository, userID int64) (*User, error) {
has, err := HasAccess(&User{ID: userID}, repo, AccessModeWrite) has, err := HasAccess(userID, repo, AccessModeWrite)
if err != nil { if err != nil {
return nil, err return nil, err
} else if !has { } else if !has {

View file

@ -219,7 +219,7 @@ func RepoAssignment(args ...bool) macaron.Handler {
if ctx.IsSigned && ctx.User.IsAdmin { if ctx.IsSigned && ctx.User.IsAdmin {
ctx.Repo.AccessMode = models.AccessModeOwner ctx.Repo.AccessMode = models.AccessModeOwner
} else { } else {
mode, err := models.AccessLevel(ctx.User, repo) mode, err := models.AccessLevel(ctx.User.ID, repo)
if err != nil { if err != nil {
ctx.Handle(500, "AccessLevel", err) ctx.Handle(500, "AccessLevel", err)
return return

View file

@ -463,7 +463,7 @@ func authenticate(ctx *context.Context, repository *models.Repository, authoriza
} }
if ctx.IsSigned { if ctx.IsSigned {
accessCheck, _ := models.HasAccess(ctx.User, repository, accessMode) accessCheck, _ := models.HasAccess(ctx.User.ID, repository, accessMode)
return accessCheck return accessCheck
} }
@ -499,7 +499,7 @@ func authenticate(ctx *context.Context, repository *models.Repository, authoriza
return false return false
} }
accessCheck, _ := models.HasAccess(userModel, repository, accessMode) accessCheck, _ := models.HasAccess(userModel.ID, repository, accessMode)
return accessCheck return accessCheck
} }

View file

@ -70,7 +70,7 @@ func repoAssignment() macaron.Handler {
if ctx.IsSigned && ctx.User.IsAdmin { if ctx.IsSigned && ctx.User.IsAdmin {
ctx.Repo.AccessMode = models.AccessModeOwner ctx.Repo.AccessMode = models.AccessModeOwner
} else { } else {
mode, err := models.AccessLevel(ctx.User, repo) mode, err := models.AccessLevel(ctx.User.ID, repo)
if err != nil { if err != nil {
ctx.Error(500, "AccessLevel", err) ctx.Error(500, "AccessLevel", err)
return return

View file

@ -131,7 +131,7 @@ func GetTeamRepos(ctx *context.APIContext) {
} }
repos := make([]*api.Repository, len(team.Repos)) repos := make([]*api.Repository, len(team.Repos))
for i, repo := range team.Repos { for i, repo := range team.Repos {
access, err := models.AccessLevel(ctx.User, repo) access, err := models.AccessLevel(ctx.User.ID, repo)
if err != nil { if err != nil {
ctx.Error(500, "GetTeamRepos", err) ctx.Error(500, "GetTeamRepos", err)
return return
@ -161,7 +161,7 @@ func AddTeamRepository(ctx *context.APIContext) {
if ctx.Written() { if ctx.Written() {
return return
} }
if access, err := models.AccessLevel(ctx.User, repo); err != nil { if access, err := models.AccessLevel(ctx.User.ID, repo); err != nil {
ctx.Error(500, "AccessLevel", err) ctx.Error(500, "AccessLevel", err)
return return
} else if access < models.AccessModeAdmin { } else if access < models.AccessModeAdmin {
@ -181,7 +181,7 @@ func RemoveTeamRepository(ctx *context.APIContext) {
if ctx.Written() { if ctx.Written() {
return return
} }
if access, err := models.AccessLevel(ctx.User, repo); err != nil { if access, err := models.AccessLevel(ctx.User.ID, repo); err != nil {
ctx.Error(500, "AccessLevel", err) ctx.Error(500, "AccessLevel", err)
return return
} else if access < models.AccessModeAdmin { } else if access < models.AccessModeAdmin {

View file

@ -20,7 +20,7 @@ func ListForks(ctx *context.APIContext) {
} }
apiForks := make([]*api.Repository, len(forks)) apiForks := make([]*api.Repository, len(forks))
for i, fork := range forks { for i, fork := range forks {
access, err := models.AccessLevel(ctx.User, fork) access, err := models.AccessLevel(ctx.User.ID, fork)
if err != nil { if err != nil {
ctx.Error(500, "AccessLevel", err) ctx.Error(500, "AccessLevel", err)
return return

View file

@ -40,7 +40,7 @@ func ListReleases(ctx *context.APIContext) {
return return
} }
rels := make([]*api.Release, len(releases)) rels := make([]*api.Release, len(releases))
access, err := models.AccessLevel(ctx.User, ctx.Repo.Repository) access, err := models.AccessLevel(ctx.User.ID, ctx.Repo.Repository)
if err != nil { if err != nil {
ctx.Error(500, "AccessLevel", err) ctx.Error(500, "AccessLevel", err)
return return

View file

@ -64,7 +64,7 @@ func Search(ctx *context.APIContext) {
}) })
return return
} }
accessMode, err := models.AccessLevel(ctx.User, repo) accessMode, err := models.AccessLevel(ctx.User.ID, repo)
if err != nil { if err != nil {
ctx.JSON(500, map[string]interface{}{ ctx.JSON(500, map[string]interface{}{
"ok": false, "ok": false,
@ -218,7 +218,7 @@ func Migrate(ctx *context.APIContext, form auth.MigrateRepoForm) {
// see https://github.com/gogits/go-gogs-client/wiki/Repositories#get // see https://github.com/gogits/go-gogs-client/wiki/Repositories#get
func Get(ctx *context.APIContext) { func Get(ctx *context.APIContext) {
repo := ctx.Repo.Repository repo := ctx.Repo.Repository
access, err := models.AccessLevel(ctx.User, repo) access, err := models.AccessLevel(ctx.User.ID, repo)
if err != nil { if err != nil {
ctx.Error(500, "GetRepository", err) ctx.Error(500, "GetRepository", err)
return return
@ -238,7 +238,7 @@ func GetByID(ctx *context.APIContext) {
return return
} }
access, err := models.AccessLevel(ctx.User, repo) access, err := models.AccessLevel(ctx.User.ID, repo)
if err != nil { if err != nil {
ctx.Error(500, "GetRepositoryByID", err) ctx.Error(500, "GetRepositoryByID", err)
return return

View file

@ -18,13 +18,10 @@ func getStarredRepos(userID int64, private bool) ([]*api.Repository, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := models.GetUserByID(userID)
if err != nil {
return nil, err
}
repos := make([]*api.Repository, len(starredRepos)) repos := make([]*api.Repository, len(starredRepos))
for i, starred := range starredRepos { for i, starred := range starredRepos {
access, err := models.AccessLevel(user, starred) access, err := models.AccessLevel(userID, starred)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -31,14 +31,10 @@ func getWatchedRepos(userID int64, private bool) ([]*api.Repository, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := models.GetUserByID(userID)
if err != nil {
return nil, err
}
repos := make([]*api.Repository, len(watchedRepos)) repos := make([]*api.Repository, len(watchedRepos))
for i, watched := range watchedRepos { for i, watched := range watchedRepos {
access, err := models.AccessLevel(user, watched) access, err := models.AccessLevel(userID, watched)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -152,13 +152,13 @@ func HTTP(ctx *context.Context) {
} }
if !isPublicPull { if !isPublicPull {
has, err := models.HasAccess(authUser, repo, accessMode) has, err := models.HasAccess(authUser.ID, repo, accessMode)
if err != nil { if err != nil {
ctx.Handle(http.StatusInternalServerError, "HasAccess", err) ctx.Handle(http.StatusInternalServerError, "HasAccess", err)
return return
} else if !has { } else if !has {
if accessMode == models.AccessModeRead { if accessMode == models.AccessModeRead {
has, err = models.HasAccess(authUser, repo, models.AccessModeWrite) has, err = models.HasAccess(authUser.ID, repo, models.AccessModeWrite)
if err != nil { if err != nil {
ctx.Handle(http.StatusInternalServerError, "HasAccess2", err) ctx.Handle(http.StatusInternalServerError, "HasAccess2", err)
return return