Penultimate round of db.DefaultContext refactor (#27414)

Part of #27065

---------

Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
This commit is contained in:
JakobDev
2023-10-11 06:24:07 +02:00
committed by GitHub
parent 50166d1f7c
commit ebe803e514
136 changed files with 428 additions and 421 deletions

View File

@ -62,7 +62,7 @@ func runListAuth(c *cli.Context) error {
return err return err
} }
authSources, err := auth_model.Sources() authSources, err := auth_model.Sources(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -100,7 +100,7 @@ func runDeleteAuth(c *cli.Context) error {
return err return err
} }
source, err := auth_model.GetSourceByID(c.Int64("id")) source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
if err != nil { if err != nil {
return err return err
} }

View File

@ -17,9 +17,9 @@ import (
type ( type (
authService struct { authService struct {
initDB func(ctx context.Context) error initDB func(ctx context.Context) error
createAuthSource func(*auth.Source) error createAuthSource func(context.Context, *auth.Source) error
updateAuthSource func(*auth.Source) error updateAuthSource func(context.Context, *auth.Source) error
getAuthSourceByID func(id int64) (*auth.Source, error) getAuthSourceByID func(ctx context.Context, id int64) (*auth.Source, error)
} }
) )
@ -289,12 +289,12 @@ func findLdapSecurityProtocolByName(name string) (ldap.SecurityProtocol, bool) {
// getAuthSource gets the login source by its id defined in the command line flags. // getAuthSource gets the login source by its id defined in the command line flags.
// It returns an error if the id is not set, does not match any source or if the source is not of expected type. // It returns an error if the id is not set, does not match any source or if the source is not of expected type.
func (a *authService) getAuthSource(c *cli.Context, authType auth.Type) (*auth.Source, error) { func (a *authService) getAuthSource(ctx context.Context, c *cli.Context, authType auth.Type) (*auth.Source, error) {
if err := argsSet(c, "id"); err != nil { if err := argsSet(c, "id"); err != nil {
return nil, err return nil, err
} }
authSource, err := a.getAuthSourceByID(c.Int64("id")) authSource, err := a.getAuthSourceByID(ctx, c.Int64("id"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -332,7 +332,7 @@ func (a *authService) addLdapBindDn(c *cli.Context) error {
return err return err
} }
return a.createAuthSource(authSource) return a.createAuthSource(ctx, authSource)
} }
// updateLdapBindDn updates a new LDAP via Bind DN authentication source. // updateLdapBindDn updates a new LDAP via Bind DN authentication source.
@ -344,7 +344,7 @@ func (a *authService) updateLdapBindDn(c *cli.Context) error {
return err return err
} }
authSource, err := a.getAuthSource(c, auth.LDAP) authSource, err := a.getAuthSource(ctx, c, auth.LDAP)
if err != nil { if err != nil {
return err return err
} }
@ -354,7 +354,7 @@ func (a *authService) updateLdapBindDn(c *cli.Context) error {
return err return err
} }
return a.updateAuthSource(authSource) return a.updateAuthSource(ctx, authSource)
} }
// addLdapSimpleAuth adds a new LDAP (simple auth) authentication source. // addLdapSimpleAuth adds a new LDAP (simple auth) authentication source.
@ -383,7 +383,7 @@ func (a *authService) addLdapSimpleAuth(c *cli.Context) error {
return err return err
} }
return a.createAuthSource(authSource) return a.createAuthSource(ctx, authSource)
} }
// updateLdapBindDn updates a new LDAP (simple auth) authentication source. // updateLdapBindDn updates a new LDAP (simple auth) authentication source.
@ -395,7 +395,7 @@ func (a *authService) updateLdapSimpleAuth(c *cli.Context) error {
return err return err
} }
authSource, err := a.getAuthSource(c, auth.DLDAP) authSource, err := a.getAuthSource(ctx, c, auth.DLDAP)
if err != nil { if err != nil {
return err return err
} }
@ -405,5 +405,5 @@ func (a *authService) updateLdapSimpleAuth(c *cli.Context) error {
return err return err
} }
return a.updateAuthSource(authSource) return a.updateAuthSource(ctx, authSource)
} }

View File

@ -210,15 +210,15 @@ func TestAddLdapBindDn(t *testing.T) {
initDB: func(context.Context) error { initDB: func(context.Context) error {
return nil return nil
}, },
createAuthSource: func(authSource *auth.Source) error { createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
createdAuthSource = authSource createdAuthSource = authSource
return nil return nil
}, },
updateAuthSource: func(authSource *auth.Source) error { updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call updateAuthSource", n) assert.FailNow(t, "case %d: should not call updateAuthSource", n)
return nil return nil
}, },
getAuthSourceByID: func(id int64) (*auth.Source, error) { getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
assert.FailNow(t, "case %d: should not call getAuthSourceByID", n) assert.FailNow(t, "case %d: should not call getAuthSourceByID", n)
return nil, nil return nil, nil
}, },
@ -441,15 +441,15 @@ func TestAddLdapSimpleAuth(t *testing.T) {
initDB: func(context.Context) error { initDB: func(context.Context) error {
return nil return nil
}, },
createAuthSource: func(authSource *auth.Source) error { createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
createdAuthSource = authSource createdAuthSource = authSource
return nil return nil
}, },
updateAuthSource: func(authSource *auth.Source) error { updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call updateAuthSource", n) assert.FailNow(t, "case %d: should not call updateAuthSource", n)
return nil return nil
}, },
getAuthSourceByID: func(id int64) (*auth.Source, error) { getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
assert.FailNow(t, "case %d: should not call getAuthSourceByID", n) assert.FailNow(t, "case %d: should not call getAuthSourceByID", n)
return nil, nil return nil, nil
}, },
@ -896,15 +896,15 @@ func TestUpdateLdapBindDn(t *testing.T) {
initDB: func(context.Context) error { initDB: func(context.Context) error {
return nil return nil
}, },
createAuthSource: func(authSource *auth.Source) error { createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call createAuthSource", n) assert.FailNow(t, "case %d: should not call createAuthSource", n)
return nil return nil
}, },
updateAuthSource: func(authSource *auth.Source) error { updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
updatedAuthSource = authSource updatedAuthSource = authSource
return nil return nil
}, },
getAuthSourceByID: func(id int64) (*auth.Source, error) { getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
if c.id != 0 { if c.id != 0 {
assert.Equal(t, c.id, id, "case %d: wrong id", n) assert.Equal(t, c.id, id, "case %d: wrong id", n)
} }
@ -1286,15 +1286,15 @@ func TestUpdateLdapSimpleAuth(t *testing.T) {
initDB: func(context.Context) error { initDB: func(context.Context) error {
return nil return nil
}, },
createAuthSource: func(authSource *auth.Source) error { createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call createAuthSource", n) assert.FailNow(t, "case %d: should not call createAuthSource", n)
return nil return nil
}, },
updateAuthSource: func(authSource *auth.Source) error { updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
updatedAuthSource = authSource updatedAuthSource = authSource
return nil return nil
}, },
getAuthSourceByID: func(id int64) (*auth.Source, error) { getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
if c.id != 0 { if c.id != 0 {
assert.Equal(t, c.id, id, "case %d: wrong id", n) assert.Equal(t, c.id, id, "case %d: wrong id", n)
} }

View File

@ -183,7 +183,7 @@ func runAddOauth(c *cli.Context) error {
} }
} }
return auth_model.CreateSource(&auth_model.Source{ return auth_model.CreateSource(ctx, &auth_model.Source{
Type: auth_model.OAuth2, Type: auth_model.OAuth2,
Name: c.String("name"), Name: c.String("name"),
IsActive: true, IsActive: true,
@ -203,7 +203,7 @@ func runUpdateOauth(c *cli.Context) error {
return err return err
} }
source, err := auth_model.GetSourceByID(c.Int64("id")) source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
if err != nil { if err != nil {
return err return err
} }
@ -294,5 +294,5 @@ func runUpdateOauth(c *cli.Context) error {
oAuth2Config.CustomURLMapping = customURLMapping oAuth2Config.CustomURLMapping = customURLMapping
source.Cfg = oAuth2Config source.Cfg = oAuth2Config
return auth_model.UpdateSource(source) return auth_model.UpdateSource(ctx, source)
} }

View File

@ -156,7 +156,7 @@ func runAddSMTP(c *cli.Context) error {
smtpConfig.Auth = "PLAIN" smtpConfig.Auth = "PLAIN"
} }
return auth_model.CreateSource(&auth_model.Source{ return auth_model.CreateSource(ctx, &auth_model.Source{
Type: auth_model.SMTP, Type: auth_model.SMTP,
Name: c.String("name"), Name: c.String("name"),
IsActive: active, IsActive: active,
@ -176,7 +176,7 @@ func runUpdateSMTP(c *cli.Context) error {
return err return err
} }
source, err := auth_model.GetSourceByID(c.Int64("id")) source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
if err != nil { if err != nil {
return err return err
} }
@ -197,5 +197,5 @@ func runUpdateSMTP(c *cli.Context) error {
source.Cfg = smtpConfig source.Cfg = smtpConfig
return auth_model.UpdateSource(source) return auth_model.UpdateSource(ctx, source)
} }

View File

@ -42,7 +42,7 @@ func (jobs ActionJobList) LoadRuns(ctx context.Context, withRepo bool) error {
for _, r := range runs { for _, r := range runs {
runsList = append(runsList, r) runsList = append(runsList, r)
} }
return runsList.LoadRepos() return runsList.LoadRepos(ctx)
} }
return nil return nil
} }

