Penultimate round of db.DefaultContext
refactor (#27414)
Part of #27065 --------- Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
This commit is contained in:
@ -62,7 +62,7 @@ func runListAuth(c *cli.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
authSources, err := auth_model.Sources()
|
authSources, err := auth_model.Sources(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -100,7 +100,7 @@ func runDeleteAuth(c *cli.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
source, err := auth_model.GetSourceByID(c.Int64("id"))
|
source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -17,9 +17,9 @@ import (
|
|||||||
type (
|
type (
|
||||||
authService struct {
|
authService struct {
|
||||||
initDB func(ctx context.Context) error
|
initDB func(ctx context.Context) error
|
||||||
createAuthSource func(*auth.Source) error
|
createAuthSource func(context.Context, *auth.Source) error
|
||||||
updateAuthSource func(*auth.Source) error
|
updateAuthSource func(context.Context, *auth.Source) error
|
||||||
getAuthSourceByID func(id int64) (*auth.Source, error)
|
getAuthSourceByID func(ctx context.Context, id int64) (*auth.Source, error)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -289,12 +289,12 @@ func findLdapSecurityProtocolByName(name string) (ldap.SecurityProtocol, bool) {
|
|||||||
|
|
||||||
// getAuthSource gets the login source by its id defined in the command line flags.
|
// getAuthSource gets the login source by its id defined in the command line flags.
|
||||||
// It returns an error if the id is not set, does not match any source or if the source is not of expected type.
|
// It returns an error if the id is not set, does not match any source or if the source is not of expected type.
|
||||||
func (a *authService) getAuthSource(c *cli.Context, authType auth.Type) (*auth.Source, error) {
|
func (a *authService) getAuthSource(ctx context.Context, c *cli.Context, authType auth.Type) (*auth.Source, error) {
|
||||||
if err := argsSet(c, "id"); err != nil {
|
if err := argsSet(c, "id"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
authSource, err := a.getAuthSourceByID(c.Int64("id"))
|
authSource, err := a.getAuthSourceByID(ctx, c.Int64("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -332,7 +332,7 @@ func (a *authService) addLdapBindDn(c *cli.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return a.createAuthSource(authSource)
|
return a.createAuthSource(ctx, authSource)
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateLdapBindDn updates a new LDAP via Bind DN authentication source.
|
// updateLdapBindDn updates a new LDAP via Bind DN authentication source.
|
||||||
@ -344,7 +344,7 @@ func (a *authService) updateLdapBindDn(c *cli.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
authSource, err := a.getAuthSource(c, auth.LDAP)
|
authSource, err := a.getAuthSource(ctx, c, auth.LDAP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -354,7 +354,7 @@ func (a *authService) updateLdapBindDn(c *cli.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return a.updateAuthSource(authSource)
|
return a.updateAuthSource(ctx, authSource)
|
||||||
}
|
}
|
||||||
|
|
||||||
// addLdapSimpleAuth adds a new LDAP (simple auth) authentication source.
|
// addLdapSimpleAuth adds a new LDAP (simple auth) authentication source.
|
||||||
@ -383,7 +383,7 @@ func (a *authService) addLdapSimpleAuth(c *cli.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return a.createAuthSource(authSource)
|
return a.createAuthSource(ctx, authSource)
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateLdapBindDn updates a new LDAP (simple auth) authentication source.
|
// updateLdapBindDn updates a new LDAP (simple auth) authentication source.
|
||||||
@ -395,7 +395,7 @@ func (a *authService) updateLdapSimpleAuth(c *cli.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
authSource, err := a.getAuthSource(c, auth.DLDAP)
|
authSource, err := a.getAuthSource(ctx, c, auth.DLDAP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -405,5 +405,5 @@ func (a *authService) updateLdapSimpleAuth(c *cli.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return a.updateAuthSource(authSource)
|
return a.updateAuthSource(ctx, authSource)
|
||||||
}
|
}
|
||||||
|
@ -210,15 +210,15 @@ func TestAddLdapBindDn(t *testing.T) {
|
|||||||
initDB: func(context.Context) error {
|
initDB: func(context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
createAuthSource: func(authSource *auth.Source) error {
|
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
|
||||||
createdAuthSource = authSource
|
createdAuthSource = authSource
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
updateAuthSource: func(authSource *auth.Source) error {
|
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
|
||||||
assert.FailNow(t, "case %d: should not call updateAuthSource", n)
|
assert.FailNow(t, "case %d: should not call updateAuthSource", n)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
getAuthSourceByID: func(id int64) (*auth.Source, error) {
|
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
|
||||||
assert.FailNow(t, "case %d: should not call getAuthSourceByID", n)
|
assert.FailNow(t, "case %d: should not call getAuthSourceByID", n)
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
},
|
||||||
@ -441,15 +441,15 @@ func TestAddLdapSimpleAuth(t *testing.T) {
|
|||||||
initDB: func(context.Context) error {
|
initDB: func(context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
createAuthSource: func(authSource *auth.Source) error {
|
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
|
||||||
createdAuthSource = authSource
|
createdAuthSource = authSource
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
updateAuthSource: func(authSource *auth.Source) error {
|
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
|
||||||
assert.FailNow(t, "case %d: should not call updateAuthSource", n)
|
assert.FailNow(t, "case %d: should not call updateAuthSource", n)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
getAuthSourceByID: func(id int64) (*auth.Source, error) {
|
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
|
||||||
assert.FailNow(t, "case %d: should not call getAuthSourceByID", n)
|
assert.FailNow(t, "case %d: should not call getAuthSourceByID", n)
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
},
|
||||||
@ -896,15 +896,15 @@ func TestUpdateLdapBindDn(t *testing.T) {
|
|||||||
initDB: func(context.Context) error {
|
initDB: func(context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
createAuthSource: func(authSource *auth.Source) error {
|
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
|
||||||
assert.FailNow(t, "case %d: should not call createAuthSource", n)
|
assert.FailNow(t, "case %d: should not call createAuthSource", n)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
updateAuthSource: func(authSource *auth.Source) error {
|
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
|
||||||
updatedAuthSource = authSource
|
updatedAuthSource = authSource
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
getAuthSourceByID: func(id int64) (*auth.Source, error) {
|
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
|
||||||
if c.id != 0 {
|
if c.id != 0 {
|
||||||
assert.Equal(t, c.id, id, "case %d: wrong id", n)
|
assert.Equal(t, c.id, id, "case %d: wrong id", n)
|
||||||
}
|
}
|
||||||
@ -1286,15 +1286,15 @@ func TestUpdateLdapSimpleAuth(t *testing.T) {
|
|||||||
initDB: func(context.Context) error {
|
initDB: func(context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
createAuthSource: func(authSource *auth.Source) error {
|
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
|
||||||
assert.FailNow(t, "case %d: should not call createAuthSource", n)
|
assert.FailNow(t, "case %d: should not call createAuthSource", n)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
updateAuthSource: func(authSource *auth.Source) error {
|
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
|
||||||
updatedAuthSource = authSource
|
updatedAuthSource = authSource
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
getAuthSourceByID: func(id int64) (*auth.Source, error) {
|
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
|
||||||
if c.id != 0 {
|
if c.id != 0 {
|
||||||
assert.Equal(t, c.id, id, "case %d: wrong id", n)
|
assert.Equal(t, c.id, id, "case %d: wrong id", n)
|
||||||
}
|
}
|
||||||
|
@ -183,7 +183,7 @@ func runAddOauth(c *cli.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return auth_model.CreateSource(&auth_model.Source{
|
return auth_model.CreateSource(ctx, &auth_model.Source{
|
||||||
Type: auth_model.OAuth2,
|
Type: auth_model.OAuth2,
|
||||||
Name: c.String("name"),
|
Name: c.String("name"),
|
||||||
IsActive: true,
|
IsActive: true,
|
||||||
@ -203,7 +203,7 @@ func runUpdateOauth(c *cli.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
source, err := auth_model.GetSourceByID(c.Int64("id"))
|
source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -294,5 +294,5 @@ func runUpdateOauth(c *cli.Context) error {
|
|||||||
oAuth2Config.CustomURLMapping = customURLMapping
|
oAuth2Config.CustomURLMapping = customURLMapping
|
||||||
source.Cfg = oAuth2Config
|
source.Cfg = oAuth2Config
|
||||||
|
|
||||||
return auth_model.UpdateSource(source)
|
return auth_model.UpdateSource(ctx, source)
|
||||||
}
|
}
|
||||||
|
@ -156,7 +156,7 @@ func runAddSMTP(c *cli.Context) error {
|
|||||||
smtpConfig.Auth = "PLAIN"
|
smtpConfig.Auth = "PLAIN"
|
||||||
}
|
}
|
||||||
|
|
||||||
return auth_model.CreateSource(&auth_model.Source{
|
return auth_model.CreateSource(ctx, &auth_model.Source{
|
||||||
Type: auth_model.SMTP,
|
Type: auth_model.SMTP,
|
||||||
Name: c.String("name"),
|
Name: c.String("name"),
|
||||||
IsActive: active,
|
IsActive: active,
|
||||||
@ -176,7 +176,7 @@ func runUpdateSMTP(c *cli.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
source, err := auth_model.GetSourceByID(c.Int64("id"))
|
source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -197,5 +197,5 @@ func runUpdateSMTP(c *cli.Context) error {
|
|||||||
|
|
||||||
source.Cfg = smtpConfig
|
source.Cfg = smtpConfig
|
||||||
|
|
||||||
return auth_model.UpdateSource(source)
|
return auth_model.UpdateSource(ctx, source)
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,7 @@ func (jobs ActionJobList) LoadRuns(ctx context.Context, withRepo bool) error {
|
|||||||
for _, r := range runs {
|
for _, r := range runs {
|
||||||
runsList = append(runsList, r)
|
runsList = append(runsList, r)
|
||||||
}
|
}
|
||||||
return runsList.LoadRepos()
|
return runsList.LoadRepos(ctx)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -52,9 +52,9 @@ func (runs RunList) LoadTriggerUser(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (runs RunList) LoadRepos() error {
|
func (runs RunList) LoadRepos(ctx context.Context) error {
|
||||||
repoIDs := runs.GetRepoIDs()
|
repoIDs := runs.GetRepoIDs()
|
||||||
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs)
|
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -49,9 +49,9 @@ func (schedules ScheduleList) LoadTriggerUser(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (schedules ScheduleList) LoadRepos() error {
|
func (schedules ScheduleList) LoadRepos(ctx context.Context) error {
|
||||||
repoIDs := schedules.GetRepoIDs()
|
repoIDs := schedules.GetRepoIDs()
|
||||||
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs)
|
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -53,9 +53,9 @@ func (specs SpecList) GetRepoIDs() []int64 {
|
|||||||
return ids.Values()
|
return ids.Values()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (specs SpecList) LoadRepos() error {
|
func (specs SpecList) LoadRepos(ctx context.Context) error {
|
||||||
repoIDs := specs.GetRepoIDs()
|
repoIDs := specs.GetRepoIDs()
|
||||||
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs)
|
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -102,7 +102,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) {
|
|||||||
stats.Counter.Follow, _ = e.Count(new(user_model.Follow))
|
stats.Counter.Follow, _ = e.Count(new(user_model.Follow))
|
||||||
stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror))
|
stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror))
|
||||||
stats.Counter.Release, _ = e.Count(new(repo_model.Release))
|
stats.Counter.Release, _ = e.Count(new(repo_model.Release))
|
||||||
stats.Counter.AuthSource = auth.CountSources()
|
stats.Counter.AuthSource = auth.CountSources(ctx)
|
||||||
stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook))
|
stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook))
|
||||||
stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone))
|
stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone))
|
||||||
stats.Counter.Label, _ = e.Count(new(issues_model.Label))
|
stats.Counter.Label, _ = e.Count(new(issues_model.Label))
|
||||||
|
@ -91,7 +91,7 @@ func addKey(ctx context.Context, key *PublicKey) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddPublicKey adds new public key to database and authorized_keys file.
|
// AddPublicKey adds new public key to database and authorized_keys file.
|
||||||
func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*PublicKey, error) {
|
func AddPublicKey(ctx context.Context, ownerID int64, name, content string, authSourceID int64) (*PublicKey, error) {
|
||||||
log.Trace(content)
|
log.Trace(content)
|
||||||
|
|
||||||
fingerprint, err := CalcFingerprint(content)
|
fingerprint, err := CalcFingerprint(content)
|
||||||
@ -99,7 +99,7 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -136,9 +136,9 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPublicKeyByID returns public key by given ID.
|
// GetPublicKeyByID returns public key by given ID.
|
||||||
func GetPublicKeyByID(keyID int64) (*PublicKey, error) {
|
func GetPublicKeyByID(ctx context.Context, keyID int64) (*PublicKey, error) {
|
||||||
key := new(PublicKey)
|
key := new(PublicKey)
|
||||||
has, err := db.GetEngine(db.DefaultContext).
|
has, err := db.GetEngine(ctx).
|
||||||
ID(keyID).
|
ID(keyID).
|
||||||
Get(key)
|
Get(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -180,7 +180,7 @@ func SearchPublicKeyByContentExact(ctx context.Context, content string) (*Public
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SearchPublicKey returns a list of public keys matching the provided arguments.
|
// SearchPublicKey returns a list of public keys matching the provided arguments.
|
||||||
func SearchPublicKey(uid int64, fingerprint string) ([]*PublicKey, error) {
|
func SearchPublicKey(ctx context.Context, uid int64, fingerprint string) ([]*PublicKey, error) {
|
||||||
keys := make([]*PublicKey, 0, 5)
|
keys := make([]*PublicKey, 0, 5)
|
||||||
cond := builder.NewCond()
|
cond := builder.NewCond()
|
||||||
if uid != 0 {
|
if uid != 0 {
|
||||||
@ -189,12 +189,12 @@ func SearchPublicKey(uid int64, fingerprint string) ([]*PublicKey, error) {
|
|||||||
if fingerprint != "" {
|
if fingerprint != "" {
|
||||||
cond = cond.And(builder.Eq{"fingerprint": fingerprint})
|
cond = cond.And(builder.Eq{"fingerprint": fingerprint})
|
||||||
}
|
}
|
||||||
return keys, db.GetEngine(db.DefaultContext).Where(cond).Find(&keys)
|
return keys, db.GetEngine(ctx).Where(cond).Find(&keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPublicKeys returns a list of public keys belongs to given user.
|
// ListPublicKeys returns a list of public keys belongs to given user.
|
||||||
func ListPublicKeys(uid int64, listOptions db.ListOptions) ([]*PublicKey, error) {
|
func ListPublicKeys(ctx context.Context, uid int64, listOptions db.ListOptions) ([]*PublicKey, error) {
|
||||||
sess := db.GetEngine(db.DefaultContext).Where("owner_id = ? AND type != ?", uid, KeyTypePrincipal)
|
sess := db.GetEngine(ctx).Where("owner_id = ? AND type != ?", uid, KeyTypePrincipal)
|
||||||
if listOptions.Page != 0 {
|
if listOptions.Page != 0 {
|
||||||
sess = db.SetSessionPagination(sess, &listOptions)
|
sess = db.SetSessionPagination(sess, &listOptions)
|
||||||
|
|
||||||
@ -207,30 +207,30 @@ func ListPublicKeys(uid int64, listOptions db.ListOptions) ([]*PublicKey, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CountPublicKeys count public keys a user has
|
// CountPublicKeys count public keys a user has
|
||||||
func CountPublicKeys(userID int64) (int64, error) {
|
func CountPublicKeys(ctx context.Context, userID int64) (int64, error) {
|
||||||
sess := db.GetEngine(db.DefaultContext).Where("owner_id = ? AND type != ?", userID, KeyTypePrincipal)
|
sess := db.GetEngine(ctx).Where("owner_id = ? AND type != ?", userID, KeyTypePrincipal)
|
||||||
return sess.Count(&PublicKey{})
|
return sess.Count(&PublicKey{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPublicKeysBySource returns a list of synchronized public keys for a given user and login source.
|
// ListPublicKeysBySource returns a list of synchronized public keys for a given user and login source.
|
||||||
func ListPublicKeysBySource(uid, authSourceID int64) ([]*PublicKey, error) {
|
func ListPublicKeysBySource(ctx context.Context, uid, authSourceID int64) ([]*PublicKey, error) {
|
||||||
keys := make([]*PublicKey, 0, 5)
|
keys := make([]*PublicKey, 0, 5)
|
||||||
return keys, db.GetEngine(db.DefaultContext).
|
return keys, db.GetEngine(ctx).
|
||||||
Where("owner_id = ? AND login_source_id = ?", uid, authSourceID).
|
Where("owner_id = ? AND login_source_id = ?", uid, authSourceID).
|
||||||
Find(&keys)
|
Find(&keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePublicKeyUpdated updates public key use time.
|
// UpdatePublicKeyUpdated updates public key use time.
|
||||||
func UpdatePublicKeyUpdated(id int64) error {
|
func UpdatePublicKeyUpdated(ctx context.Context, id int64) error {
|
||||||
// Check if key exists before update as affected rows count is unreliable
|
// Check if key exists before update as affected rows count is unreliable
|
||||||
// and will return 0 affected rows if two updates are made at the same time
|
// and will return 0 affected rows if two updates are made at the same time
|
||||||
if cnt, err := db.GetEngine(db.DefaultContext).ID(id).Count(&PublicKey{}); err != nil {
|
if cnt, err := db.GetEngine(ctx).ID(id).Count(&PublicKey{}); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if cnt != 1 {
|
} else if cnt != 1 {
|
||||||
return ErrKeyNotExist{id}
|
return ErrKeyNotExist{id}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := db.GetEngine(db.DefaultContext).ID(id).Cols("updated_unix").Update(&PublicKey{
|
_, err := db.GetEngine(ctx).ID(id).Cols("updated_unix").Update(&PublicKey{
|
||||||
UpdatedUnix: timeutil.TimeStampNow(),
|
UpdatedUnix: timeutil.TimeStampNow(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -250,7 +250,7 @@ func DeletePublicKeys(ctx context.Context, keyIDs ...int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PublicKeysAreExternallyManaged returns whether the provided KeyID represents an externally managed Key
|
// PublicKeysAreExternallyManaged returns whether the provided KeyID represents an externally managed Key
|
||||||
func PublicKeysAreExternallyManaged(keys []*PublicKey) ([]bool, error) {
|
func PublicKeysAreExternallyManaged(ctx context.Context, keys []*PublicKey) ([]bool, error) {
|
||||||
sources := make([]*auth.Source, 0, 5)
|
sources := make([]*auth.Source, 0, 5)
|
||||||
externals := make([]bool, len(keys))
|
externals := make([]bool, len(keys))
|
||||||
keyloop:
|
keyloop:
|
||||||
@ -272,7 +272,7 @@ keyloop:
|
|||||||
|
|
||||||
if source == nil {
|
if source == nil {
|
||||||
var err error
|
var err error
|
||||||
source, err = auth.GetSourceByID(key.LoginSourceID)
|
source, err = auth.GetSourceByID(ctx, key.LoginSourceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if auth.IsErrSourceNotExist(err) {
|
if auth.IsErrSourceNotExist(err) {
|
||||||
externals[i] = false
|
externals[i] = false
|
||||||
@ -295,15 +295,15 @@ keyloop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PublicKeyIsExternallyManaged returns whether the provided KeyID represents an externally managed Key
|
// PublicKeyIsExternallyManaged returns whether the provided KeyID represents an externally managed Key
|
||||||
func PublicKeyIsExternallyManaged(id int64) (bool, error) {
|
func PublicKeyIsExternallyManaged(ctx context.Context, id int64) (bool, error) {
|
||||||
key, err := GetPublicKeyByID(id)
|
key, err := GetPublicKeyByID(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
if key.LoginSourceID == 0 {
|
if key.LoginSourceID == 0 {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
source, err := auth.GetSourceByID(key.LoginSourceID)
|
source, err := auth.GetSourceByID(ctx, key.LoginSourceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if auth.IsErrSourceNotExist(err) {
|
if auth.IsErrSourceNotExist(err) {
|
||||||
return false, nil
|
return false, nil
|
||||||
@ -318,9 +318,9 @@ func PublicKeyIsExternallyManaged(id int64) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// deleteKeysMarkedForDeletion returns true if ssh keys needs update
|
// deleteKeysMarkedForDeletion returns true if ssh keys needs update
|
||||||
func deleteKeysMarkedForDeletion(keys []string) (bool, error) {
|
func deleteKeysMarkedForDeletion(ctx context.Context, keys []string) (bool, error) {
|
||||||
// Start session
|
// Start session
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -349,7 +349,7 @@ func deleteKeysMarkedForDeletion(keys []string) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddPublicKeysBySource add a users public keys. Returns true if there are changes.
|
// AddPublicKeysBySource add a users public keys. Returns true if there are changes.
|
||||||
func AddPublicKeysBySource(usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
|
func AddPublicKeysBySource(ctx context.Context, usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
|
||||||
var sshKeysNeedUpdate bool
|
var sshKeysNeedUpdate bool
|
||||||
for _, sshKey := range sshPublicKeys {
|
for _, sshKey := range sshPublicKeys {
|
||||||
var err error
|
var err error
|
||||||
@ -368,7 +368,7 @@ func AddPublicKeysBySource(usr *user_model.User, s *auth.Source, sshPublicKeys [
|
|||||||
marshalled = marshalled[:len(marshalled)-1]
|
marshalled = marshalled[:len(marshalled)-1]
|
||||||
sshKeyName := fmt.Sprintf("%s-%s", s.Name, ssh.FingerprintSHA256(out))
|
sshKeyName := fmt.Sprintf("%s-%s", s.Name, ssh.FingerprintSHA256(out))
|
||||||
|
|
||||||
if _, err := AddPublicKey(usr.ID, sshKeyName, marshalled, s.ID); err != nil {
|
if _, err := AddPublicKey(ctx, usr.ID, sshKeyName, marshalled, s.ID); err != nil {
|
||||||
if IsErrKeyAlreadyExist(err) {
|
if IsErrKeyAlreadyExist(err) {
|
||||||
log.Trace("AddPublicKeysBySource[%s]: Public SSH Key %s already exists for user", sshKeyName, usr.Name)
|
log.Trace("AddPublicKeysBySource[%s]: Public SSH Key %s already exists for user", sshKeyName, usr.Name)
|
||||||
} else {
|
} else {
|
||||||
@ -387,14 +387,14 @@ func AddPublicKeysBySource(usr *user_model.User, s *auth.Source, sshPublicKeys [
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SynchronizePublicKeys updates a users public keys. Returns true if there are changes.
|
// SynchronizePublicKeys updates a users public keys. Returns true if there are changes.
|
||||||
func SynchronizePublicKeys(usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
|
func SynchronizePublicKeys(ctx context.Context, usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
|
||||||
var sshKeysNeedUpdate bool
|
var sshKeysNeedUpdate bool
|
||||||
|
|
||||||
log.Trace("synchronizePublicKeys[%s]: Handling Public SSH Key synchronization for user %s", s.Name, usr.Name)
|
log.Trace("synchronizePublicKeys[%s]: Handling Public SSH Key synchronization for user %s", s.Name, usr.Name)
|
||||||
|
|
||||||
// Get Public Keys from DB with current LDAP source
|
// Get Public Keys from DB with current LDAP source
|
||||||
var giteaKeys []string
|
var giteaKeys []string
|
||||||
keys, err := ListPublicKeysBySource(usr.ID, s.ID)
|
keys, err := ListPublicKeysBySource(ctx, usr.ID, s.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("synchronizePublicKeys[%s]: Error listing Public SSH Keys for user %s: %v", s.Name, usr.Name, err)
|
log.Error("synchronizePublicKeys[%s]: Error listing Public SSH Keys for user %s: %v", s.Name, usr.Name, err)
|
||||||
}
|
}
|
||||||
@ -429,7 +429,7 @@ func SynchronizePublicKeys(usr *user_model.User, s *auth.Source, sshPublicKeys [
|
|||||||
newKeys = append(newKeys, key)
|
newKeys = append(newKeys, key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if AddPublicKeysBySource(usr, s, newKeys) {
|
if AddPublicKeysBySource(ctx, usr, s, newKeys) {
|
||||||
sshKeysNeedUpdate = true
|
sshKeysNeedUpdate = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -443,7 +443,7 @@ func SynchronizePublicKeys(usr *user_model.User, s *auth.Source, sshPublicKeys [
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Delete keys from DB that no longer exist in the source
|
// Delete keys from DB that no longer exist in the source
|
||||||
needUpd, err := deleteKeysMarkedForDeletion(giteaKeysToDelete)
|
needUpd, err := deleteKeysMarkedForDeletion(ctx, giteaKeysToDelete)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("synchronizePublicKeys[%s]: Error deleting Public Keys marked for deletion for user %s: %v", s.Name, usr.Name, err)
|
log.Error("synchronizePublicKeys[%s]: Error deleting Public Keys marked for deletion for user %s: %v", s.Name, usr.Name, err)
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,7 @@ import (
|
|||||||
func ParseCommitWithSSHSignature(ctx context.Context, c *git.Commit, committer *user_model.User) *CommitVerification {
|
func ParseCommitWithSSHSignature(ctx context.Context, c *git.Commit, committer *user_model.User) *CommitVerification {
|
||||||
// Now try to associate the signature with the committer, if present
|
// Now try to associate the signature with the committer, if present
|
||||||
if committer.ID != 0 {
|
if committer.ID != 0 {
|
||||||
keys, err := ListPublicKeys(committer.ID, db.ListOptions{})
|
keys, err := ListPublicKeys(ctx, committer.ID, db.ListOptions{})
|
||||||
if err != nil { // Skipping failed to get ssh keys of user
|
if err != nil { // Skipping failed to get ssh keys of user
|
||||||
log.Error("ListPublicKeys: %v", err)
|
log.Error("ListPublicKeys: %v", err)
|
||||||
return &CommitVerification{
|
return &CommitVerification{
|
||||||
|
@ -48,8 +48,8 @@ func (key *DeployKey) AfterLoad() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetContent gets associated public key content.
|
// GetContent gets associated public key content.
|
||||||
func (key *DeployKey) GetContent() error {
|
func (key *DeployKey) GetContent(ctx context.Context) error {
|
||||||
pkey, err := GetPublicKeyByID(key.KeyID)
|
pkey, err := GetPublicKeyByID(ctx, key.KeyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
@ -199,8 +200,8 @@ func (source *Source) SkipVerify() bool {
|
|||||||
|
|
||||||
// CreateSource inserts a AuthSource in the DB if not already
|
// CreateSource inserts a AuthSource in the DB if not already
|
||||||
// existing with the given name.
|
// existing with the given name.
|
||||||
func CreateSource(source *Source) error {
|
func CreateSource(ctx context.Context, source *Source) error {
|
||||||
has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source))
|
has, err := db.GetEngine(ctx).Where("name=?", source.Name).Exist(new(Source))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if has {
|
} else if has {
|
||||||
@ -211,7 +212,7 @@ func CreateSource(source *Source) error {
|
|||||||
source.IsSyncEnabled = false
|
source.IsSyncEnabled = false
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.GetEngine(db.DefaultContext).Insert(source)
|
_, err = db.GetEngine(ctx).Insert(source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -232,7 +233,7 @@ func CreateSource(source *Source) error {
|
|||||||
err = registerableSource.RegisterSource()
|
err = registerableSource.RegisterSource()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// remove the AuthSource in case of errors while registering configuration
|
// remove the AuthSource in case of errors while registering configuration
|
||||||
if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil {
|
if _, err := db.GetEngine(ctx).Delete(source); err != nil {
|
||||||
log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
|
log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -240,33 +241,33 @@ func CreateSource(source *Source) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sources returns a slice of all login sources found in DB.
|
// Sources returns a slice of all login sources found in DB.
|
||||||
func Sources() ([]*Source, error) {
|
func Sources(ctx context.Context) ([]*Source, error) {
|
||||||
auths := make([]*Source, 0, 6)
|
auths := make([]*Source, 0, 6)
|
||||||
return auths, db.GetEngine(db.DefaultContext).Find(&auths)
|
return auths, db.GetEngine(ctx).Find(&auths)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SourcesByType returns all sources of the specified type
|
// SourcesByType returns all sources of the specified type
|
||||||
func SourcesByType(loginType Type) ([]*Source, error) {
|
func SourcesByType(ctx context.Context, loginType Type) ([]*Source, error) {
|
||||||
sources := make([]*Source, 0, 1)
|
sources := make([]*Source, 0, 1)
|
||||||
if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil {
|
if err := db.GetEngine(ctx).Where("type = ?", loginType).Find(&sources); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return sources, nil
|
return sources, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllActiveSources returns all active sources
|
// AllActiveSources returns all active sources
|
||||||
func AllActiveSources() ([]*Source, error) {
|
func AllActiveSources(ctx context.Context) ([]*Source, error) {
|
||||||
sources := make([]*Source, 0, 5)
|
sources := make([]*Source, 0, 5)
|
||||||
if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil {
|
if err := db.GetEngine(ctx).Where("is_active = ?", true).Find(&sources); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return sources, nil
|
return sources, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ActiveSources returns all active sources of the specified type
|
// ActiveSources returns all active sources of the specified type
|
||||||
func ActiveSources(tp Type) ([]*Source, error) {
|
func ActiveSources(ctx context.Context, tp Type) ([]*Source, error) {
|
||||||
sources := make([]*Source, 0, 1)
|
sources := make([]*Source, 0, 1)
|
||||||
if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
|
if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return sources, nil
|
return sources, nil
|
||||||
@ -274,11 +275,11 @@ func ActiveSources(tp Type) ([]*Source, error) {
|
|||||||
|
|
||||||
// IsSSPIEnabled returns true if there is at least one activated login
|
// IsSSPIEnabled returns true if there is at least one activated login
|
||||||
// source of type LoginSSPI
|
// source of type LoginSSPI
|
||||||
func IsSSPIEnabled() bool {
|
func IsSSPIEnabled(ctx context.Context) bool {
|
||||||
if !db.HasEngine {
|
if !db.HasEngine {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
sources, err := ActiveSources(SSPI)
|
sources, err := ActiveSources(ctx, SSPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("ActiveSources: %v", err)
|
log.Error("ActiveSources: %v", err)
|
||||||
return false
|
return false
|
||||||
@ -287,7 +288,7 @@ func IsSSPIEnabled() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetSourceByID returns login source by given ID.
|
// GetSourceByID returns login source by given ID.
|
||||||
func GetSourceByID(id int64) (*Source, error) {
|
func GetSourceByID(ctx context.Context, id int64) (*Source, error) {
|
||||||
source := new(Source)
|
source := new(Source)
|
||||||
if id == 0 {
|
if id == 0 {
|
||||||
source.Cfg = registeredConfigs[NoType]()
|
source.Cfg = registeredConfigs[NoType]()
|
||||||
@ -297,7 +298,7 @@ func GetSourceByID(id int64) (*Source, error) {
|
|||||||
return source, nil
|
return source, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source)
|
has, err := db.GetEngine(ctx).ID(id).Get(source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if !has {
|
} else if !has {
|
||||||
@ -307,24 +308,24 @@ func GetSourceByID(id int64) (*Source, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSource updates a Source record in DB.
|
// UpdateSource updates a Source record in DB.
|
||||||
func UpdateSource(source *Source) error {
|
func UpdateSource(ctx context.Context, source *Source) error {
|
||||||
var originalSource *Source
|
var originalSource *Source
|
||||||
if source.IsOAuth2() {
|
if source.IsOAuth2() {
|
||||||
// keep track of the original values so we can restore in case of errors while registering OAuth2 providers
|
// keep track of the original values so we can restore in case of errors while registering OAuth2 providers
|
||||||
var err error
|
var err error
|
||||||
if originalSource, err = GetSourceByID(source.ID); err != nil {
|
if originalSource, err = GetSourceByID(ctx, source.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
has, err := db.GetEngine(db.DefaultContext).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
|
has, err := db.GetEngine(ctx).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if has {
|
} else if has {
|
||||||
return ErrSourceAlreadyExist{source.Name}
|
return ErrSourceAlreadyExist{source.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source)
|
_, err = db.GetEngine(ctx).ID(source.ID).AllCols().Update(source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -345,7 +346,7 @@ func UpdateSource(source *Source) error {
|
|||||||
err = registerableSource.RegisterSource()
|
err = registerableSource.RegisterSource()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// restore original values since we cannot update the provider it self
|
// restore original values since we cannot update the provider it self
|
||||||
if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalSource); err != nil {
|
if _, err := db.GetEngine(ctx).ID(source.ID).AllCols().Update(originalSource); err != nil {
|
||||||
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
|
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -353,8 +354,8 @@ func UpdateSource(source *Source) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CountSources returns number of login sources.
|
// CountSources returns number of login sources.
|
||||||
func CountSources() int64 {
|
func CountSources(ctx context.Context) int64 {
|
||||||
count, _ := db.GetEngine(db.DefaultContext).Count(new(Source))
|
count, _ := db.GetEngine(ctx).Count(new(Source))
|
||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ func TestDumpAuthSource(t *testing.T) {
|
|||||||
|
|
||||||
auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource))
|
auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource))
|
||||||
|
|
||||||
auth_model.CreateSource(&auth_model.Source{
|
auth_model.CreateSource(db.DefaultContext, &auth_model.Source{
|
||||||
Type: auth_model.OAuth2,
|
Type: auth_model.OAuth2,
|
||||||
Name: "TestSource",
|
Name: "TestSource",
|
||||||
IsActive: false,
|
IsActive: false,
|
||||||
|
@ -111,7 +111,7 @@ func findCodeComments(ctx context.Context, opts FindCommentsOptions, issue *Issu
|
|||||||
if comment.RenderedContent, err = markdown.RenderString(&markup.RenderContext{
|
if comment.RenderedContent, err = markdown.RenderString(&markup.RenderContext{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
URLPrefix: issue.Repo.Link(),
|
URLPrefix: issue.Repo.Link(),
|
||||||
Metas: issue.Repo.ComposeMetas(),
|
Metas: issue.Repo.ComposeMetas(ctx),
|
||||||
}, comment.Content); err != nil {
|
}, comment.Content); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -127,8 +127,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// CreateIssueDependency creates a new dependency for an issue
|
// CreateIssueDependency creates a new dependency for an issue
|
||||||
func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error {
|
func CreateIssueDependency(ctx context.Context, user *user_model.User, issue, dep *Issue) error {
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -168,8 +168,8 @@ func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveIssueDependency removes a dependency from an issue
|
// RemoveIssueDependency removes a dependency from an issue
|
||||||
func RemoveIssueDependency(user *user_model.User, issue, dep *Issue, depType DependencyType) (err error) {
|
func RemoveIssueDependency(ctx context.Context, user *user_model.User, issue, dep *Issue, depType DependencyType) (err error) {
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -28,16 +28,16 @@ func TestCreateIssueDependency(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Create a dependency and check if it was successful
|
// Create a dependency and check if it was successful
|
||||||
err = issues_model.CreateIssueDependency(user1, issue1, issue2)
|
err = issues_model.CreateIssueDependency(db.DefaultContext, user1, issue1, issue2)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Do it again to see if it will check if the dependency already exists
|
// Do it again to see if it will check if the dependency already exists
|
||||||
err = issues_model.CreateIssueDependency(user1, issue1, issue2)
|
err = issues_model.CreateIssueDependency(db.DefaultContext, user1, issue1, issue2)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.True(t, issues_model.IsErrDependencyExists(err))
|
assert.True(t, issues_model.IsErrDependencyExists(err))
|
||||||
|
|
||||||
// Check for circular dependencies
|
// Check for circular dependencies
|
||||||
err = issues_model.CreateIssueDependency(user1, issue2, issue1)
|
err = issues_model.CreateIssueDependency(db.DefaultContext, user1, issue2, issue1)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.True(t, issues_model.IsErrCircularDependency(err))
|
assert.True(t, issues_model.IsErrCircularDependency(err))
|
||||||
|
|
||||||
@ -57,6 +57,6 @@ func TestCreateIssueDependency(t *testing.T) {
|
|||||||
assert.True(t, left)
|
assert.True(t, left)
|
||||||
|
|
||||||
// Test removing the dependency
|
// Test removing the dependency
|
||||||
err = issues_model.RemoveIssueDependency(user1, issue1, issue2, issues_model.DependencyTypeBlockedBy)
|
err = issues_model.RemoveIssueDependency(db.DefaultContext, user1, issue1, issue2, issues_model.DependencyTypeBlockedBy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
@ -83,12 +83,12 @@ func RemoveDuplicateExclusiveIssueLabels(ctx context.Context, issue *Issue, labe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewIssueLabel creates a new issue-label relation.
|
// NewIssueLabel creates a new issue-label relation.
|
||||||
func NewIssueLabel(issue *Issue, label *Label, doer *user_model.User) (err error) {
|
func NewIssueLabel(ctx context.Context, issue *Issue, label *Label, doer *user_model.User) (err error) {
|
||||||
if HasIssueLabel(db.DefaultContext, issue.ID, label.ID) {
|
if HasIssueLabel(ctx, issue.ID, label.ID) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -149,8 +149,8 @@ func newIssueLabels(ctx context.Context, issue *Issue, labels []*Label, doer *us
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewIssueLabels creates a list of issue-label relations.
|
// NewIssueLabels creates a list of issue-label relations.
|
||||||
func NewIssueLabels(issue *Issue, labels []*Label, doer *user_model.User) (err error) {
|
func NewIssueLabels(ctx context.Context, issue *Issue, labels []*Label, doer *user_model.User) (err error) {
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -359,8 +359,8 @@ func clearIssueLabels(ctx context.Context, issue *Issue, doer *user_model.User)
|
|||||||
|
|
||||||
// ClearIssueLabels removes all issue labels as the given user.
|
// ClearIssueLabels removes all issue labels as the given user.
|
||||||
// Triggers appropriate WebHooks, if any.
|
// Triggers appropriate WebHooks, if any.
|
||||||
func ClearIssueLabels(issue *Issue, doer *user_model.User) (err error) {
|
func ClearIssueLabels(ctx context.Context, issue *Issue, doer *user_model.User) (err error) {
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -432,8 +432,8 @@ func RemoveDuplicateExclusiveLabels(labels []*Label) []*Label {
|
|||||||
|
|
||||||
// ReplaceIssueLabels removes all current labels and add new labels to the issue.
|
// ReplaceIssueLabels removes all current labels and add new labels to the issue.
|
||||||
// Triggers appropriate WebHooks, if any.
|
// Triggers appropriate WebHooks, if any.
|
||||||
func ReplaceIssueLabels(issue *Issue, labels []*Label, doer *user_model.User) (err error) {
|
func ReplaceIssueLabels(ctx context.Context, issue *Issue, labels []*Label, doer *user_model.User) (err error) {
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ package issues_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"code.gitea.io/gitea/models/db"
|
||||||
issues_model "code.gitea.io/gitea/models/issues"
|
issues_model "code.gitea.io/gitea/models/issues"
|
||||||
"code.gitea.io/gitea/models/unittest"
|
"code.gitea.io/gitea/models/unittest"
|
||||||
user_model "code.gitea.io/gitea/models/user"
|
user_model "code.gitea.io/gitea/models/user"
|
||||||
@ -21,7 +22,7 @@ func TestNewIssueLabelsScope(t *testing.T) {
|
|||||||
label2 := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 8})
|
label2 := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 8})
|
||||||
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
|
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
|
||||||
|
|
||||||
assert.NoError(t, issues_model.NewIssueLabels(issue, []*issues_model.Label{label1, label2}, doer))
|
assert.NoError(t, issues_model.NewIssueLabels(db.DefaultContext, issue, []*issues_model.Label{label1, label2}, doer))
|
||||||
|
|
||||||
assert.Len(t, issue.Labels, 1)
|
assert.Len(t, issue.Labels, 1)
|
||||||
assert.Equal(t, label2.ID, issue.Labels[0].ID)
|
assert.Equal(t, label2.ID, issue.Labels[0].ID)
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
package issues
|
package issues
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"code.gitea.io/gitea/models/db"
|
"code.gitea.io/gitea/models/db"
|
||||||
user_model "code.gitea.io/gitea/models/user"
|
user_model "code.gitea.io/gitea/models/user"
|
||||||
)
|
)
|
||||||
@ -17,16 +19,16 @@ type IssueLockOptions struct {
|
|||||||
|
|
||||||
// LockIssue locks an issue. This would limit commenting abilities to
|
// LockIssue locks an issue. This would limit commenting abilities to
|
||||||
// users with write access to the repo
|
// users with write access to the repo
|
||||||
func LockIssue(opts *IssueLockOptions) error {
|
func LockIssue(ctx context.Context, opts *IssueLockOptions) error {
|
||||||
return updateIssueLock(opts, true)
|
return updateIssueLock(ctx, opts, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnlockIssue unlocks a previously locked issue.
|
// UnlockIssue unlocks a previously locked issue.
|
||||||
func UnlockIssue(opts *IssueLockOptions) error {
|
func UnlockIssue(ctx context.Context, opts *IssueLockOptions) error {
|
||||||
return updateIssueLock(opts, false)
|
return updateIssueLock(ctx, opts, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateIssueLock(opts *IssueLockOptions, lock bool) error {
|
func updateIssueLock(ctx context.Context, opts *IssueLockOptions, lock bool) error {
|
||||||
if opts.Issue.IsLocked == lock {
|
if opts.Issue.IsLocked == lock {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -39,7 +41,7 @@ func updateIssueLock(opts *IssueLockOptions, lock bool) error {
|
|||||||
commentType = CommentTypeUnlock
|
commentType = CommentTypeUnlock
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -38,11 +38,7 @@ func (issue *Issue) projectID(ctx context.Context) int64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ProjectBoardID return project board id if issue was assigned to one
|
// ProjectBoardID return project board id if issue was assigned to one
|
||||||
func (issue *Issue) ProjectBoardID() int64 {
|
func (issue *Issue) ProjectBoardID(ctx context.Context) int64 {
|
||||||
return issue.projectBoardID(db.DefaultContext)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (issue *Issue) projectBoardID(ctx context.Context) int64 {
|
|
||||||
var ip project_model.ProjectIssue
|
var ip project_model.ProjectIssue
|
||||||
has, err := db.GetEngine(ctx).Where("issue_id=?", issue.ID).Get(&ip)
|
has, err := db.GetEngine(ctx).Where("issue_id=?", issue.ID).Get(&ip)
|
||||||
if err != nil || !has {
|
if err != nil || !has {
|
||||||
@ -100,8 +96,8 @@ func LoadIssuesFromBoardList(ctx context.Context, bs project_model.BoardList) (m
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ChangeProjectAssign changes the project associated with an issue
|
// ChangeProjectAssign changes the project associated with an issue
|
||||||
func ChangeProjectAssign(issue *Issue, doer *user_model.User, newProjectID int64) error {
|
func ChangeProjectAssign(ctx context.Context, issue *Issue, doer *user_model.User, newProjectID int64) error {
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -156,8 +152,8 @@ func addUpdateIssueProject(ctx context.Context, issue *Issue, doer *user_model.U
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MoveIssueAcrossProjectBoards move a card from one board to another
|
// MoveIssueAcrossProjectBoards move a card from one board to another
|
||||||
func MoveIssueAcrossProjectBoards(issue *Issue, board *project_model.Board) error {
|
func MoveIssueAcrossProjectBoards(ctx context.Context, issue *Issue, board *project_model.Board) error {
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -444,9 +444,9 @@ func applySubscribedCondition(sess *xorm.Session, subscriberID int64) *xorm.Sess
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRepoIDsForIssuesOptions find all repo ids for the given options
|
// GetRepoIDsForIssuesOptions find all repo ids for the given options
|
||||||
func GetRepoIDsForIssuesOptions(opts *IssuesOptions, user *user_model.User) ([]int64, error) {
|
func GetRepoIDsForIssuesOptions(ctx context.Context, opts *IssuesOptions, user *user_model.User) ([]int64, error) {
|
||||||
repoIDs := make([]int64, 0, 5)
|
repoIDs := make([]int64, 0, 5)
|
||||||
e := db.GetEngine(db.DefaultContext)
|
e := db.GetEngine(ctx)
|
||||||
|
|
||||||
sess := e.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")
|
sess := e.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ func TestIssue_ReplaceLabels(t *testing.T) {
|
|||||||
for i, labelID := range labelIDs {
|
for i, labelID := range labelIDs {
|
||||||
labels[i] = unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: labelID, RepoID: repo.ID})
|
labels[i] = unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: labelID, RepoID: repo.ID})
|
||||||
}
|
}
|
||||||
assert.NoError(t, issues_model.ReplaceIssueLabels(issue, labels, doer))
|
assert.NoError(t, issues_model.ReplaceIssueLabels(db.DefaultContext, issue, labels, doer))
|
||||||
unittest.AssertCount(t, &issues_model.IssueLabel{IssueID: issueID}, len(expectedLabelIDs))
|
unittest.AssertCount(t, &issues_model.IssueLabel{IssueID: issueID}, len(expectedLabelIDs))
|
||||||
for _, labelID := range expectedLabelIDs {
|
for _, labelID := range expectedLabelIDs {
|
||||||
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issueID, LabelID: labelID})
|
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issueID, LabelID: labelID})
|
||||||
@ -122,7 +122,7 @@ func TestIssue_ClearLabels(t *testing.T) {
|
|||||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||||
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: test.issueID})
|
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: test.issueID})
|
||||||
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: test.doerID})
|
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: test.doerID})
|
||||||
assert.NoError(t, issues_model.ClearIssueLabels(issue, doer))
|
assert.NoError(t, issues_model.ClearIssueLabels(db.DefaultContext, issue, doer))
|
||||||
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: test.issueID})
|
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: test.issueID})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -230,7 +230,7 @@ func TestGetRepoIDsForIssuesOptions(t *testing.T) {
|
|||||||
[]int64{1, 2},
|
[]int64{1, 2},
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
repoIDs, err := issues_model.GetRepoIDsForIssuesOptions(&test.Opts, user)
|
repoIDs, err := issues_model.GetRepoIDsForIssuesOptions(db.DefaultContext, &test.Opts, user)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
if assert.Len(t, repoIDs, len(test.ExpectedRepoIDs)) {
|
if assert.Len(t, repoIDs, len(test.ExpectedRepoIDs)) {
|
||||||
for i, repoID := range repoIDs {
|
for i, repoID := range repoIDs {
|
||||||
|
@ -307,7 +307,7 @@ func TestNewIssueLabel(t *testing.T) {
|
|||||||
|
|
||||||
// add new IssueLabel
|
// add new IssueLabel
|
||||||
prevNumIssues := label.NumIssues
|
prevNumIssues := label.NumIssues
|
||||||
assert.NoError(t, issues_model.NewIssueLabel(issue, label, doer))
|
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, label, doer))
|
||||||
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: label.ID})
|
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: label.ID})
|
||||||
unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{
|
unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{
|
||||||
Type: issues_model.CommentTypeLabel,
|
Type: issues_model.CommentTypeLabel,
|
||||||
@ -320,7 +320,7 @@ func TestNewIssueLabel(t *testing.T) {
|
|||||||
assert.EqualValues(t, prevNumIssues+1, label.NumIssues)
|
assert.EqualValues(t, prevNumIssues+1, label.NumIssues)
|
||||||
|
|
||||||
// re-add existing IssueLabel
|
// re-add existing IssueLabel
|
||||||
assert.NoError(t, issues_model.NewIssueLabel(issue, label, doer))
|
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, label, doer))
|
||||||
unittest.CheckConsistencyFor(t, &issues_model.Issue{}, &issues_model.Label{})
|
unittest.CheckConsistencyFor(t, &issues_model.Issue{}, &issues_model.Label{})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -334,19 +334,19 @@ func TestNewIssueExclusiveLabel(t *testing.T) {
|
|||||||
exclusiveLabelB := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 8})
|
exclusiveLabelB := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 8})
|
||||||
|
|
||||||
// coexisting regular and exclusive label
|
// coexisting regular and exclusive label
|
||||||
assert.NoError(t, issues_model.NewIssueLabel(issue, otherLabel, doer))
|
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, otherLabel, doer))
|
||||||
assert.NoError(t, issues_model.NewIssueLabel(issue, exclusiveLabelA, doer))
|
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, exclusiveLabelA, doer))
|
||||||
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
|
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
|
||||||
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
|
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
|
||||||
|
|
||||||
// exclusive label replaces existing one
|
// exclusive label replaces existing one
|
||||||
assert.NoError(t, issues_model.NewIssueLabel(issue, exclusiveLabelB, doer))
|
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, exclusiveLabelB, doer))
|
||||||
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
|
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
|
||||||
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelB.ID})
|
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelB.ID})
|
||||||
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
|
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
|
||||||
|
|
||||||
// exclusive label replaces existing one again
|
// exclusive label replaces existing one again
|
||||||
assert.NoError(t, issues_model.NewIssueLabel(issue, exclusiveLabelA, doer))
|
assert.NoError(t, issues_model.NewIssueLabel(db.DefaultContext, issue, exclusiveLabelA, doer))
|
||||||
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
|
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: otherLabel.ID})
|
||||||
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
|
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelA.ID})
|
||||||
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelB.ID})
|
unittest.AssertNotExistsBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: exclusiveLabelB.ID})
|
||||||
@ -359,7 +359,7 @@ func TestNewIssueLabels(t *testing.T) {
|
|||||||
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 5})
|
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 5})
|
||||||
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
|
doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
|
||||||
|
|
||||||
assert.NoError(t, issues_model.NewIssueLabels(issue, []*issues_model.Label{label1, label2}, doer))
|
assert.NoError(t, issues_model.NewIssueLabels(db.DefaultContext, issue, []*issues_model.Label{label1, label2}, doer))
|
||||||
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: label1.ID})
|
unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: issue.ID, LabelID: label1.ID})
|
||||||
unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{
|
unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{
|
||||||
Type: issues_model.CommentTypeLabel,
|
Type: issues_model.CommentTypeLabel,
|
||||||
@ -377,7 +377,7 @@ func TestNewIssueLabels(t *testing.T) {
|
|||||||
assert.EqualValues(t, 1, label2.NumClosedIssues)
|
assert.EqualValues(t, 1, label2.NumClosedIssues)
|
||||||
|
|
||||||
// corner case: test empty slice
|
// corner case: test empty slice
|
||||||
assert.NoError(t, issues_model.NewIssueLabels(issue, []*issues_model.Label{}, doer))
|
assert.NoError(t, issues_model.NewIssueLabels(db.DefaultContext, issue, []*issues_model.Label{}, doer))
|
||||||
|
|
||||||
unittest.CheckConsistencyFor(t, &issues_model.Issue{}, &issues_model.Label{})
|
unittest.CheckConsistencyFor(t, &issues_model.Issue{}, &issues_model.Label{})
|
||||||
}
|
}
|
||||||
|
@ -58,8 +58,8 @@ func (opts GetMilestonesOption) toCond() builder.Cond {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetMilestones returns milestones filtered by GetMilestonesOption's
|
// GetMilestones returns milestones filtered by GetMilestonesOption's
|
||||||
func GetMilestones(opts GetMilestonesOption) (MilestoneList, int64, error) {
|
func GetMilestones(ctx context.Context, opts GetMilestonesOption) (MilestoneList, int64, error) {
|
||||||
sess := db.GetEngine(db.DefaultContext).Where(opts.toCond())
|
sess := db.GetEngine(ctx).Where(opts.toCond())
|
||||||
|
|
||||||
if opts.Page != 0 {
|
if opts.Page != 0 {
|
||||||
sess = db.SetSessionPagination(sess, &opts)
|
sess = db.SetSessionPagination(sess, &opts)
|
||||||
|
@ -40,7 +40,7 @@ func TestGetMilestonesByRepoID(t *testing.T) {
|
|||||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||||
test := func(repoID int64, state api.StateType) {
|
test := func(repoID int64, state api.StateType) {
|
||||||
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID})
|
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID})
|
||||||
milestones, _, err := issues_model.GetMilestones(issues_model.GetMilestonesOption{
|
milestones, _, err := issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
|
||||||
RepoID: repo.ID,
|
RepoID: repo.ID,
|
||||||
State: state,
|
State: state,
|
||||||
})
|
})
|
||||||
@ -77,7 +77,7 @@ func TestGetMilestonesByRepoID(t *testing.T) {
|
|||||||
test(3, api.StateClosed)
|
test(3, api.StateClosed)
|
||||||
test(3, api.StateAll)
|
test(3, api.StateAll)
|
||||||
|
|
||||||
milestones, _, err := issues_model.GetMilestones(issues_model.GetMilestonesOption{
|
milestones, _, err := issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
|
||||||
RepoID: unittest.NonexistentID,
|
RepoID: unittest.NonexistentID,
|
||||||
State: api.StateOpen,
|
State: api.StateOpen,
|
||||||
})
|
})
|
||||||
@ -90,7 +90,7 @@ func TestGetMilestones(t *testing.T) {
|
|||||||
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
|
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
|
||||||
test := func(sortType string, sortCond func(*issues_model.Milestone) int) {
|
test := func(sortType string, sortCond func(*issues_model.Milestone) int) {
|
||||||
for _, page := range []int{0, 1} {
|
for _, page := range []int{0, 1} {
|
||||||
milestones, _, err := issues_model.GetMilestones(issues_model.GetMilestonesOption{
|
milestones, _, err := issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
|
||||||
ListOptions: db.ListOptions{
|
ListOptions: db.ListOptions{
|
||||||
Page: page,
|
Page: page,
|
||||||
PageSize: setting.UI.IssuePagingNum,
|
PageSize: setting.UI.IssuePagingNum,
|
||||||
@ -107,7 +107,7 @@ func TestGetMilestones(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.True(t, sort.IntsAreSorted(values))
|
assert.True(t, sort.IntsAreSorted(values))
|
||||||
|
|
||||||
milestones, _, err = issues_model.GetMilestones(issues_model.GetMilestonesOption{
|
milestones, _, err = issues_model.GetMilestones(db.DefaultContext, issues_model.GetMilestonesOption{
|
||||||
ListOptions: db.ListOptions{
|
ListOptions: db.ListOptions{
|
||||||
Page: page,
|
Page: page,
|
||||||
PageSize: setting.UI.IssuePagingNum,
|
PageSize: setting.UI.IssuePagingNum,
|
||||||
|
@ -378,9 +378,9 @@ func (pr *PullRequest) GetApprovalCounts(ctx context.Context) ([]*ReviewCount, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetApprovers returns the approvers of the pull request
|
// GetApprovers returns the approvers of the pull request
|
||||||
func (pr *PullRequest) GetApprovers() string {
|
func (pr *PullRequest) GetApprovers(ctx context.Context) string {
|
||||||
stringBuilder := strings.Builder{}
|
stringBuilder := strings.Builder{}
|
||||||
if err := pr.getReviewedByLines(&stringBuilder); err != nil {
|
if err := pr.getReviewedByLines(ctx, &stringBuilder); err != nil {
|
||||||
log.Error("Unable to getReviewedByLines: Error: %v", err)
|
log.Error("Unable to getReviewedByLines: Error: %v", err)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -388,14 +388,14 @@ func (pr *PullRequest) GetApprovers() string {
|
|||||||
return stringBuilder.String()
|
return stringBuilder.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pr *PullRequest) getReviewedByLines(writer io.Writer) error {
|
func (pr *PullRequest) getReviewedByLines(ctx context.Context, writer io.Writer) error {
|
||||||
maxReviewers := setting.Repository.PullRequest.DefaultMergeMessageMaxApprovers
|
maxReviewers := setting.Repository.PullRequest.DefaultMergeMessageMaxApprovers
|
||||||
|
|
||||||
if maxReviewers == 0 {
|
if maxReviewers == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -594,9 +594,9 @@ func GetUnmergedPullRequest(ctx context.Context, headRepoID, baseRepoID int64, h
|
|||||||
|
|
||||||
// GetLatestPullRequestByHeadInfo returns the latest pull request (regardless of its status)
|
// GetLatestPullRequestByHeadInfo returns the latest pull request (regardless of its status)
|
||||||
// by given head information (repo and branch).
|
// by given head information (repo and branch).
|
||||||
func GetLatestPullRequestByHeadInfo(repoID int64, branch string) (*PullRequest, error) {
|
func GetLatestPullRequestByHeadInfo(ctx context.Context, repoID int64, branch string) (*PullRequest, error) {
|
||||||
pr := new(PullRequest)
|
pr := new(PullRequest)
|
||||||
has, err := db.GetEngine(db.DefaultContext).
|
has, err := db.GetEngine(ctx).
|
||||||
Where("head_repo_id = ? AND head_branch = ? AND flow = ?", repoID, branch, PullRequestFlowGithub).
|
Where("head_repo_id = ? AND head_branch = ? AND flow = ?", repoID, branch, PullRequestFlowGithub).
|
||||||
OrderBy("id DESC").
|
OrderBy("id DESC").
|
||||||
Get(pr)
|
Get(pr)
|
||||||
@ -646,9 +646,9 @@ func GetPullRequestByID(ctx context.Context, id int64) (*PullRequest, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPullRequestByIssueIDWithNoAttributes returns pull request with no attributes loaded by given issue ID.
|
// GetPullRequestByIssueIDWithNoAttributes returns pull request with no attributes loaded by given issue ID.
|
||||||
func GetPullRequestByIssueIDWithNoAttributes(issueID int64) (*PullRequest, error) {
|
func GetPullRequestByIssueIDWithNoAttributes(ctx context.Context, issueID int64) (*PullRequest, error) {
|
||||||
var pr PullRequest
|
var pr PullRequest
|
||||||
has, err := db.GetEngine(db.DefaultContext).Where("issue_id = ?", issueID).Get(&pr)
|
has, err := db.GetEngine(ctx).Where("issue_id = ?", issueID).Get(&pr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -687,14 +687,14 @@ func GetAllUnmergedAgitPullRequestByPoster(ctx context.Context, uid int64) ([]*P
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update updates all fields of pull request.
|
// Update updates all fields of pull request.
|
||||||
func (pr *PullRequest) Update() error {
|
func (pr *PullRequest) Update(ctx context.Context) error {
|
||||||
_, err := db.GetEngine(db.DefaultContext).ID(pr.ID).AllCols().Update(pr)
|
_, err := db.GetEngine(ctx).ID(pr.ID).AllCols().Update(pr)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateCols updates specific fields of pull request.
|
// UpdateCols updates specific fields of pull request.
|
||||||
func (pr *PullRequest) UpdateCols(cols ...string) error {
|
func (pr *PullRequest) UpdateCols(ctx context.Context, cols ...string) error {
|
||||||
_, err := db.GetEngine(db.DefaultContext).ID(pr.ID).Cols(cols...).Update(pr)
|
_, err := db.GetEngine(ctx).ID(pr.ID).Cols(cols...).Update(pr)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -706,8 +706,8 @@ func (pr *PullRequest) UpdateColsIfNotMerged(ctx context.Context, cols ...string
|
|||||||
|
|
||||||
// IsWorkInProgress determine if the Pull Request is a Work In Progress by its title
|
// IsWorkInProgress determine if the Pull Request is a Work In Progress by its title
|
||||||
// Issue must be set before this method can be called.
|
// Issue must be set before this method can be called.
|
||||||
func (pr *PullRequest) IsWorkInProgress() bool {
|
func (pr *PullRequest) IsWorkInProgress(ctx context.Context) bool {
|
||||||
if err := pr.LoadIssue(db.DefaultContext); err != nil {
|
if err := pr.LoadIssue(ctx); err != nil {
|
||||||
log.Error("LoadIssue: %v", err)
|
log.Error("LoadIssue: %v", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -774,8 +774,8 @@ func GetPullRequestsByHeadBranch(ctx context.Context, headBranch string, headRep
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetBaseBranchLink returns the relative URL of the base branch
|
// GetBaseBranchLink returns the relative URL of the base branch
|
||||||
func (pr *PullRequest) GetBaseBranchLink() string {
|
func (pr *PullRequest) GetBaseBranchLink(ctx context.Context) string {
|
||||||
if err := pr.LoadBaseRepo(db.DefaultContext); err != nil {
|
if err := pr.LoadBaseRepo(ctx); err != nil {
|
||||||
log.Error("LoadBaseRepo: %v", err)
|
log.Error("LoadBaseRepo: %v", err)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -786,12 +786,12 @@ func (pr *PullRequest) GetBaseBranchLink() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetHeadBranchLink returns the relative URL of the head branch
|
// GetHeadBranchLink returns the relative URL of the head branch
|
||||||
func (pr *PullRequest) GetHeadBranchLink() string {
|
func (pr *PullRequest) GetHeadBranchLink(ctx context.Context) string {
|
||||||
if pr.Flow == PullRequestFlowAGit {
|
if pr.Flow == PullRequestFlowAGit {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pr.LoadHeadRepo(db.DefaultContext); err != nil {
|
if err := pr.LoadHeadRepo(ctx); err != nil {
|
||||||
log.Error("LoadHeadRepo: %v", err)
|
log.Error("LoadHeadRepo: %v", err)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -810,14 +810,14 @@ func UpdateAllowEdits(ctx context.Context, pr *PullRequest) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Mergeable returns if the pullrequest is mergeable.
|
// Mergeable returns if the pullrequest is mergeable.
|
||||||
func (pr *PullRequest) Mergeable() bool {
|
func (pr *PullRequest) Mergeable(ctx context.Context) bool {
|
||||||
// If a pull request isn't mergable if it's:
|
// If a pull request isn't mergable if it's:
|
||||||
// - Being conflict checked.
|
// - Being conflict checked.
|
||||||
// - Has a conflict.
|
// - Has a conflict.
|
||||||
// - Received a error while being conflict checked.
|
// - Received a error while being conflict checked.
|
||||||
// - Is a work-in-progress pull request.
|
// - Is a work-in-progress pull request.
|
||||||
return pr.Status != PullRequestStatusChecking && pr.Status != PullRequestStatusConflict &&
|
return pr.Status != PullRequestStatusChecking && pr.Status != PullRequestStatusConflict &&
|
||||||
pr.Status != PullRequestStatusError && !pr.IsWorkInProgress()
|
pr.Status != PullRequestStatusError && !pr.IsWorkInProgress(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasEnoughApprovals returns true if pr has enough granted approvals.
|
// HasEnoughApprovals returns true if pr has enough granted approvals.
|
||||||
@ -890,7 +890,7 @@ func MergeBlockedByOutdatedBranch(protectBranch *git_model.ProtectedBranch, pr *
|
|||||||
func PullRequestCodeOwnersReview(ctx context.Context, pull *Issue, pr *PullRequest) error {
|
func PullRequestCodeOwnersReview(ctx context.Context, pull *Issue, pr *PullRequest) error {
|
||||||
files := []string{"CODEOWNERS", "docs/CODEOWNERS", ".gitea/CODEOWNERS"}
|
files := []string{"CODEOWNERS", "docs/CODEOWNERS", ".gitea/CODEOWNERS"}
|
||||||
|
|
||||||
if pr.IsWorkInProgress() {
|
if pr.IsWorkInProgress(ctx) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -213,7 +213,7 @@ func TestPullRequest_Update(t *testing.T) {
|
|||||||
pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1})
|
pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1})
|
||||||
pr.BaseBranch = "baseBranch"
|
pr.BaseBranch = "baseBranch"
|
||||||
pr.HeadBranch = "headBranch"
|
pr.HeadBranch = "headBranch"
|
||||||
pr.Update()
|
pr.Update(db.DefaultContext)
|
||||||
|
|
||||||
pr = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: pr.ID})
|
pr = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: pr.ID})
|
||||||
assert.Equal(t, "baseBranch", pr.BaseBranch)
|
assert.Equal(t, "baseBranch", pr.BaseBranch)
|
||||||
@ -228,7 +228,7 @@ func TestPullRequest_UpdateCols(t *testing.T) {
|
|||||||
BaseBranch: "baseBranch",
|
BaseBranch: "baseBranch",
|
||||||
HeadBranch: "headBranch",
|
HeadBranch: "headBranch",
|
||||||
}
|
}
|
||||||
assert.NoError(t, pr.UpdateCols("head_branch"))
|
assert.NoError(t, pr.UpdateCols(db.DefaultContext, "head_branch"))
|
||||||
|
|
||||||
pr = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1})
|
pr = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1})
|
||||||
assert.Equal(t, "master", pr.BaseBranch)
|
assert.Equal(t, "master", pr.BaseBranch)
|
||||||
@ -260,13 +260,13 @@ func TestPullRequest_IsWorkInProgress(t *testing.T) {
|
|||||||
pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 2})
|
pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 2})
|
||||||
pr.LoadIssue(db.DefaultContext)
|
pr.LoadIssue(db.DefaultContext)
|
||||||
|
|
||||||
assert.False(t, pr.IsWorkInProgress())
|
assert.False(t, pr.IsWorkInProgress(db.DefaultContext))
|
||||||
|
|
||||||
pr.Issue.Title = "WIP: " + pr.Issue.Title
|
pr.Issue.Title = "WIP: " + pr.Issue.Title
|
||||||
assert.True(t, pr.IsWorkInProgress())
|
assert.True(t, pr.IsWorkInProgress(db.DefaultContext))
|
||||||
|
|
||||||
pr.Issue.Title = "[wip]: " + pr.Issue.Title
|
pr.Issue.Title = "[wip]: " + pr.Issue.Title
|
||||||
assert.True(t, pr.IsWorkInProgress())
|
assert.True(t, pr.IsWorkInProgress(db.DefaultContext))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPullRequest_GetWorkInProgressPrefixWorkInProgress(t *testing.T) {
|
func TestPullRequest_GetWorkInProgressPrefixWorkInProgress(t *testing.T) {
|
||||||
@ -334,7 +334,7 @@ func TestGetApprovers(t *testing.T) {
|
|||||||
// Official reviews are already deduplicated. Allow unofficial reviews
|
// Official reviews are already deduplicated. Allow unofficial reviews
|
||||||
// to assert that there are no duplicated approvers.
|
// to assert that there are no duplicated approvers.
|
||||||
setting.Repository.PullRequest.DefaultMergeMessageOfficialApproversOnly = false
|
setting.Repository.PullRequest.DefaultMergeMessageOfficialApproversOnly = false
|
||||||
approvers := pr.GetApprovers()
|
approvers := pr.GetApprovers(db.DefaultContext)
|
||||||
expected := "Reviewed-by: User Five <user5@example.com>\nReviewed-by: Org Six <org6@example.com>\n"
|
expected := "Reviewed-by: User Five <user5@example.com>\nReviewed-by: Org Six <org6@example.com>\n"
|
||||||
assert.EqualValues(t, expected, approvers)
|
assert.EqualValues(t, expected, approvers)
|
||||||
}
|
}
|
||||||
|
@ -277,8 +277,8 @@ func UpdateRepoStats(ctx context.Context, id int64) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateUserStarNumbers(users []user_model.User) error {
|
func updateUserStarNumbers(ctx context.Context, users []user_model.User) error {
|
||||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
ctx, committer, err := db.TxContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -294,19 +294,19 @@ func updateUserStarNumbers(users []user_model.User) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DoctorUserStarNum recalculate Stars number for all user
|
// DoctorUserStarNum recalculate Stars number for all user
|
||||||
func DoctorUserStarNum() (err error) {
|
func DoctorUserStarNum(ctx context.Context) (err error) {
|
||||||
const batchSize = 100
|
const batchSize = 100
|
||||||
|
|
||||||
for start := 0; ; start += batchSize {
|
for start := 0; ; start += batchSize {
|
||||||
users := make([]user_model.User, 0, batchSize)
|
users := make([]user_model.User, 0, batchSize)
|
||||||
if err = db.GetEngine(db.DefaultContext).Limit(batchSize, start).Where("type = ?", 0).Cols("id").Find(&users); err != nil {
|
if err = db.GetEngine(ctx).Limit(batchSize, start).Where("type = ?", 0).Cols("id").Find(&users); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(users) == 0 {
|
if len(users) == 0 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = updateUserStarNumbers(users); err != nil {
|
if err = updateUserStarNumbers(ctx, users); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user