Penultimate round of db.DefaultContext refactor (#27414)

Part of #27065

---------

Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
This commit is contained in:
JakobDev 2023-10-11 06:24:07 +02:00 committed by GitHub
parent 50166d1f7c
commit ebe803e514
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
136 changed files with 428 additions and 421 deletions

View File

@ -62,7 +62,7 @@ func runListAuth(c *cli.Context) error {
return err
}
authSources, err := auth_model.Sources()
authSources, err := auth_model.Sources(ctx)
if err != nil {
return err
}
@ -100,7 +100,7 @@ func runDeleteAuth(c *cli.Context) error {
return err
}
source, err := auth_model.GetSourceByID(c.Int64("id"))
source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
if err != nil {
return err
}

View File

@ -17,9 +17,9 @@ import (
type (
authService struct {
initDB func(ctx context.Context) error
createAuthSource func(*auth.Source) error
updateAuthSource func(*auth.Source) error
getAuthSourceByID func(id int64) (*auth.Source, error)
createAuthSource func(context.Context, *auth.Source) error
updateAuthSource func(context.Context, *auth.Source) error
getAuthSourceByID func(ctx context.Context, id int64) (*auth.Source, error)
}
)
@ -289,12 +289,12 @@ func findLdapSecurityProtocolByName(name string) (ldap.SecurityProtocol, bool) {
// getAuthSource gets the login source by its id defined in the command line flags.
// It returns an error if the id is not set, does not match any source or if the source is not of expected type.
func (a *authService) getAuthSource(c *cli.Context, authType auth.Type) (*auth.Source, error) {
func (a *authService) getAuthSource(ctx context.Context, c *cli.Context, authType auth.Type) (*auth.Source, error) {
if err := argsSet(c, "id"); err != nil {
return nil, err
}
authSource, err := a.getAuthSourceByID(c.Int64("id"))
authSource, err := a.getAuthSourceByID(ctx, c.Int64("id"))
if err != nil {
return nil, err
}
@ -332,7 +332,7 @@ func (a *authService) addLdapBindDn(c *cli.Context) error {
return err
}
return a.createAuthSource(authSource)
return a.createAuthSource(ctx, authSource)
}
// updateLdapBindDn updates a new LDAP via Bind DN authentication source.
@ -344,7 +344,7 @@ func (a *authService) updateLdapBindDn(c *cli.Context) error {
return err
}
authSource, err := a.getAuthSource(c, auth.LDAP)
authSource, err := a.getAuthSource(ctx, c, auth.LDAP)
if err != nil {
return err
}
@ -354,7 +354,7 @@ func (a *authService) updateLdapBindDn(c *cli.Context) error {
return err
}
return a.updateAuthSource(authSource)
return a.updateAuthSource(ctx, authSource)
}
// addLdapSimpleAuth adds a new LDAP (simple auth) authentication source.
@ -383,7 +383,7 @@ func (a *authService) addLdapSimpleAuth(c *cli.Context) error {
return err
}
return a.createAuthSource(authSource)
return a.createAuthSource(ctx, authSource)
}
// updateLdapBindDn updates a new LDAP (simple auth) authentication source.
@ -395,7 +395,7 @@ func (a *authService) updateLdapSimpleAuth(c *cli.Context) error {
return err
}
authSource, err := a.getAuthSource(c, auth.DLDAP)
authSource, err := a.getAuthSource(ctx, c, auth.DLDAP)
if err != nil {
return err
}
@ -405,5 +405,5 @@ func (a *authService) updateLdapSimpleAuth(c *cli.Context) error {
return err
}
return a.updateAuthSource(authSource)
return a.updateAuthSource(ctx, authSource)
}

View File

@ -210,15 +210,15 @@ func TestAddLdapBindDn(t *testing.T) {
initDB: func(context.Context) error {
return nil
},
createAuthSource: func(authSource *auth.Source) error {
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
createdAuthSource = authSource
return nil
},
updateAuthSource: func(authSource *auth.Source) error {
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call updateAuthSource", n)
return nil
},
getAuthSourceByID: func(id int64) (*auth.Source, error) {
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
assert.FailNow(t, "case %d: should not call getAuthSourceByID", n)
return nil, nil
},
@ -441,15 +441,15 @@ func TestAddLdapSimpleAuth(t *testing.T) {
initDB: func(context.Context) error {
return nil
},
createAuthSource: func(authSource *auth.Source) error {
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
createdAuthSource = authSource
return nil
},
updateAuthSource: func(authSource *auth.Source) error {
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call updateAuthSource", n)
return nil
},
getAuthSourceByID: func(id int64) (*auth.Source, error) {
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
assert.FailNow(t, "case %d: should not call getAuthSourceByID", n)
return nil, nil
},
@ -896,15 +896,15 @@ func TestUpdateLdapBindDn(t *testing.T) {
initDB: func(context.Context) error {
return nil
},
createAuthSource: func(authSource *auth.Source) error {
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call createAuthSource", n)
return nil
},
updateAuthSource: func(authSource *auth.Source) error {
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
updatedAuthSource = authSource
return nil
},
getAuthSourceByID: func(id int64) (*auth.Source, error) {
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
if c.id != 0 {
assert.Equal(t, c.id, id, "case %d: wrong id", n)
}
@ -1286,15 +1286,15 @@ func TestUpdateLdapSimpleAuth(t *testing.T) {
initDB: func(context.Context) error {
return nil
},
createAuthSource: func(authSource *auth.Source) error {
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call createAuthSource", n)
return nil
},
updateAuthSource: func(authSource *auth.Source) error {
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
updatedAuthSource = authSource
return nil
},
getAuthSourceByID: func(id int64) (*auth.Source, error) {
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
if c.id != 0 {
assert.Equal(t, c.id, id, "case %d: wrong id", n)
}

View File

@ -183,7 +183,7 @@ func runAddOauth(c *cli.Context) error {
}
}
return auth_model.CreateSource(&auth_model.Source{
return auth_model.CreateSource(ctx, &auth_model.Source{
Type: auth_model.OAuth2,
Name: c.String("name"),
IsActive: true,
@ -203,7 +203,7 @@ func runUpdateOauth(c *cli.Context) error {
return err
}
source, err := auth_model.GetSourceByID(c.Int64("id"))
source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
if err != nil {
return err
}
@ -294,5 +294,5 @@ func runUpdateOauth(c *cli.Context) error {
oAuth2Config.CustomURLMapping = customURLMapping
source.Cfg = oAuth2Config
return auth_model.UpdateSource(source)
return auth_model.UpdateSource(ctx, source)
}

View File

@ -156,7 +156,7 @@ func runAddSMTP(c *cli.Context) error {
smtpConfig.Auth = "PLAIN"
}
return auth_model.CreateSource(&auth_model.Source{
return auth_model.CreateSource(ctx, &auth_model.Source{
Type: auth_model.SMTP,
Name: c.String("name"),
IsActive: active,
@ -176,7 +176,7 @@ func runUpdateSMTP(c *cli.Context) error {
return err
}
source, err := auth_model.GetSourceByID(c.Int64("id"))
source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
if err != nil {
return err
}
@ -197,5 +197,5 @@ func runUpdateSMTP(c *cli.Context) error {
source.Cfg = smtpConfig
return auth_model.UpdateSource(source)
return auth_model.UpdateSource(ctx, source)
}

View File

@ -42,7 +42,7 @@ func (jobs ActionJobList) LoadRuns(ctx context.Context, withRepo bool) error {
for _, r := range runs {
runsList = append(runsList, r)
}
return runsList.LoadRepos()
return runsList.LoadRepos(ctx)
}
return nil
}

View File

@ -52,9 +52,9 @@ func (runs RunList) LoadTriggerUser(ctx context.Context) error {
return nil
}
func (runs RunList) LoadRepos() error {
func (runs RunList) LoadRepos(ctx context.Context) error {
repoIDs := runs.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs)
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil {
return err
}

View File

@ -49,9 +49,9 @@ func (schedules ScheduleList) LoadTriggerUser(ctx context.Context) error {
return nil
}
func (schedules ScheduleList) LoadRepos() error {
func (schedules ScheduleList) LoadRepos(ctx context.Context) error {
repoIDs := schedules.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs)
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil {
return err
}

View File

@ -53,9 +53,9 @@ func (specs SpecList) GetRepoIDs() []int64 {
return ids.Values()
}
func (specs SpecList) LoadRepos() error {
func (specs SpecList) LoadRepos(ctx context.Context) error {
repoIDs := specs.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs)
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil {
return err
}

View File

@ -102,7 +102,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) {
stats.Counter.Follow, _ = e.Count(new(user_model.Follow))
stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror))
stats.Counter.Release, _ = e.Count(new(repo_model.Release))
stats.Counter.AuthSource = auth.CountSources()
stats.Counter.AuthSource = auth.CountSources(ctx)
stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook))
stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone))
stats.Counter.Label, _ = e.Count(new(issues_model.Label))

View File

@ -91,7 +91,7 @@ func addKey(ctx context.Context, key *PublicKey) (err error) {
}
// AddPublicKey adds new public key to database and authorized_keys file.
func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*PublicKey, error) {
func AddPublicKey(ctx context.Context, ownerID int64, name, content string, authSourceID int64) (*PublicKey, error) {
log.Trace(content)
fingerprint, err := CalcFingerprint(content)
@ -99,7 +99,7 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
return nil, err
}
ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
@ -136,9 +136,9 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
}
// GetPublicKeyByID returns public key by given ID.
func GetPublicKeyByID(keyID int64) (*PublicKey, error) {
func GetPublicKeyByID(ctx context.Context, keyID int64) (*PublicKey, error) {
key := new(PublicKey)
has, err := db.GetEngine(db.DefaultContext).
has, err := db.GetEngine(ctx).
ID(keyID).
Get(key)
if err != nil {
@ -180,7 +180,7 @@ func SearchPublicKeyByContentExact(ctx context.Context, content string) (*Public
}
// SearchPublicKey returns a list of public keys matching the provided arguments.
func SearchPublicKey(uid int64, fingerprint string) ([]*PublicKey, error) {
func SearchPublicKey(ctx context.Context, uid int64, fingerprint string) ([]*PublicKey, error) {
keys := make([]*PublicKey, 0, 5)
cond := builder.NewCond()
if uid != 0 {
@ -189,12 +189,12 @@ func SearchPublicKey(uid int64, fingerprint string) ([]*PublicKey, error) {
if fingerprint != "" {
cond = cond.And(builder.Eq{"fingerprint": fingerprint})
}
return keys, db.GetEngine(db.DefaultContext).Where(cond).Find(&keys)
return keys, db.GetEngine(ctx).Where(cond).Find(&keys)
}
// ListPublicKeys returns a list of public keys belongs to given user.
func ListPublicKeys(uid int64, listOptions db.ListOptions) ([]*PublicKey, error) {
sess := db.GetEngine(db.DefaultContext).Where("owner_id = ? AND type != ?", uid, KeyTypePrincipal)
func ListPublicKeys(ctx context.Context, uid int64, listOptions db.ListOptions) ([]*PublicKey, error) {
sess := db.GetEngine(ctx).Where("owner_id = ? AND type != ?", uid, KeyTypePrincipal)
if listOptions.Page != 0 {
sess = db.SetSessionPagination(sess, &listOptions)
@ -207,30 +207,30 @@ func ListPublicKeys(uid int64, listOptions db.ListOptions) ([]*PublicKey, error)
}
// CountPublicKeys count public keys a user has
func CountPublicKeys(userID int64) (int64, error) {
sess := db.GetEngine(db.DefaultContext).Where("owner_id = ? AND type != ?", userID, KeyTypePrincipal)
func CountPublicKeys(ctx context.Context, userID int64) (int64, error) {
sess := db.GetEngine(ctx).Where("owner_id = ? AND type != ?", userID, KeyTypePrincipal)
return sess.Count(&PublicKey{})
}
// ListPublicKeysBySource returns a list of synchronized public keys for a given user and login source.
func ListPublicKeysBySource(uid, authSourceID int64) ([]*PublicKey, error) {
func ListPublicKeysBySource(ctx context.Context, uid, authSourceID int64) ([]*PublicKey, error) {
keys := make([]*PublicKey, 0, 5)
return keys, db.GetEngine(db.DefaultContext).
return keys, db.GetEngine(ctx).
Where("owner_id = ? AND login_source_id = ?", uid, authSourceID).
Find(&keys)
}
// UpdatePublicKeyUpdated updates public key use time.
func UpdatePublicKeyUpdated(id int64) error {
func UpdatePublicKeyUpdated(ctx context.Context, id int64) error {
// Check if key exists before update as affected rows count is unreliable
// and will return 0 affected rows if two updates are made at the same time
if cnt, err := db.GetEngine(db.DefaultContext).ID(id).Count(&PublicKey{}); err != nil {
if cnt, err := db.GetEngine(ctx).ID(id).Count(&PublicKey{}); err != nil {
return err
} else if cnt != 1 {
return ErrKeyNotExist{id}
}
_, err := db.GetEngine(db.DefaultContext).ID(id).Cols("updated_unix").Update(&PublicKey{
_, err := db.GetEngine(ctx).ID(id).Cols("updated_unix").Update(&PublicKey{
UpdatedUnix: timeutil.TimeStampNow(),
})
if err != nil {
@ -250,7 +250,7 @@ func DeletePublicKeys(ctx context.Context, keyIDs ...int64) error {
}
// PublicKeysAreExternallyManaged returns whether the provided KeyID represents an externally managed Key
func PublicKeysAreExternallyManaged(keys []*PublicKey) ([]bool, error) {
func PublicKeysAreExternallyManaged(ctx context.Context, keys []*PublicKey) ([]bool, error) {
sources := make([]*auth.Source, 0, 5)
externals := make([]bool, len(keys))
keyloop:
@ -272,7 +272,7 @@ keyloop:
if source == nil {
var err error
source, err = auth.GetSourceByID(key.LoginSourceID)
source, err = auth.GetSourceByID(ctx, key.LoginSourceID)
if err != nil {
if auth.IsErrSourceNotExist(err) {
externals[i] = false
@ -295,15 +295,15 @@ keyloop:
}
// PublicKeyIsExternallyManaged returns whether the provided KeyID represents an externally managed Key
func PublicKeyIsExternallyManaged(id int64) (bool, error) {
key, err := GetPublicKeyByID(id)
func PublicKeyIsExternallyManaged(ctx context.Context, id int64) (bool, error) {
key, err := GetPublicKeyByID(ctx, id)
if err != nil {
return false, err
}
if key.LoginSourceID == 0 {
return false, nil
}
source, err := auth.GetSourceByID(key.LoginSourceID)
source, err := auth.GetSourceByID(ctx, key.LoginSourceID)
if err != nil {
if auth.IsErrSourceNotExist(err) {
return false, nil
@ -318,9 +318,9 @@ func PublicKeyIsExternallyManaged(id int64) (bool, error) {
}
// deleteKeysMarkedForDeletion returns true if ssh keys needs update
func deleteKeysMarkedForDeletion(keys []string) (bool, error) {
func deleteKeysMarkedForDeletion(ctx context.Context, keys []string) (bool, error) {
// Start session
ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return false, err
}
@ -349,7 +349,7 @@ func deleteKeysMarkedForDeletion(keys []string) (bool, error) {
}
// AddPublicKeysBySource add a users public keys. Returns true if there are changes.
func AddPublicKeysBySource(usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
func AddPublicKeysBySource(ctx context.Context, usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
var sshKeysNeedUpdate bool
for _, sshKey := range sshPublicKeys {
var err error
@ -368,7 +368,7 @@ func AddPublicKeysBySource(usr *user_model.User, s *auth.Source, sshPublicKeys [
marshalled = marshalled[:len(marshalled)-1]
sshKeyName := fmt.Sprintf("%s-%s", s.Name, ssh.FingerprintSHA256(out))
if _, err := AddPublicKey(usr.ID, sshKeyName, marshalled, s.ID); err != nil {
if _, err := AddPublicKey(ctx, usr.ID, sshKeyName, marshalled, s.ID); err != nil {
if IsErrKeyAlreadyExist(err) {
log.Trace("AddPublicKeysBySource[%s]: Public SSH Key %s already exists for user", sshKeyName, usr.Name)
} else {
@ -387,14 +387,14 @@ func AddPublicKeysBySource(usr *user_model.User, s *auth.Source, sshPublicKeys [
}
// SynchronizePublicKeys updates a users public keys. Returns true if there are changes.
func SynchronizePublicKeys(usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
func SynchronizePublicKeys(ctx context.Context, usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
var sshKeysNeedUpdate bool
log.Trace("synchronizePublicKeys[%s]: Handling Public SSH Key synchronization for user %s", s.Name, usr.Name)
// Get Public Keys from DB with current LDAP source
var giteaKeys []string
keys, err := ListPublicKeysBySource(usr.ID, s.ID)
keys, err := ListPublicKeysBySource(ctx, usr.ID, s.ID)
if err != nil {
log.Error("synchronizePublicKeys[%s]: Error listing Public SSH Keys for user %s: %v", s.Name, usr.Name, err)
}
@ -429,7 +429,7 @@ func SynchronizePublicKeys(usr *user_model.User, s *auth.Source, sshPublicKeys [
newKeys = append(newKeys, key)
}
}
if AddPublicKeysBySource(usr, s, newKeys) {
if AddPublicKeysBySource(ctx, usr, s, newKeys) {
sshKeysNeedUpdate = true
}
@ -443,7 +443,7 @@ func SynchronizePublicKeys(usr *user_model.User, s *auth.Source, sshPublicKeys [
}
// Delete keys from DB that no longer exist in the source
needUpd, err := deleteKeysMarkedForDeletion(giteaKeysToDelete)
needUpd, err := deleteKeysMarkedForDeletion(ctx, giteaKeysToDelete)
if err != nil {
log.Error("synchronizePublicKeys[%s]: Error deleting Public Keys marked for deletion for user %s: %v", s.Name, usr.Name, err)
}

View File

@ -21,7 +21,7 @@ import (
func ParseCommitWithSSHSignature(ctx context.Context, c *git.Commit, committer *user_model.User) *CommitVerification {
// Now try to associate the signature with the committer, if present
if committer.ID != 0 {
keys, err := ListPublicKeys(committer.ID, db.ListOptions{})
keys, err := ListPublicKeys(ctx, committer.ID, db.ListOptions{})
if err != nil { // Skipping failed to get ssh keys of user
log.Error("ListPublicKeys: %v", err)
return &CommitVerification{

View File

@ -48,8 +48,8 @@ func (key *DeployKey) AfterLoad() {
}
// GetContent gets associated public key content.
func (key *DeployKey) GetContent() error {
pkey, err := GetPublicKeyByID(key.KeyID)
func (key *DeployKey) GetContent(ctx context.Context) error {
pkey, err := GetPublicKeyByID(ctx, key.KeyID)
if err != nil {
return err
}

View File

@ -5,6 +5,7 @@
package auth
import (
"context"
"fmt"
"reflect"
@ -199,8 +200,8 @@ func (source *Source) SkipVerify() bool {
// CreateSource inserts a AuthSource in the DB if not already
// existing with the given name.
func CreateSource(source *Source) error {
has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source))
func CreateSource(ctx context.Context, source *Source) error {
has, err := db.GetEngine(ctx).Where("name=?", source.Name).Exist(new(Source))
if err != nil {
return err
} else if has {
@ -211,7 +212,7 @@ func CreateSource(source *Source) error {
source.IsSyncEnabled = false
}
_, err = db.GetEngine(db.DefaultContext).Insert(source)
_, err = db.GetEngine(ctx).Insert(source)
if err != nil {
return err
}
@ -232,7 +233,7 @@ func CreateSource(source *Source) error {
err = registerableSource.RegisterSource()
if err != nil {
// remove the AuthSource in case of errors while registering configuration
if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil {
if _, err := db.GetEngine(ctx).Delete(source); err != nil {
log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
@ -240,33 +241,33 @@ func CreateSource(source *Source) error {
}
// Sources returns a slice of all login sources found in DB.
func Sources() ([]*Source, error) {
func Sources(ctx context.Context) ([]*Source, error) {
auths := make([]*Source, 0, 6)
return auths, db.GetEngine(db.DefaultContext).Find(&auths)
return auths, db.GetEngine(ctx).Find(&auths)
}
// SourcesByType returns all sources of the specified type
func SourcesByType(loginType Type) ([]*Source, error) {
func SourcesByType(ctx context.Context, loginType Type) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil {
if err := db.GetEngine(ctx).Where("type = ?", loginType).Find(&sources); err != nil {
return nil, err
}
return sources, nil
}
// AllActiveSources returns all active sources
func AllActiveSources() ([]*Source, error) {
func AllActiveSources(ctx context.Context) ([]*Source, error) {
sources := make([]*Source, 0, 5)
if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil {
if err := db.GetEngine(ctx).Where("is_active = ?", true).Find(&sources); err != nil {
return nil, err
}
return sources, nil
}
// ActiveSources returns all active sources of the specified type
func ActiveSources(tp Type) ([]*Source, error) {
func ActiveSources(ctx context.Context, tp Type) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
return nil, err
}
return sources, nil
@ -274,11 +275,11 @@ func ActiveSources(tp Type) ([]*Source, error) {
// IsSSPIEnabled returns true if there is at least one activated login
// source of type LoginSSPI
func IsSSPIEnabled() bool {
func IsSSPIEnabled(ctx context.Context) bool {
if !db.HasEngine {
return false
}
sources, err := ActiveSources(SSPI)
sources, err := ActiveSources(ctx, SSPI)
if err != nil {
log.Error("ActiveSources: %v", err)
return false
@ -287,7 +288,7 @@ func IsSSPIEnabled() bool {
}
// GetSourceByID returns login source by given ID.
func GetSourceByID(id int64) (*Source, error) {
func GetSourceByID(ctx context.Context, id int64) (*Source, error) {
source := new(Source)
if id == 0 {
source.Cfg = registeredConfigs[NoType]()
@ -297,7 +298,7 @@ func GetSourceByID(id int64) (*Source, error) {
return source, nil
}
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source)
has, err := db.GetEngine(ctx).ID(id).Get(source)
if err != nil {
return nil, err
} else if !has {
@ -307,24 +308,24 @@ func GetSourceByID(id int64) (*Source, error) {
}
// UpdateSource updates a Source record in DB.
func UpdateSource(source *Source) error {
func UpdateSource(ctx context.Context, source *Source) error {
var originalSource *Source
if source.IsOAuth2() {
// keep track of the original values so we can restore in case of errors while registering OAuth2 providers
var err error
if originalSource, err = GetSourceByID(source.ID); err != nil {
if originalSource, err = GetSourceByID(ctx, source.ID); err != nil {
return err
}
}
has, err := db.GetEngine(db.DefaultContext).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
has, err := db.GetEngine(ctx).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
if err != nil {
return err
} else if has {
return ErrSourceAlreadyExist{source.Name}
}
_, err = db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source)
_, err = db.GetEngine(ctx).ID(source.ID).AllCols().Update(source)
if err != nil {
return err
}
@ -345,7 +346,7 @@ func UpdateSource(source *Source) error {
err = registerableSource.RegisterSource()
if err != nil {
// restore original values since we cannot update the provider it self
if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalSource); err != nil {
if _, err := db.GetEngine(ctx).ID(source.ID).AllCols().Update(originalSource); err != nil {
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
@ -353,8 +354,8 @@ func UpdateSource(source *Source) error {
}
// CountSources returns number of login sources.
func CountSources() int64 {
count, _ := db.GetEngine(db.DefaultContext).Count(new(Source))
func CountSources(ctx context.Context) int64 {
count, _ := db.GetEngine(ctx).Count(new(Source))
return count
}

View File

@ -42,7 +42,7 @@ func TestDumpAuthSource(t *testing.T) {
auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource))
auth_model.CreateSource(&auth_model.Source{
auth_model.CreateSource(db.DefaultContext, &auth_model.Source{
Type: auth_model.OAuth2,
Name: "TestSource",
IsActive: false,

View File

@ -111,7 +111,7 @@ func findCodeComments(ctx context.Context, opts FindCommentsOptions, issue *Issu
if comment.RenderedContent, err = markdown.RenderString(&markup.RenderContext{
Ctx: ctx,
URLPrefix: issue.Repo.Link(),
Metas: issue.Repo.ComposeMetas(),
Metas: issue.Repo.ComposeMetas(ctx),
}, comment.Content); err != nil {
return nil, err
}

View File

@ -127,8 +127,8 @@ const (
)
// CreateIssueDependency creates a new dependency for an issue
func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func CreateIssueDependency(ctx context.Context, user *user_model.User, issue, dep *Issue) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -168,8 +168,8 @@ func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error {
}
// RemoveIssueDependency removes a dependency from an issue
func RemoveIssueDependency(user *user_model.User, issue, dep *Issue, depType DependencyType) (err error) {
ctx, committer, err := db.TxContext(db.DefaultContext)
func RemoveIssueDependency(ctx context.Context, user *user_model.User, issue, dep *Issue, depType DependencyType) (err error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}

View File

@ -28,16 +28,16 @@ func TestCreateIssueDependency(t *testing.T) {
assert.NoError(t, err)
// Create a dependency and check if it was successful
err = issues_model.CreateIssueDependency(user1, issue1, issue2)
err = issues_model.CreateIssueDependency(db.DefaultContext, user1, issue1, issue2)
assert.NoError(t, err)
// Do it again to see if it will check if the dependency already exists
err = issues_model.CreateIssueDependency(user1, issue1, issue2)
err = issues_model.CreateIssueDependency(db.DefaultContext, user1, issue1, issue2)
assert.Error(t, err)
assert.True(t, issues_model.IsErrDependencyExists(err))
// Check for circular dependencies
err = issues_model.CreateIssueDependency(user1, issue2, issue1)
err = issues_model.CreateIssueDependency(db.DefaultContext, user1, issue2, issue1)
assert.Error(t, err)
assert.True(t, issues_model.IsErrCircularDependency(err))
@ -57,6 +57,6 @@ func TestCreateIssueDependency(t *testing.T) {
assert.True(t, left)
// Test removing the dependency
err = issues_model.RemoveIssueDependency(user1, issue1, issue2, issues_model.DependencyTypeBlockedBy)
err = issues_model.RemoveIssueDependency(db.DefaultContext, user1, issue1, issue2, issues_model.DependencyTypeBlockedBy)
assert.NoError(t, err)
}

View File

@ -83,12 +83,12 @@ func RemoveDuplicateExclusiveIssueLabels(ctx context.Context, issue *Issue, labe
}
// NewIssueLabel creates a new issue-label relation.
func NewIssueLabel(issue *Issue, label *Label, doer *user_model.User) (err error) {
if HasIssueLabel(db.DefaultContext, issue.ID, label.ID) {
func NewIssueLabel(ctx context.Context, issue *Issue, label *Label, doer *user_model.User) (err error) {
if HasIssueLabel(ctx, issue.ID, label.ID) {
return nil
}
ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -149,8 +149,8 @@ func newIssueLabels(ctx context.Context, issue *Issue, labels []*Label, doer *us
}
// NewIssueLabels creates a list of issue-label relations.
func NewIssueLabels(issue *Issue, labels []*Label, doer *user_model.User) (err error) {
ctx, committer, err := db.TxContext(db.DefaultContext)
func NewIssueLabels(ctx context.Context, issue *Issue, labels []*Label, doer *user_model.User) (err error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -359,8 +359,8 @@ func clearIssueLabels(ctx context.Context, issue *Issue, doer *user_model.User)
// ClearIssueLabels removes all issue labels as the given user.
// Triggers appropriate WebHooks, if any.
func ClearIssueLabels(issue *Issue, doer *user_model.User) (err error) {
ctx, committer, err := db.TxContext(db.DefaultContext)
func ClearIssueLabels(ctx context.Context, issue *Issue, doer *user_model.User) (err error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -432,8 +432,8 @@ func RemoveDuplicateExclusiveLabels(labels []*Label) []*Label {
// ReplaceIssueLabels removes all current labels and add new labels to the issue.
// Triggers appropriate WebHooks, if any.
func ReplaceIssueLabels(issue *Issue, labels []*Label, doer *user_model.User) (err error) {
ctx, committer, err := db.TxContext(db.DefaultContext)
func ReplaceIssueLabels(ctx context.Context, issue *Issue, labels []*Label, doer *user_model.User) (err error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}

View File

@ -6,6 +6,7 @@ package issues_test
import (
"testing"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
@ -21,7 +22,7 @@ func TestNewIssueLabelsScope(t *testing.T) {
label2 := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 8})
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
assert.NoError(t, issues_model.NewIssueLabels(issue, []*issues_model.Label{label1, label2}, doer))
assert.NoError(t, issues_model.NewIssueLabels(db.DefaultContext, issue, []*issues_model.Label{label1, label2}, doer))
assert.Len(t, issue.Labels, 1)
assert.Equal(t, label2.ID, issue.Labels[0].ID)

View File

@ -4,6 +4,8 @@
package issues
import (
"context"
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
)
@ -17,16 +19,16 @@ type IssueLockOptions struct {
// LockIssue locks an issue. This would limit commenting abilities to
// users with write access to the repo
func LockIssue(opts *IssueLockOptions) error {
return updateIssueLock(opts, true)
func LockIssue(ctx context.Context, opts *IssueLockOptions) error {
return updateIssueLock(ctx, opts, true)
}
// UnlockIssue unlocks a previously locked issue.
func UnlockIssue(opts *IssueLockOptions) error {
return updateIssueLock(opts, false)
func UnlockIssue(ctx context.Context, opts *IssueLockOptions) error {
return updateIssueLock(ctx, opts, false)
}
func updateIssueLock(opts *IssueLockOptions, lock bool) error {
func updateIssueLock(ctx context.Context, opts *IssueLockOptions, lock bool) error {
if opts.Issue.IsLocked == lock {
return nil
}
@ -39,7 +41,7 @@ func updateIssueLock(opts *IssueLockOptions, lock bool) error {
commentType = CommentTypeUnlock
}
ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}

View File

@ -38,11 +38,7 @@ func (issue *Issue) projectID(ctx context.Context) int64 {
}
// ProjectBoardID return project board id if issue was assigned to one
func (issue *Issue) ProjectBoardID() int64 {
return issue.projectBoardID(db.DefaultContext)
}
func (issue *Issue) projectBoardID(ctx context.Context) int64 {
func (issue *Issue) ProjectBoardID(ctx context.Context) int64 {
var ip project_model.ProjectIssue
has, err := db.GetEngine(ctx).Where("issue_id=?", issue.ID).Get(&ip)
if err != nil || !has {
@ -100,8 +96,8 @@ func LoadIssuesFromBoardList(ctx context.Context, bs project_model.BoardList) (m
}
// ChangeProjectAssign changes the project associated with an issue
func ChangeProjectAssign(issue *Issue, doer *user_model.User, newProjectID int64) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func ChangeProjectAssign(ctx context.Context, issue *Issue, doer *user_model.User, newProjectID int64) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -156,8 +152,8 @@ func addUpdateIssueProject(ctx context.Context, issue *Issue, doer *user_model.U
}
// MoveIssueAcrossProjectBoards move a card from one board to another
func MoveIssueAcrossProjectBoards(issue *Issue, board *project_model.Board) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func MoveIssueAcrossProjectBoards(ctx context.Context, issue *Issue, board *project_model.Board) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}

View File

@ -444,9 +444,9 @@ func applySubscribedCondition(sess *xorm.Session, subscriberID int64) *xorm.Sess
}
// GetRepoIDsForIssuesOptions find all repo ids for the given options
func GetRepoIDsForIssuesOptions(opts *IssuesOptions, user *user_model.User) ([]int64, error) {
func GetRepoIDsForIssuesOptions(ctx context.Context, opts *IssuesOptions, user *user_model.User) ([]int64, error) {
repoIDs := make([]int64, 0, 5)
e := db.GetEngine(db.DefaultContext)
e := db.GetEngine(ctx)
sess := e.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")

View File

@ -34,7 +34,7 @@ func TestIssue_ReplaceLabels(t *testing.T) {
for i, labelID := range labelIDs {
labels[i] = unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: labelID, RepoID: repo.ID})
}
assert.NoError(t, issues_model.ReplaceIssueLabels(issue, labels, doer))
assert.NoError(t, issues_model.ReplaceIssueLabels(db.DefaultContext, issue, labels, doer))
unittest.AssertCount(t, &issues_model.IssueLabel{IssueID: issueID}, len(expectedLabelIDs))
for _, labelID := range expectedLabelIDs {
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issueID, LabelID: labelID})
@ -122,7 +122,7 @@ func TestIssue_ClearLabels(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: test.issueID})
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: test.doerID})
assert.NoError(t, issues_model.ClearIssueLabels(issue, doer))
assert.NoError(t, issues_model.ClearIssueLabels(db.DefaultContext, issue, doer))
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: test.issueID})
}
}
@ -230,7 +230,7 @@ func TestGetRepoIDsForIssuesOptions(t *testing.T) {
[]int64{1, 2},
},
} {
repoIDs, err := issues_model.GetRepoIDsForIssuesOptions(&test.Opts, user)
repoIDs, err := issues_model.GetRepoIDsForIssuesOptions(db.DefaultContext, &test.Opts, user)
assert.NoError(t, err)
if assert.Len(t, repoIDs, len(test.ExpectedRepoIDs)) {
for i, repoID := range repoIDs {

View File

@ -307,7 +307,7 @@ func TestNewIssueLabel(t *testing.T) {
// add new IssueLabel
prevNumIssues := label.NumIssues
assert.NoError(t, issues_model.NewIssueLabel(issue, label, doer))
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, label, doer))
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: label.ID})
unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{
Type: issues_model.CommentTypeLabel,
@ -320,7 +320,7 @@ func TestNewIssueLabel(t *testing.T) {
assert.EqualValues(t, prevNumIssues+1, label.NumIssues)
// re-add existing IssueLabel
assert.NoError(t, issues_model.NewIssueLabel(issue, label, doer))
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, label, doer))
unittest.CheckConsistencyFor(t, &issues_model.Issue{}, &issues_model.Label{})
}
@ -334,19 +334,19 @@ func TestNewIssueExclusiveLabel(t *testing.T) {
exclusiveLabelB := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 8})
// coexisting regular and exclusive label
assert.NoError(t, issues_model.NewIssueLabel(issue, otherLabel, doer))
assert.NoError(t, issues_model.NewIssueLabel(issue, exclusiveLabelA, doer))
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, otherLabel, doer))
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, exclusiveLabelA, doer))
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
// exclusive label replaces existing one
assert.NoError(t, issues_model.NewIssueLabel(issue, exclusiveLabelB, doer))
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, exclusiveLabelB, doer))
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelB.ID})
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
// exclusive label replaces existing one again
assert.NoError(t, issues_model.NewIssueLabel(issue, exclusiveLabelA, doer))
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, exclusiveLabelA, doer))
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelB.ID})
@ -359,7 +359,7 @@ func TestNewIssueLabels(t *testing.T) {
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 5})
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
assert.NoError(t, issues_model.NewIssueLabels(issue, []*issues_model.Label{label1, label2}, doer))
assert.NoError(t, issues_model.NewIssueLabels(db.DefaultContext, issue, []*issues_model.Label{label1, label2}, doer))
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: label1.ID})
unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{
Type: issues_model.CommentTypeLabel,
@ -377,7 +377,7 @@ func TestNewIssueLabels(t *testing.T) {
assert.EqualValues(t, 1, label2.NumClosedIssues)
// corner case: test empty slice
assert.NoError(t, issues_model.NewIssueLabels(issue, []*issues_model.Label{}, doer))
assert.NoError(t, issues_model.NewIssueLabels(db.DefaultContext, issue, []*issues_model.Label{}, doer))
unittest.CheckConsistencyFor(t, &issues_model.Issue{}, &issues_model.Label{})
}

View File

@ -58,8 +58,8 @@ func (opts GetMilestonesOption) toCond() builder.Cond {
}
// GetMilestones returns milestones filtered by GetMilestonesOption's
func GetMilestones(opts GetMilestonesOption) (MilestoneList, int64, error) {
sess := db.GetEngine(db.DefaultContext).Where(opts.toCond())
func GetMilestones(ctx context.Context, opts GetMilestonesOption) (MilestoneList, int64, error) {
sess := db.GetEngine(ctx).Where(opts.toCond())
if opts.Page != 0 {
sess = db.SetSessionPagination(sess, &opts)

View File

@ -40,7 +40,7 @@ func TestGetMilestonesByRepoID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
test := func(repoID int64, state api.StateType) {
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID})
milestones, _, err := issues_model.GetMilestones(issues_model.GetMilestonesOption{
milestones, _, err := issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
RepoID: repo.ID,
State: state,
})
@ -77,7 +77,7 @@ func TestGetMilestonesByRepoID(t *testing.T) {
test(3, api.StateClosed)
test(3, api.StateAll)
milestones, _, err := issues_model.GetMilestones(issues_model.GetMilestonesOption{
milestones, _, err := issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
RepoID: unittest.NonexistentID,
State: api.StateOpen,
})
@ -90,7 +90,7 @@ func TestGetMilestones(t *testing.T) {
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
test := func(sortType string, sortCond func(*issues_model.Milestone) int) {
for _, page := range []int{0, 1} {
milestones, _, err := issues_model.GetMilestones(issues_model.GetMilestonesOption{
milestones, _, err := issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
ListOptions: db.ListOptions{
Page: page,
PageSize: setting.UI.IssuePagingNum,
@ -107,7 +107,7 @@ func TestGetMilestones(t *testing.T) {
}
assert.True(t, sort.IntsAreSorted(values))
milestones, _, err = issues_model.GetMilestones(issues_model.GetMilestonesOption{
milestones, _, err = issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
ListOptions: db.ListOptions{
Page: page,
PageSize: setting.UI.IssuePagingNum,

View File

@ -378,9 +378,9 @@ func (pr *PullRequest) GetApprovalCounts(ctx context.Context) ([]*ReviewCount, e
}
// GetApprovers returns the approvers of the pull request
func (pr *PullRequest) GetApprovers() string {
func (pr *PullRequest) GetApprovers(ctx context.Context) string {
stringBuilder := strings.Builder{}
if err := pr.getReviewedByLines(&stringBuilder); err != nil {
if err := pr.getReviewedByLines(ctx, &stringBuilder); err != nil {
log.Error("Unable to getReviewedByLines: Error: %v", err)
return ""
}
@ -388,14 +388,14 @@ func (pr *PullRequest) GetApprovers() string {
return stringBuilder.String()
}
func (pr *PullRequest) getReviewedByLines(writer io.Writer) error {
func (pr *PullRequest) getReviewedByLines(ctx context.Context, writer io.Writer) error {
maxReviewers := setting.Repository.PullRequest.DefaultMergeMessageMaxApprovers
if maxReviewers == 0 {
return nil
}
ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -594,9 +594,9 @@ func GetUnmergedPullRequest(ctx context.Context, headRepoID, baseRepoID int64, h
// GetLatestPullRequestByHeadInfo returns the latest pull request (regardless of its status)
// by given head information (repo and branch).
func GetLatestPullRequestByHeadInfo(repoID int64, branch string) (*PullRequest, error) {
func GetLatestPullRequestByHeadInfo(ctx context.Context, repoID int64, branch string) (*PullRequest, error) {
pr := new(PullRequest)
has, err := db.GetEngine(db.DefaultContext).
has, err := db.GetEngine(ctx).
Where("head_repo_id = ? AND head_branch = ? AND flow = ?", repoID, branch, PullRequestFlowGithub).
OrderBy("id DESC").
Get(pr)
@ -646,9 +646,9 @@ func GetPullRequestByID(ctx context.Context, id int64) (*PullRequest, error) {
}
// GetPullRequestByIssueIDWithNoAttributes returns pull request with no attributes loaded by given issue ID.
func GetPullRequestByIssueIDWithNoAttributes(issueID int64) (*PullRequest, error) {
func GetPullRequestByIssueIDWithNoAttributes(ctx context.Context, issueID int64) (*PullRequest, error) {
var pr PullRequest
has, err := db.GetEngine(db.DefaultContext).Where("issue_id = ?", issueID).Get(&pr)
has, err := db.GetEngine(ctx).Where("issue_id = ?", issueID).Get(&pr)
if err != nil {
return nil, err
}
@ -687,14 +687,14 @@ func GetAllUnmergedAgitPullRequestByPoster(ctx context.Context, uid int64) ([]*P
}
// Update updates all fields of pull request.
func (pr *PullRequest) Update() error {
_, err := db.GetEngine(db.DefaultContext).ID(pr.ID).AllCols().Update(pr)
func (pr *PullRequest) Update(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(pr.ID).AllCols().Update(pr)
return err
}
// UpdateCols updates specific fields of pull request.
func (pr *PullRequest) UpdateCols(cols ...string) error {
_, err := db.GetEngine(db.DefaultContext).ID(pr.ID).Cols(cols...).Update(pr)
func (pr *PullRequest) UpdateCols(ctx context.Context, cols ...string) error {
_, err := db.GetEngine(ctx).ID(pr.ID).Cols(cols...).Update(pr)
return err
}
@ -706,8 +706,8 @@ func (pr *PullRequest) UpdateColsIfNotMerged(ctx context.Context, cols ...string
// IsWorkInProgress determine if the Pull Request is a Work In Progress by its title
// Issue must be set before this method can be called.
func (pr *PullRequest) IsWorkInProgress() bool {
if err := pr.LoadIssue(db.DefaultContext); err != nil {
func (pr *PullRequest) IsWorkInProgress(ctx context.Context) bool {
if err := pr.LoadIssue(ctx); err != nil {
log.Error("LoadIssue: %v", err)
return false
}
@ -774,8 +774,8 @@ func GetPullRequestsByHeadBranch(ctx context.Context, headBranch string, headRep
}
// GetBaseBranchLink returns the relative URL of the base branch
func (pr *PullRequest) GetBaseBranchLink() string {
if err := pr.LoadBaseRepo(db.DefaultContext); err != nil {
func (pr *PullRequest) GetBaseBranchLink(ctx context.Context) string {
if err := pr.LoadBaseRepo(ctx); err != nil {
log.Error("LoadBaseRepo: %v", err)
return ""
}
@ -786,12 +786,12 @@ func (pr *PullRequest) GetBaseBranchLink() string {
}
// GetHeadBranchLink returns the relative URL of the head branch
func (pr *PullRequest) GetHeadBranchLink() string {
func (pr *PullRequest) GetHeadBranchLink(ctx context.Context) string {
if pr.Flow == PullRequestFlowAGit {
return ""
}
if err := pr.LoadHeadRepo(db.DefaultContext); err != nil {
if err := pr.LoadHeadRepo(ctx); err != nil {
log.Error("LoadHeadRepo: %v", err)
return ""
}
@ -810,14 +810,14 @@ func UpdateAllowEdits(ctx context.Context, pr *PullRequest) error {
}
// Mergeable returns if the pullrequest is mergeable.
func (pr *PullRequest) Mergeable() bool {
func (pr *PullRequest) Mergeable(ctx context.Context) bool {
// If a pull request isn't mergable if it's:
// - Being conflict checked.
// - Has a conflict.
// - Received a error while being conflict checked.
// - Is a work-in-progress pull request.
return pr.Status != PullRequestStatusChecking && pr.Status != PullRequestStatusConflict &&
pr.Status != PullRequestStatusError && !pr.IsWorkInProgress()
pr.Status != PullRequestStatusError && !pr.IsWorkInProgress(ctx)
}
// HasEnoughApprovals returns true if pr has enough granted approvals.
@ -890,7 +890,7 @@ func MergeBlockedByOutdatedBranch(protectBranch *git_model.ProtectedBranch, pr *
func PullRequestCodeOwnersReview(ctx context.Context, pull *Issue, pr *PullRequest) error {
files := []string{"CODEOWNERS", "docs/CODEOWNERS", ".gitea/CODEOWNERS"}
if pr.IsWorkInProgress() {
if pr.IsWorkInProgress(ctx) {
return nil
}

View File

@ -213,7 +213,7 @@ func TestPullRequest_Update(t *testing.T) {
pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1})
pr.BaseBranch = "baseBranch"
pr.HeadBranch = "headBranch"
pr.Update()
pr.Update(db.DefaultContext)
pr = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: pr.ID})
assert.Equal(t, "baseBranch", pr.BaseBranch)
@ -228,7 +228,7 @@ func TestPullRequest_UpdateCols(t *testing.T) {
BaseBranch: "baseBranch",
HeadBranch: "headBranch",
}
assert.NoError(t, pr.UpdateCols("head_branch"))
assert.NoError(t, pr.UpdateCols(db.DefaultContext, "head_branch"))
pr = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1})
assert.Equal(t, "master", pr.BaseBranch)
@ -260,13 +260,13 @@ func TestPullRequest_IsWorkInProgress(t *testing.T) {
pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 2})
pr.LoadIssue(db.DefaultContext)
assert.False(t, pr.IsWorkInProgress())
assert.False(t, pr.IsWorkInProgress(db.DefaultContext))
pr.Issue.Title = "WIP: " + pr.Issue.Title
assert.True(t, pr.IsWorkInProgress())
assert.True(t, pr.IsWorkInProgress(db.DefaultContext))
pr.Issue.Title = "[wip]: " + pr.Issue.Title
assert.True(t, pr.IsWorkInProgress())
assert.True(t, pr.IsWorkInProgress(db.DefaultContext))
}
func TestPullRequest_GetWorkInProgressPrefixWorkInProgress(t *testing.T) {
@ -334,7 +334,7 @@ func TestGetApprovers(t *testing.T) {
// Official reviews are already deduplicated. Allow unofficial reviews
// to assert that there are no duplicated approvers.
setting.Repository.PullRequest.DefaultMergeMessageOfficialApproversOnly = false
approvers := pr.GetApprovers()
approvers := pr.GetApprovers(db.DefaultContext)
expected := "Reviewed-by: User Five <user5@example.com>\nReviewed-by: Org Six <org6@example.com>\n"
assert.EqualValues(t, expected, approvers)
}

View File

@ -277,8 +277,8 @@ func UpdateRepoStats(ctx context.Context, id int64) error {
return nil
}
func updateUserStarNumbers(users []user_model.User) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func updateUserStarNumbers(ctx context.Context, users []user_model.User) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -294,19 +294,19 @@ func updateUserStarNumbers(users []user_model.User) error {
}
// DoctorUserStarNum recalculate Stars number for all user
func DoctorUserStarNum() (err error) {
func DoctorUserStarNum(ctx context.Context) (err error) {
const batchSize = 100
for start := 0; ; start += batchSize {
users := make([]user_model.User, 0, batchSize)
if err = db.GetEngine(db.DefaultContext).Limit(batchSize, start).Where("type = ?", 0).Cols("id").Find(&users); err != nil {
if err = db.GetEngine(ctx).Limit(batchSize, start).Where("type = ?", 0).Cols("id").Find(&users); err != nil {
return err
}
if len(users) == 0 {
break
}
if err = updateUserStarNumbers(users); err != nil {
if err = updateUserStarNumbers(ctx, users); err != nil {
return err
}
}

Some files were not shown because too many files have changed in this diff Show More