Move almost all functions' parameter db.Engine to context.Context (#19748)

* Move almost all functions' parameter db.Engine to context.Context
* remove some unnecessary wrap functions
This commit is contained in:
2022-05-20 22:08:52 +08:00
committed by GitHub
parent d81e31ad78
commit fd7d83ace6
232 changed files with 1463 additions and 2108 deletions

View File

@ -490,7 +490,7 @@ func runChangePassword(c *cli.Context) error {
return errors.New("The password you chose is on a list of stolen passwords previously exposed in public data breaches. Please try again with a different password.\nFor more details, see https://haveibeenpwned.com/Passwords")
}
uname := c.String("username")
user, err := user_model.GetUserByName(uname)
user, err := user_model.GetUserByName(ctx, uname)
if err != nil {
return err
}
@ -659,7 +659,7 @@ func runDeleteUser(c *cli.Context) error {
if c.IsSet("email") {
user, err = user_model.GetUserByEmail(c.String("email"))
} else if c.IsSet("username") {
user, err = user_model.GetUserByName(c.String("username"))
user, err = user_model.GetUserByName(ctx, c.String("username"))
} else {
user, err = user_model.GetUserByID(c.Int64("id"))
}
@ -689,7 +689,7 @@ func runGenerateAccessToken(c *cli.Context) error {
return err
}
user, err := user_model.GetUserByName(c.String("username"))
user, err := user_model.GetUserByName(ctx, c.String("username"))
if err != nil {
return err
}

View File

@ -33,7 +33,7 @@ func TestAPIGetTrackedTimes(t *testing.T) {
resp := session.MakeRequest(t, req, http.StatusOK)
var apiTimes api.TrackedTimeList
DecodeJSON(t, resp, &apiTimes)
expect, err := models.GetTrackedTimes(&models.FindTrackedTimesOptions{IssueID: issue2.ID})
expect, err := models.GetTrackedTimes(db.DefaultContext, &models.FindTrackedTimesOptions{IssueID: issue2.ID})
assert.NoError(t, err)
assert.Len(t, apiTimes, 3)
@ -83,7 +83,7 @@ func TestAPIDeleteTrackedTime(t *testing.T) {
session.MakeRequest(t, req, http.StatusNotFound)
// Reset time of user 2 on issue 2
trackedSeconds, err := models.GetTrackedSeconds(models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
trackedSeconds, err := models.GetTrackedSeconds(db.DefaultContext, models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
assert.NoError(t, err)
assert.Equal(t, int64(3661), trackedSeconds)
@ -91,7 +91,7 @@ func TestAPIDeleteTrackedTime(t *testing.T) {
session.MakeRequest(t, req, http.StatusNoContent)
session.MakeRequest(t, req, http.StatusNotFound)
trackedSeconds, err = models.GetTrackedSeconds(models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
trackedSeconds, err = models.GetTrackedSeconds(db.DefaultContext, models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
assert.NoError(t, err)
assert.Equal(t, int64(0), trackedSeconds)
}

View File

@ -388,7 +388,7 @@ func testAPIRepoMigrateConflict(t *testing.T, u *url.URL) {
defer util.RemoveAll(dstPath)
t.Run("CreateRepo", doAPICreateRepository(httpContext, false))
user, err := user_model.GetUserByName(httpContext.Username)
user, err := user_model.GetUserByName(db.DefaultContext, httpContext.Username)
assert.NoError(t, err)
userID := user.ID

View File

@ -321,7 +321,7 @@ func TestLDAPGroupTeamSyncAddMember(t *testing.T) {
addAuthSourceLDAP(t, "", "on", `{"cn=ship_crew,ou=people,dc=planetexpress,dc=com":{"org26": ["team11"]},"cn=admin_staff,ou=people,dc=planetexpress,dc=com": {"non-existent": ["non-existent"]}}`)
org, err := organization.GetOrgByName("org26")
assert.NoError(t, err)
team, err := organization.GetTeam(org.ID, "team11")
team, err := organization.GetTeam(db.DefaultContext, org.ID, "team11")
assert.NoError(t, err)
auth.SyncExternalUsers(context.Background(), true)
for _, gitLDAPUser := range gitLDAPUsers {
@ -366,7 +366,7 @@ func TestLDAPGroupTeamSyncRemoveMember(t *testing.T) {
addAuthSourceLDAP(t, "", "on", `{"cn=dispatch,ou=people,dc=planetexpress,dc=com": {"org26": ["team11"]}}`)
org, err := organization.GetOrgByName("org26")
assert.NoError(t, err)
team, err := organization.GetTeam(org.ID, "team11")
team, err := organization.GetTeam(db.DefaultContext, org.ID, "team11")
assert.NoError(t, err)
loginUserWithPassword(t, gitLDAPUsers[0].UserName, gitLDAPUsers[0].Password)
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{

View File

@ -18,6 +18,7 @@ import (
"time"
"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/perm"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unittest"
@ -438,7 +439,7 @@ func doProtectBranch(ctx APITestContext, branch, userToWhitelist, unprotectedFil
})
ctx.Session.MakeRequest(t, req, http.StatusSeeOther)
} else {
user, err := user_model.GetUserByName(userToWhitelist)
user, err := user_model.GetUserByName(db.DefaultContext, userToWhitelist)
assert.NoError(t, err)
// Change branch to protected
req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/settings/branches/%s", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame), url.PathEscape(branch)), map[string]string{

View File

@ -75,7 +75,7 @@ func TestMirrorPull(t *testing.T) {
IsTag: true,
}, nil, ""))
_, err = repo_model.GetMirrorByRepoID(mirror.ID)
_, err = repo_model.GetMirrorByRepoID(ctx, mirror.ID)
assert.NoError(t, err)
ok := mirror_service.SyncPullMirror(ctx, mirror.ID)

View File

@ -18,6 +18,7 @@ import (
"time"
"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
@ -407,7 +408,7 @@ func TestConflictChecking(t *testing.T) {
assert.NoError(t, err)
issue := unittest.AssertExistsAndLoadBean(t, &models.Issue{Title: "PR with conflict!"}).(*models.Issue)
conflictingPR, err := models.GetPullRequestByIssueID(issue.ID)
conflictingPR, err := models.GetPullRequestByIssueID(db.DefaultContext, issue.ID)
assert.NoError(t, err)
// Ensure conflictedFiles is populated.

View File

@ -11,6 +11,7 @@ import (
"time"
"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/git"
@ -165,7 +166,7 @@ func createOutdatedPR(t *testing.T, actor, forkOrg *user_model.User) *models.Pul
assert.NoError(t, err)
issue := unittest.AssertExistsAndLoadBean(t, &models.Issue{Title: "Test Pull -to-update-"}).(*models.Issue)
pr, err := models.GetPullRequestByIssueID(issue.ID)
pr, err := models.GetPullRequestByIssueID(db.DefaultContext, issue.ID)
assert.NoError(t, err)
return pr

View File

@ -222,9 +222,8 @@ func (a *Action) getCommentLink(ctx context.Context) string {
if a == nil {
return "#"
}
e := db.GetEngine(ctx)
if a.Comment == nil && a.CommentID != 0 {
a.Comment, _ = getCommentByID(e, a.CommentID)
a.Comment, _ = GetCommentByID(ctx, a.CommentID)
}
if a.Comment != nil {
return a.Comment.HTMLURL()
@ -239,7 +238,7 @@ func (a *Action) getCommentLink(ctx context.Context) string {
return "#"
}
issue, err := getIssueByID(e, issueID)
issue, err := getIssueByID(ctx, issueID)
if err != nil {
return "#"
}
@ -340,8 +339,7 @@ func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, error) {
return nil, err
}
e := db.GetEngine(ctx)
sess := e.Where(cond).
sess := db.GetEngine(ctx).Where(cond).
Select("`action`.*"). // this line will avoid select other joined table's columns
Join("INNER", "repository", "`repository`.id = `action`.repo_id")
@ -354,7 +352,7 @@ func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, error) {
return nil, fmt.Errorf("Find: %v", err)
}
if err := ActionList(actions).loadAttributes(e); err != nil {
if err := ActionList(actions).loadAttributes(ctx); err != nil {
return nil, fmt.Errorf("LoadAttributes: %v", err)
}
@ -504,7 +502,7 @@ func notifyWatchers(ctx context.Context, actions ...*Action) error {
permIssue = make([]bool, len(watchers))
permPR = make([]bool, len(watchers))
for i, watcher := range watchers {
user, err := user_model.GetUserByIDEngine(e, watcher.UserID)
user, err := user_model.GetUserByIDCtx(ctx, watcher.UserID)
if err != nil {
permCode[i] = false
permIssue[i] = false

View File

@ -5,6 +5,7 @@
package models
import (
"context"
"fmt"
"code.gitea.io/gitea/models/db"
@ -26,14 +27,14 @@ func (actions ActionList) getUserIDs() []int64 {
return container.KeysInt64(userIDs)
}
func (actions ActionList) loadUsers(e db.Engine) (map[int64]*user_model.User, error) {
func (actions ActionList) loadUsers(ctx context.Context) (map[int64]*user_model.User, error) {
if len(actions) == 0 {
return nil, nil
}
userIDs := actions.getUserIDs()
userMaps := make(map[int64]*user_model.User, len(userIDs))
err := e.
err := db.GetEngine(ctx).
In("id", userIDs).
Find(&userMaps)
if err != nil {
@ -56,14 +57,14 @@ func (actions ActionList) getRepoIDs() []int64 {
return container.KeysInt64(repoIDs)
}
func (actions ActionList) loadRepositories(e db.Engine) error {
func (actions ActionList) loadRepositories(ctx context.Context) error {
if len(actions) == 0 {
return nil
}
repoIDs := actions.getRepoIDs()
repoMaps := make(map[int64]*repo_model.Repository, len(repoIDs))
err := e.In("id", repoIDs).Find(&repoMaps)
err := db.GetEngine(ctx).In("id", repoIDs).Find(&repoMaps)
if err != nil {
return fmt.Errorf("find repository: %v", err)
}
@ -74,7 +75,7 @@ func (actions ActionList) loadRepositories(e db.Engine) error {
return nil
}
func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_model.User) (err error) {
func (actions ActionList) loadRepoOwner(ctx context.Context, userMap map[int64]*user_model.User) (err error) {
if userMap == nil {
userMap = make(map[int64]*user_model.User)
}
@ -85,7 +86,7 @@ func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_mod
}
repoOwner, ok := userMap[action.Repo.OwnerID]
if !ok {
repoOwner, err = user_model.GetUserByID(action.Repo.OwnerID)
repoOwner, err = user_model.GetUserByIDCtx(ctx, action.Repo.OwnerID)
if err != nil {
if user_model.IsErrUserNotExist(err) {
continue
@ -101,15 +102,15 @@ func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_mod
}
// loadAttributes loads all attributes
func (actions ActionList) loadAttributes(e db.Engine) error {
userMap, err := actions.loadUsers(e)
func (actions ActionList) loadAttributes(ctx context.Context) error {
userMap, err := actions.loadUsers(ctx)
if err != nil {
return err
}
if err := actions.loadRepositories(e); err != nil {
if err := actions.loadRepositories(ctx); err != nil {
return err
}
return actions.loadRepoOwner(e, userMap)
return actions.loadRepoOwner(ctx, userMap)
}

View File

@ -198,16 +198,16 @@ func parseGPGKey(ownerID int64, e *openpgp.Entity, verified bool) (*GPGKey, erro
}
// deleteGPGKey does the actual key deletion
func deleteGPGKey(e db.Engine, keyID string) (int64, error) {
func deleteGPGKey(ctx context.Context, keyID string) (int64, error) {
if keyID == "" {
return 0, fmt.Errorf("empty KeyId forbidden") // Should never happen but just to be sure
}
// Delete imported key
n, err := e.Where("key_id=?", keyID).Delete(new(GPGKeyImport))
n, err := db.GetEngine(ctx).Where("key_id=?", keyID).Delete(new(GPGKeyImport))
if err != nil {
return n, err
}
return e.Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey))
return db.GetEngine(ctx).Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey))
}
// DeleteGPGKey deletes GPG key information in database.
@ -231,7 +231,7 @@ func DeleteGPGKey(doer *user_model.User, id int64) (err error) {
}
defer committer.Close()
if _, err = deleteGPGKey(db.GetEngine(ctx), key.KeyID); err != nil {
if _, err = deleteGPGKey(ctx, key.KeyID); err != nil {
return err
}

View File

@ -5,6 +5,7 @@
package asymkey
import (
"context"
"strings"
"code.gitea.io/gitea/models/db"
@ -29,21 +30,21 @@ import (
// This file contains functions relating to adding GPG Keys
// addGPGKey add key, import and subkeys to database
func addGPGKey(e db.Engine, key *GPGKey, content string) (err error) {
func addGPGKey(ctx context.Context, key *GPGKey, content string) (err error) {
// Add GPGKeyImport
if _, err = e.Insert(GPGKeyImport{
if err = db.Insert(ctx, &GPGKeyImport{
KeyID: key.KeyID,
Content: content,
}); err != nil {
return err
}
// Save GPG primary key.
if _, err = e.Insert(key); err != nil {
if err = db.Insert(ctx, key); err != nil {
return err
}
// Save GPG subs key.
for _, subkey := range key.SubsKey {
if err := addGPGSubKey(e, subkey); err != nil {
if err := addGPGSubKey(ctx, subkey); err != nil {
return err
}
}
@ -51,14 +52,14 @@ func addGPGKey(e db.Engine, key *GPGKey, content string) (err error) {
}
// addGPGSubKey add subkeys to database
func addGPGSubKey(e db.Engine, key *GPGKey) (err error) {
func addGPGSubKey(ctx context.Context, key *GPGKey) (err error) {
// Save GPG primary key.
if _, err = e.Insert(key); err != nil {
if err = db.Insert(ctx, key); err != nil {
return err
}
// Save GPG subs key.
for _, subkey := range key.SubsKey {
if err := addGPGSubKey(e, subkey); err != nil {
if err := addGPGSubKey(ctx, subkey); err != nil {
return err
}
}
@ -158,7 +159,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro
return nil, err
}
if err = addGPGKey(db.GetEngine(ctx), key, content); err != nil {
if err = addGPGKey(ctx, key, content); err != nil {
return nil, err
}
keys = append(keys, key)

View File

@ -75,7 +75,7 @@ func (key *PublicKey) AuthorizedString() string {
return AuthorizedStringForKey(key)
}
func addKey(e db.Engine, key *PublicKey) (err error) {
func addKey(ctx context.Context, key *PublicKey) (err error) {
if len(key.Fingerprint) == 0 {
key.Fingerprint, err = calcFingerprint(key.Content)
if err != nil {
@ -84,7 +84,7 @@ func addKey(e db.Engine, key *PublicKey) (err error) {
}
// Save SSH key.
if _, err = e.Insert(key); err != nil {
if err = db.Insert(ctx, key); err != nil {
return err
}
@ -105,14 +105,13 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
return nil, err
}
defer committer.Close()
sess := db.GetEngine(ctx)
if err := checkKeyFingerprint(sess, fingerprint); err != nil {
if err := checkKeyFingerprint(ctx, fingerprint); err != nil {
return nil, err
}
// Key name of same user cannot be duplicated.
has, err := sess.
has, err := db.GetEngine(ctx).
Where("owner_id = ? AND name = ?", ownerID, name).
Get(new(PublicKey))
if err != nil {
@ -130,7 +129,7 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
Type: KeyTypeUser,
LoginSourceID: authSourceID,
}
if err = addKey(sess, key); err != nil {
if err = addKey(ctx, key); err != nil {
return nil, fmt.Errorf("addKey: %v", err)
}
@ -151,29 +150,12 @@ func GetPublicKeyByID(keyID int64) (*PublicKey, error) {
return key, nil
}
func searchPublicKeyByContentWithEngine(e db.Engine, content string) (*PublicKey, error) {
key := new(PublicKey)
has, err := e.
Where("content like ?", content+"%").
Get(key)
if err != nil {
return nil, err
} else if !has {
return nil, ErrKeyNotExist{}
}
return key, nil
}
// SearchPublicKeyByContent searches content as prefix (leak e-mail part)
// and returns public key found.
func SearchPublicKeyByContent(content string) (*PublicKey, error) {
return searchPublicKeyByContentWithEngine(db.GetEngine(db.DefaultContext), content)
}
func searchPublicKeyByContentExactWithEngine(e db.Engine, content string) (*PublicKey, error) {
func SearchPublicKeyByContent(ctx context.Context, content string) (*PublicKey, error) {
key := new(PublicKey)
has, err := e.
Where("content = ?", content).
has, err := db.GetEngine(ctx).
Where("content like ?", content+"%").
Get(key)
if err != nil {
return nil, err
@ -185,8 +167,17 @@ func searchPublicKeyByContentExactWithEngine(e db.Engine, content string) (*Publ
// SearchPublicKeyByContentExact searches content
// and returns public key found.
func SearchPublicKeyByContentExact(content string) (*PublicKey, error) {
return searchPublicKeyByContentExactWithEngine(db.GetEngine(db.DefaultContext), content)
func SearchPublicKeyByContentExact(ctx context.Context, content string) (*PublicKey, error) {
key := new(PublicKey)
has, err := db.GetEngine(ctx).
Where("content = ?", content).
Get(key)
if err != nil {
return nil, err
} else if !has {
return nil, ErrKeyNotExist{}
}
return key, nil
}
// SearchPublicKey returns a list of public keys matching the provided arguments.
@ -335,12 +326,11 @@ func deleteKeysMarkedForDeletion(keys []string) (bool, error) {
return false, err
}
defer committer.Close()
sess := db.GetEngine(ctx)
// Delete keys marked for deletion
var sshKeysNeedUpdate bool
for _, KeyToDelete := range keys {
key, err := searchPublicKeyByContentWithEngine(sess, KeyToDelete)
key, err := SearchPublicKeyByContent(ctx, KeyToDelete)
if err != nil {
log.Error("SearchPublicKeyByContent: %v", err)
continue

View File

@ -6,6 +6,7 @@ package asymkey
import (
"bufio"
"context"
"fmt"
"io"
"os"
@ -165,7 +166,7 @@ func RewriteAllPublicKeys() error {
}
}
if err := RegeneratePublicKeys(t); err != nil {
if err := RegeneratePublicKeys(db.DefaultContext, t); err != nil {
return err
}
@ -174,12 +175,8 @@ func RewriteAllPublicKeys() error {
}
// RegeneratePublicKeys regenerates the authorized_keys file
func RegeneratePublicKeys(t io.StringWriter) error {
return regeneratePublicKeys(db.GetEngine(db.DefaultContext), t)
}
func regeneratePublicKeys(e db.Engine, t io.StringWriter) error {
if err := e.Where("type != ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) {
func RegeneratePublicKeys(ctx context.Context, t io.StringWriter) error {
if err := db.GetEngine(ctx).Where("type != ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) {
_, err = t.WriteString((bean.(*PublicKey)).AuthorizedString())
return err
}); err != nil {

View File

@ -6,6 +6,7 @@ package asymkey
import (
"bufio"
"context"
"fmt"
"io"
"os"
@ -42,11 +43,7 @@ const authorizedPrincipalsFile = "authorized_principals"
// RewriteAllPrincipalKeys removes any authorized principal and rewrite all keys from database again.
// Note: db.GetEngine(db.DefaultContext).Iterate does not get latest data after insert/delete, so we have to call this function
// outside any session scope independently.
func RewriteAllPrincipalKeys() error {
return rewriteAllPrincipalKeys(db.GetEngine(db.DefaultContext))
}
func rewriteAllPrincipalKeys(e db.Engine) error {
func RewriteAllPrincipalKeys(ctx context.Context) error {
// Don't rewrite key if internal server
if setting.SSH.StartBuiltinServer || !setting.SSH.CreateAuthorizedPrincipalsFile {
return nil
@ -92,7 +89,7 @@ func rewriteAllPrincipalKeys(e db.Engine) error {
}
}
if err := regeneratePrincipalKeys(e, t); err != nil {
if err := regeneratePrincipalKeys(ctx, t); err != nil {
return err
}
@ -100,13 +97,8 @@ func rewriteAllPrincipalKeys(e db.Engine) error {
return util.Rename(tmpPath, fPath)
}
// RegeneratePrincipalKeys regenerates the authorized_principals file
func RegeneratePrincipalKeys(t io.StringWriter) error {
return regeneratePrincipalKeys(db.GetEngine(db.DefaultContext), t)
}
func regeneratePrincipalKeys(e db.Engine, t io.StringWriter) error {
if err := e.Where("type = ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) {
func regeneratePrincipalKeys(ctx context.Context, t io.StringWriter) error {
if err := db.GetEngine(ctx).Where("type = ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) {
_, err = t.WriteString((bean.(*PublicKey)).AuthorizedString())
return err
}); err != nil {

View File

@ -67,9 +67,9 @@ func init() {
db.RegisterModel(new(DeployKey))
}
func checkDeployKey(e db.Engine, keyID, repoID int64, name string) error {
func checkDeployKey(ctx context.Context, keyID, repoID int64, name string) error {
// Note: We want error detail, not just true or false here.
has, err := e.
has, err := db.GetEngine(ctx).
Where("key_id = ? AND repo_id = ?", keyID, repoID).
Get(new(DeployKey))
if err != nil {
@ -78,7 +78,7 @@ func checkDeployKey(e db.Engine, keyID, repoID int64, name string) error {
return ErrDeployKeyAlreadyExist{keyID, repoID}
}
has, err = e.
has, err = db.GetEngine(ctx).
Where("repo_id = ? AND name = ?", repoID, name).
Get(new(DeployKey))
if err != nil {
@ -91,8 +91,8 @@ func checkDeployKey(e db.Engine, keyID, repoID int64, name string) error {
}
// addDeployKey adds new key-repo relation.
func addDeployKey(e db.Engine, keyID, repoID int64, name, fingerprint string, mode perm.AccessMode) (*DeployKey, error) {
if err := checkDeployKey(e, keyID, repoID, name); err != nil {
func addDeployKey(ctx context.Context, keyID, repoID int64, name, fingerprint string, mode perm.AccessMode) (*DeployKey, error) {
if err := checkDeployKey(ctx, keyID, repoID, name); err != nil {
return nil, err
}
@ -103,8 +103,7 @@ func addDeployKey(e db.Engine, keyID, repoID int64, name, fingerprint string, mo
Fingerprint: fingerprint,
Mode: mode,
}
_, err := e.Insert(key)
return key, err
return key, db.Insert(ctx, key)
}
// HasDeployKey returns true if public key is a deploy key of given repository.
@ -133,12 +132,10 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey
}
defer committer.Close()
sess := db.GetEngine(ctx)
pkey := &PublicKey{
Fingerprint: fingerprint,
}
has, err := sess.Get(pkey)
has, err := db.GetByBean(ctx, pkey)
if err != nil {
return nil, err
}
@ -153,12 +150,12 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey
pkey.Type = KeyTypeDeploy
pkey.Content = content
pkey.Name = name
if err = addKey(sess, pkey); err != nil {
if err = addKey(ctx, pkey); err != nil {
return nil, fmt.Errorf("addKey: %v", err)
}
}
key, err := addDeployKey(sess, pkey.ID, repoID, name, pkey.Fingerprint, accessMode)
key, err := addDeployKey(ctx, pkey.ID, repoID, name, pkey.Fingerprint, accessMode)
if err != nil {
return nil, err
}
@ -179,16 +176,12 @@ func GetDeployKeyByID(ctx context.Context, id int64) (*DeployKey, error) {
}
// GetDeployKeyByRepo returns deploy key by given public key ID and repository ID.
func GetDeployKeyByRepo(keyID, repoID int64) (*DeployKey, error) {
return getDeployKeyByRepo(db.GetEngine(db.DefaultContext), keyID, repoID)
}
func getDeployKeyByRepo(e db.Engine, keyID, repoID int64) (*DeployKey, error) {
func GetDeployKeyByRepo(ctx context.Context, keyID, repoID int64) (*DeployKey, error) {
key := &DeployKey{
KeyID: keyID,
RepoID: repoID,
}
has, err := e.Get(key)
has, err := db.GetByBean(ctx, key)
if err != nil {
return nil, err
} else if !has {

View File

@ -5,6 +5,7 @@
package asymkey
import (
"context"
"errors"
"fmt"
"strings"
@ -31,8 +32,8 @@ import (
// checkKeyFingerprint only checks if key fingerprint has been used as public key,
// it is OK to use same key as deploy key for multiple repositories/users.
func checkKeyFingerprint(e db.Engine, fingerprint string) error {
has, err := e.Get(&PublicKey{
func checkKeyFingerprint(ctx context.Context, fingerprint string) error {
has, err := db.GetByBean(ctx, &PublicKey{
Fingerprint: fingerprint,
})
if err != nil {

View File

@ -31,10 +31,9 @@ func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*Public
return nil, err
}
defer committer.Close()
sess := db.GetEngine(ctx)
// Principals cannot be duplicated.
has, err := sess.
has, err := db.GetEngine(ctx).
Where("content = ? AND type = ?", content, KeyTypePrincipal).
Get(new(PublicKey))
if err != nil {
@ -51,7 +50,7 @@ func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*Public
Type: KeyTypePrincipal,
LoginSourceID: authSourceID,
}
if err = addPrincipalKey(sess, key); err != nil {
if err = db.Insert(ctx, key); err != nil {
return nil, fmt.Errorf("addKey: %v", err)
}
@ -61,16 +60,7 @@ func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*Public
committer.Close()
return key, RewriteAllPrincipalKeys()
}
func addPrincipalKey(e db.Engine, key *PublicKey) (err error) {
// Save Key representing a principal.
if _, err = e.Insert(key); err != nil {
return err
}
return nil
return key, RewriteAllPrincipalKeys(db.DefaultContext)
}
// CheckPrincipalKeyString strips spaces and returns an error if the given principal contains newlines

View File

@ -92,13 +92,9 @@ func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
}
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) {
return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID)
}
func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) {
func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@ -107,17 +103,13 @@ func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant
}
// CreateGrant generates a grant for an user
func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) {
return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope)
}
func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) {
func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) {
grant := &OAuth2Grant{
ApplicationID: app.ID,
UserID: userID,
Scope: scope,
}
_, err := e.Insert(grant)
err := db.Insert(ctx, grant)
if err != nil {
return nil, err
}
@ -125,13 +117,9 @@ func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope strin
}
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) {
return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID)
}
func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) {
func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := e.Where("client_id = ?", clientID).Get(app)
has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app)
if !has {
return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
}
@ -139,13 +127,9 @@ func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Ap
}
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) {
return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id)
}
func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) {
func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := e.ID(id).Get(app)
has, err := db.GetEngine(ctx).ID(id).Get(app)
if err != nil {
return nil, err
}
@ -156,13 +140,9 @@ func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, er
}
// GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user
func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) {
return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID)
}
func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) {
func GetOAuth2ApplicationsByUserID(ctx context.Context, userID int64) (apps []*OAuth2Application, err error) {
apps = make([]*OAuth2Application, 0)
err = e.Where("uid = ?", userID).Find(&apps)
err = db.GetEngine(ctx).Where("uid = ?", userID).Find(&apps)
return
}
@ -174,11 +154,7 @@ type CreateOAuth2ApplicationOptions struct {
}
// CreateOAuth2Application inserts a new oauth2 application
func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
return createOAuth2Application(db.GetEngine(db.DefaultContext), opts)
}
func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
clientID := uuid.New().String()
app := &OAuth2Application{
UID: opts.UserID,
@ -186,7 +162,7 @@ func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (
ClientID: clientID,
RedirectURIs: opts.RedirectURIs,
}
if _, err := e.Insert(app); err != nil {
if err := db.Insert(ctx, app); err != nil {
return nil, err
}
return app, nil
@ -207,9 +183,8 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
return nil, err
}
defer committer.Close()
sess := db.GetEngine(ctx)
app, err := getOAuth2ApplicationByID(sess, opts.ID)
app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
if err != nil {
return nil, err
}
@ -220,7 +195,7 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
app.Name = opts.Name
app.RedirectURIs = opts.RedirectURIs
if err = updateOAuth2Application(sess, app); err != nil {
if err = updateOAuth2Application(ctx, app); err != nil {
return nil, err
}
app.ClientSecret = ""
@ -228,14 +203,15 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
return app, committer.Commit()
}
func updateOAuth2Application(e db.Engine, app *OAuth2Application) error {
if _, err := e.ID(app.ID).Update(app); err != nil {
func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
if _, err := db.GetEngine(ctx).ID(app.ID).Update(app); err != nil {
return err
}
return nil
}
func deleteOAuth2Application(sess db.Engine, id, userid int64) error {
func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
sess := db.GetEngine(ctx)
if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil {
return err
} else if deleted == 0 {
@ -269,7 +245,7 @@ func DeleteOAuth2Application(id, userid int64) error {
return err
}
defer committer.Close()
if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil {
if err := deleteOAuth2Application(ctx, id, userid); err != nil {
return err
}
return committer.Commit()
@ -328,21 +304,13 @@ func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect
}
// Invalidate deletes the auth code from the database to invalidate this code
func (code *OAuth2AuthorizationCode) Invalidate() error {
return code.invalidate(db.GetEngine(db.DefaultContext))
}
func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error {
_, err := e.Delete(code)
func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code)
return err
}
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
return code.validateCodeChallenge(verifier)
}
func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool {
switch code.CodeChallengeMethod {
case "S256":
// base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
@ -360,19 +328,15 @@ func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool
}
// GetOAuth2AuthorizationByCode returns an authorization by its code
func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) {
return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code)
}
func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) {
func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) {
auth = new(OAuth2AuthorizationCode)
if has, err := e.Where("code = ?", code).Get(auth); err != nil {
if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil {
return nil, err
} else if !has {
return nil, nil
}
auth.Grant = new(OAuth2Grant)
if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil {
if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@ -401,11 +365,7 @@ func (grant *OAuth2Grant) TableName() string {
}
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) {
return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod)
}
func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
rBytes, err := util.CryptoRandomBytes(32)
if err != nil {
return &OAuth2AuthorizationCode{}, err
@ -422,23 +382,19 @@ func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
}
if _, err := e.Insert(code); err != nil {
if err := db.Insert(ctx, code); err != nil {
return nil, err
}
return code, nil
}
// IncreaseCounter increases the counter and updates the grant
func (grant *OAuth2Grant) IncreaseCounter() error {
return grant.increaseCount(db.GetEngine(db.DefaultContext))
}
func (grant *OAuth2Grant) increaseCount(e db.Engine) error {
_, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
if err != nil {
return err
}
updatedGrant, err := getOAuth2GrantByID(e, grant.ID)
updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID)
if err != nil {
return err
}
@ -457,13 +413,9 @@ func (grant *OAuth2Grant) ScopeContains(scope string) bool {
}
// SetNonce updates the current nonce value of a grant
func (grant *OAuth2Grant) SetNonce(nonce string) error {
return grant.setNonce(db.GetEngine(db.DefaultContext), nonce)
}
func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error {
grant.Nonce = nonce
_, err := e.ID(grant.ID).Cols("nonce").Update(grant)
_, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant)
if err != nil {
return err
}
@ -471,13 +423,9 @@ func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
}
// GetOAuth2GrantByID returns the grant with the given ID
func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) {
return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id)
}
func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := e.ID(id).Get(grant); err != nil {
if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@ -486,18 +434,14 @@ func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
}
// GetOAuth2GrantsByUserID lists all grants of a certain user
func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) {
return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid)
}
func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) {
type joinedOAuth2Grant struct {
Grant *OAuth2Grant `xorm:"extends"`
Application *OAuth2Application `xorm:"extends"`
}
var results *xorm.Rows
var err error
if results, err = e.
if results, err = db.GetEngine(ctx).
Table("oauth2_grant").
Where("user_id = ?", uid).
Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
@ -518,12 +462,8 @@ func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
}
// RevokeOAuth2Grant deletes the grant with grantID and userID
func RevokeOAuth2Grant(grantID, userID int64) error {
return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID)
}
func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error {
_, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID})
func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error {
_, err := db.DeleteByBean(ctx, &OAuth2Grant{ID: grantID, UserID: userID})
return err
}

View File

@ -7,6 +7,7 @@ package auth
import (
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
@ -52,18 +53,18 @@ func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
func TestGetOAuth2ApplicationByClientID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app, err := GetOAuth2ApplicationByClientID("da7da3ba-9a13-4167-856f-3899de0b0138")
app, err := GetOAuth2ApplicationByClientID(db.DefaultContext, "da7da3ba-9a13-4167-856f-3899de0b0138")
assert.NoError(t, err)
assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
app, err = GetOAuth2ApplicationByClientID("invalid client id")
app, err = GetOAuth2ApplicationByClientID(db.DefaultContext, "invalid client id")
assert.Error(t, err)
assert.Nil(t, app)
}
func TestCreateOAuth2Application(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app, err := CreateOAuth2Application(CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
app, err := CreateOAuth2Application(db.DefaultContext, CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
assert.NoError(t, err)
assert.Equal(t, "newapp", app.Name)
assert.Len(t, app.ClientID, 36)
@ -77,11 +78,11 @@ func TestOAuth2Application_TableName(t *testing.T) {
func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
grant, err := app.GetGrantByUserID(1)
grant, err := app.GetGrantByUserID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Equal(t, int64(1), grant.UserID)
grant, err = app.GetGrantByUserID(34923458)
grant, err = app.GetGrantByUserID(db.DefaultContext, 34923458)
assert.NoError(t, err)
assert.Nil(t, grant)
}
@ -89,7 +90,7 @@ func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
func TestOAuth2Application_CreateGrant(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
grant, err := app.CreateGrant(2, "")
grant, err := app.CreateGrant(db.DefaultContext, 2, "")
assert.NoError(t, err)
assert.NotNil(t, grant)
assert.Equal(t, int64(2), grant.UserID)
@ -101,11 +102,11 @@ func TestOAuth2Application_CreateGrant(t *testing.T) {
func TestGetOAuth2GrantByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant, err := GetOAuth2GrantByID(1)
grant, err := GetOAuth2GrantByID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Equal(t, int64(1), grant.ID)
grant, err = GetOAuth2GrantByID(34923458)
grant, err = GetOAuth2GrantByID(db.DefaultContext, 34923458)
assert.NoError(t, err)
assert.Nil(t, grant)
}
@ -113,7 +114,7 @@ func TestGetOAuth2GrantByID(t *testing.T) {
func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 1}).(*OAuth2Grant)
assert.NoError(t, grant.IncreaseCounter())
assert.NoError(t, grant.IncreaseCounter(db.DefaultContext))
assert.Equal(t, int64(2), grant.Counter)
unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 2})
}
@ -130,7 +131,7 @@ func TestOAuth2Grant_ScopeContains(t *testing.T) {
func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1}).(*OAuth2Grant)
code, err := grant.GenerateNewAuthorizationCode("https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
code, err := grant.GenerateNewAuthorizationCode(db.DefaultContext, "https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
assert.NoError(t, err)
assert.NotNil(t, code)
assert.True(t, len(code.Code) > 32) // secret length > 32
@ -142,20 +143,20 @@ func TestOAuth2Grant_TableName(t *testing.T) {
func TestGetOAuth2GrantsByUserID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
result, err := GetOAuth2GrantsByUserID(1)
result, err := GetOAuth2GrantsByUserID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, int64(1), result[0].ID)
assert.Equal(t, result[0].ApplicationID, result[0].Application.ID)
result, err = GetOAuth2GrantsByUserID(34134)
result, err = GetOAuth2GrantsByUserID(db.DefaultContext, 34134)
assert.NoError(t, err)
assert.Empty(t, result)
}
func TestRevokeOAuth2Grant(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.NoError(t, RevokeOAuth2Grant(1, 1))
assert.NoError(t, RevokeOAuth2Grant(db.DefaultContext, 1, 1))
unittest.AssertNotExistsBean(t, &OAuth2Grant{ID: 1, UserID: 1})
}
@ -163,13 +164,13 @@ func TestRevokeOAuth2Grant(t *testing.T) {
func TestGetOAuth2AuthorizationByCode(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
code, err := GetOAuth2AuthorizationByCode("authcode")
code, err := GetOAuth2AuthorizationByCode(db.DefaultContext, "authcode")
assert.NoError(t, err)
assert.NotNil(t, code)
assert.Equal(t, "authcode", code.Code)
assert.Equal(t, int64(1), code.ID)
code, err = GetOAuth2AuthorizationByCode("does not exist")
code, err = GetOAuth2AuthorizationByCode(db.DefaultContext, "does not exist")
assert.NoError(t, err)
assert.Nil(t, code)
}
@ -224,7 +225,7 @@ func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
code := unittest.AssertExistsAndLoadBean(t, &OAuth2AuthorizationCode{Code: "authcode"}).(*OAuth2AuthorizationCode)
assert.NoError(t, code.Invalidate())
assert.NoError(t, code.Invalidate(db.DefaultContext))
unittest.AssertNotExistsBean(t, &OAuth2AuthorizationCode{Code: "authcode"})
}

View File

@ -306,13 +306,9 @@ func (protectBranch *ProtectedBranch) IsUnprotectedFile(patterns []glob.Glob, pa
}
// GetProtectedBranchBy getting protected branch by ID/Name
func GetProtectedBranchBy(repoID int64, branchName string) (*ProtectedBranch, error) {
return getProtectedBranchBy(db.GetEngine(db.DefaultContext), repoID, branchName)
}
func getProtectedBranchBy(e db.Engine, repoID int64, branchName string) (*ProtectedBranch, error) {
func GetProtectedBranchBy(ctx context.Context, repoID int64, branchName string) (*ProtectedBranch, error) {
rel := &ProtectedBranch{RepoID: repoID, BranchName: branchName}
has, err := e.Get(rel)
has, err := db.GetByBean(ctx, rel)
if err != nil {
return nil, err
}
@ -632,7 +628,7 @@ func RenameBranch(repo *repo_model.Repository, from, to string, gitAction func(i
}
// 2. Update protected branch if needed
protectedBranch, err := getProtectedBranchBy(sess, repo.ID, from)
protectedBranch, err := GetProtectedBranchBy(ctx, repo.ID, from)
if err != nil {
return err
}

View File

@ -49,21 +49,21 @@ func init() {
}
// upsertCommitStatusIndex the function will not return until it acquires the lock or receives an error.
func upsertCommitStatusIndex(e db.Engine, repoID int64, sha string) (err error) {
func upsertCommitStatusIndex(ctx context.Context, repoID int64, sha string) (err error) {
// An atomic UPSERT operation (INSERT/UPDATE) is the only operation
// that ensures that the key is actually locked.
switch {
case setting.Database.UseSQLite3 || setting.Database.UsePostgreSQL:
_, err = e.Exec("INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+
_, err = db.Exec(ctx, "INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+
"VALUES (?,?,1) ON CONFLICT (repo_id,sha) DO UPDATE SET max_index = `commit_status_index`.max_index+1",
repoID, sha)
case setting.Database.UseMySQL:
_, err = e.Exec("INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+
_, err = db.Exec(ctx, "INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+
"VALUES (?,?,1) ON DUPLICATE KEY UPDATE max_index = max_index+1",
repoID, sha)
case setting.Database.UseMSSQL:
// https://weblogs.sqlteam.com/dang/2009/01/31/upsert-race-condition-with-merge/
_, err = e.Exec("MERGE `commit_status_index` WITH (HOLDLOCK) as target "+
_, err = db.Exec(ctx, "MERGE `commit_status_index` WITH (HOLDLOCK) as target "+
"USING (SELECT ? AS repo_id, ? AS sha) AS src "+
"ON src.repo_id = target.repo_id AND src.sha = target.sha "+
"WHEN MATCHED THEN UPDATE SET target.max_index = target.max_index+1 "+
@ -100,17 +100,17 @@ func getNextCommitStatusIndex(repoID int64, sha string) (int64, error) {
defer commiter.Close()
var preIdx int64
_, err = ctx.Engine().SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ?", repoID, sha).Get(&preIdx)
_, err = db.GetEngine(ctx).SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ?", repoID, sha).Get(&preIdx)
if err != nil {
return 0, err
}
if err := upsertCommitStatusIndex(ctx.Engine(), repoID, sha); err != nil {
if err := upsertCommitStatusIndex(ctx, repoID, sha); err != nil {
return 0, err
}
var curIdx int64
has, err := ctx.Engine().SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ? AND max_index=?", repoID, sha, preIdx+1).Get(&curIdx)
has, err := db.GetEngine(ctx).SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ? AND max_index=?", repoID, sha, preIdx+1).Get(&curIdx)
if err != nil {
return 0, err
}
@ -131,7 +131,7 @@ func (status *CommitStatus) loadAttributes(ctx context.Context) (err error) {
}
}
if status.Creator == nil && status.CreatorID > 0 {
status.Creator, err = user_model.GetUserByIDEngine(db.GetEngine(ctx), status.CreatorID)
status.Creator, err = user_model.GetUserByIDCtx(ctx, status.CreatorID)
if err != nil {
return fmt.Errorf("getUserByID [%d]: %v", status.CreatorID, err)
}
@ -231,12 +231,7 @@ type CommitStatusIndex struct {
}
// GetLatestCommitStatus returns all statuses with a unique context for a given commit.
func GetLatestCommitStatus(repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) {
return GetLatestCommitStatusCtx(db.DefaultContext, repoID, sha, listOptions)
}
// GetLatestCommitStatusCtx returns all statuses with a unique context for a given commit.
func GetLatestCommitStatusCtx(ctx context.Context, repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) {
func GetLatestCommitStatus(ctx context.Context, repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) {
ids := make([]int64, 0, 10)
sess := db.GetEngine(ctx).Table(&CommitStatus{}).
Where("repo_id = ?", repoID).And("sha = ?", sha).
@ -341,7 +336,7 @@ func ParseCommitsWithStatus(oldCommits []*asymkey_model.SignCommit, repo *repo_m
commit := &SignCommitWithStatuses{
SignCommit: c,
}
statuses, _, err := GetLatestCommitStatus(repo.ID, commit.ID.String(), db.ListOptions{})
statuses, _, err := GetLatestCommitStatus(db.DefaultContext, repo.ID, commit.ID.String(), db.ListOptions{})
if err != nil {
log.Error("GetLatestCommitStatus: %v", err)
} else {

View File

@ -117,7 +117,7 @@ func DeleteOrphanedIssues() error {
var attachmentPaths []string
for i := range ids {
paths, err := deleteIssuesByRepoID(db.GetEngine(ctx), ids[i])
paths, err := deleteIssuesByRepoID(ctx, ids[i])
if err != nil {
return err
}

View File

@ -5,6 +5,7 @@
package db
import (
"context"
"errors"
"fmt"
@ -74,8 +75,8 @@ func GetNextResourceIndex(tableName string, groupID int64) (int64, error) {
}
// DeleteResouceIndex delete resource index
func DeleteResouceIndex(e Engine, tableName string, groupID int64) error {
_, err := e.Exec(fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID)
func DeleteResouceIndex(ctx context.Context, tableName string, groupID int64) error {
_, err := Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID)
return err
}

File diff suppressed because it is too large Load Diff

View File

@ -25,20 +25,15 @@ func init() {
}
// LoadAssignees load assignees of this issue.
func (issue *Issue) LoadAssignees() error {
return issue.loadAssignees(db.GetEngine(db.DefaultContext))
}
// This loads all assignees of an issue
func (issue *Issue) loadAssignees(e db.Engine) (err error) {
func (issue *Issue) LoadAssignees(ctx context.Context) (err error) {
// Reset maybe preexisting assignees
issue.Assignees = []*user_model.User{}
issue.Assignee = nil
err = e.Table("`user`").
err = db.GetEngine(ctx).Table("`user`").
Join("INNER", "issue_assignees", "assignee_id = `user`.id").
Where("issue_assignees.issue_id = ?", issue.ID).
Find(&issue.Assignees)
if err != nil {
return err
}
@ -47,7 +42,6 @@ func (issue *Issue) loadAssignees(e db.Engine) (err error) {
if len(issue.Assignees) > 0 {
issue.Assignee = issue.Assignees[0]
}
return
}
@ -63,33 +57,9 @@ func GetAssigneeIDsByIssue(issueID int64) ([]int64, error) {
Find(&userIDs)
}
// GetAssigneesByIssue returns everyone assigned to that issue
func GetAssigneesByIssue(issue *Issue) (assignees []*user_model.User, err error) {
return getAssigneesByIssue(db.GetEngine(db.DefaultContext), issue)
}
func getAssigneesByIssue(e db.Engine, issue *Issue) (assignees []*user_model.User, err error) {
err = issue.loadAssignees(e)
if err != nil {
return assignees, err
}
return issue.Assignees, nil
}
// IsUserAssignedToIssue returns true when the user is assigned to the issue
func IsUserAssignedToIssue(issue *Issue, user *user_model.User) (isAssigned bool, err error) {
return isUserAssignedToIssue(db.GetEngine(db.DefaultContext), issue, user)
}
func isUserAssignedToIssue(e db.Engine, issue *Issue, user *user_model.User) (isAssigned bool, err error) {
return e.Get(&IssueAssignees{IssueID: issue.ID, AssigneeID: user.ID})
}
// ClearAssigneeByUserID deletes all assignments of an user
func clearAssigneeByUserID(sess db.Engine, userID int64) (err error) {
_, err = sess.Delete(&IssueAssignees{AssigneeID: userID})
return
func IsUserAssignedToIssue(ctx context.Context, issue *Issue, user *user_model.User) (isAssigned bool, err error) {
return db.GetByBean(ctx, &IssueAssignees{IssueID: issue.ID, AssigneeID: user.ID})
}
// ToggleIssueAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it.
@ -113,8 +83,7 @@ func ToggleIssueAssignee(issue *Issue, doer *user_model.User, assigneeID int64)
}
func toggleIssueAssignee(ctx context.Context, issue *Issue, doer *user_model.User, assigneeID int64, isCreate bool) (removed bool, comment *Comment, err error) {
sess := db.GetEngine(ctx)
removed, err = toggleUserAssignee(sess, issue, assigneeID)
removed, err = toggleUserAssignee(ctx, issue, assigneeID)
if err != nil {
return false, nil, fmt.Errorf("UpdateIssueUserByAssignee: %v", err)
}
@ -147,39 +116,38 @@ func toggleIssueAssignee(ctx context.Context, issue *Issue, doer *user_model.Use
}
// toggles user assignee state in database
func toggleUserAssignee(e db.Engine, issue *Issue, assigneeID int64) (removed bool, err error) {
func toggleUserAssignee(ctx context.Context, issue *Issue, assigneeID int64) (removed bool, err error) {
// Check if the user exists
assignee, err := user_model.GetUserByIDEngine(e, assigneeID)
assignee, err := user_model.GetUserByIDCtx(ctx, assigneeID)
if err != nil {
return false, err
}
// Check if the submitted user is already assigned, if yes delete him otherwise add him
var i int
for i = 0; i < len(issue.Assignees); i++ {
found := false
i := 0
for ; i < len(issue.Assignees); i++ {
if issue.Assignees[i].ID == assigneeID {
found = true
break
}
}
assigneeIn := IssueAssignees{AssigneeID: assigneeID, IssueID: issue.ID}
toBeDeleted := i < len(issue.Assignees)
if toBeDeleted {
issue.Assignees = append(issue.Assignees[:i], issue.Assignees[i:]...)
_, err = e.Delete(assigneeIn)
if found {
issue.Assignees = append(issue.Assignees[:i], issue.Assignees[i+1:]...)
_, err = db.DeleteByBean(ctx, &assigneeIn)
if err != nil {
return toBeDeleted, err
return found, err
}
} else {
issue.Assignees = append(issue.Assignees, assignee)
_, err = e.Insert(assigneeIn)
if err != nil {
return toBeDeleted, err
if err = db.Insert(ctx, &assigneeIn); err != nil {
return found, err
}
}
return toBeDeleted, nil
return found, nil
}
// MakeIDsFromAPIAssigneesToAdd returns an array with all assignee IDs

View File

@ -7,6 +7,7 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
@ -37,28 +38,28 @@ func TestUpdateAssignee(t *testing.T) {
assert.NoError(t, err)
// Check if he got removed
isAssigned, err := IsUserAssignedToIssue(issue, user1)
isAssigned, err := IsUserAssignedToIssue(db.DefaultContext, issue, user1)
assert.NoError(t, err)
assert.False(t, isAssigned)
// Check if they're all there
assignees, err := GetAssigneesByIssue(issue)
err = issue.LoadAssignees(db.DefaultContext)
assert.NoError(t, err)
var expectedAssignees []*user_model.User
expectedAssignees = append(expectedAssignees, user2, user3)
for in, assignee := range assignees {
for in, assignee := range issue.Assignees {
assert.Equal(t, assignee.ID, expectedAssignees[in].ID)
}
// Check if the user is assigned
isAssigned, err = IsUserAssignedToIssue(issue, user2)
isAssigned, err = IsUserAssignedToIssue(db.DefaultContext, issue, user2)
assert.NoError(t, err)
assert.True(t, isAssigned)
// This user should not be assigned
isAssigned, err = IsUserAssignedToIssue(issue, &user_model.User{ID: 4})
isAssigned, err = IsUserAssignedToIssue(db.DefaultContext, issue, &user_model.User{ID: 4})
assert.NoError(t, err)
assert.False(t, isAssigned)
}

View File

@ -298,7 +298,7 @@ func (c *Comment) LoadIssueCtx(ctx context.Context) (err error) {
if c.Issue != nil {
return nil
}
c.Issue, err = getIssueByID(db.GetEngine(ctx), c.IssueID)
c.Issue, err = getIssueByID(ctx, c.IssueID)
return
}
@ -329,12 +329,12 @@ func (c *Comment) AfterLoad(session *xorm.Session) {
}
}
func (c *Comment) loadPoster(e db.Engine) (err error) {
func (c *Comment) loadPoster(ctx context.Context) (err error) {
if c.PosterID <= 0 || c.Poster != nil {
return nil
}
c.Poster, err = user_model.GetUserByIDEngine(e, c.PosterID)
c.Poster, err = user_model.GetUserByIDCtx(ctx, c.PosterID)
if err != nil {
if user_model.IsErrUserNotExist(err) {
c.PosterID = -1
@ -525,7 +525,7 @@ func (c *Comment) LoadMilestone() error {
// LoadPoster loads comment poster
func (c *Comment) LoadPoster() error {
return c.loadPoster(db.GetEngine(db.DefaultContext))
return c.loadPoster(db.DefaultContext)
}
// LoadAttachments loads attachments (it never returns error, the error during `GetAttachmentsByCommentIDCtx` is ignored)
@ -535,7 +535,7 @@ func (c *Comment) LoadAttachments() error {
}
var err error
c.Attachments, err = repo_model.GetAttachmentsByCommentIDCtx(db.DefaultContext, c.ID)
c.Attachments, err = repo_model.GetAttachmentsByCommentID(db.DefaultContext, c.ID)
if err != nil {
log.Error("getAttachmentsByCommentID[%d]: %v", c.ID, err)
}
@ -557,7 +557,7 @@ func (c *Comment) UpdateAttachments(uuids []string) error {
for i := 0; i < len(attachments); i++ {
attachments[i].IssueID = c.IssueID
attachments[i].CommentID = c.ID
if err := repo_model.UpdateAttachmentCtx(ctx, attachments[i]); err != nil {
if err := repo_model.UpdateAttachment(ctx, attachments[i]); err != nil {
return fmt.Errorf("update attachment [id: %d]: %v", attachments[i].ID, err)
}
}
@ -590,7 +590,7 @@ func (c *Comment) LoadAssigneeUserAndTeam() error {
}
if c.Issue.Repo.Owner.IsOrganization() {
c.AssigneeTeam, err = organization.GetTeamByID(c.AssigneeTeamID)
c.AssigneeTeam, err = organization.GetTeamByID(db.DefaultContext, c.AssigneeTeamID)
if err != nil && !organization.IsErrTeamNotExist(err) {
return err
}
@ -624,7 +624,7 @@ func (c *Comment) LoadDepIssueDetails() (err error) {
if c.DependentIssueID <= 0 || c.DependentIssue != nil {
return nil
}
c.DependentIssue, err = getIssueByID(db.GetEngine(db.DefaultContext), c.DependentIssueID)
c.DependentIssue, err = getIssueByID(db.DefaultContext, c.DependentIssueID)
return err
}
@ -661,9 +661,9 @@ func (c *Comment) LoadReactions(repo *repo_model.Repository) error {
return c.loadReactions(db.DefaultContext, repo)
}
func (c *Comment) loadReview(e db.Engine) (err error) {
func (c *Comment) loadReview(ctx context.Context) (err error) {
if c.Review == nil {
if c.Review, err = getReviewByID(e, c.ReviewID); err != nil {
if c.Review, err = GetReviewByID(ctx, c.ReviewID); err != nil {
return err
}
}
@ -673,7 +673,7 @@ func (c *Comment) loadReview(e db.Engine) (err error) {
// LoadReview loads the associated review
func (c *Comment) LoadReview() error {
return c.loadReview(db.GetEngine(db.DefaultContext))
return c.loadReview(db.DefaultContext)
}
var notEnoughLines = regexp.MustCompile(`fatal: file .* has only \d+ lines?`)
@ -830,13 +830,12 @@ func CreateCommentCtx(ctx context.Context, opts *CreateCommentOptions) (_ *Comme
}
func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment *Comment) (err error) {
e := db.GetEngine(ctx)
// Check comment type.
switch opts.Type {
case CommentTypeCode:
if comment.ReviewID != 0 {
if comment.Review == nil {
if err := comment.loadReview(e); err != nil {
if err := comment.loadReview(ctx); err != nil {
return err
}
}
@ -846,7 +845,7 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment
}
fallthrough
case CommentTypeComment:
if _, err = e.Exec("UPDATE `issue` SET num_comments=num_comments+1 WHERE id=?", opts.Issue.ID); err != nil {
if _, err = db.Exec(ctx, "UPDATE `issue` SET num_comments=num_comments+1 WHERE id=?", opts.Issue.ID); err != nil {
return err
}
fallthrough
@ -861,7 +860,7 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment
attachments[i].IssueID = opts.Issue.ID
attachments[i].CommentID = comment.ID
// No assign value could be 0, so ignore AllCols().
if _, err = e.ID(attachments[i].ID).Update(attachments[i]); err != nil {
if _, err = db.GetEngine(ctx).ID(attachments[i].ID).Update(attachments[i]); err != nil {
return fmt.Errorf("update attachment [%d]: %v", attachments[i].ID, err)
}
}
@ -1031,13 +1030,9 @@ func CreateRefComment(doer *user_model.User, repo *repo_model.Repository, issue
}
// GetCommentByID returns the comment by given ID.
func GetCommentByID(id int64) (*Comment, error) {
return getCommentByID(db.GetEngine(db.DefaultContext), id)
}
func getCommentByID(e db.Engine, id int64) (*Comment, error) {
func GetCommentByID(ctx context.Context, id int64) (*Comment, error) {
c := new(Comment)
has, err := e.ID(id).Get(c)
has, err := db.GetEngine(ctx).ID(id).Get(c)
if err != nil {
return nil, err
} else if !has {
@ -1088,9 +1083,10 @@ func (opts *FindCommentsOptions) toConds() builder.Cond {
return cond
}
func findComments(e db.Engine, opts *FindCommentsOptions) ([]*Comment, error) {
// FindComments returns all comments according options
func FindComments(ctx context.Context, opts *FindCommentsOptions) ([]*Comment, error) {
comments := make([]*Comment, 0, 10)
sess := e.Where(opts.toConds())
sess := db.GetEngine(ctx).Where(opts.toConds())
if opts.RepoID > 0 {
sess.Join("INNER", "issue", "issue.id = comment.issue_id")
}
@ -1107,11 +1103,6 @@ func findComments(e db.Engine, opts *FindCommentsOptions) ([]*Comment, error) {
Find(&comments)
}
// FindComments returns all comments according options
func FindComments(opts *FindCommentsOptions) ([]*Comment, error) {
return findComments(db.GetEngine(db.DefaultContext), opts)
}
// CountComments count all comments according options by ignoring pagination
func CountComments(opts *FindCommentsOptions) (int64, error) {
sess := db.GetEngine(db.DefaultContext).Where(opts.toConds())
@ -1167,7 +1158,7 @@ func deleteComment(ctx context.Context, comment *Comment) error {
return err
}
if _, err := e.Delete(&issues_model.ContentHistory{
if _, err := db.DeleteByBean(ctx, &issues_model.ContentHistory{
CommentID: comment.ID,
}); err != nil {
return err
@ -1182,7 +1173,7 @@ func deleteComment(ctx context.Context, comment *Comment) error {
return err
}
if err := comment.neuterCrossReferences(e); err != nil {
if err := comment.neuterCrossReferences(ctx); err != nil {
return err
}
@ -1192,7 +1183,8 @@ func deleteComment(ctx context.Context, comment *Comment) error {
// CodeComments represents comments on code by using this structure: FILENAME -> LINE (+ == proposed; - == previous) -> COMMENTS
type CodeComments map[string]map[int64][]*Comment
func fetchCodeComments(ctx context.Context, issue *Issue, currentUser *user_model.User) (CodeComments, error) {
// FetchCodeComments will return a 2d-map: ["Path"]["Line"] = Comments at line
func FetchCodeComments(ctx context.Context, issue *Issue, currentUser *user_model.User) (CodeComments, error) {
return fetchCodeCommentsByReview(ctx, issue, currentUser, nil)
}
@ -1242,7 +1234,7 @@ func findCodeComments(ctx context.Context, opts FindCommentsOptions, issue *Issu
return nil, err
}
if err := CommentList(comments).loadPosters(e); err != nil {
if err := CommentList(comments).loadPosters(ctx); err != nil {
return nil, err
}
@ -1302,11 +1294,6 @@ func FetchCodeCommentsByLine(ctx context.Context, issue *Issue, currentUser *use
return findCodeComments(ctx, opts, issue, currentUser, nil)
}
// FetchCodeComments will return a 2d-map: ["Path"]["Line"] = Comments at line
func FetchCodeComments(ctx context.Context, issue *Issue, currentUser *user_model.User) (CodeComments, error) {
return fetchCodeComments(ctx, issue, currentUser)
}
// UpdateCommentsMigrationsByType updates comments' migrations information via given git service type and original id and poster id
func UpdateCommentsMigrationsByType(tp structs.GitServiceType, originalAuthorID string, posterID int64) error {
_, err := db.GetEngine(db.DefaultContext).Table("comment").

View File

@ -27,7 +27,7 @@ func (comments CommentList) getPosterIDs() []int64 {
return container.KeysInt64(posterIDs)
}
func (comments CommentList) loadPosters(e db.Engine) error {
func (comments CommentList) loadPosters(ctx context.Context) error {
if len(comments) == 0 {
return nil
}
@ -40,7 +40,7 @@ func (comments CommentList) loadPosters(e db.Engine) error {
if left < limit {
limit = left
}
err := e.
err := db.GetEngine(ctx).
In("id", posterIDs[:limit]).
Find(&posterMaps)
if err != nil {
@ -80,7 +80,7 @@ func (comments CommentList) getLabelIDs() []int64 {
return container.KeysInt64(ids)
}
func (comments CommentList) loadLabels(e db.Engine) error {
func (comments CommentList) loadLabels(ctx context.Context) error {
if len(comments) == 0 {
return nil
}
@ -93,7 +93,7 @@ func (comments CommentList) loadLabels(e db.Engine) error {
if left < limit {
limit = left
}
rows, err := e.
rows, err := db.GetEngine(ctx).
In("id", labelIDs[:limit]).
Rows(new(Label))
if err != nil {
@ -130,7 +130,7 @@ func (comments CommentList) getMilestoneIDs() []int64 {
return container.KeysInt64(ids)
}
func (comments CommentList) loadMilestones(e db.Engine) error {
func (comments CommentList) loadMilestones(ctx context.Context) error {
if len(comments) == 0 {
return nil
}
@ -147,7 +147,7 @@ func (comments CommentList) loadMilestones(e db.Engine) error {
if left < limit {
limit = left
}
err := e.
err := db.GetEngine(ctx).
In("id", milestoneIDs[:limit]).
Find(&milestoneMaps)
if err != nil {
@ -173,7 +173,7 @@ func (comments CommentList) getOldMilestoneIDs() []int64 {
return container.KeysInt64(ids)
}
func (comments CommentList) loadOldMilestones(e db.Engine) error {
func (comments CommentList) loadOldMilestones(ctx context.Context) error {
if len(comments) == 0 {
return nil
}
@ -190,7 +190,7 @@ func (comments CommentList) loadOldMilestones(e db.Engine) error {
if left < limit {
limit = left
}
err := e.
err := db.GetEngine(ctx).
In("id", milestoneIDs[:limit]).
Find(&milestoneMaps)
if err != nil {
@ -216,7 +216,7 @@ func (comments CommentList) getAssigneeIDs() []int64 {
return container.KeysInt64(ids)
}
func (comments CommentList) loadAssignees(e db.Engine) error {
func (comments CommentList) loadAssignees(ctx context.Context) error {
if len(comments) == 0 {
return nil
}
@ -229,7 +229,7 @@ func (comments CommentList) loadAssignees(e db.Engine) error {
if left < limit {
limit = left
}
rows, err := e.
rows, err := db.GetEngine(ctx).
In("id", assigneeIDs[:limit]).
Rows(new(user_model.User))
if err != nil {
@ -290,7 +290,7 @@ func (comments CommentList) Issues() IssueList {
return issueList
}
func (comments CommentList) loadIssues(e db.Engine) error {
func (comments CommentList) loadIssues(ctx context.Context) error {
if len(comments) == 0 {
return nil
}
@ -303,7 +303,7 @@ func (comments CommentList) loadIssues(e db.Engine) error {
if left < limit {
limit = left
}
rows, err := e.
rows, err := db.GetEngine(ctx).
In("id", issueIDs[:limit]).
Rows(new(Issue))
if err != nil {
@ -397,7 +397,7 @@ func (comments CommentList) loadDependentIssues(ctx context.Context) error {
return nil
}
func (comments CommentList) loadAttachments(e db.Engine) (err error) {
func (comments CommentList) loadAttachments(ctx context.Context) (err error) {
if len(comments) == 0 {
return nil
}
@ -410,7 +410,7 @@ func (comments CommentList) loadAttachments(e db.Engine) (err error) {
if left < limit {
limit = left
}
rows, err := e.Table("attachment").
rows, err := db.GetEngine(ctx).Table("attachment").
Join("INNER", "comment", "comment.id = attachment.comment_id").
In("comment.id", commentsIDs[:limit]).
Rows(new(repo_model.Attachment))
@ -449,7 +449,7 @@ func (comments CommentList) getReviewIDs() []int64 {
return container.KeysInt64(ids)
}
func (comments CommentList) loadReviews(e db.Engine) error {
func (comments CommentList) loadReviews(ctx context.Context) error {
if len(comments) == 0 {
return nil
}
@ -462,7 +462,7 @@ func (comments CommentList) loadReviews(e db.Engine) error {
if left < limit {
limit = left
}
rows, err := e.
rows, err := db.GetEngine(ctx).
In("id", reviewIDs[:limit]).
Rows(new(Review))
if err != nil {
@ -493,36 +493,35 @@ func (comments CommentList) loadReviews(e db.Engine) error {
// loadAttributes loads all attributes
func (comments CommentList) loadAttributes(ctx context.Context) (err error) {
e := db.GetEngine(ctx)
if err = comments.loadPosters(e); err != nil {
if err = comments.loadPosters(ctx); err != nil {
return
}
if err = comments.loadLabels(e); err != nil {
if err = comments.loadLabels(ctx); err != nil {
return
}
if err = comments.loadMilestones(e); err != nil {
if err = comments.loadMilestones(ctx); err != nil {
return
}
if err = comments.loadOldMilestones(e); err != nil {
if err = comments.loadOldMilestones(ctx); err != nil {
return
}
if err = comments.loadAssignees(e); err != nil {
if err = comments.loadAssignees(ctx); err != nil {
return
}
if err = comments.loadAttachments(e); err != nil {
if err = comments.loadAttachments(ctx); err != nil {
return
}
if err = comments.loadReviews(e); err != nil {
if err = comments.loadReviews(ctx); err != nil {
return
}
if err = comments.loadIssues(e); err != nil {
if err = comments.loadIssues(ctx); err != nil {
return
}
@ -541,15 +540,15 @@ func (comments CommentList) LoadAttributes() error {
// LoadAttachments loads attachments
func (comments CommentList) LoadAttachments() error {
return comments.loadAttachments(db.GetEngine(db.DefaultContext))
return comments.loadAttachments(db.DefaultContext)
}
// LoadPosters loads posters
func (comments CommentList) LoadPosters() error {
return comments.loadPosters(db.GetEngine(db.DefaultContext))
return comments.loadPosters(db.DefaultContext)
}
// LoadIssues loads issues of comments
func (comments CommentList) LoadIssues() error {
return comments.loadIssues(db.GetEngine(db.DefaultContext))
return comments.loadIssues(db.DefaultContext)
}

View File

@ -42,10 +42,9 @@ func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error {
return err
}
defer committer.Close()
sess := db.GetEngine(ctx)
// Check if it aleready exists
exists, err := issueDepExists(sess, issue.ID, dep.ID)
exists, err := issueDepExists(ctx, issue.ID, dep.ID)
if err != nil {
return err
}
@ -53,7 +52,7 @@ func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error {
return ErrDependencyExists{issue.ID, dep.ID}
}
// And if it would be circular
circular, err := issueDepExists(sess, dep.ID, issue.ID)
circular, err := issueDepExists(ctx, dep.ID, issue.ID)
if err != nil {
return err
}
@ -114,8 +113,8 @@ func RemoveIssueDependency(user *user_model.User, issue, dep *Issue, depType Dep
}
// Check if the dependency already exists
func issueDepExists(e db.Engine, issueID, depID int64) (bool, error) {
return e.Where("(issue_id = ? AND dependency_id = ?)", issueID, depID).Exist(&IssueDependency{})
func issueDepExists(ctx context.Context, issueID, depID int64) (bool, error) {
return db.GetEngine(ctx).Where("(issue_id = ? AND dependency_id = ?)", issueID, depID).Exist(&IssueDependency{})
}
// IssueNoDependenciesLeft checks if issue can be closed

Some files were not shown because too many files have changed in this diff Show More