View File

@ -52,9 +52,9 @@ func (runs RunList) LoadTriggerUser(ctx context.Context) error {
return nil return nil
} }
func (runs RunList) LoadRepos() error { func (runs RunList) LoadRepos(ctx context.Context) error {
repoIDs := runs.GetRepoIDs() repoIDs := runs.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs) repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil { if err != nil {
return err return err
} }

View File

@ -49,9 +49,9 @@ func (schedules ScheduleList) LoadTriggerUser(ctx context.Context) error {
return nil return nil
} }
func (schedules ScheduleList) LoadRepos() error { func (schedules ScheduleList) LoadRepos(ctx context.Context) error {
repoIDs := schedules.GetRepoIDs() repoIDs := schedules.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs) repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil { if err != nil {
return err return err
} }

View File

@ -53,9 +53,9 @@ func (specs SpecList) GetRepoIDs() []int64 {
return ids.Values() return ids.Values()
} }
func (specs SpecList) LoadRepos() error { func (specs SpecList) LoadRepos(ctx context.Context) error {
repoIDs := specs.GetRepoIDs() repoIDs := specs.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs) repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil { if err != nil {
return err return err
} }

View File

@ -102,7 +102,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) {
stats.Counter.Follow, _ = e.Count(new(user_model.Follow)) stats.Counter.Follow, _ = e.Count(new(user_model.Follow))
stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror)) stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror))
stats.Counter.Release, _ = e.Count(new(repo_model.Release)) stats.Counter.Release, _ = e.Count(new(repo_model.Release))
stats.Counter.AuthSource = auth.CountSources() stats.Counter.AuthSource = auth.CountSources(ctx)
stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook)) stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook))
stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone)) stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone))
stats.Counter.Label, _ = e.Count(new(issues_model.Label)) stats.Counter.Label, _ = e.Count(new(issues_model.Label))

View File

