diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index 17f0ea9c..d0443297 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -744,12 +744,6 @@ func (db *DB) CountTasksOfJobInStatus( return 0, 0, err } - // Convert from []api.TaskStatus to []string for feeding to sqlc. - statusesAsStrings := make([]string, len(taskStatuses)) - for index := range taskStatuses { - statusesAsStrings[index] = string(taskStatuses[index]) - } - results, err := queries.JobCountTaskStatuses(ctx, int64(job.ID)) if err != nil { return 0, 0, jobError(err, "count tasks of job %s in status %q", job.UUID, taskStatuses) @@ -803,15 +797,9 @@ func (db *DB) FetchTasksOfJobInStatus(ctx context.Context, job *Job, taskStatuse return nil, err } - // Convert from []api.TaskStatus to []string for feeding to sqlc. - statusesAsStrings := make([]string, len(taskStatuses)) - for index := range taskStatuses { - statusesAsStrings[index] = string(taskStatuses[index]) - } - rows, err := queries.FetchTasksOfJobInStatus(ctx, sqlc.FetchTasksOfJobInStatusParams{ JobID: int64(job.ID), - TaskStatus: statusesAsStrings, + TaskStatus: convertTaskStatuses(taskStatuses), }) if err != nil { return nil, taskError(err, "fetching tasks of job %s in status %q", job.UUID, taskStatuses) @@ -837,13 +825,20 @@ func (db *DB) UpdateJobsTaskStatuses(ctx context.Context, job *Job, return taskError(nil, "empty status not allowed") } - tx := db.gormDB.WithContext(ctx). - Model(Task{}). - Where("job_Id = ?", job.ID). - Updates(Task{Status: taskStatus, Activity: activity}) + queries, err := db.queries() + if err != nil { + return err + } - if tx.Error != nil { - return taskError(tx.Error, "updating status of all tasks of job %s", job.UUID) + err = queries.UpdateJobsTaskStatuses(ctx, sqlc.UpdateJobsTaskStatusesParams{ + UpdatedAt: db.now(), + Status: string(taskStatus), + Activity: activity, + JobID: int64(job.ID), + }) + + if err != nil { + return taskError(err, "updating status of all tasks of job %s", job.UUID) } return nil } @@ -857,13 +852,21 @@ func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job, return taskError(nil, "empty status not allowed") } - tx := db.gormDB.WithContext(ctx). - Model(Task{}). - Where("job_Id = ?", job.ID). - Where("status in ?", statusesToUpdate). - Updates(Task{Status: taskStatus, Activity: activity}) - if tx.Error != nil { - return taskError(tx.Error, "updating status of all tasks in status %v of job %s", statusesToUpdate, job.UUID) + queries, err := db.queries() + if err != nil { + return err + } + + err = queries.UpdateJobsTaskStatusesConditional(ctx, sqlc.UpdateJobsTaskStatusesConditionalParams{ + UpdatedAt: db.now(), + Status: string(taskStatus), + Activity: activity, + JobID: int64(job.ID), + StatusesToUpdate: convertTaskStatuses(statusesToUpdate), + }) + + if err != nil { + return taskError(err, "updating status of all tasks in status %v of job %s", statusesToUpdate, job.UUID) } return nil } @@ -1027,3 +1030,12 @@ func convertSqlcTask(task sqlc.Task, jobUUID string, workerUUID string) (*Task, return &dbTask, nil } + +// convertTaskStatuses converts from []api.TaskStatus to []string for feeding to sqlc. +func convertTaskStatuses(taskStatuses []api.TaskStatus) []string { + statusesAsStrings := make([]string, len(taskStatuses)) + for index := range taskStatuses { + statusesAsStrings[index] = string(taskStatuses[index]) + } + return statusesAsStrings +} diff --git a/internal/manager/persistence/sqlc/query_jobs.sql b/internal/manager/persistence/sqlc/query_jobs.sql index 231a2d5d..ff81a65d 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql +++ b/internal/manager/persistence/sqlc/query_jobs.sql @@ -128,6 +128,20 @@ UPDATE tasks SET activity = @activity WHERE id=@id; +-- name: UpdateJobsTaskStatusesConditional :exec +UPDATE tasks SET + updated_at = @updated_at, + status = @status, + activity = @activity +WHERE job_id = @job_id AND status in (sqlc.slice('statuses_to_update')); + +-- name: UpdateJobsTaskStatuses :exec +UPDATE tasks SET + updated_at = @updated_at, + status = @status, + activity = @activity +WHERE job_id = @job_id; + -- name: TaskAssignToWorker :exec UPDATE tasks SET updated_at = @updated_at, diff --git a/internal/manager/persistence/sqlc/query_jobs.sql.go b/internal/manager/persistence/sqlc/query_jobs.sql.go index 38f5f84f..ba2287f4 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql.go +++ b/internal/manager/persistence/sqlc/query_jobs.sql.go @@ -679,6 +679,66 @@ func (q *Queries) TaskAssignToWorker(ctx context.Context, arg TaskAssignToWorker return err } +const updateJobsTaskStatuses = `-- name: UpdateJobsTaskStatuses :exec +UPDATE tasks SET + updated_at = ?1, + status = ?2, + activity = ?3 +WHERE job_id = ?4 +` + +type UpdateJobsTaskStatusesParams struct { + UpdatedAt sql.NullTime + Status string + Activity string + JobID int64 +} + +func (q *Queries) UpdateJobsTaskStatuses(ctx context.Context, arg UpdateJobsTaskStatusesParams) error { + _, err := q.db.ExecContext(ctx, updateJobsTaskStatuses, + arg.UpdatedAt, + arg.Status, + arg.Activity, + arg.JobID, + ) + return err +} + +const updateJobsTaskStatusesConditional = `-- name: UpdateJobsTaskStatusesConditional :exec +UPDATE tasks SET + updated_at = ?1, + status = ?2, + activity = ?3 +WHERE job_id = ?4 AND status in (/*SLICE:statuses_to_update*/?) +` + +type UpdateJobsTaskStatusesConditionalParams struct { + UpdatedAt sql.NullTime + Status string + Activity string + JobID int64 + StatusesToUpdate []string +} + +func (q *Queries) UpdateJobsTaskStatusesConditional(ctx context.Context, arg UpdateJobsTaskStatusesConditionalParams) error { + query := updateJobsTaskStatusesConditional + var queryParams []interface{} + queryParams = append(queryParams, arg.UpdatedAt) + queryParams = append(queryParams, arg.Status) + queryParams = append(queryParams, arg.Activity) + queryParams = append(queryParams, arg.JobID) + if len(arg.StatusesToUpdate) > 0 { + for _, v := range arg.StatusesToUpdate { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:statuses_to_update*/?", strings.Repeat(",?", len(arg.StatusesToUpdate))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:statuses_to_update*/?", "NULL", 1) + } + _, err := q.db.ExecContext(ctx, query, queryParams...) + return err +} + const updateTask = `-- name: UpdateTask :exec UPDATE tasks SET updated_at = ?1,