Move almost all functions' parameter db.Engine to context.Context (#19748)
* Move almost all functions' parameter db.Engine to context.Context * remove some unnecessary wrap functions
This commit is contained in:
@ -490,7 +490,7 @@ func runChangePassword(c *cli.Context) error {
|
||||
return errors.New("The password you chose is on a list of stolen passwords previously exposed in public data breaches. Please try again with a different password.\nFor more details, see https://haveibeenpwned.com/Passwords")
|
||||
}
|
||||
uname := c.String("username")
|
||||
user, err := user_model.GetUserByName(uname)
|
||||
user, err := user_model.GetUserByName(ctx, uname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -659,7 +659,7 @@ func runDeleteUser(c *cli.Context) error {
|
||||
if c.IsSet("email") {
|
||||
user, err = user_model.GetUserByEmail(c.String("email"))
|
||||
} else if c.IsSet("username") {
|
||||
user, err = user_model.GetUserByName(c.String("username"))
|
||||
user, err = user_model.GetUserByName(ctx, c.String("username"))
|
||||
} else {
|
||||
user, err = user_model.GetUserByID(c.Int64("id"))
|
||||
}
|
||||
@ -689,7 +689,7 @@ func runGenerateAccessToken(c *cli.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := user_model.GetUserByName(c.String("username"))
|
||||
user, err := user_model.GetUserByName(ctx, c.String("username"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ func TestAPIGetTrackedTimes(t *testing.T) {
|
||||
resp := session.MakeRequest(t, req, http.StatusOK)
|
||||
var apiTimes api.TrackedTimeList
|
||||
DecodeJSON(t, resp, &apiTimes)
|
||||
expect, err := models.GetTrackedTimes(&models.FindTrackedTimesOptions{IssueID: issue2.ID})
|
||||
expect, err := models.GetTrackedTimes(db.DefaultContext, &models.FindTrackedTimesOptions{IssueID: issue2.ID})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, apiTimes, 3)
|
||||
|
||||
@ -83,7 +83,7 @@ func TestAPIDeleteTrackedTime(t *testing.T) {
|
||||
session.MakeRequest(t, req, http.StatusNotFound)
|
||||
|
||||
// Reset time of user 2 on issue 2
|
||||
trackedSeconds, err := models.GetTrackedSeconds(models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
|
||||
trackedSeconds, err := models.GetTrackedSeconds(db.DefaultContext, models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(3661), trackedSeconds)
|
||||
|
||||
@ -91,7 +91,7 @@ func TestAPIDeleteTrackedTime(t *testing.T) {
|
||||
session.MakeRequest(t, req, http.StatusNoContent)
|
||||
session.MakeRequest(t, req, http.StatusNotFound)
|
||||
|
||||
trackedSeconds, err = models.GetTrackedSeconds(models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
|
||||
trackedSeconds, err = models.GetTrackedSeconds(db.DefaultContext, models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), trackedSeconds)
|
||||
}
|
||||
|
@ -388,7 +388,7 @@ func testAPIRepoMigrateConflict(t *testing.T, u *url.URL) {
|
||||
defer util.RemoveAll(dstPath)
|
||||
t.Run("CreateRepo", doAPICreateRepository(httpContext, false))
|
||||
|
||||
user, err := user_model.GetUserByName(httpContext.Username)
|
||||
user, err := user_model.GetUserByName(db.DefaultContext, httpContext.Username)
|
||||
assert.NoError(t, err)
|
||||
userID := user.ID
|
||||
|
||||
|
@ -321,7 +321,7 @@ func TestLDAPGroupTeamSyncAddMember(t *testing.T) {
|
||||
addAuthSourceLDAP(t, "", "on", `{"cn=ship_crew,ou=people,dc=planetexpress,dc=com":{"org26": ["team11"]},"cn=admin_staff,ou=people,dc=planetexpress,dc=com": {"non-existent": ["non-existent"]}}`)
|
||||
org, err := organization.GetOrgByName("org26")
|
||||
assert.NoError(t, err)
|
||||
team, err := organization.GetTeam(org.ID, "team11")
|
||||
team, err := organization.GetTeam(db.DefaultContext, org.ID, "team11")
|
||||
assert.NoError(t, err)
|
||||
auth.SyncExternalUsers(context.Background(), true)
|
||||
for _, gitLDAPUser := range gitLDAPUsers {
|
||||
@ -366,7 +366,7 @@ func TestLDAPGroupTeamSyncRemoveMember(t *testing.T) {
|
||||
addAuthSourceLDAP(t, "", "on", `{"cn=dispatch,ou=people,dc=planetexpress,dc=com": {"org26": ["team11"]}}`)
|
||||
org, err := organization.GetOrgByName("org26")
|
||||
assert.NoError(t, err)
|
||||
team, err := organization.GetTeam(org.ID, "team11")
|
||||
team, err := organization.GetTeam(db.DefaultContext, org.ID, "team11")
|
||||
assert.NoError(t, err)
|
||||
loginUserWithPassword(t, gitLDAPUsers[0].UserName, gitLDAPUsers[0].Password)
|
||||
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models"
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/perm"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
@ -438,7 +439,7 @@ func doProtectBranch(ctx APITestContext, branch, userToWhitelist, unprotectedFil
|
||||
})
|
||||
ctx.Session.MakeRequest(t, req, http.StatusSeeOther)
|
||||
} else {
|
||||
user, err := user_model.GetUserByName(userToWhitelist)
|
||||
user, err := user_model.GetUserByName(db.DefaultContext, userToWhitelist)
|
||||
assert.NoError(t, err)
|
||||
// Change branch to protected
|
||||
req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/settings/branches/%s", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame), url.PathEscape(branch)), map[string]string{
|
||||
|
@ -75,7 +75,7 @@ func TestMirrorPull(t *testing.T) {
|
||||
IsTag: true,
|
||||
}, nil, ""))
|
||||
|
||||
_, err = repo_model.GetMirrorByRepoID(mirror.ID)
|
||||
_, err = repo_model.GetMirrorByRepoID(ctx, mirror.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ok := mirror_service.SyncPullMirror(ctx, mirror.ID)
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models"
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
@ -407,7 +408,7 @@ func TestConflictChecking(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
issue := unittest.AssertExistsAndLoadBean(t, &models.Issue{Title: "PR with conflict!"}).(*models.Issue)
|
||||
conflictingPR, err := models.GetPullRequestByIssueID(issue.ID)
|
||||
conflictingPR, err := models.GetPullRequestByIssueID(db.DefaultContext, issue.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Ensure conflictedFiles is populated.
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models"
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/git"
|
||||
@ -165,7 +166,7 @@ func createOutdatedPR(t *testing.T, actor, forkOrg *user_model.User) *models.Pul
|
||||
assert.NoError(t, err)
|
||||
|
||||
issue := unittest.AssertExistsAndLoadBean(t, &models.Issue{Title: "Test Pull -to-update-"}).(*models.Issue)
|
||||
pr, err := models.GetPullRequestByIssueID(issue.ID)
|
||||
pr, err := models.GetPullRequestByIssueID(db.DefaultContext, issue.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return pr
|
||||
|
@ -222,9 +222,8 @@ func (a *Action) getCommentLink(ctx context.Context) string {
|
||||
if a == nil {
|
||||
return "#"
|
||||
}
|
||||
e := db.GetEngine(ctx)
|
||||
if a.Comment == nil && a.CommentID != 0 {
|
||||
a.Comment, _ = getCommentByID(e, a.CommentID)
|
||||
a.Comment, _ = GetCommentByID(ctx, a.CommentID)
|
||||
}
|
||||
if a.Comment != nil {
|
||||
return a.Comment.HTMLURL()
|
||||
@ -239,7 +238,7 @@ func (a *Action) getCommentLink(ctx context.Context) string {
|
||||
return "#"
|
||||
}
|
||||
|
||||
issue, err := getIssueByID(e, issueID)
|
||||
issue, err := getIssueByID(ctx, issueID)
|
||||
if err != nil {
|
||||
return "#"
|
||||
}
|
||||
@ -340,8 +339,7 @@ func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e := db.GetEngine(ctx)
|
||||
sess := e.Where(cond).
|
||||
sess := db.GetEngine(ctx).Where(cond).
|
||||
Select("`action`.*"). // this line will avoid select other joined table's columns
|
||||
Join("INNER", "repository", "`repository`.id = `action`.repo_id")
|
||||
|
||||
@ -354,7 +352,7 @@ func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, error) {
|
||||
return nil, fmt.Errorf("Find: %v", err)
|
||||
}
|
||||
|
||||
if err := ActionList(actions).loadAttributes(e); err != nil {
|
||||
if err := ActionList(actions).loadAttributes(ctx); err != nil {
|
||||
return nil, fmt.Errorf("LoadAttributes: %v", err)
|
||||
}
|
||||
|
||||
@ -504,7 +502,7 @@ func notifyWatchers(ctx context.Context, actions ...*Action) error {
|
||||
permIssue = make([]bool, len(watchers))
|
||||
permPR = make([]bool, len(watchers))
|
||||
for i, watcher := range watchers {
|
||||
user, err := user_model.GetUserByIDEngine(e, watcher.UserID)
|
||||
user, err := user_model.GetUserByIDCtx(ctx, watcher.UserID)
|
||||
if err != nil {
|
||||
permCode[i] = false
|
||||
permIssue[i] = false
|
||||
|
@ -5,6 +5,7 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
@ -26,14 +27,14 @@ func (actions ActionList) getUserIDs() []int64 {
|
||||
return container.KeysInt64(userIDs)
|
||||
}
|
||||
|
||||
func (actions ActionList) loadUsers(e db.Engine) (map[int64]*user_model.User, error) {
|
||||
func (actions ActionList) loadUsers(ctx context.Context) (map[int64]*user_model.User, error) {
|
||||
if len(actions) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
userIDs := actions.getUserIDs()
|
||||
userMaps := make(map[int64]*user_model.User, len(userIDs))
|
||||
err := e.
|
||||
err := db.GetEngine(ctx).
|
||||
In("id", userIDs).
|
||||
Find(&userMaps)
|
||||
if err != nil {
|
||||
@ -56,14 +57,14 @@ func (actions ActionList) getRepoIDs() []int64 {
|
||||
return container.KeysInt64(repoIDs)
|
||||
}
|
||||
|
||||
func (actions ActionList) loadRepositories(e db.Engine) error {
|
||||
func (actions ActionList) loadRepositories(ctx context.Context) error {
|
||||
if len(actions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
repoIDs := actions.getRepoIDs()
|
||||
repoMaps := make(map[int64]*repo_model.Repository, len(repoIDs))
|
||||
err := e.In("id", repoIDs).Find(&repoMaps)
|
||||
err := db.GetEngine(ctx).In("id", repoIDs).Find(&repoMaps)
|
||||
if err != nil {
|
||||
return fmt.Errorf("find repository: %v", err)
|
||||
}
|
||||
@ -74,7 +75,7 @@ func (actions ActionList) loadRepositories(e db.Engine) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_model.User) (err error) {
|
||||
func (actions ActionList) loadRepoOwner(ctx context.Context, userMap map[int64]*user_model.User) (err error) {
|
||||
if userMap == nil {
|
||||
userMap = make(map[int64]*user_model.User)
|
||||
}
|
||||
@ -85,7 +86,7 @@ func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_mod
|
||||
}
|
||||
repoOwner, ok := userMap[action.Repo.OwnerID]
|
||||
if !ok {
|
||||
repoOwner, err = user_model.GetUserByID(action.Repo.OwnerID)
|
||||
repoOwner, err = user_model.GetUserByIDCtx(ctx, action.Repo.OwnerID)
|
||||
if err != nil {
|
||||
if user_model.IsErrUserNotExist(err) {
|
||||
continue
|
||||
@ -101,15 +102,15 @@ func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_mod
|
||||
}
|
||||
|
||||
// loadAttributes loads all attributes
|
||||
func (actions ActionList) loadAttributes(e db.Engine) error {
|
||||
userMap, err := actions.loadUsers(e)
|
||||
func (actions ActionList) loadAttributes(ctx context.Context) error {
|
||||
userMap, err := actions.loadUsers(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := actions.loadRepositories(e); err != nil {
|
||||
if err := actions.loadRepositories(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return actions.loadRepoOwner(e, userMap)
|
||||
return actions.loadRepoOwner(ctx, userMap)
|
||||
}
|
||||
|
@ -198,16 +198,16 @@ func parseGPGKey(ownerID int64, e *openpgp.Entity, verified bool) (*GPGKey, erro
|
||||
}
|
||||
|
||||
// deleteGPGKey does the actual key deletion
|
||||
func deleteGPGKey(e db.Engine, keyID string) (int64, error) {
|
||||
func deleteGPGKey(ctx context.Context, keyID string) (int64, error) {
|
||||
if keyID == "" {
|
||||
return 0, fmt.Errorf("empty KeyId forbidden") // Should never happen but just to be sure
|
||||
}
|
||||
// Delete imported key
|
||||
n, err := e.Where("key_id=?", keyID).Delete(new(GPGKeyImport))
|
||||
n, err := db.GetEngine(ctx).Where("key_id=?", keyID).Delete(new(GPGKeyImport))
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
return e.Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey))
|
||||
return db.GetEngine(ctx).Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey))
|
||||
}
|
||||
|
||||
// DeleteGPGKey deletes GPG key information in database.
|
||||
@ -231,7 +231,7 @@ func DeleteGPGKey(doer *user_model.User, id int64) (err error) {
|
||||
}
|
||||
defer committer.Close()
|
||||
|
||||
if _, err = deleteGPGKey(db.GetEngine(ctx), key.KeyID); err != nil {
|
||||
if _, err = deleteGPGKey(ctx, key.KeyID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
@ -29,21 +30,21 @@ import (
|
||||
// This file contains functions relating to adding GPG Keys
|
||||
|
||||
// addGPGKey add key, import and subkeys to database
|
||||
func addGPGKey(e db.Engine, key *GPGKey, content string) (err error) {
|
||||
func addGPGKey(ctx context.Context, key *GPGKey, content string) (err error) {
|
||||
// Add GPGKeyImport
|
||||
if _, err = e.Insert(GPGKeyImport{
|
||||
if err = db.Insert(ctx, &GPGKeyImport{
|
||||
KeyID: key.KeyID,
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
// Save GPG primary key.
|
||||
if _, err = e.Insert(key); err != nil {
|
||||
if err = db.Insert(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
// Save GPG subs key.
|
||||
for _, subkey := range key.SubsKey {
|
||||
if err := addGPGSubKey(e, subkey); err != nil {
|
||||
if err := addGPGSubKey(ctx, subkey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -51,14 +52,14 @@ func addGPGKey(e db.Engine, key *GPGKey, content string) (err error) {
|
||||
}
|
||||
|
||||
// addGPGSubKey add subkeys to database
|
||||
func addGPGSubKey(e db.Engine, key *GPGKey) (err error) {
|
||||
func addGPGSubKey(ctx context.Context, key *GPGKey) (err error) {
|
||||
// Save GPG primary key.
|
||||
if _, err = e.Insert(key); err != nil {
|
||||
if err = db.Insert(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
// Save GPG subs key.
|
||||
for _, subkey := range key.SubsKey {
|
||||
if err := addGPGSubKey(e, subkey); err != nil {
|
||||
if err := addGPGSubKey(ctx, subkey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -158,7 +159,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = addGPGKey(db.GetEngine(ctx), key, content); err != nil {
|
||||
if err = addGPGKey(ctx, key, content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys = append(keys, key)
|
||||
|
@ -75,7 +75,7 @@ func (key *PublicKey) AuthorizedString() string {
|
||||
return AuthorizedStringForKey(key)
|
||||
}
|
||||
|
||||
func addKey(e db.Engine, key *PublicKey) (err error) {
|
||||
func addKey(ctx context.Context, key *PublicKey) (err error) {
|
||||
if len(key.Fingerprint) == 0 {
|
||||
key.Fingerprint, err = calcFingerprint(key.Content)
|
||||
if err != nil {
|
||||
@ -84,7 +84,7 @@ func addKey(e db.Engine, key *PublicKey) (err error) {
|
||||
}
|
||||
|
||||
// Save SSH key.
|
||||
if _, err = e.Insert(key); err != nil {
|
||||
if err = db.Insert(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -105,14 +105,13 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
|
||||
return nil, err
|
||||
}
|
||||
defer committer.Close()
|
||||
sess := db.GetEngine(ctx)
|
||||
|
||||
if err := checkKeyFingerprint(sess, fingerprint); err != nil {
|
||||
if err := checkKeyFingerprint(ctx, fingerprint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Key name of same user cannot be duplicated.
|
||||
has, err := sess.
|
||||
has, err := db.GetEngine(ctx).
|
||||
Where("owner_id = ? AND name = ?", ownerID, name).
|
||||
Get(new(PublicKey))
|
||||
if err != nil {
|
||||
@ -130,7 +129,7 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
|
||||
Type: KeyTypeUser,
|
||||
LoginSourceID: authSourceID,
|
||||
}
|
||||
if err = addKey(sess, key); err != nil {
|
||||
if err = addKey(ctx, key); err != nil {
|
||||
return nil, fmt.Errorf("addKey: %v", err)
|
||||
}
|
||||
|
||||
@ -151,29 +150,12 @@ func GetPublicKeyByID(keyID int64) (*PublicKey, error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func searchPublicKeyByContentWithEngine(e db.Engine, content string) (*PublicKey, error) {
|
||||
key := new(PublicKey)
|
||||
has, err := e.
|
||||
Where("content like ?", content+"%").
|
||||
Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrKeyNotExist{}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// SearchPublicKeyByContent searches content as prefix (leak e-mail part)
|
||||
// and returns public key found.
|
||||
func SearchPublicKeyByContent(content string) (*PublicKey, error) {
|
||||
return searchPublicKeyByContentWithEngine(db.GetEngine(db.DefaultContext), content)
|
||||
}
|
||||
|
||||
func searchPublicKeyByContentExactWithEngine(e db.Engine, content string) (*PublicKey, error) {
|
||||
func SearchPublicKeyByContent(ctx context.Context, content string) (*PublicKey, error) {
|
||||
key := new(PublicKey)
|
||||
has, err := e.
|
||||
Where("content = ?", content).
|
||||
has, err := db.GetEngine(ctx).
|
||||
Where("content like ?", content+"%").
|
||||
Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -185,8 +167,17 @@ func searchPublicKeyByContentExactWithEngine(e db.Engine, content string) (*Publ
|
||||
|
||||
// SearchPublicKeyByContentExact searches content
|
||||
// and returns public key found.
|
||||
func SearchPublicKeyByContentExact(content string) (*PublicKey, error) {
|
||||
return searchPublicKeyByContentExactWithEngine(db.GetEngine(db.DefaultContext), content)
|
||||
func SearchPublicKeyByContentExact(ctx context.Context, content string) (*PublicKey, error) {
|
||||
key := new(PublicKey)
|
||||
has, err := db.GetEngine(ctx).
|
||||
Where("content = ?", content).
|
||||
Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrKeyNotExist{}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// SearchPublicKey returns a list of public keys matching the provided arguments.
|
||||
@ -335,12 +326,11 @@ func deleteKeysMarkedForDeletion(keys []string) (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
defer committer.Close()
|
||||
sess := db.GetEngine(ctx)
|
||||
|
||||
// Delete keys marked for deletion
|
||||
var sshKeysNeedUpdate bool
|
||||
for _, KeyToDelete := range keys {
|
||||
key, err := searchPublicKeyByContentWithEngine(sess, KeyToDelete)
|
||||
key, err := SearchPublicKeyByContent(ctx, KeyToDelete)
|
||||
if err != nil {
|
||||
log.Error("SearchPublicKeyByContent: %v", err)
|
||||
continue
|
||||
|
@ -6,6 +6,7 @@ package asymkey
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@ -165,7 +166,7 @@ func RewriteAllPublicKeys() error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := RegeneratePublicKeys(t); err != nil {
|
||||
if err := RegeneratePublicKeys(db.DefaultContext, t); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -174,12 +175,8 @@ func RewriteAllPublicKeys() error {
|
||||
}
|
||||
|
||||
// RegeneratePublicKeys regenerates the authorized_keys file
|
||||
func RegeneratePublicKeys(t io.StringWriter) error {
|
||||
return regeneratePublicKeys(db.GetEngine(db.DefaultContext), t)
|
||||
}
|
||||
|
||||
func regeneratePublicKeys(e db.Engine, t io.StringWriter) error {
|
||||
if err := e.Where("type != ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) {
|
||||
func RegeneratePublicKeys(ctx context.Context, t io.StringWriter) error {
|
||||
if err := db.GetEngine(ctx).Where("type != ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) {
|
||||
_, err = t.WriteString((bean.(*PublicKey)).AuthorizedString())
|
||||
return err
|
||||
}); err != nil {
|
||||
|
@ -6,6 +6,7 @@ package asymkey
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@ -42,11 +43,7 @@ const authorizedPrincipalsFile = "authorized_principals"
|
||||
// RewriteAllPrincipalKeys removes any authorized principal and rewrite all keys from database again.
|
||||
// Note: db.GetEngine(db.DefaultContext).Iterate does not get latest data after insert/delete, so we have to call this function
|
||||
// outside any session scope independently.
|
||||
func RewriteAllPrincipalKeys() error {
|
||||
return rewriteAllPrincipalKeys(db.GetEngine(db.DefaultContext))
|
||||
}
|
||||
|
||||
func rewriteAllPrincipalKeys(e db.Engine) error {
|
||||
func RewriteAllPrincipalKeys(ctx context.Context) error {
|
||||
// Don't rewrite key if internal server
|
||||
if setting.SSH.StartBuiltinServer || !setting.SSH.CreateAuthorizedPrincipalsFile {
|
||||
return nil
|
||||
@ -92,7 +89,7 @@ func rewriteAllPrincipalKeys(e db.Engine) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := regeneratePrincipalKeys(e, t); err != nil {
|
||||
if err := regeneratePrincipalKeys(ctx, t); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -100,13 +97,8 @@ func rewriteAllPrincipalKeys(e db.Engine) error {
|
||||
return util.Rename(tmpPath, fPath)
|
||||
}
|
||||
|
||||
// RegeneratePrincipalKeys regenerates the authorized_principals file
|
||||
func RegeneratePrincipalKeys(t io.StringWriter) error {
|
||||
return regeneratePrincipalKeys(db.GetEngine(db.DefaultContext), t)
|
||||
}
|
||||
|
||||
func regeneratePrincipalKeys(e db.Engine, t io.StringWriter) error {
|
||||
if err := e.Where("type = ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) {
|
||||
func regeneratePrincipalKeys(ctx context.Context, t io.StringWriter) error {
|
||||
if err := db.GetEngine(ctx).Where("type = ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) {
|
||||
_, err = t.WriteString((bean.(*PublicKey)).AuthorizedString())
|
||||
return err
|
||||
}); err != nil {
|
||||
|
@ -67,9 +67,9 @@ func init() {
|
||||
db.RegisterModel(new(DeployKey))
|
||||
}
|
||||
|
||||
func checkDeployKey(e db.Engine, keyID, repoID int64, name string) error {
|
||||
func checkDeployKey(ctx context.Context, keyID, repoID int64, name string) error {
|
||||
// Note: We want error detail, not just true or false here.
|
||||
has, err := e.
|
||||
has, err := db.GetEngine(ctx).
|
||||
Where("key_id = ? AND repo_id = ?", keyID, repoID).
|
||||
Get(new(DeployKey))
|
||||
if err != nil {
|
||||
@ -78,7 +78,7 @@ func checkDeployKey(e db.Engine, keyID, repoID int64, name string) error {
|
||||
return ErrDeployKeyAlreadyExist{keyID, repoID}
|
||||
}
|
||||
|
||||
has, err = e.
|
||||
has, err = db.GetEngine(ctx).
|
||||
Where("repo_id = ? AND name = ?", repoID, name).
|
||||
Get(new(DeployKey))
|
||||
if err != nil {
|
||||
@ -91,8 +91,8 @@ func checkDeployKey(e db.Engine, keyID, repoID int64, name string) error {
|
||||
}
|
||||
|
||||
// addDeployKey adds new key-repo relation.
|
||||
func addDeployKey(e db.Engine, keyID, repoID int64, name, fingerprint string, mode perm.AccessMode) (*DeployKey, error) {
|
||||
if err := checkDeployKey(e, keyID, repoID, name); err != nil {
|
||||
func addDeployKey(ctx context.Context, keyID, repoID int64, name, fingerprint string, mode perm.AccessMode) (*DeployKey, error) {
|
||||
if err := checkDeployKey(ctx, keyID, repoID, name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -103,8 +103,7 @@ func addDeployKey(e db.Engine, keyID, repoID int64, name, fingerprint string, mo
|
||||
Fingerprint: fingerprint,
|
||||
Mode: mode,
|
||||
}
|
||||
_, err := e.Insert(key)
|
||||
return key, err
|
||||
return key, db.Insert(ctx, key)
|
||||
}
|
||||
|
||||
// HasDeployKey returns true if public key is a deploy key of given repository.
|
||||
@ -133,12 +132,10 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey
|
||||
}
|
||||
defer committer.Close()
|
||||
|
||||
sess := db.GetEngine(ctx)
|
||||
|
||||
pkey := &PublicKey{
|
||||
Fingerprint: fingerprint,
|
||||
}
|
||||
has, err := sess.Get(pkey)
|
||||
has, err := db.GetByBean(ctx, pkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -153,12 +150,12 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey
|
||||
pkey.Type = KeyTypeDeploy
|
||||
pkey.Content = content
|
||||
pkey.Name = name
|
||||
if err = addKey(sess, pkey); err != nil {
|
||||
if err = addKey(ctx, pkey); err != nil {
|
||||
return nil, fmt.Errorf("addKey: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
key, err := addDeployKey(sess, pkey.ID, repoID, name, pkey.Fingerprint, accessMode)
|
||||
key, err := addDeployKey(ctx, pkey.ID, repoID, name, pkey.Fingerprint, accessMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -179,16 +176,12 @@ func GetDeployKeyByID(ctx context.Context, id int64) (*DeployKey, error) {
|
||||
}
|
||||
|
||||
// GetDeployKeyByRepo returns deploy key by given public key ID and repository ID.
|
||||
func GetDeployKeyByRepo(keyID, repoID int64) (*DeployKey, error) {
|
||||
return getDeployKeyByRepo(db.GetEngine(db.DefaultContext), keyID, repoID)
|
||||
}
|
||||
|
||||
func getDeployKeyByRepo(e db.Engine, keyID, repoID int64) (*DeployKey, error) {
|
||||
func GetDeployKeyByRepo(ctx context.Context, keyID, repoID int64) (*DeployKey, error) {
|
||||
key := &DeployKey{
|
||||
KeyID: keyID,
|
||||
RepoID: repoID,
|
||||
}
|
||||
has, err := e.Get(key)
|
||||
has, err := db.GetByBean(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
|
@ -5,6 +5,7 @@
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
@ -31,8 +32,8 @@ import (
|
||||
|
||||
// checkKeyFingerprint only checks if key fingerprint has been used as public key,
|
||||
// it is OK to use same key as deploy key for multiple repositories/users.
|
||||
func checkKeyFingerprint(e db.Engine, fingerprint string) error {
|
||||
has, err := e.Get(&PublicKey{
|
||||
func checkKeyFingerprint(ctx context.Context, fingerprint string) error {
|
||||
has, err := db.GetByBean(ctx, &PublicKey{
|
||||
Fingerprint: fingerprint,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -31,10 +31,9 @@ func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*Public
|
||||
return nil, err
|
||||
}
|
||||
defer committer.Close()
|
||||
sess := db.GetEngine(ctx)
|
||||
|
||||
// Principals cannot be duplicated.
|
||||
has, err := sess.
|
||||
has, err := db.GetEngine(ctx).
|
||||
Where("content = ? AND type = ?", content, KeyTypePrincipal).
|
||||
Get(new(PublicKey))
|
||||
if err != nil {
|
||||
@ -51,7 +50,7 @@ func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*Public
|
||||
Type: KeyTypePrincipal,
|
||||
LoginSourceID: authSourceID,
|
||||
}
|
||||
if err = addPrincipalKey(sess, key); err != nil {
|
||||
if err = db.Insert(ctx, key); err != nil {
|
||||
return nil, fmt.Errorf("addKey: %v", err)
|
||||
}
|
||||
|
||||
@ -61,16 +60,7 @@ func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*Public
|
||||
|
||||
committer.Close()
|
||||
|
||||
return key, RewriteAllPrincipalKeys()
|
||||
}
|
||||
|
||||
func addPrincipalKey(e db.Engine, key *PublicKey) (err error) {
|
||||
// Save Key representing a principal.
|
||||
if _, err = e.Insert(key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return key, RewriteAllPrincipalKeys(db.DefaultContext)
|
||||
}
|
||||
|
||||
// CheckPrincipalKeyString strips spaces and returns an error if the given principal contains newlines
|
||||
|
@ -92,13 +92,9 @@ func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
|
||||
}
|
||||
|
||||
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
|
||||
func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) {
|
||||
return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID)
|
||||
}
|
||||
|
||||
func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) {
|
||||
func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) {
|
||||
grant = new(OAuth2Grant)
|
||||
if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
|
||||
if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
@ -107,17 +103,13 @@ func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant
|
||||
}
|
||||
|
||||
// CreateGrant generates a grant for an user
|
||||
func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) {
|
||||
return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope)
|
||||
}
|
||||
|
||||
func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) {
|
||||
func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) {
|
||||
grant := &OAuth2Grant{
|
||||
ApplicationID: app.ID,
|
||||
UserID: userID,
|
||||
Scope: scope,
|
||||
}
|
||||
_, err := e.Insert(grant)
|
||||
err := db.Insert(ctx, grant)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -125,13 +117,9 @@ func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope strin
|
||||
}
|
||||
|
||||
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
|
||||
func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) {
|
||||
return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID)
|
||||
}
|
||||
|
||||
func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) {
|
||||
func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) {
|
||||
app = new(OAuth2Application)
|
||||
has, err := e.Where("client_id = ?", clientID).Get(app)
|
||||
has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app)
|
||||
if !has {
|
||||
return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
|
||||
}
|
||||
@ -139,13 +127,9 @@ func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Ap
|
||||
}
|
||||
|
||||
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
|
||||
func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) {
|
||||
return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id)
|
||||
}
|
||||
|
||||
func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) {
|
||||
func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) {
|
||||
app = new(OAuth2Application)
|
||||
has, err := e.ID(id).Get(app)
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(app)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -156,13 +140,9 @@ func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, er
|
||||
}
|
||||
|
||||
// GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user
|
||||
func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) {
|
||||
return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID)
|
||||
}
|
||||
|
||||
func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) {
|
||||
func GetOAuth2ApplicationsByUserID(ctx context.Context, userID int64) (apps []*OAuth2Application, err error) {
|
||||
apps = make([]*OAuth2Application, 0)
|
||||
err = e.Where("uid = ?", userID).Find(&apps)
|
||||
err = db.GetEngine(ctx).Where("uid = ?", userID).Find(&apps)
|
||||
return
|
||||
}
|
||||
|
||||
@ -174,11 +154,7 @@ type CreateOAuth2ApplicationOptions struct {
|
||||
}
|
||||
|
||||
// CreateOAuth2Application inserts a new oauth2 application
|
||||
func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
return createOAuth2Application(db.GetEngine(db.DefaultContext), opts)
|
||||
}
|
||||
|
||||
func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
clientID := uuid.New().String()
|
||||
app := &OAuth2Application{
|
||||
UID: opts.UserID,
|
||||
@ -186,7 +162,7 @@ func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (
|
||||
ClientID: clientID,
|
||||
RedirectURIs: opts.RedirectURIs,
|
||||
}
|
||||
if _, err := e.Insert(app); err != nil {
|
||||
if err := db.Insert(ctx, app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return app, nil
|
||||
@ -207,9 +183,8 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
|
||||
return nil, err
|
||||
}
|
||||
defer committer.Close()
|
||||
sess := db.GetEngine(ctx)
|
||||
|
||||
app, err := getOAuth2ApplicationByID(sess, opts.ID)
|
||||
app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -220,7 +195,7 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
|
||||
app.Name = opts.Name
|
||||
app.RedirectURIs = opts.RedirectURIs
|
||||
|
||||
if err = updateOAuth2Application(sess, app); err != nil {
|
||||
if err = updateOAuth2Application(ctx, app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
app.ClientSecret = ""
|
||||
@ -228,14 +203,15 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
|
||||
return app, committer.Commit()
|
||||
}
|
||||
|
||||
func updateOAuth2Application(e db.Engine, app *OAuth2Application) error {
|
||||
if _, err := e.ID(app.ID).Update(app); err != nil {
|
||||
func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
|
||||
if _, err := db.GetEngine(ctx).ID(app.ID).Update(app); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteOAuth2Application(sess db.Engine, id, userid int64) error {
|
||||
func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
|
||||
sess := db.GetEngine(ctx)
|
||||
if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil {
|
||||
return err
|
||||
} else if deleted == 0 {
|
||||
@ -269,7 +245,7 @@ func DeleteOAuth2Application(id, userid int64) error {
|
||||
return err
|
||||
}
|
||||
defer committer.Close()
|
||||
if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil {
|
||||
if err := deleteOAuth2Application(ctx, id, userid); err != nil {
|
||||
return err
|
||||
}
|
||||
return committer.Commit()
|
||||
@ -328,21 +304,13 @@ func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect
|
||||
}
|
||||
|
||||
// Invalidate deletes the auth code from the database to invalidate this code
|
||||
func (code *OAuth2AuthorizationCode) Invalidate() error {
|
||||
return code.invalidate(db.GetEngine(db.DefaultContext))
|
||||
}
|
||||
|
||||
func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error {
|
||||
_, err := e.Delete(code)
|
||||
func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code)
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
|
||||
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
|
||||
return code.validateCodeChallenge(verifier)
|
||||
}
|
||||
|
||||
func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool {
|
||||
switch code.CodeChallengeMethod {
|
||||
case "S256":
|
||||
// base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
|
||||
@ -360,19 +328,15 @@ func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool
|
||||
}
|
||||
|
||||
// GetOAuth2AuthorizationByCode returns an authorization by its code
|
||||
func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) {
|
||||
return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code)
|
||||
}
|
||||
|
||||
func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) {
|
||||
func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) {
|
||||
auth = new(OAuth2AuthorizationCode)
|
||||
if has, err := e.Where("code = ?", code).Get(auth); err != nil {
|
||||
if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
}
|
||||
auth.Grant = new(OAuth2Grant)
|
||||
if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil {
|
||||
if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
@ -401,11 +365,7 @@ func (grant *OAuth2Grant) TableName() string {
|
||||
}
|
||||
|
||||
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
|
||||
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) {
|
||||
return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod)
|
||||
}
|
||||
|
||||
func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
|
||||
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
|
||||
rBytes, err := util.CryptoRandomBytes(32)
|
||||
if err != nil {
|
||||
return &OAuth2AuthorizationCode{}, err
|
||||
@ -422,23 +382,19 @@ func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
}
|
||||
if _, err := e.Insert(code); err != nil {
|
||||
if err := db.Insert(ctx, code); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// IncreaseCounter increases the counter and updates the grant
|
||||
func (grant *OAuth2Grant) IncreaseCounter() error {
|
||||
return grant.increaseCount(db.GetEngine(db.DefaultContext))
|
||||
}
|
||||
|
||||
func (grant *OAuth2Grant) increaseCount(e db.Engine) error {
|
||||
_, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
|
||||
func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updatedGrant, err := getOAuth2GrantByID(e, grant.ID)
|
||||
updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -457,13 +413,9 @@ func (grant *OAuth2Grant) ScopeContains(scope string) bool {
|
||||
}
|
||||
|
||||
// SetNonce updates the current nonce value of a grant
|
||||
func (grant *OAuth2Grant) SetNonce(nonce string) error {
|
||||
return grant.setNonce(db.GetEngine(db.DefaultContext), nonce)
|
||||
}
|
||||
|
||||
func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
|
||||
func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error {
|
||||
grant.Nonce = nonce
|
||||
_, err := e.ID(grant.ID).Cols("nonce").Update(grant)
|
||||
_, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -471,13 +423,9 @@ func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
|
||||
}
|
||||
|
||||
// GetOAuth2GrantByID returns the grant with the given ID
|
||||
func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) {
|
||||
return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id)
|
||||
}
|
||||
|
||||
func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
|
||||
func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) {
|
||||
grant = new(OAuth2Grant)
|
||||
if has, err := e.ID(id).Get(grant); err != nil {
|
||||
if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
@ -486,18 +434,14 @@ func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
|
||||
}
|
||||
|
||||
// GetOAuth2GrantsByUserID lists all grants of a certain user
|
||||
func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) {
|
||||
return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid)
|
||||
}
|
||||
|
||||
func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
|
||||
func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) {
|
||||
type joinedOAuth2Grant struct {
|
||||
Grant *OAuth2Grant `xorm:"extends"`
|
||||
Application *OAuth2Application `xorm:"extends"`
|
||||
}
|
||||
var results *xorm.Rows
|
||||
var err error
|
||||
if results, err = e.
|
||||
if results, err = db.GetEngine(ctx).
|
||||
Table("oauth2_grant").
|
||||
Where("user_id = ?", uid).
|
||||
Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
|
||||
@ -518,12 +462,8 @@ func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
|
||||
}
|
||||
|
||||
// RevokeOAuth2Grant deletes the grant with grantID and userID
|
||||
func RevokeOAuth2Grant(grantID, userID int64) error {
|
||||
return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID)
|
||||
}
|
||||
|
||||
func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error {
|
||||
_, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID})
|
||||
func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error {
|
||||
_, err := db.DeleteByBean(ctx, &OAuth2Grant{ID: grantID, UserID: userID})
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -7,6 +7,7 @@ package auth
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -52,18 +53,18 @@ func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
|
||||
|
||||
func TestGetOAuth2ApplicationByClientID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app, err := GetOAuth2ApplicationByClientID("da7da3ba-9a13-4167-856f-3899de0b0138")
|
||||
app, err := GetOAuth2ApplicationByClientID(db.DefaultContext, "da7da3ba-9a13-4167-856f-3899de0b0138")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
|
||||
|
||||
app, err = GetOAuth2ApplicationByClientID("invalid client id")
|
||||
app, err = GetOAuth2ApplicationByClientID(db.DefaultContext, "invalid client id")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, app)
|
||||
}
|
||||
|
||||
func TestCreateOAuth2Application(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app, err := CreateOAuth2Application(CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
|
||||
app, err := CreateOAuth2Application(db.DefaultContext, CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "newapp", app.Name)
|
||||
assert.Len(t, app.ClientID, 36)
|
||||
@ -77,11 +78,11 @@ func TestOAuth2Application_TableName(t *testing.T) {
|
||||
func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
|
||||
grant, err := app.GetGrantByUserID(1)
|
||||
grant, err := app.GetGrantByUserID(db.DefaultContext, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), grant.UserID)
|
||||
|
||||
grant, err = app.GetGrantByUserID(34923458)
|
||||
grant, err = app.GetGrantByUserID(db.DefaultContext, 34923458)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, grant)
|
||||
}
|
||||
@ -89,7 +90,7 @@ func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
|
||||
func TestOAuth2Application_CreateGrant(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
|
||||
grant, err := app.CreateGrant(2, "")
|
||||
grant, err := app.CreateGrant(db.DefaultContext, 2, "")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, grant)
|
||||
assert.Equal(t, int64(2), grant.UserID)
|
||||
@ -101,11 +102,11 @@ func TestOAuth2Application_CreateGrant(t *testing.T) {
|
||||
|
||||
func TestGetOAuth2GrantByID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant, err := GetOAuth2GrantByID(1)
|
||||
grant, err := GetOAuth2GrantByID(db.DefaultContext, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), grant.ID)
|
||||
|
||||
grant, err = GetOAuth2GrantByID(34923458)
|
||||
grant, err = GetOAuth2GrantByID(db.DefaultContext, 34923458)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, grant)
|
||||
}
|
||||
@ -113,7 +114,7 @@ func TestGetOAuth2GrantByID(t *testing.T) {
|
||||
func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 1}).(*OAuth2Grant)
|
||||
assert.NoError(t, grant.IncreaseCounter())
|
||||
assert.NoError(t, grant.IncreaseCounter(db.DefaultContext))
|
||||
assert.Equal(t, int64(2), grant.Counter)
|
||||
unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 2})
|
||||
}
|
||||
@ -130,7 +131,7 @@ func TestOAuth2Grant_ScopeContains(t *testing.T) {
|
||||
func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1}).(*OAuth2Grant)
|
||||
code, err := grant.GenerateNewAuthorizationCode("https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
|
||||
code, err := grant.GenerateNewAuthorizationCode(db.DefaultContext, "https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, code)
|
||||
assert.True(t, len(code.Code) > 32) // secret length > 32
|
||||
@ -142,20 +143,20 @@ func TestOAuth2Grant_TableName(t *testing.T) {
|
||||
|
||||
func TestGetOAuth2GrantsByUserID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
result, err := GetOAuth2GrantsByUserID(1)
|
||||
result, err := GetOAuth2GrantsByUserID(db.DefaultContext, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, int64(1), result[0].ID)
|
||||
assert.Equal(t, result[0].ApplicationID, result[0].Application.ID)
|
||||
|
||||
result, err = GetOAuth2GrantsByUserID(34134)
|
||||
result, err = GetOAuth2GrantsByUserID(db.DefaultContext, 34134)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestRevokeOAuth2Grant(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
assert.NoError(t, RevokeOAuth2Grant(1, 1))
|
||||
assert.NoError(t, RevokeOAuth2Grant(db.DefaultContext, 1, 1))
|
||||
unittest.AssertNotExistsBean(t, &OAuth2Grant{ID: 1, UserID: 1})
|
||||
}
|
||||
|
||||
@ -163,13 +164,13 @@ func TestRevokeOAuth2Grant(t *testing.T) {
|
||||
|
||||
func TestGetOAuth2AuthorizationByCode(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
code, err := GetOAuth2AuthorizationByCode("authcode")
|
||||
code, err := GetOAuth2AuthorizationByCode(db.DefaultContext, "authcode")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, code)
|
||||
assert.Equal(t, "authcode", code.Code)
|
||||
assert.Equal(t, int64(1), code.ID)
|
||||
|
||||
code, err = GetOAuth2AuthorizationByCode("does not exist")
|
||||
code, err = GetOAuth2AuthorizationByCode(db.DefaultContext, "does not exist")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, code)
|
||||
}
|
||||
@ -224,7 +225,7 @@ func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
|
||||
func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
code := unittest.AssertExistsAndLoadBean(t, &OAuth2AuthorizationCode{Code: "authcode"}).(*OAuth2AuthorizationCode)
|
||||
assert.NoError(t, code.Invalidate())
|
||||
assert.NoError(t, code.Invalidate(db.DefaultContext))
|
||||
unittest.AssertNotExistsBean(t, &OAuth2AuthorizationCode{Code: "authcode"})
|
||||
}
|
||||
|
||||
|
@ -306,13 +306,9 @@ func (protectBranch *ProtectedBranch) IsUnprotectedFile(patterns []glob.Glob, pa
|
||||
}
|
||||
|
||||
// GetProtectedBranchBy getting protected branch by ID/Name
|
||||
func GetProtectedBranchBy(repoID int64, branchName string) (*ProtectedBranch, error) {
|
||||
return getProtectedBranchBy(db.GetEngine(db.DefaultContext), repoID, branchName)
|
||||
}
|
||||
|
||||
func getProtectedBranchBy(e db.Engine, repoID int64, branchName string) (*ProtectedBranch, error) {
|
||||
func GetProtectedBranchBy(ctx context.Context, repoID int64, branchName string) (*ProtectedBranch, error) {
|
||||
rel := &ProtectedBranch{RepoID: repoID, BranchName: branchName}
|
||||
has, err := e.Get(rel)
|
||||
has, err := db.GetByBean(ctx, rel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -632,7 +628,7 @@ func RenameBranch(repo *repo_model.Repository, from, to string, gitAction func(i
|
||||
}
|
||||
|
||||
// 2. Update protected branch if needed
|
||||
protectedBranch, err := getProtectedBranchBy(sess, repo.ID, from)
|
||||
protectedBranch, err := GetProtectedBranchBy(ctx, repo.ID, from)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -49,21 +49,21 @@ func init() {
|
||||
}
|
||||
|
||||
// upsertCommitStatusIndex the function will not return until it acquires the lock or receives an error.
|
||||
func upsertCommitStatusIndex(e db.Engine, repoID int64, sha string) (err error) {
|
||||
func upsertCommitStatusIndex(ctx context.Context, repoID int64, sha string) (err error) {
|
||||
// An atomic UPSERT operation (INSERT/UPDATE) is the only operation
|
||||
// that ensures that the key is actually locked.
|
||||
switch {
|
||||
case setting.Database.UseSQLite3 || setting.Database.UsePostgreSQL:
|
||||
_, err = e.Exec("INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+
|
||||
_, err = db.Exec(ctx, "INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+
|
||||
"VALUES (?,?,1) ON CONFLICT (repo_id,sha) DO UPDATE SET max_index = `commit_status_index`.max_index+1",
|
||||
repoID, sha)
|
||||
case setting.Database.UseMySQL:
|
||||
_, err = e.Exec("INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+
|
||||
_, err = db.Exec(ctx, "INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+
|
||||
"VALUES (?,?,1) ON DUPLICATE KEY UPDATE max_index = max_index+1",
|
||||
repoID, sha)
|
||||
case setting.Database.UseMSSQL:
|
||||
// https://weblogs.sqlteam.com/dang/2009/01/31/upsert-race-condition-with-merge/
|
||||
_, err = e.Exec("MERGE `commit_status_index` WITH (HOLDLOCK) as target "+
|
||||
_, err = db.Exec(ctx, "MERGE `commit_status_index` WITH (HOLDLOCK) as target "+
|
||||
"USING (SELECT ? AS repo_id, ? AS sha) AS src "+
|
||||
"ON src.repo_id = target.repo_id AND src.sha = target.sha "+
|
||||
"WHEN MATCHED THEN UPDATE SET target.max_index = target.max_index+1 "+
|
||||
@ -100,17 +100,17 @@ func getNextCommitStatusIndex(repoID int64, sha string) (int64, error) {
|
||||
defer commiter.Close()
|
||||
|
||||
var preIdx int64
|
||||
_, err = ctx.Engine().SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ?", repoID, sha).Get(&preIdx)
|
||||
_, err = db.GetEngine(ctx).SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ?", repoID, sha).Get(&preIdx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := upsertCommitStatusIndex(ctx.Engine(), repoID, sha); err != nil {
|
||||
if err := upsertCommitStatusIndex(ctx, repoID, sha); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var curIdx int64
|
||||
has, err := ctx.Engine().SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ? AND max_index=?", repoID, sha, preIdx+1).Get(&curIdx)
|
||||
has, err := db.GetEngine(ctx).SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ? AND max_index=?", repoID, sha, preIdx+1).Get(&curIdx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -131,7 +131,7 @@ func (status *CommitStatus) loadAttributes(ctx context.Context) (err error) {
|
||||
}
|
||||
}
|
||||
if status.Creator == nil && status.CreatorID > 0 {
|
||||
status.Creator, err = user_model.GetUserByIDEngine(db.GetEngine(ctx), status.CreatorID)
|
||||
status.Creator, err = user_model.GetUserByIDCtx(ctx, status.CreatorID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getUserByID [%d]: %v", status.CreatorID, err)
|
||||
}
|
||||
@ -231,12 +231,7 @@ type CommitStatusIndex struct {
|
||||
}
|
||||
|
||||
// GetLatestCommitStatus returns all statuses with a unique context for a given commit.
|
||||
func GetLatestCommitStatus(repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) {
|
||||
return GetLatestCommitStatusCtx(db.DefaultContext, repoID, sha, listOptions)
|
||||
}
|
||||
|
||||
// GetLatestCommitStatusCtx returns all statuses with a unique context for a given commit.
|
||||
func GetLatestCommitStatusCtx(ctx context.Context, repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) {
|
||||
func GetLatestCommitStatus(ctx context.Context, repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) {
|
||||
ids := make([]int64, 0, 10)
|
||||
sess := db.GetEngine(ctx).Table(&CommitStatus{}).
|
||||
Where("repo_id = ?", repoID).And("sha = ?", sha).
|
||||
@ -341,7 +336,7 @@ func ParseCommitsWithStatus(oldCommits []*asymkey_model.SignCommit, repo *repo_m
|
||||
commit := &SignCommitWithStatuses{
|
||||
SignCommit: c,
|
||||
}
|
||||
statuses, _, err := GetLatestCommitStatus(repo.ID, commit.ID.String(), db.ListOptions{})
|
||||
statuses, _, err := GetLatestCommitStatus(db.DefaultContext, repo.ID, commit.ID.String(), db.ListOptions{})
|
||||
if err != nil {
|
||||
log.Error("GetLatestCommitStatus: %v", err)
|
||||
} else {
|
||||
|
@ -117,7 +117,7 @@ func DeleteOrphanedIssues() error {
|
||||
|
||||
var attachmentPaths []string
|
||||
for i := range ids {
|
||||
paths, err := deleteIssuesByRepoID(db.GetEngine(ctx), ids[i])
|
||||
paths, err := deleteIssuesByRepoID(ctx, ids[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -5,6 +5,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
@ -74,8 +75,8 @@ func GetNextResourceIndex(tableName string, groupID int64) (int64, error) {
|
||||
}
|
||||
|
||||
// DeleteResouceIndex delete resource index
|
||||
func DeleteResouceIndex(e Engine, tableName string, groupID int64) error {
|
||||
_, err := e.Exec(fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID)
|
||||
func DeleteResouceIndex(ctx context.Context, tableName string, groupID int64) error {
|
||||
_, err := Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID)
|
||||
return err
|
||||
}
|
||||
|
||||
|
153
models/issue.go
153
models/issue.go
File diff suppressed because it is too large
Load Diff
@ -25,20 +25,15 @@ func init() {
|
||||
}
|
||||
|
||||
// LoadAssignees load assignees of this issue.
|
||||
func (issue *Issue) LoadAssignees() error {
|
||||
return issue.loadAssignees(db.GetEngine(db.DefaultContext))
|
||||
}
|
||||
|
||||
// This loads all assignees of an issue
|
||||
func (issue *Issue) loadAssignees(e db.Engine) (err error) {
|
||||
func (issue *Issue) LoadAssignees(ctx context.Context) (err error) {
|
||||
// Reset maybe preexisting assignees
|
||||
issue.Assignees = []*user_model.User{}
|
||||
issue.Assignee = nil
|
||||
|
||||
err = e.Table("`user`").
|
||||
err = db.GetEngine(ctx).Table("`user`").
|
||||
Join("INNER", "issue_assignees", "assignee_id = `user`.id").
|
||||
Where("issue_assignees.issue_id = ?", issue.ID).
|
||||
Find(&issue.Assignees)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -47,7 +42,6 @@ func (issue *Issue) loadAssignees(e db.Engine) (err error) {
|
||||
if len(issue.Assignees) > 0 {
|
||||
issue.Assignee = issue.Assignees[0]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -63,33 +57,9 @@ func GetAssigneeIDsByIssue(issueID int64) ([]int64, error) {
|
||||
Find(&userIDs)
|
||||
}
|
||||
|
||||
// GetAssigneesByIssue returns everyone assigned to that issue
|
||||
func GetAssigneesByIssue(issue *Issue) (assignees []*user_model.User, err error) {
|
||||
return getAssigneesByIssue(db.GetEngine(db.DefaultContext), issue)
|
||||
}
|
||||
|
||||
func getAssigneesByIssue(e db.Engine, issue *Issue) (assignees []*user_model.User, err error) {
|
||||
err = issue.loadAssignees(e)
|
||||
if err != nil {
|
||||
return assignees, err
|
||||
}
|
||||
|
||||
return issue.Assignees, nil
|
||||
}
|
||||
|
||||
// IsUserAssignedToIssue returns true when the user is assigned to the issue
|
||||
func IsUserAssignedToIssue(issue *Issue, user *user_model.User) (isAssigned bool, err error) {
|
||||
return isUserAssignedToIssue(db.GetEngine(db.DefaultContext), issue, user)
|
||||
}
|
||||
|
||||
func isUserAssignedToIssue(e db.Engine, issue *Issue, user *user_model.User) (isAssigned bool, err error) {
|
||||
return e.Get(&IssueAssignees{IssueID: issue.ID, AssigneeID: user.ID})
|
||||
}
|
||||
|
||||
// ClearAssigneeByUserID deletes all assignments of an user
|
||||
func clearAssigneeByUserID(sess db.Engine, userID int64) (err error) {
|
||||
_, err = sess.Delete(&IssueAssignees{AssigneeID: userID})
|
||||
return
|
||||
func IsUserAssignedToIssue(ctx context.Context, issue *Issue, user *user_model.User) (isAssigned bool, err error) {
|
||||
return db.GetByBean(ctx, &IssueAssignees{IssueID: issue.ID, AssigneeID: user.ID})
|
||||
}
|
||||
|
||||
// ToggleIssueAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it.
|
||||
@ -113,8 +83,7 @@ func ToggleIssueAssignee(issue *Issue, doer *user_model.User, assigneeID int64)
|
||||
}
|
||||
|
||||
func toggleIssueAssignee(ctx context.Context, issue *Issue, doer *user_model.User, assigneeID int64, isCreate bool) (removed bool, comment *Comment, err error) {
|
||||
sess := db.GetEngine(ctx)
|
||||
removed, err = toggleUserAssignee(sess, issue, assigneeID)
|
||||
removed, err = toggleUserAssignee(ctx, issue, assigneeID)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("UpdateIssueUserByAssignee: %v", err)
|
||||
}
|
||||
@ -147,39 +116,38 @@ func toggleIssueAssignee(ctx context.Context, issue *Issue, doer *user_model.Use
|
||||
}
|
||||
|
||||
// toggles user assignee state in database
|
||||
func toggleUserAssignee(e db.Engine, issue *Issue, assigneeID int64) (removed bool, err error) {
|
||||
func toggleUserAssignee(ctx context.Context, issue *Issue, assigneeID int64) (removed bool, err error) {
|
||||
// Check if the user exists
|
||||
assignee, err := user_model.GetUserByIDEngine(e, assigneeID)
|
||||
assignee, err := user_model.GetUserByIDCtx(ctx, assigneeID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Check if the submitted user is already assigned, if yes delete him otherwise add him
|
||||
var i int
|
||||
for i = 0; i < len(issue.Assignees); i++ {
|
||||
found := false
|
||||
i := 0
|
||||
for ; i < len(issue.Assignees); i++ {
|
||||
if issue.Assignees[i].ID == assigneeID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assigneeIn := IssueAssignees{AssigneeID: assigneeID, IssueID: issue.ID}
|
||||
|
||||
toBeDeleted := i < len(issue.Assignees)
|
||||
if toBeDeleted {
|
||||
issue.Assignees = append(issue.Assignees[:i], issue.Assignees[i:]...)
|
||||
_, err = e.Delete(assigneeIn)
|
||||
if found {
|
||||
issue.Assignees = append(issue.Assignees[:i], issue.Assignees[i+1:]...)
|
||||
_, err = db.DeleteByBean(ctx, &assigneeIn)
|
||||
if err != nil {
|
||||
return toBeDeleted, err
|
||||
return found, err
|
||||
}
|
||||
} else {
|
||||
issue.Assignees = append(issue.Assignees, assignee)
|
||||
_, err = e.Insert(assigneeIn)
|
||||
if err != nil {
|
||||
return toBeDeleted, err
|
||||
if err = db.Insert(ctx, &assigneeIn); err != nil {
|
||||
return found, err
|
||||
}
|
||||
}
|
||||
|
||||
return toBeDeleted, nil
|
||||
return found, nil
|
||||
}
|
||||
|
||||
// MakeIDsFromAPIAssigneesToAdd returns an array with all assignee IDs
|
||||
|
@ -7,6 +7,7 @@ package models
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
|
||||
@ -37,28 +38,28 @@ func TestUpdateAssignee(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check if he got removed
|
||||
isAssigned, err := IsUserAssignedToIssue(issue, user1)
|
||||
isAssigned, err := IsUserAssignedToIssue(db.DefaultContext, issue, user1)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isAssigned)
|
||||
|
||||
// Check if they're all there
|
||||
assignees, err := GetAssigneesByIssue(issue)
|
||||
err = issue.LoadAssignees(db.DefaultContext)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var expectedAssignees []*user_model.User
|
||||
expectedAssignees = append(expectedAssignees, user2, user3)
|
||||
|
||||
for in, assignee := range assignees {
|
||||
for in, assignee := range issue.Assignees {
|
||||
assert.Equal(t, assignee.ID, expectedAssignees[in].ID)
|
||||
}
|
||||
|
||||
// Check if the user is assigned
|
||||
isAssigned, err = IsUserAssignedToIssue(issue, user2)
|
||||
isAssigned, err = IsUserAssignedToIssue(db.DefaultContext, issue, user2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isAssigned)
|
||||
|
||||
// This user should not be assigned
|
||||
isAssigned, err = IsUserAssignedToIssue(issue, &user_model.User{ID: 4})
|
||||
isAssigned, err = IsUserAssignedToIssue(db.DefaultContext, issue, &user_model.User{ID: 4})
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isAssigned)
|
||||
}
|
||||
|
@ -298,7 +298,7 @@ func (c *Comment) LoadIssueCtx(ctx context.Context) (err error) {
|
||||
if c.Issue != nil {
|
||||
return nil
|
||||
}
|
||||
c.Issue, err = getIssueByID(db.GetEngine(ctx), c.IssueID)
|
||||
c.Issue, err = getIssueByID(ctx, c.IssueID)
|
||||
return
|
||||
}
|
||||
|
||||
@ -329,12 +329,12 @@ func (c *Comment) AfterLoad(session *xorm.Session) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Comment) loadPoster(e db.Engine) (err error) {
|
||||
func (c *Comment) loadPoster(ctx context.Context) (err error) {
|
||||
if c.PosterID <= 0 || c.Poster != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.Poster, err = user_model.GetUserByIDEngine(e, c.PosterID)
|
||||
c.Poster, err = user_model.GetUserByIDCtx(ctx, c.PosterID)
|
||||
if err != nil {
|
||||
if user_model.IsErrUserNotExist(err) {
|
||||
c.PosterID = -1
|
||||
@ -525,7 +525,7 @@ func (c *Comment) LoadMilestone() error {
|
||||
|
||||
// LoadPoster loads comment poster
|
||||
func (c *Comment) LoadPoster() error {
|
||||
return c.loadPoster(db.GetEngine(db.DefaultContext))
|
||||
return c.loadPoster(db.DefaultContext)
|
||||
}
|
||||
|
||||
// LoadAttachments loads attachments (it never returns error, the error during `GetAttachmentsByCommentIDCtx` is ignored)
|
||||
@ -535,7 +535,7 @@ func (c *Comment) LoadAttachments() error {
|
||||
}
|
||||
|
||||
var err error
|
||||
c.Attachments, err = repo_model.GetAttachmentsByCommentIDCtx(db.DefaultContext, c.ID)
|
||||
c.Attachments, err = repo_model.GetAttachmentsByCommentID(db.DefaultContext, c.ID)
|
||||
if err != nil {
|
||||
log.Error("getAttachmentsByCommentID[%d]: %v", c.ID, err)
|
||||
}
|
||||
@ -557,7 +557,7 @@ func (c *Comment) UpdateAttachments(uuids []string) error {
|
||||
for i := 0; i < len(attachments); i++ {
|
||||
attachments[i].IssueID = c.IssueID
|
||||
attachments[i].CommentID = c.ID
|
||||
if err := repo_model.UpdateAttachmentCtx(ctx, attachments[i]); err != nil {
|
||||
if err := repo_model.UpdateAttachment(ctx, attachments[i]); err != nil {
|
||||
return fmt.Errorf("update attachment [id: %d]: %v", attachments[i].ID, err)
|
||||
}
|
||||
}
|
||||
@ -590,7 +590,7 @@ func (c *Comment) LoadAssigneeUserAndTeam() error {
|
||||
}
|
||||
|
||||
if c.Issue.Repo.Owner.IsOrganization() {
|
||||
c.AssigneeTeam, err = organization.GetTeamByID(c.AssigneeTeamID)
|
||||
c.AssigneeTeam, err = organization.GetTeamByID(db.DefaultContext, c.AssigneeTeamID)
|
||||
if err != nil && !organization.IsErrTeamNotExist(err) {
|
||||
return err
|
||||
}
|
||||
@ -624,7 +624,7 @@ func (c *Comment) LoadDepIssueDetails() (err error) {
|
||||
if c.DependentIssueID <= 0 || c.DependentIssue != nil {
|
||||
return nil
|
||||
}
|
||||
c.DependentIssue, err = getIssueByID(db.GetEngine(db.DefaultContext), c.DependentIssueID)
|
||||
c.DependentIssue, err = getIssueByID(db.DefaultContext, c.DependentIssueID)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -661,9 +661,9 @@ func (c *Comment) LoadReactions(repo *repo_model.Repository) error {
|
||||
return c.loadReactions(db.DefaultContext, repo)
|
||||
}
|
||||
|
||||
func (c *Comment) loadReview(e db.Engine) (err error) {
|
||||
func (c *Comment) loadReview(ctx context.Context) (err error) {
|
||||
if c.Review == nil {
|
||||
if c.Review, err = getReviewByID(e, c.ReviewID); err != nil {
|
||||
if c.Review, err = GetReviewByID(ctx, c.ReviewID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -673,7 +673,7 @@ func (c *Comment) loadReview(e db.Engine) (err error) {
|
||||
|
||||
// LoadReview loads the associated review
|
||||
func (c *Comment) LoadReview() error {
|
||||
return c.loadReview(db.GetEngine(db.DefaultContext))
|
||||
return c.loadReview(db.DefaultContext)
|
||||
}
|
||||
|
||||
var notEnoughLines = regexp.MustCompile(`fatal: file .* has only \d+ lines?`)
|
||||
@ -830,13 +830,12 @@ func CreateCommentCtx(ctx context.Context, opts *CreateCommentOptions) (_ *Comme
|
||||
}
|
||||
|
||||
func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment *Comment) (err error) {
|
||||
e := db.GetEngine(ctx)
|
||||
// Check comment type.
|
||||
switch opts.Type {
|
||||
case CommentTypeCode:
|
||||
if comment.ReviewID != 0 {
|
||||
if comment.Review == nil {
|
||||
if err := comment.loadReview(e); err != nil {
|
||||
if err := comment.loadReview(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -846,7 +845,7 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment
|
||||
}
|
||||
fallthrough
|
||||
case CommentTypeComment:
|
||||
if _, err = e.Exec("UPDATE `issue` SET num_comments=num_comments+1 WHERE id=?", opts.Issue.ID); err != nil {
|
||||
if _, err = db.Exec(ctx, "UPDATE `issue` SET num_comments=num_comments+1 WHERE id=?", opts.Issue.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
fallthrough
|
||||
@ -861,7 +860,7 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment
|
||||
attachments[i].IssueID = opts.Issue.ID
|
||||
attachments[i].CommentID = comment.ID
|
||||
// No assign value could be 0, so ignore AllCols().
|
||||
if _, err = e.ID(attachments[i].ID).Update(attachments[i]); err != nil {
|
||||
if _, err = db.GetEngine(ctx).ID(attachments[i].ID).Update(attachments[i]); err != nil {
|
||||
return fmt.Errorf("update attachment [%d]: %v", attachments[i].ID, err)
|
||||
}
|
||||
}
|
||||
@ -1031,13 +1030,9 @@ func CreateRefComment(doer *user_model.User, repo *repo_model.Repository, issue
|
||||
}
|
||||
|
||||
// GetCommentByID returns the comment by given ID.
|
||||
func GetCommentByID(id int64) (*Comment, error) {
|
||||
return getCommentByID(db.GetEngine(db.DefaultContext), id)
|
||||
}
|
||||
|
||||
func getCommentByID(e db.Engine, id int64) (*Comment, error) {
|
||||
func GetCommentByID(ctx context.Context, id int64) (*Comment, error) {
|
||||
c := new(Comment)
|
||||
has, err := e.ID(id).Get(c)
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
@ -1088,9 +1083,10 @@ func (opts *FindCommentsOptions) toConds() builder.Cond {
|
||||
return cond
|
||||
}
|
||||
|
||||
func findComments(e db.Engine, opts *FindCommentsOptions) ([]*Comment, error) {
|
||||
// FindComments returns all comments according options
|
||||
func FindComments(ctx context.Context, opts *FindCommentsOptions) ([]*Comment, error) {
|
||||
comments := make([]*Comment, 0, 10)
|
||||
sess := e.Where(opts.toConds())
|
||||
sess := db.GetEngine(ctx).Where(opts.toConds())
|
||||
if opts.RepoID > 0 {
|
||||
sess.Join("INNER", "issue", "issue.id = comment.issue_id")
|
||||
}
|
||||
@ -1107,11 +1103,6 @@ func findComments(e db.Engine, opts *FindCommentsOptions) ([]*Comment, error) {
|
||||
Find(&comments)
|
||||
}
|
||||
|
||||
// FindComments returns all comments according options
|
||||
func FindComments(opts *FindCommentsOptions) ([]*Comment, error) {
|
||||
return findComments(db.GetEngine(db.DefaultContext), opts)
|
||||
}
|
||||
|
||||
// CountComments count all comments according options by ignoring pagination
|
||||
func CountComments(opts *FindCommentsOptions) (int64, error) {
|
||||
sess := db.GetEngine(db.DefaultContext).Where(opts.toConds())
|
||||
@ -1167,7 +1158,7 @@ func deleteComment(ctx context.Context, comment *Comment) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := e.Delete(&issues_model.ContentHistory{
|
||||
if _, err := db.DeleteByBean(ctx, &issues_model.ContentHistory{
|
||||
CommentID: comment.ID,
|
||||
}); err != nil {
|
||||
return err
|
||||
@ -1182,7 +1173,7 @@ func deleteComment(ctx context.Context, comment *Comment) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := comment.neuterCrossReferences(e); err != nil {
|
||||
if err := comment.neuterCrossReferences(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1192,7 +1183,8 @@ func deleteComment(ctx context.Context, comment *Comment) error {
|
||||
// CodeComments represents comments on code by using this structure: FILENAME -> LINE (+ == proposed; - == previous) -> COMMENTS
|
||||
type CodeComments map[string]map[int64][]*Comment
|
||||
|
||||
func fetchCodeComments(ctx context.Context, issue *Issue, currentUser *user_model.User) (CodeComments, error) {
|
||||
// FetchCodeComments will return a 2d-map: ["Path"]["Line"] = Comments at line
|
||||
func FetchCodeComments(ctx context.Context, issue *Issue, currentUser *user_model.User) (CodeComments, error) {
|
||||
return fetchCodeCommentsByReview(ctx, issue, currentUser, nil)
|
||||
}
|
||||
|
||||
@ -1242,7 +1234,7 @@ func findCodeComments(ctx context.Context, opts FindCommentsOptions, issue *Issu
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := CommentList(comments).loadPosters(e); err != nil {
|
||||
if err := CommentList(comments).loadPosters(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -1302,11 +1294,6 @@ func FetchCodeCommentsByLine(ctx context.Context, issue *Issue, currentUser *use
|
||||
return findCodeComments(ctx, opts, issue, currentUser, nil)
|
||||
}
|
||||
|
||||
// FetchCodeComments will return a 2d-map: ["Path"]["Line"] = Comments at line
|
||||
func FetchCodeComments(ctx context.Context, issue *Issue, currentUser *user_model.User) (CodeComments, error) {
|
||||
return fetchCodeComments(ctx, issue, currentUser)
|
||||
}
|
||||
|
||||
// UpdateCommentsMigrationsByType updates comments' migrations information via given git service type and original id and poster id
|
||||
func UpdateCommentsMigrationsByType(tp structs.GitServiceType, originalAuthorID string, posterID int64) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).Table("comment").
|
||||
|
@ -27,7 +27,7 @@ func (comments CommentList) getPosterIDs() []int64 {
|
||||
return container.KeysInt64(posterIDs)
|
||||
}
|
||||
|
||||
func (comments CommentList) loadPosters(e db.Engine) error {
|
||||
func (comments CommentList) loadPosters(ctx context.Context) error {
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -40,7 +40,7 @@ func (comments CommentList) loadPosters(e db.Engine) error {
|
||||
if left < limit {
|
||||
limit = left
|
||||
}
|
||||
err := e.
|
||||
err := db.GetEngine(ctx).
|
||||
In("id", posterIDs[:limit]).
|
||||
Find(&posterMaps)
|
||||
if err != nil {
|
||||
@ -80,7 +80,7 @@ func (comments CommentList) getLabelIDs() []int64 {
|
||||
return container.KeysInt64(ids)
|
||||
}
|
||||
|
||||
func (comments CommentList) loadLabels(e db.Engine) error {
|
||||
func (comments CommentList) loadLabels(ctx context.Context) error {
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -93,7 +93,7 @@ func (comments CommentList) loadLabels(e db.Engine) error {
|
||||
if left < limit {
|
||||
limit = left
|
||||
}
|
||||
rows, err := e.
|
||||
rows, err := db.GetEngine(ctx).
|
||||
In("id", labelIDs[:limit]).
|
||||
Rows(new(Label))
|
||||
if err != nil {
|
||||
@ -130,7 +130,7 @@ func (comments CommentList) getMilestoneIDs() []int64 {
|
||||
return container.KeysInt64(ids)
|
||||
}
|
||||
|
||||
func (comments CommentList) loadMilestones(e db.Engine) error {
|
||||
func (comments CommentList) loadMilestones(ctx context.Context) error {
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -147,7 +147,7 @@ func (comments CommentList) loadMilestones(e db.Engine) error {
|
||||
if left < limit {
|
||||
limit = left
|
||||
}
|
||||
err := e.
|
||||
err := db.GetEngine(ctx).
|
||||
In("id", milestoneIDs[:limit]).
|
||||
Find(&milestoneMaps)
|
||||
if err != nil {
|
||||
@ -173,7 +173,7 @@ func (comments CommentList) getOldMilestoneIDs() []int64 {
|
||||
return container.KeysInt64(ids)
|
||||
}
|
||||
|
||||
func (comments CommentList) loadOldMilestones(e db.Engine) error {
|
||||
func (comments CommentList) loadOldMilestones(ctx context.Context) error {
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -190,7 +190,7 @@ func (comments CommentList) loadOldMilestones(e db.Engine) error {
|
||||
if left < limit {
|
||||
limit = left
|
||||
}
|
||||
err := e.
|
||||
err := db.GetEngine(ctx).
|
||||
In("id", milestoneIDs[:limit]).
|
||||
Find(&milestoneMaps)
|
||||
if err != nil {
|
||||
@ -216,7 +216,7 @@ func (comments CommentList) getAssigneeIDs() []int64 {
|
||||
return container.KeysInt64(ids)
|
||||
}
|
||||
|
||||
func (comments CommentList) loadAssignees(e db.Engine) error {
|
||||
func (comments CommentList) loadAssignees(ctx context.Context) error {
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -229,7 +229,7 @@ func (comments CommentList) loadAssignees(e db.Engine) error {
|
||||
if left < limit {
|
||||
limit = left
|
||||
}
|
||||
rows, err := e.
|
||||
rows, err := db.GetEngine(ctx).
|
||||
In("id", assigneeIDs[:limit]).
|
||||
Rows(new(user_model.User))
|
||||
if err != nil {
|
||||
@ -290,7 +290,7 @@ func (comments CommentList) Issues() IssueList {
|
||||
return issueList
|
||||
}
|
||||
|
||||
func (comments CommentList) loadIssues(e db.Engine) error {
|
||||
func (comments CommentList) loadIssues(ctx context.Context) error {
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -303,7 +303,7 @@ func (comments CommentList) loadIssues(e db.Engine) error {
|
||||
if left < limit {
|
||||
limit = left
|
||||
}
|
||||
rows, err := e.
|
||||
rows, err := db.GetEngine(ctx).
|
||||
In("id", issueIDs[:limit]).
|
||||
Rows(new(Issue))
|
||||
if err != nil {
|
||||
@ -397,7 +397,7 @@ func (comments CommentList) loadDependentIssues(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (comments CommentList) loadAttachments(e db.Engine) (err error) {
|
||||
func (comments CommentList) loadAttachments(ctx context.Context) (err error) {
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -410,7 +410,7 @@ func (comments CommentList) loadAttachments(e db.Engine) (err error) {
|
||||
if left < limit {
|
||||
limit = left
|
||||
}
|
||||
rows, err := e.Table("attachment").
|
||||
rows, err := db.GetEngine(ctx).Table("attachment").
|
||||
Join("INNER", "comment", "comment.id = attachment.comment_id").
|
||||
In("comment.id", commentsIDs[:limit]).
|
||||
Rows(new(repo_model.Attachment))
|
||||
@ -449,7 +449,7 @@ func (comments CommentList) getReviewIDs() []int64 {
|
||||
return container.KeysInt64(ids)
|
||||
}
|
||||
|
||||
func (comments CommentList) loadReviews(e db.Engine) error {
|
||||
func (comments CommentList) loadReviews(ctx context.Context) error {
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -462,7 +462,7 @@ func (comments CommentList) loadReviews(e db.Engine) error {
|
||||
if left < limit {
|
||||
limit = left
|
||||
}
|
||||
rows, err := e.
|
||||
rows, err := db.GetEngine(ctx).
|
||||
In("id", reviewIDs[:limit]).
|
||||
Rows(new(Review))
|
||||
if err != nil {
|
||||
@ -493,36 +493,35 @@ func (comments CommentList) loadReviews(e db.Engine) error {
|
||||
|
||||
// loadAttributes loads all attributes
|
||||
func (comments CommentList) loadAttributes(ctx context.Context) (err error) {
|
||||
e := db.GetEngine(ctx)
|
||||
if err = comments.loadPosters(e); err != nil {
|
||||
if err = comments.loadPosters(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = comments.loadLabels(e); err != nil {
|
||||
if err = comments.loadLabels(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = comments.loadMilestones(e); err != nil {
|
||||
if err = comments.loadMilestones(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = comments.loadOldMilestones(e); err != nil {
|
||||
if err = comments.loadOldMilestones(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = comments.loadAssignees(e); err != nil {
|
||||
if err = comments.loadAssignees(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = comments.loadAttachments(e); err != nil {
|
||||
if err = comments.loadAttachments(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = comments.loadReviews(e); err != nil {
|
||||
if err = comments.loadReviews(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = comments.loadIssues(e); err != nil {
|
||||
if err = comments.loadIssues(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -541,15 +540,15 @@ func (comments CommentList) LoadAttributes() error {
|
||||
|
||||
// LoadAttachments loads attachments
|
||||
func (comments CommentList) LoadAttachments() error {
|
||||
return comments.loadAttachments(db.GetEngine(db.DefaultContext))
|
||||
return comments.loadAttachments(db.DefaultContext)
|
||||
}
|
||||
|
||||
// LoadPosters loads posters
|
||||
func (comments CommentList) LoadPosters() error {
|
||||
return comments.loadPosters(db.GetEngine(db.DefaultContext))
|
||||
return comments.loadPosters(db.DefaultContext)
|
||||
}
|
||||
|
||||
// LoadIssues loads issues of comments
|
||||
func (comments CommentList) LoadIssues() error {
|
||||
return comments.loadIssues(db.GetEngine(db.DefaultContext))
|
||||
return comments.loadIssues(db.DefaultContext)
|
||||
}
|
||||
|
@ -42,10 +42,9 @@ func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error {
|
||||
return err
|
||||
}
|
||||
defer committer.Close()
|
||||
sess := db.GetEngine(ctx)
|
||||
|
||||
// Check if it aleready exists
|
||||
exists, err := issueDepExists(sess, issue.ID, dep.ID)
|
||||
exists, err := issueDepExists(ctx, issue.ID, dep.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -53,7 +52,7 @@ func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error {
|
||||
return ErrDependencyExists{issue.ID, dep.ID}
|
||||
}
|
||||
// And if it would be circular
|
||||
circular, err := issueDepExists(sess, dep.ID, issue.ID)
|
||||
circular, err := issueDepExists(ctx, dep.ID, issue.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -114,8 +113,8 @@ func RemoveIssueDependency(user *user_model.User, issue, dep *Issue, depType Dep
|
||||
}
|
||||
|
||||
// Check if the dependency already exists
|
||||
func issueDepExists(e db.Engine, issueID, depID int64) (bool, error) {
|
||||
return e.Where("(issue_id = ? AND dependency_id = ?)", issueID, depID).Exist(&IssueDependency{})
|
||||
func issueDepExists(ctx context.Context, issueID, depID int64) (bool, error) {
|
||||
return db.GetEngine(ctx).Where("(issue_id = ? AND dependency_id = ?)", issueID, depID).Exist(&IssueDependency{})
|
||||
}
|
||||
|
||||
// IssueNoDependenciesLeft checks if issue can be closed
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user