From 4ab853da40fb36cb1b87bd71961397de4f73a3d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Mon, 20 May 2024 21:18:19 +0200 Subject: [PATCH] Manager: Convert JobHasTasksInStatus and CountTasksOfJobInStatus to sqlc No functional changes. --- internal/manager/persistence/jobs.go | 54 ++++++++++--------- internal/manager/persistence/jobs_test.go | 6 +++ .../manager/persistence/sqlc/query_jobs.sql | 11 ++++ .../persistence/sqlc/query_jobs.sql.go | 53 ++++++++++++++++++ 4 files changed, 100 insertions(+), 24 deletions(-) diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index 4c761a87..bcff8dbb 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -715,38 +715,44 @@ func (db *DB) FetchTasksOfWorkerInStatusOfJob(ctx context.Context, worker *Worke } func (db *DB) JobHasTasksInStatus(ctx context.Context, job *Job, taskStatus api.TaskStatus) (bool, error) { - var numTasksInStatus int64 - tx := db.gormDB.WithContext(ctx). - Model(&Task{}). - Where("job_id", job.ID). - Where("status", taskStatus). - Count(&numTasksInStatus) - if tx.Error != nil { - return false, taskError(tx.Error, "counting tasks of job %s in status %q", job.UUID, taskStatus) + queries, err := db.queries() + if err != nil { + return false, err } - return numTasksInStatus > 0, nil + + count, err := queries.JobCountTasksInStatus(ctx, sqlc.JobCountTasksInStatusParams{ + JobID: int64(job.ID), + TaskStatus: string(taskStatus), + }) + if err != nil { + return false, taskError(err, "counting tasks of job %s in status %q", job.UUID, taskStatus) + } + + return count > 0, nil } +// CountTasksOfJobInStatus counts the number of tasks in the job. +// It returns two counts, one is the number of tasks in the given statuses, the +// other is the total number of tasks of the job. func (db *DB) CountTasksOfJobInStatus( ctx context.Context, job *Job, taskStatuses ...api.TaskStatus, ) (numInStatus, numTotal int, err error) { - type Result struct { - Status api.TaskStatus - NumTasks int + queries, err := db.queries() + if err != nil { + return 0, 0, err } - var results []Result - tx := db.gormDB.WithContext(ctx). - Model(&Task{}). - Select("status, count(*) as num_tasks"). - Where("job_id", job.ID). - Group("status"). - Scan(&results) + // Convert from []api.TaskStatus to []string for feeding to sqlc. + statusesAsStrings := make([]string, len(taskStatuses)) + for index := range taskStatuses { + statusesAsStrings[index] = string(taskStatuses[index]) + } - if tx.Error != nil { - return 0, 0, jobError(tx.Error, "count tasks of job %s in status %q", job.UUID, taskStatuses) + results, err := queries.JobCountTaskStatuses(ctx, int64(job.ID)) + if err != nil { + return 0, 0, jobError(err, "count tasks of job %s in status %q", job.UUID, taskStatuses) } // Create lookup table for which statuses to count. @@ -757,10 +763,10 @@ func (db *DB) CountTasksOfJobInStatus( // Count the number of tasks per status. for _, result := range results { - if countStatus[result.Status] { - numInStatus += result.NumTasks + if countStatus[api.TaskStatus(result.Status)] { + numInStatus += int(result.NumTasks) } - numTotal += result.NumTasks + numTotal += int(result.NumTasks) } return diff --git a/internal/manager/persistence/jobs_test.go b/internal/manager/persistence/jobs_test.go index 20a0b1d5..af338f25 100644 --- a/internal/manager/persistence/jobs_test.go +++ b/internal/manager/persistence/jobs_test.go @@ -396,6 +396,12 @@ func TestCountTasksOfJobInStatus(t *testing.T) { require.NoError(t, err) assert.Equal(t, 0, numActive) assert.Equal(t, 3, numTotal) + + numCounted, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, + api.TaskStatusFailed, api.TaskStatusQueued) + require.NoError(t, err) + assert.Equal(t, 3, numCounted) + assert.Equal(t, 3, numTotal) } func TestCheckIfJobsHoldLargeNumOfTasks(t *testing.T) { diff --git a/internal/manager/persistence/sqlc/query_jobs.sql b/internal/manager/persistence/sqlc/query_jobs.sql index 21b7e442..aa02ebb4 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql +++ b/internal/manager/persistence/sqlc/query_jobs.sql @@ -120,3 +120,14 @@ UPDATE tasks SET updated_at = @updated_at, worker_id = @worker_id WHERE id=@id; + +-- name: JobCountTasksInStatus :one +-- Fetch number of tasks in the given status, of the given job. +SELECT count(*) as num_tasks FROM tasks +WHERE job_id = @job_id AND status = @task_status; + +-- name: JobCountTaskStatuses :many +-- Fetch (status, num tasks in that status) rows for the given job. +SELECT status, count(*) as num_tasks FROM tasks +WHERE job_id = @job_id +GROUP BY status; diff --git a/internal/manager/persistence/sqlc/query_jobs.sql.go b/internal/manager/persistence/sqlc/query_jobs.sql.go index 09811f0e..5adb0878 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql.go +++ b/internal/manager/persistence/sqlc/query_jobs.sql.go @@ -396,6 +396,59 @@ func (q *Queries) FetchTasksOfWorkerInStatusOfJob(ctx context.Context, arg Fetch return items, nil } +const jobCountTaskStatuses = `-- name: JobCountTaskStatuses :many +SELECT status, count(*) as num_tasks FROM tasks +WHERE job_id = ?1 +GROUP BY status +` + +type JobCountTaskStatusesRow struct { + Status string + NumTasks int64 +} + +// Fetch (status, num tasks in that status) rows for the given job. +func (q *Queries) JobCountTaskStatuses(ctx context.Context, jobID int64) ([]JobCountTaskStatusesRow, error) { + rows, err := q.db.QueryContext(ctx, jobCountTaskStatuses, jobID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []JobCountTaskStatusesRow + for rows.Next() { + var i JobCountTaskStatusesRow + if err := rows.Scan(&i.Status, &i.NumTasks); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobCountTasksInStatus = `-- name: JobCountTasksInStatus :one +SELECT count(*) as num_tasks FROM tasks +WHERE job_id = ?1 AND status = ?2 +` + +type JobCountTasksInStatusParams struct { + JobID int64 + TaskStatus string +} + +// Fetch number of tasks in the given status, of the given job. +func (q *Queries) JobCountTasksInStatus(ctx context.Context, arg JobCountTasksInStatusParams) (int64, error) { + row := q.db.QueryRowContext(ctx, jobCountTasksInStatus, arg.JobID, arg.TaskStatus) + var num_tasks int64 + err := row.Scan(&num_tasks) + return num_tasks, err +} + const requestJobDeletion = `-- name: RequestJobDeletion :exec UPDATE jobs SET updated_at = ?1,