@ -91,7 +91,7 @@ func addKey(ctx context.Context, key *PublicKey) (err error) {
} }
// AddPublicKey adds new public key to database and authorized_keys file. // AddPublicKey adds new public key to database and authorized_keys file.
func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*PublicKey, error) { func AddPublicKey(ctx context.Context, ownerID int64, name, content string, authSourceID int64) (*PublicKey, error) {
log.Trace(content) log.Trace(content)
fingerprint, err := CalcFingerprint(content) fingerprint, err := CalcFingerprint(content)
@ -99,7 +99,7 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
return nil, err return nil, err
} }
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -136,9 +136,9 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
} }
// GetPublicKeyByID returns public key by given ID. // GetPublicKeyByID returns public key by given ID.
func GetPublicKeyByID(keyID int64) (*PublicKey, error) { func GetPublicKeyByID(ctx context.Context, keyID int64) (*PublicKey, error) {
key := new(PublicKey) key := new(PublicKey)
has, err := db.GetEngine(db.DefaultContext). has, err := db.GetEngine(ctx).
ID(keyID). ID(keyID).
Get(key) Get(key)
if err != nil { if err != nil {
@ -180,7 +180,7 @@ func SearchPublicKeyByContentExact(ctx context.Context, content string) (*Public
} }
// SearchPublicKey returns a list of public keys matching the provided arguments. // SearchPublicKey returns a list of public keys matching the provided arguments.
func SearchPublicKey(uid int64, fingerprint string) ([]*PublicKey, error) { func SearchPublicKey(ctx context.Context, uid int64, fingerprint string) ([]*PublicKey, error) {
keys := make([]*PublicKey, 0, 5) keys := make([]*PublicKey, 0, 5)
cond := builder.NewCond() cond := builder.NewCond()
if uid != 0 { if uid != 0 {
@ -189,12 +189,12 @@ func SearchPublicKey(uid int64, fingerprint string) ([]*PublicKey, error) {
if fingerprint != "" { if fingerprint != "" {
cond = cond.And(builder.Eq{"fingerprint": fingerprint}) cond = cond.And(builder.Eq{"fingerprint": fingerprint})
} }
return keys, db.GetEngine(db.DefaultContext).Where(cond).Find(&keys) return keys, db.GetEngine(ctx).Where(cond).Find(&keys)
} }
// ListPublicKeys returns a list of public keys belongs to given user. // ListPublicKeys returns a list of public keys belongs to given user.
func ListPublicKeys(uid int64, listOptions db.ListOptions) ([]*PublicKey, error) { func ListPublicKeys(ctx context.Context, uid int64, listOptions db.ListOptions) ([]*PublicKey, error) {
sess := db.GetEngine(db.DefaultContext).Where("owner_id = ? AND type != ?", uid, KeyTypePrincipal) sess := db.GetEngine(ctx).Where("owner_id = ? AND type != ?", uid, KeyTypePrincipal)
if listOptions.Page != 0 { if listOptions.Page != 0 {
sess = db.SetSessionPagination(sess, &listOptions) sess = db.SetSessionPagination(sess, &listOptions)
@ -207,30 +207,30 @@ func ListPublicKeys(uid int64, listOptions db.ListOptions) ([]*PublicKey, error)
} }
// CountPublicKeys count public keys a user has // CountPublicKeys count public keys a user has
func CountPublicKeys(userID int64) (int64, error) { func CountPublicKeys(ctx context.Context, userID int64) (int64, error) {
sess := db.GetEngine(db.DefaultContext).Where("owner_id = ? AND type != ?", userID, KeyTypePrincipal) sess := db.GetEngine(ctx).Where("owner_id = ? AND type != ?", userID, KeyTypePrincipal)
return sess.Count(&PublicKey{}) return sess.Count(&PublicKey{})
} }
// ListPublicKeysBySource returns a list of synchronized public keys for a given user and login source. // ListPublicKeysBySource returns a list of synchronized public keys for a given user and login source.
func ListPublicKeysBySource(uid, authSourceID int64) ([]*PublicKey, error) { func ListPublicKeysBySource(ctx context.Context, uid, authSourceID int64) ([]*PublicKey, error) {
keys := make([]*PublicKey, 0, 5) keys := make([]*PublicKey, 0, 5)
return keys, db.GetEngine(db.DefaultContext). return keys, db.GetEngine(ctx).
Where("owner_id = ? AND login_source_id = ?", uid, authSourceID). Where("owner_id = ? AND login_source_id = ?", uid, authSourceID).
Find(&keys) Find(&keys)
} }
// UpdatePublicKeyUpdated updates public key use time. // UpdatePublicKeyUpdated updates public key use time.
func UpdatePublicKeyUpdated(id int64) error { func UpdatePublicKeyUpdated(ctx context.Context, id int64) error {
// Check if key exists before update as affected rows count is unreliable // Check if key exists before update as affected rows count is unreliable
// and will return 0 affected rows if two updates are made at the same time // and will return 0 affected rows if two updates are made at the same time
if cnt, err := db.GetEngine(db.DefaultContext).ID(id).Count(&PublicKey{}); err != nil { if cnt, err := db.GetEngine(ctx).ID(id).Count(&PublicKey{}); err != nil {
return err return err
} else if cnt != 1 { } else if cnt != 1 {
return ErrKeyNotExist{id} return ErrKeyNotExist{id}
} }
_, err := db.GetEngine(db.DefaultContext).ID(id).Cols("updated_unix").Update(&PublicKey{ _, err := db.GetEngine(ctx).ID(id).Cols("updated_unix").Update(&PublicKey{
UpdatedUnix: timeutil.TimeStampNow(), UpdatedUnix: timeutil.TimeStampNow(),
}) })
if err != nil { if err != nil {
@ -250,7 +250,7 @@ func DeletePublicKeys(ctx context.Context, keyIDs ...int64) error {
} }
// PublicKeysAreExternallyManaged returns whether the provided KeyID represents an externally managed Key // PublicKeysAreExternallyManaged returns whether the provided KeyID represents an externally managed Key
func PublicKeysAreExternallyManaged(keys []*PublicKey) ([]bool, error) { func PublicKeysAreExternallyManaged(ctx context.Context, keys []*PublicKey) ([]bool, error) {
sources := make([]*auth.Source, 0, 5) sources := make([]*auth.Source, 0, 5)
externals := make([]bool, len(keys)) externals := make([]bool, len(keys))
keyloop: keyloop:
@ -272,7 +272,7 @@ keyloop:
if source == nil { if source == nil {
var err error var err error
source, err = auth.GetSourceByID(key.LoginSourceID) source, err = auth.GetSourceByID(ctx, key.LoginSourceID)
if err != nil { if err != nil {
if auth.IsErrSourceNotExist(err) { if auth.IsErrSourceNotExist(err) {
externals[i] = false externals[i] = false
@ -295,15 +295,15 @@ keyloop:
} }
// PublicKeyIsExternallyManaged returns whether the provided KeyID represents an externally managed Key // PublicKeyIsExternallyManaged returns whether the provided KeyID represents an externally managed Key
func PublicKeyIsExternallyManaged(id int64) (bool, error) { func PublicKeyIsExternallyManaged(ctx context.Context, id int64) (bool, error) {
key, err := GetPublicKeyByID(id) key, err := GetPublicKeyByID(ctx, id)
if err != nil { if err != nil {
return false, err return false, err
} }
if key.LoginSourceID == 0 { if key.LoginSourceID == 0 {
return false, nil return false, nil
} }
source, err := auth.GetSourceByID(key.LoginSourceID) source, err := auth.GetSourceByID(ctx, key.LoginSourceID)
if err != nil { if err != nil {
if auth.IsErrSourceNotExist(err) { if auth.IsErrSourceNotExist(err) {
return false, nil return false, nil
@ -318,9 +318,9 @@ func PublicKeyIsExternallyManaged(id int64) (bool, error) {
} }
// deleteKeysMarkedForDeletion returns true if ssh keys needs update // deleteKeysMarkedForDeletion returns true if ssh keys needs update
func deleteKeysMarkedForDeletion(keys []string) (bool, error) { func deleteKeysMarkedForDeletion(ctx context.Context, keys []string) (bool, error) {
// Start session // Start session
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -349,7 +349,7 @@ func deleteKeysMarkedForDeletion(keys []string) (bool, error) {
} }
// AddPublicKeysBySource add a users public keys. Returns true if there are changes. // AddPublicKeysBySource add a users public keys. Returns true if there are changes.
func AddPublicKeysBySource(usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool { func AddPublicKeysBySource(ctx context.Context, usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
var sshKeysNeedUpdate bool var sshKeysNeedUpdate bool
for _, sshKey := range sshPublicKeys { for _, sshKey := range sshPublicKeys {
var err error var err error
@ -368,7 +368,7 @@ func AddPublicKeysBySource(usr *user_model.User, s *auth.Source, sshPublicKeys [
marshalled = marshalled[:len(marshalled)-1] marshalled = marshalled[:len(marshalled)-1]
sshKeyName := fmt.Sprintf("%s-%s", s.Name, ssh.FingerprintSHA256(out)) sshKeyName := fmt.Sprintf("%s-%s", s.Name, ssh.FingerprintSHA256(out))
if _, err := AddPublicKey(usr.ID, sshKeyName, marshalled, s.ID); err != nil { if _, err := AddPublicKey(ctx, usr.ID, sshKeyName, marshalled, s.ID); err != nil {
if IsErrKeyAlreadyExist(err) { if IsErrKeyAlreadyExist(err) {
log.Trace("AddPublicKeysBySource[%s]: Public SSH Key %s already exists for user", sshKeyName, usr.Name) log.Trace("AddPublicKeysBySource[%s]: Public SSH Key %s already exists for user", sshKeyName, usr.Name)
} else { } else {
@ -387,14 +387,14 @@ func AddPublicKeysBySource(usr *user_model.User, s *auth.Source, sshPublicKeys [
} }
// SynchronizePublicKeys updates a users public keys. Returns true if there are changes. // SynchronizePublicKeys updates a users public keys. Returns true if there are changes.
func SynchronizePublicKeys(usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool { func SynchronizePublicKeys(ctx context.Context, usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
var sshKeysNeedUpdate bool var sshKeysNeedUpdate bool
log.Trace("synchronizePublicKeys[%s]: Handling Public SSH Key synchronization for user %s", s.Name, usr.Name) log.Trace("synchronizePublicKeys[%s]: Handling Public SSH Key synchronization for user %s", s.Name, usr.Name)
// Get Public Keys from DB with current LDAP source // Get Public Keys from DB with current LDAP source
var giteaKeys []string var giteaKeys []string
keys, err := ListPublicKeysBySource(usr.ID, s.ID) keys, err := ListPublicKeysBySource(ctx, usr.ID, s.ID)
if err != nil { if err != nil {
log.Error("synchronizePublicKeys[%s]: Error listing Public SSH Keys for user %s: %v", s.Name, usr.Name, err) log.Error("synchronizePublicKeys[%s]: Error listing Public SSH Keys for user %s: %v", s.Name, usr.Name, err)
} }
@ -429,7 +429,7 @@ func SynchronizePublicKeys(usr *user_model.User, s *auth.Source, sshPublicKeys [
newKeys = append(newKeys, key) newKeys = append(newKeys, key)
} }
} }
if AddPublicKeysBySource(usr, s, newKeys) { if AddPublicKeysBySource(ctx, usr, s, newKeys) {
sshKeysNeedUpdate = true sshKeysNeedUpdate = true
} }
@ -443,7 +443,7 @@ func SynchronizePublicKeys(usr *user_model.User, s *auth.Source, sshPublicKeys [
} }
// Delete keys from DB that no longer exist in the source // Delete keys from DB that no longer exist in the source
needUpd, err := deleteKeysMarkedForDeletion(giteaKeysToDelete) needUpd, err := deleteKeysMarkedForDeletion(ctx, giteaKeysToDelete)
if err != nil { if err != nil {
log.Error("synchronizePublicKeys[%s]: Error deleting Public Keys marked for deletion for user %s: %v", s.Name, usr.Name, err) log.Error("synchronizePublicKeys[%s]: Error deleting Public Keys marked for deletion for user %s: %v", s.Name, usr.Name, err)
} }

View File

@ -21,7 +21,7 @@ import (
func ParseCommitWithSSHSignature(ctx context.Context, c *git.Commit, committer *user_model.User) *CommitVerification { func ParseCommitWithSSHSignature(ctx context.Context, c *git.Commit, committer *user_model.User) *CommitVerification {
// Now try to associate the signature with the committer, if present // Now try to associate the signature with the committer, if present
if committer.ID != 0 { if committer.ID != 0 {
keys, err := ListPublicKeys(committer.ID, db.ListOptions{}) keys, err := ListPublicKeys(ctx, committer.ID, db.ListOptions{})
if err != nil { // Skipping failed to get ssh keys of user if err != nil { // Skipping failed to get ssh keys of user
log.Error("ListPublicKeys: %v", err) log.Error("ListPublicKeys: %v", err)
return &CommitVerification{ return &CommitVerification{

View File

@ -48,8 +48,8 @@ func (key *DeployKey) AfterLoad() {
} }
// GetContent gets associated public key content. // GetContent gets associated public key content.
func (key *DeployKey) GetContent() error { func (key *DeployKey) GetContent(ctx context.Context) error {
pkey, err := GetPublicKeyByID(key.KeyID) pkey, err := GetPublicKeyByID(ctx, key.KeyID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,6 +5,7 @@
package auth package auth
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
@ -199,8 +200,8 @@ func (source *Source) SkipVerify() bool {
// CreateSource inserts a AuthSource in the DB if not already // CreateSource inserts a AuthSource in the DB if not already
// existing with the given name. // existing with the given name.
func CreateSource(source *Source) error { func CreateSource(ctx context.Context, source *Source) error {
has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source)) has, err := db.GetEngine(ctx).Where("name=?", source.Name).Exist(new(Source))
if err != nil { if err != nil {
return err return err
} else if has { } else if has {
@ -211,7 +212,7 @@ func CreateSource(source *Source) error {
source.IsSyncEnabled = false source.IsSyncEnabled = false
} }
_, err = db.GetEngine(db.DefaultContext).Insert(source) _, err = db.GetEngine(ctx).Insert(source)
if err != nil { if err != nil {
return err return err
} }
@ -232,7 +233,7 @@ func CreateSource(source *Source) error {
err = registerableSource.RegisterSource() err = registerableSource.RegisterSource()
if err != nil { if err != nil {
// remove the AuthSource in case of errors while registering configuration // remove the AuthSource in case of errors while registering configuration
if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil { if _, err := db.GetEngine(ctx).Delete(source); err != nil {
log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err) log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
} }
} }
@ -240,33 +241,33 @@ func CreateSource(source *Source) error {
} }
// Sources returns a slice of all login sources found in DB. // Sources returns a slice of all login sources found in DB.
func Sources() ([]*Source, error) { func Sources(ctx context.Context) ([]*Source, error) {
auths := make([]*Source, 0, 6) auths := make([]*Source, 0, 6)
return auths, db.GetEngine(db.DefaultContext).Find(&auths) return auths, db.GetEngine(ctx).Find(&auths)
} }
// SourcesByType returns all sources of the specified type // SourcesByType returns all sources of the specified type
func SourcesByType(loginType Type) ([]*Source, error) { func SourcesByType(ctx context.Context, loginType Type) ([]*Source, error) {
sources := make([]*Source, 0, 1) sources := make([]*Source, 0, 1)
if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil { if err := db.GetEngine(ctx).Where("type = ?", loginType).Find(&sources); err != nil {
return nil, err return nil, err
} }
return sources, nil return sources, nil
} }
// AllActiveSources returns all active sources // AllActiveSources returns all active sources
func AllActiveSources() ([]*Source, error) { func AllActiveSources(ctx context.Context) ([]*Source, error) {
sources := make([]*Source, 0, 5) sources := make([]*Source, 0, 5)
if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil { if err := db.GetEngine(ctx).Where("is_active = ?", true).Find(&sources); err != nil {
return nil, err return nil, err
} }
return sources, nil return sources, nil
} }
// ActiveSources returns all active sources of the specified type // ActiveSources returns all active sources of the specified type
func ActiveSources(tp Type) ([]*Source, error) { func ActiveSources(ctx context.Context, tp Type) ([]*Source, error) {
sources := make([]*Source, 0, 1) sources := make([]*Source, 0, 1)
if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil { if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
return nil, err return nil, err
} }
return sources, nil return sources, nil
@ -274,11 +275,11 @@ func ActiveSources(tp Type) ([]*Source, error) {
// IsSSPIEnabled returns true if there is at least one activated login // IsSSPIEnabled returns true if there is at least one activated login
// source of type LoginSSPI // source of type LoginSSPI
func IsSSPIEnabled() bool { func IsSSPIEnabled(ctx context.Context) bool {
if !db.HasEngine { if !db.HasEngine {
return false return false
} }
sources, err := ActiveSources(SSPI) sources, err := ActiveSources(ctx, SSPI)
if err != nil { if err != nil {
log.Error("ActiveSources: %v", err) log.Error("ActiveSources: %v", err)
return false return false
@ -287,7 +288,7 @@ func IsSSPIEnabled() bool {
} }
// GetSourceByID returns login source by given ID. // GetSourceByID returns login source by given ID.
func GetSourceByID(id int64) (*Source, error) { func GetSourceByID(ctx context.Context, id int64) (*Source, error) {
source := new(Source) source := new(Source)
if id == 0 { if id == 0 {
source.Cfg = registeredConfigs[NoType]() source.Cfg = registeredConfigs[NoType]()
@ -297,7 +298,7 @@ func GetSourceByID(id int64) (*Source, error) {
return source, nil return source, nil
} }
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source) has, err := db.GetEngine(ctx).ID(id).Get(source)
if err != nil { if err != nil {
return nil, err return nil, err
} else if !has { } else if !has {
@ -307,24 +308,24 @@ func GetSourceByID(id int64) (*Source, error) {
} }
// UpdateSource updates a Source record in DB. // UpdateSource updates a Source record in DB.
func UpdateSource(source *Source) error { func UpdateSource(ctx context.Context, source *Source) error {
var originalSource *Source var originalSource *Source
if source.IsOAuth2() { if source.IsOAuth2() {
// keep track of the original values so we can restore in case of errors while registering OAuth2 providers // keep track of the original values so we can restore in case of errors while registering OAuth2 providers
var err error var err error
if originalSource, err = GetSourceByID(source.ID); err != nil { if originalSource, err = GetSourceByID(ctx, source.ID); err != nil {
return err return err
} }
} }
has, err := db.GetEngine(db.DefaultContext).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source)) has, err := db.GetEngine(ctx).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
if err != nil { if err != nil {
return err return err
} else if has { } else if has {
return ErrSourceAlreadyExist{source.Name} return ErrSourceAlreadyExist{source.Name}
} }
_, err = db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source) _, err = db.GetEngine(ctx).ID(source.ID).AllCols().Update(source)
if err != nil { if err != nil {
return err return err
} }
@ -345,7 +346,7 @@ func UpdateSource(source *Source) error {
err = registerableSource.RegisterSource() err = registerableSource.RegisterSource()
if err != nil { if err != nil {
// restore original values since we cannot update the provider it self // restore original values since we cannot update the provider it self
if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalSource); err != nil { if _, err := db.GetEngine(ctx).ID(source.ID).AllCols().Update(originalSource); err != nil {
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err) log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
} }
} }
@ -353,8 +354,8 @@ func UpdateSource(source *Source) error {
} }
// CountSources returns number of login sources. // CountSources returns number of login sources.
func CountSources() int64 { func CountSources(ctx context.Context) int64 {
count, _ := db.GetEngine(db.DefaultContext).Count(new(Source)) count, _ := db.GetEngine(ctx).Count(new(Source))
return count return count
} }

View File

@ -42,7 +42,7 @@ func TestDumpAuthSource(t *testing.T) {
auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource)) auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource))
auth_model.CreateSource(&auth_model.Source{ auth_model.CreateSource(db.DefaultContext, &auth_model.Source{
Type: auth_model.OAuth2, Type: auth_model.OAuth2,
Name: "TestSource", Name: "TestSource",
IsActive: false, IsActive: false,

View File

@ -111,7 +111,7 @@ func findCodeComments(ctx context.Context, opts FindCommentsOptions, issue *Issu
if comment.RenderedContent, err = markdown.RenderString(&markup.RenderContext{ if comment.RenderedContent, err = markdown.RenderString(&markup.RenderContext{
Ctx: ctx, Ctx: ctx,
URLPrefix: issue.Repo.Link(), URLPrefix: issue.Repo.Link(),
Metas: issue.Repo.ComposeMetas(), Metas: issue.Repo.ComposeMetas(ctx),
}, comment.Content); err != nil { }, comment.Content); err != nil {
return nil, err return nil, err
} }

View File

@ -127,8 +127,8 @@ const (
) )
// CreateIssueDependency creates a new dependency for an issue // CreateIssueDependency creates a new dependency for an issue
func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error { func CreateIssueDependency(ctx context.Context, user *user_model.User, issue, dep *Issue) error {
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -168,8 +168,8 @@ func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error {
} }
// RemoveIssueDependency removes a dependency from an issue // RemoveIssueDependency removes a dependency from an issue
func RemoveIssueDependency(user *user_model.User, issue, dep *Issue, depType DependencyType) (err error) { func RemoveIssueDependency(ctx context.Context, user *user_model.User, issue, dep *Issue, depType DependencyType) (err error) {
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return err return err
} }

View File

@ -28,16 +28,16 @@ func TestCreateIssueDependency(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// Create a dependency and check if it was successful // Create a dependency and check if it was successful
err = issues_model.CreateIssueDependency(user1, issue1, issue2) err = issues_model.CreateIssueDependency(db.DefaultContext, user1, issue1, issue2)
assert.NoError(t, err) assert.NoError(t, err)
// Do it again to see if it will check if the dependency already exists // Do it again to see if it will check if the dependency already exists
err = issues_model.CreateIssueDependency(user1, issue1, issue2) err = issues_model.CreateIssueDependency(db.DefaultContext, user1, issue1, issue2)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, issues_model.IsErrDependencyExists(err)) assert.True(t, issues_model.IsErrDependencyExists(err))
// Check for circular dependencies // Check for circular dependencies
err = issues_model.CreateIssueDependency(user1, issue2, issue1) err = issues_model.CreateIssueDependency(db.DefaultContext, user1, issue2, issue1)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, issues_model.IsErrCircularDependency(err)) assert.True(t, issues_model.IsErrCircularDependency(err))
@ -57,6 +57,6 @@ func TestCreateIssueDependency(t *testing.T) {
assert.True(t, left) assert.True(t, left)
// Test removing the dependency // Test removing the dependency
err = issues_model.RemoveIssueDependency(user1, issue1, issue2, issues_model.DependencyTypeBlockedBy) err = issues_model.RemoveIssueDependency(db.DefaultContext, user1, issue1, issue2, issues_model.DependencyTypeBlockedBy)
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@ -83,12 +83,12 @@ func RemoveDuplicateExclusiveIssueLabels(ctx context.Context, issue *Issue, labe
} }
// NewIssueLabel creates a new issue-label relation. // NewIssueLabel creates a new issue-label relation.
func NewIssueLabel(issue *Issue, label *Label, doer *user_model.User) (err error) { func NewIssueLabel(ctx context.Context, issue *Issue, label *Label, doer *user_model.User) (err error) {
if HasIssueLabel(db.DefaultContext, issue.ID, label.ID) { if HasIssueLabel(ctx, issue.ID, label.ID) {
return nil return nil
} }
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -149,8 +149,8 @@ func newIssueLabels(ctx context.Context, issue *Issue, labels []*Label, doer *us
} }
// NewIssueLabels creates a list of issue-label relations. // NewIssueLabels creates a list of issue-label relations.
func NewIssueLabels(issue *Issue, labels []*Label, doer *user_model.User) (err error) { func NewIssueLabels(ctx context.Context, issue *Issue, labels []*Label, doer *user_model.User) (err error) {
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -359,8 +359,8 @@ func clearIssueLabels(ctx context.Context, issue *Issue, doer *user_model.User)
// ClearIssueLabels removes all issue labels as the given user. // ClearIssueLabels removes all issue labels as the given user.
// Triggers appropriate WebHooks, if any. // Triggers appropriate WebHooks, if any.
func ClearIssueLabels(issue *Issue, doer *user_model.User) (err error) { func ClearIssueLabels(ctx context.Context, issue *Issue, doer *user_model.User) (err error) {
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -432,8 +432,8 @@ func RemoveDuplicateExclusiveLabels(labels []*Label) []*Label {
// ReplaceIssueLabels removes all current labels and add new labels to the issue. // ReplaceIssueLabels removes all current labels and add new labels to the issue.
// Triggers appropriate WebHooks, if any. // Triggers appropriate WebHooks, if any.
func ReplaceIssueLabels(issue *Issue, labels []*Label, doer *user_model.User) (err error) { func ReplaceIssueLabels(ctx context.Context, issue *Issue, labels []*Label, doer *user_model.User) (err error) {
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return err return err
} }

View File

@ -6,6 +6,7 @@ package issues_test
import ( import (
"testing" "testing"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues" issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/unittest" "code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user" user_model "code.gitea.io/gitea/models/user"
@ -21,7 +22,7 @@ func TestNewIssueLabelsScope(t *testing.T) {
label2 := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 8}) label2 := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 8})
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
assert.NoError(t, issues_model.NewIssueLabels(issue, []*issues_model.Label{label1, label2}, doer)) assert.NoError(t, issues_model.NewIssueLabels(db.DefaultContext, issue, []*issues_model.Label{label1, label2}, doer))
assert.Len(t, issue.Labels, 1) assert.Len(t, issue.Labels, 1)
assert.Equal(t, label2.ID, issue.Labels[0].ID) assert.Equal(t, label2.ID, issue.Labels[0].ID)

View File

@ -4,6 +4,8 @@
package issues package issues
import ( import (
"context"
"code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user" user_model "code.gitea.io/gitea/models/user"
) )
@ -17,16 +19,16 @@ type IssueLockOptions struct {
// LockIssue locks an issue. This would limit commenting abilities to // LockIssue locks an issue. This would limit commenting abilities to
// users with write access to the repo // users with write access to the repo
func LockIssue(opts *IssueLockOptions) error { func LockIssue(ctx context.Context, opts *IssueLockOptions) error {
return updateIssueLock(opts, true) return updateIssueLock(ctx, opts, true)
} }
// UnlockIssue unlocks a previously locked issue. // UnlockIssue unlocks a previously locked issue.
func UnlockIssue(opts *IssueLockOptions) error { func UnlockIssue(ctx context.Context, opts *IssueLockOptions) error {
return updateIssueLock(opts, false) return updateIssueLock(ctx, opts, false)
} }
func updateIssueLock(opts *IssueLockOptions, lock bool) error { func updateIssueLock(ctx context.Context, opts *IssueLockOptions, lock bool) error {
if opts.Issue.IsLocked == lock { if opts.Issue.IsLocked == lock {
return nil return nil
} }
@ -39,7 +41,7 @@ func updateIssueLock(opts *IssueLockOptions, lock bool) error {
commentType = CommentTypeUnlock commentType = CommentTypeUnlock
} }
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return err return err
} }

View File

@ -38,11 +38,7 @@ func (issue *Issue) projectID(ctx context.Context) int64 {
} }
// ProjectBoardID return project board id if issue was assigned to one // ProjectBoardID return project board id if issue was assigned to one
func (issue *Issue) ProjectBoardID() int64 { func (issue *Issue) ProjectBoardID(ctx context.Context) int64 {
return issue.projectBoardID(db.DefaultContext)
}
func (issue *Issue) projectBoardID(ctx context.Context) int64 {
var ip project_model.ProjectIssue var ip project_model.ProjectIssue
has, err := db.GetEngine(ctx).Where("issue_id=?", issue.ID).Get(&ip) has, err := db.GetEngine(ctx).Where("issue_id=?", issue.ID).Get(&ip)
if err != nil || !has { if err != nil || !has {
@ -100,8 +96,8 @@ func LoadIssuesFromBoardList(ctx context.Context, bs project_model.BoardList) (m
} }
// ChangeProjectAssign changes the project associated with an issue // ChangeProjectAssign changes the project associated with an issue
func ChangeProjectAssign(issue *Issue, doer *user_model.User, newProjectID int64) error { func ChangeProjectAssign(ctx context.Context, issue *Issue, doer *user_model.User, newProjectID int64) error {
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -156,8 +152,8 @@ func addUpdateIssueProject(ctx context.Context, issue *Issue, doer *user_model.U
} }
// MoveIssueAcrossProjectBoards move a card from one board to another // MoveIssueAcrossProjectBoards move a card from one board to another
func MoveIssueAcrossProjectBoards(issue *Issue, board *project_model.Board) error { func MoveIssueAcrossProjectBoards(ctx context.Context, issue *Issue, board *project_model.Board) error {
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return err return err
} }

View File

@ -444,9 +444,9 @@ func applySubscribedCondition(sess *xorm.Session, subscriberID int64) *xorm.Sess
} }
// GetRepoIDsForIssuesOptions find all repo ids for the given options // GetRepoIDsForIssuesOptions find all repo ids for the given options
func GetRepoIDsForIssuesOptions(opts *IssuesOptions, user *user_model.User) ([]int64, error) { func GetRepoIDsForIssuesOptions(ctx context.Context, opts *IssuesOptions, user *user_model.User) ([]int64, error) {
repoIDs := make([]int64, 0, 5) repoIDs := make([]int64, 0, 5)
e := db.GetEngine(db.DefaultContext) e := db.GetEngine(ctx)
sess := e.Join("INNER", "repository", "`issue`.repo_id = `repository`.id") sess := e.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")

View File

@ -34,7 +34,7 @@ func TestIssue_ReplaceLabels(t *testing.T) {
for i, labelID := range labelIDs { for i, labelID := range labelIDs {
labels[i] = unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: labelID, RepoID: repo.ID}) labels[i] = unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: labelID, RepoID: repo.ID})
} }
assert.NoError(t, issues_model.ReplaceIssueLabels(issue, labels, doer)) assert.NoError(t, issues_model.ReplaceIssueLabels(db.DefaultContext, issue, labels, doer))
unittest.AssertCount(t, &issues_model.IssueLabel{IssueID: issueID}, len(expectedLabelIDs)) unittest.AssertCount(t, &issues_model.IssueLabel{IssueID: issueID}, len(expectedLabelIDs))
for _, labelID := range expectedLabelIDs { for _, labelID := range expectedLabelIDs {
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issueID, LabelID: labelID}) unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issueID, LabelID: labelID})
@ -122,7 +122,7 @@ func TestIssue_ClearLabels(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase()) assert.NoError(t, unittest.PrepareTestDatabase())
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: test.issueID}) issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: test.issueID})
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: test.doerID}) doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: test.doerID})
assert.NoError(t, issues_model.ClearIssueLabels(issue, doer)) assert.NoError(t, issues_model.ClearIssueLabels(db.DefaultContext, issue, doer))
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: test.issueID}) unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: test.issueID})
} }
} }
@ -230,7 +230,7 @@ func TestGetRepoIDsForIssuesOptions(t *testing.T) {
[]int64{1, 2}, []int64{1, 2},
}, },
} { } {
repoIDs, err := issues_model.GetRepoIDsForIssuesOptions(&test.Opts, user) repoIDs, err := issues_model.GetRepoIDsForIssuesOptions(db.DefaultContext, &test.Opts, user)
assert.NoError(t, err) assert.NoError(t, err)
if assert.Len(t, repoIDs, len(test.ExpectedRepoIDs)) { if assert.Len(t, repoIDs, len(test.ExpectedRepoIDs)) {
for i, repoID := range repoIDs { for i, repoID := range repoIDs {

View File

@ -307,7 +307,7 @@ func TestNewIssueLabel(t *testing.T) {
// add new IssueLabel // add new IssueLabel
prevNumIssues := label.NumIssues prevNumIssues := label.NumIssues
assert.NoError(t, issues_model.NewIssueLabel(issue, label, doer)) assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, label, doer))
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: label.ID}) unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: label.ID})
unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{ unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{
Type: issues_model.CommentTypeLabel, Type: issues_model.CommentTypeLabel,
@ -320,7 +320,7 @@ func TestNewIssueLabel(t *testing.T) {
assert.EqualValues(t, prevNumIssues+1, label.NumIssues) assert.EqualValues(t, prevNumIssues+1, label.NumIssues)
// re-add existing IssueLabel // re-add existing IssueLabel
assert.NoError(t, issues_model.NewIssueLabel(issue, label, doer)) assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, label, doer))
unittest.CheckConsistencyFor(t, &issues_model.Issue{}, &issues_model.Label{}) unittest.CheckConsistencyFor(t, &issues_model.Issue{}, &issues_model.Label{})
} }
@ -334,19 +334,19 @@ func TestNewIssueExclusiveLabel(t *testing.T) {
exclusiveLabelB := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 8}) exclusiveLabelB := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 8})
// coexisting regular and exclusive label // coexisting regular and exclusive label
assert.NoError(t, issues_model.NewIssueLabel(issue, otherLabel, doer)) assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, otherLabel, doer))
assert.NoError(t, issues_model.NewIssueLabel(issue, exclusiveLabelA, doer)) assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, exclusiveLabelA, doer))
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID}) unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID}) unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
// exclusive label replaces existing one // exclusive label replaces existing one
assert.NoError(t, issues_model.NewIssueLabel(issue, exclusiveLabelB, doer)) assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, exclusiveLabelB, doer))
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID}) unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelB.ID}) unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelB.ID})
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID}) unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
// exclusive label replaces existing one again // exclusive label replaces existing one again
assert.NoError(t, issues_model.NewIssueLabel(issue, exclusiveLabelA, doer)) assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, exclusiveLabelA, doer))
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID}) unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID}) unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelB.ID}) unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelB.ID})
@ -359,7 +359,7 @@ func TestNewIssueLabels(t *testing.T) {
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 5}) issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 5})
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
assert.NoError(t, issues_model.NewIssueLabels(issue, []*issues_model.Label{label1, label2}, doer)) assert.NoError(t, issues_model.NewIssueLabels(db.DefaultContext, issue, []*issues_model.Label{label1, label2}, doer))
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: label1.ID}) unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: label1.ID})
unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{ unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{
Type: issues_model.CommentTypeLabel, Type: issues_model.CommentTypeLabel,
@ -377,7 +377,7 @@ func TestNewIssueLabels(t *testing.T) {
assert.EqualValues(t, 1, label2.NumClosedIssues) assert.EqualValues(t, 1, label2.NumClosedIssues)
// corner case: test empty slice // corner case: test empty slice
assert.NoError(t, issues_model.NewIssueLabels(issue, []*issues_model.Label{}, doer)) assert.NoError(t, issues_model.NewIssueLabels(db.DefaultContext, issue, []*issues_model.Label{}, doer))
unittest.CheckConsistencyFor(t, &issues_model.Issue{}, &issues_model.Label{}) unittest.CheckConsistencyFor(t, &issues_model.Issue{}, &issues_model.Label{})
} }

View File

@ -58,8 +58,8 @@ func (opts GetMilestonesOption) toCond() builder.Cond {
} }
// GetMilestones returns milestones filtered by GetMilestonesOption's // GetMilestones returns milestones filtered by GetMilestonesOption's
func GetMilestones(opts GetMilestonesOption) (MilestoneList, int64, error) { func GetMilestones(ctx context.Context, opts GetMilestonesOption) (MilestoneList, int64, error) {
sess := db.GetEngine(db.DefaultContext).Where(opts.toCond()) sess := db.GetEngine(ctx).Where(opts.toCond())
if opts.Page != 0 { if opts.Page != 0 {
sess = db.SetSessionPagination(sess, &opts) sess = db.SetSessionPagination(sess, &opts)

View File

@ -40,7 +40,7 @@ func TestGetMilestonesByRepoID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase()) assert.NoError(t, unittest.PrepareTestDatabase())
test := func(repoID int64, state api.StateType) { test := func(repoID int64, state api.StateType) {
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID}) repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID})
milestones, _, err := issues_model.GetMilestones(issues_model.GetMilestonesOption{ milestones, _, err := issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
RepoID: repo.ID, RepoID: repo.ID,
State: state, State: state,
}) })
@ -77,7 +77,7 @@ func TestGetMilestonesByRepoID(t *testing.T) {
test(3, api.StateClosed) test(3, api.StateClosed)
test(3, api.StateAll) test(3, api.StateAll)
milestones, _, err := issues_model.GetMilestones(issues_model.GetMilestonesOption{ milestones, _, err := issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
RepoID: unittest.NonexistentID, RepoID: unittest.NonexistentID,
State: api.StateOpen, State: api.StateOpen,
}) })
@ -90,7 +90,7 @@ func TestGetMilestones(t *testing.T) {
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1}) repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
test := func(sortType string, sortCond func(*issues_model.Milestone) int) { test := func(sortType string, sortCond func(*issues_model.Milestone) int) {
for _, page := range []int{0, 1} { for _, page := range []int{0, 1} {
milestones, _, err := issues_model.GetMilestones(issues_model.GetMilestonesOption{ milestones, _, err := issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
ListOptions: db.ListOptions{ ListOptions: db.ListOptions{
Page: page, Page: page,
PageSize: setting.UI.IssuePagingNum, PageSize: setting.UI.IssuePagingNum,
@ -107,7 +107,7 @@ func TestGetMilestones(t *testing.T) {
} }
assert.True(t, sort.IntsAreSorted(values)) assert.True(t, sort.IntsAreSorted(values))
milestones, _, err = issues_model.GetMilestones(issues_model.GetMilestonesOption{ milestones, _, err = issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
ListOptions: db.ListOptions{ ListOptions: db.ListOptions{
Page: page, Page: page,
PageSize: setting.UI.IssuePagingNum, PageSize: setting.UI.IssuePagingNum,

View File

@ -378,9 +378,9 @@ func (pr *PullRequest) GetApprovalCounts(ctx context.Context) ([]*ReviewCount, e
} }
// GetApprovers returns the approvers of the pull request // GetApprovers returns the approvers of the pull request
func (pr *PullRequest) GetApprovers() string { func (pr *PullRequest) GetApprovers(ctx context.Context) string {
stringBuilder := strings.Builder{} stringBuilder := strings.Builder{}
if err := pr.getReviewedByLines(&stringBuilder); err != nil { if err := pr.getReviewedByLines(ctx, &stringBuilder); err != nil {
log.Error("Unable to getReviewedByLines: Error: %v", err) log.Error("Unable to getReviewedByLines: Error: %v", err)
return "" return ""
} }
@ -388,14 +388,14 @@ func (pr *PullRequest) GetApprovers() string {
return stringBuilder.String() return stringBuilder.String()
} }
func (pr *PullRequest) getReviewedByLines(writer io.Writer) error { func (pr *PullRequest) getReviewedByLines(ctx context.Context, writer io.Writer) error {
maxReviewers := setting.Repository.PullRequest.DefaultMergeMessageMaxApprovers maxReviewers := setting.Repository.PullRequest.DefaultMergeMessageMaxApprovers
if maxReviewers == 0 { if maxReviewers == 0 {
return nil return nil
} }
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -594,9 +594,9 @@ func GetUnmergedPullRequest(ctx context.Context, headRepoID, baseRepoID int64, h
// GetLatestPullRequestByHeadInfo returns the latest pull request (regardless of its status) // GetLatestPullRequestByHeadInfo returns the latest pull request (regardless of its status)
// by given head information (repo and branch). // by given head information (repo and branch).
func GetLatestPullRequestByHeadInfo(repoID int64, branch string) (*PullRequest, error) { func GetLatestPullRequestByHeadInfo(ctx context.Context, repoID int64, branch string) (*PullRequest, error) {
pr := new(PullRequest) pr := new(PullRequest)
has, err := db.GetEngine(db.DefaultContext). has, err := db.GetEngine(ctx).
Where("head_repo_id = ? AND head_branch = ? AND flow = ?", repoID, branch, PullRequestFlowGithub). Where("head_repo_id = ? AND head_branch = ? AND flow = ?", repoID, branch, PullRequestFlowGithub).
OrderBy("id DESC"). OrderBy("id DESC").
Get(pr) Get(pr)
@ -646,9 +646,9 @@ func GetPullRequestByID(ctx context.Context, id int64) (*PullRequest, error) {
} }
// GetPullRequestByIssueIDWithNoAttributes returns pull request with no attributes loaded by given issue ID. // GetPullRequestByIssueIDWithNoAttributes returns pull request with no attributes loaded by given issue ID.
func GetPullRequestByIssueIDWithNoAttributes(issueID int64) (*PullRequest, error) { func GetPullRequestByIssueIDWithNoAttributes(ctx context.Context, issueID int64) (*PullRequest, error) {
var pr PullRequest var pr PullRequest
has, err := db.GetEngine(db.DefaultContext).Where("issue_id = ?", issueID).Get(&pr) has, err := db.GetEngine(ctx).Where("issue_id = ?", issueID).Get(&pr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -687,14 +687,14 @@ func GetAllUnmergedAgitPullRequestByPoster(ctx context.Context, uid int64) ([]*P
} }
// Update updates all fields of pull request. // Update updates all fields of pull request.
func (pr *PullRequest) Update() error { func (pr *PullRequest) Update(ctx context.Context) error {
_, err := db.GetEngine(db.DefaultContext).ID(pr.ID).AllCols().Update(pr) _, err := db.GetEngine(ctx).ID(pr.ID).AllCols().Update(pr)
return err return err
} }
// UpdateCols updates specific fields of pull request. // UpdateCols updates specific fields of pull request.
func (pr *PullRequest) UpdateCols(cols ...string) error { func (pr *PullRequest) UpdateCols(ctx context.Context, cols ...string) error {
_, err := db.GetEngine(db.DefaultContext).ID(pr.ID).Cols(cols...).Update(pr) _, err := db.GetEngine(ctx).ID(pr.ID).Cols(cols...).Update(pr)
return err return err
} }
@ -706,8 +706,8 @@ func (pr *PullRequest) UpdateColsIfNotMerged(ctx context.Context, cols ...string
// IsWorkInProgress determine if the Pull Request is a Work In Progress by its title // IsWorkInProgress determine if the Pull Request is a Work In Progress by its title
// Issue must be set before this method can be called. // Issue must be set before this method can be called.
func (pr *PullRequest) IsWorkInProgress() bool { func (pr *PullRequest) IsWorkInProgress(ctx context.Context) bool {
if err := pr.LoadIssue(db.DefaultContext); err != nil { if err := pr.LoadIssue(ctx); err != nil {
log.Error("LoadIssue: %v", err) log.Error("LoadIssue: %v", err)
return false return false
} }
@ -774,8 +774,8 @@ func GetPullRequestsByHeadBranch(ctx context.Context, headBranch string, headRep
} }
// GetBaseBranchLink returns the relative URL of the base branch // GetBaseBranchLink returns the relative URL of the base branch
func (pr *PullRequest) GetBaseBranchLink() string { func (pr *PullRequest) GetBaseBranchLink(ctx context.Context) string {
if err := pr.LoadBaseRepo(db.DefaultContext); err != nil { if err := pr.LoadBaseRepo(ctx); err != nil {
log.Error("LoadBaseRepo: %v", err) log.Error("LoadBaseRepo: %v", err)
return "" return ""
} }
@ -786,12 +786,12 @@ func (pr *PullRequest) GetBaseBranchLink() string {
} }
// GetHeadBranchLink returns the relative URL of the head branch // GetHeadBranchLink returns the relative URL of the head branch
func (pr *PullRequest) GetHeadBranchLink() string { func (pr *PullRequest) GetHeadBranchLink(ctx context.Context) string {
if pr.Flow == PullRequestFlowAGit { if pr.Flow == PullRequestFlowAGit {
return "" return ""
} }
if err := pr.LoadHeadRepo(db.DefaultContext); err != nil { if err := pr.LoadHeadRepo(ctx); err != nil {
log.Error("LoadHeadRepo: %v", err) log.Error("LoadHeadRepo: %v", err)
return "" return ""
} }
@ -810,14 +810,14 @@ func UpdateAllowEdits(ctx context.Context, pr *PullRequest) error {
} }
// Mergeable returns if the pullrequest is mergeable. // Mergeable returns if the pullrequest is mergeable.
func (pr *PullRequest) Mergeable() bool { func (pr *PullRequest) Mergeable(ctx context.Context) bool {
// If a pull request isn't mergable if it's: // If a pull request isn't mergable if it's:
// - Being conflict checked. // - Being conflict checked.
// - Has a conflict. // - Has a conflict.
// - Received a error while being conflict checked. // - Received a error while being conflict checked.
// - Is a work-in-progress pull request. // - Is a work-in-progress pull request.
return pr.Status != PullRequestStatusChecking && pr.Status != PullRequestStatusConflict && return pr.Status != PullRequestStatusChecking && pr.Status != PullRequestStatusConflict &&
pr.Status != PullRequestStatusError && !pr.IsWorkInProgress() pr.Status != PullRequestStatusError && !pr.IsWorkInProgress(ctx)
} }
// HasEnoughApprovals returns true if pr has enough granted approvals. // HasEnoughApprovals returns true if pr has enough granted approvals.
@ -890,7 +890,7 @@ func MergeBlockedByOutdatedBranch(protectBranch *git_model.ProtectedBranch, pr *
func PullRequestCodeOwnersReview(ctx context.Context, pull *Issue, pr *PullRequest) error { func PullRequestCodeOwnersReview(ctx context.Context, pull *Issue, pr *PullRequest) error {
files := []string{"CODEOWNERS", "docs/CODEOWNERS", ".gitea/CODEOWNERS"} files := []string{"CODEOWNERS", "docs/CODEOWNERS", ".gitea/CODEOWNERS"}
if pr.IsWorkInProgress() { if pr.IsWorkInProgress(ctx) {
return nil return nil
} }

View File

@ -213,7 +213,7 @@ func TestPullRequest_Update(t *testing.T) {
pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1}) pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1})
pr.BaseBranch = "baseBranch" pr.BaseBranch = "baseBranch"
pr.HeadBranch = "headBranch" pr.HeadBranch = "headBranch"
pr.Update() pr.Update(db.DefaultContext)
pr = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: pr.ID}) pr = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: pr.ID})
assert.Equal(t, "baseBranch", pr.BaseBranch) assert.Equal(t, "baseBranch", pr.BaseBranch)
@ -228,7 +228,7 @@ func TestPullRequest_UpdateCols(t *testing.T) {
BaseBranch: "baseBranch", BaseBranch: "baseBranch",
HeadBranch: "headBranch", HeadBranch: "headBranch",
} }
assert.NoError(t, pr.UpdateCols("head_branch")) assert.NoError(t, pr.UpdateCols(db.DefaultContext, "head_branch"))
pr = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1}) pr = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1})
assert.Equal(t, "master", pr.BaseBranch) assert.Equal(t, "master", pr.BaseBranch)
@ -260,13 +260,13 @@ func TestPullRequest_IsWorkInProgress(t *testing.T) {
pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 2}) pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 2})
pr.LoadIssue(db.DefaultContext) pr.LoadIssue(db.DefaultContext)
assert.False(t, pr.IsWorkInProgress()) assert.False(t, pr.IsWorkInProgress(db.DefaultContext))
pr.Issue.Title = "WIP: " + pr.Issue.Title pr.Issue.Title = "WIP: " + pr.Issue.Title
assert.True(t, pr.IsWorkInProgress()) assert.True(t, pr.IsWorkInProgress(db.DefaultContext))
pr.Issue.Title = "[wip]: " + pr.Issue.Title pr.Issue.Title = "[wip]: " + pr.Issue.Title
assert.True(t, pr.IsWorkInProgress()) assert.True(t, pr.IsWorkInProgress(db.DefaultContext))
} }
func TestPullRequest_GetWorkInProgressPrefixWorkInProgress(t *testing.T) { func TestPullRequest_GetWorkInProgressPrefixWorkInProgress(t *testing.T) {
@ -334,7 +334,7 @@ func TestGetApprovers(t *testing.T) {
// Official reviews are already deduplicated. Allow unofficial reviews // Official reviews are already deduplicated. Allow unofficial reviews
// to assert that there are no duplicated approvers. // to assert that there are no duplicated approvers.
setting.Repository.PullRequest.DefaultMergeMessageOfficialApproversOnly = false setting.Repository.PullRequest.DefaultMergeMessageOfficialApproversOnly = false
approvers := pr.GetApprovers() approvers := pr.GetApprovers(db.DefaultContext)
expected := "Reviewed-by: User Five <user5@example.com>\nReviewed-by: Org Six <org6@example.com>\n" expected := "Reviewed-by: User Five <user5@example.com>\nReviewed-by: Org Six <org6@example.com>\n"
assert.EqualValues(t, expected, approvers) assert.EqualValues(t, expected, approvers)
} }

View File

@ -277,8 +277,8 @@ func UpdateRepoStats(ctx context.Context, id int64) error {
return nil return nil
} }
func updateUserStarNumbers(users []user_model.User) error { func updateUserStarNumbers(ctx context.Context, users []user_model.User) error {
ctx, committer, err := db.TxContext(db.DefaultContext) ctx, committer, err := db.TxContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -294,19 +294,19 @@ func updateUserStarNumbers(users []user_model.User) error {
} }
// DoctorUserStarNum recalculate Stars number for all user // DoctorUserStarNum recalculate Stars number for all user
func DoctorUserStarNum() (err error) { func DoctorUserStarNum(ctx context.Context) (err error) {
const batchSize = 100 const batchSize = 100
for start := 0; ; start += batchSize { for start := 0; ; start += batchSize {
users := make([]user_model.User, 0, batchSize) users := make([]user_model.User, 0, batchSize)
if err = db.GetEngine(db.DefaultContext).Limit(batchSize, start).Where("type = ?", 0).Cols("id").Find(&users); err != nil { if err = db.GetEngine(ctx).Limit(batchSize, start).Where("type = ?", 0).Cols("id").Find(&users); err != nil {
return err return err
} }
if len(users) == 0 { if len(users) == 0 {
break break
} }
if err = updateUserStarNumbers(users); err != nil { if err = updateUserStarNumbers(ctx, users); err != nil {
return err return err
} }
} }

Some files were not shown because too many files have changed in this diff Show More