Manager: convert task scheduler from gorm to sqlc

Convert the task scheduler from gorm to sqlc. This makes the query
considerably easier to read.

No functional changes intended.
This commit is contained in:
Sybren A. Stüvel 2024-06-30 21:17:13 +02:00
parent d86c97d06e
commit bfe47ea394
7 changed files with 444 additions and 125 deletions

@ -189,6 +189,37 @@ func (db *DB) queries() (*sqlc.Queries, error) {
return sqlc.New(&loggingWrapper), nil
}
// queriesTX holds a SQLC Queries struct together with the functions that end
// the database transaction it was created for.
type queriesTX struct {
	// queries is the SQLC Queries struct to run statements with.
	queries *sqlc.Queries

	// commit and rollback respectively commit and abort the transaction.
	commit, rollback func() error
}
// queriesWithTX begins a new database transaction and returns the SQLC Queries
// struct bound to that transaction, together with the functions to commit it
// or roll it back.
//
// This function itself neither commits nor rolls back the transaction; that is
// left to the caller, via the returned struct.
func (db *DB) queriesWithTX() (*queriesTX, error) {
	lowLevelDB, err := db.gormDB.DB()
	if err != nil {
		return nil, fmt.Errorf("could not get low-level database driver: %w", err)
	}

	transaction, err := lowLevelDB.Begin()
	if err != nil {
		return nil, fmt.Errorf("could not begin database transaction: %w", err)
	}

	// The transaction is wrapped in a LoggingDBConn before it is handed to SQLC,
	// just like the non-transactional connection is.
	return &queriesTX{
		queries:  sqlc.New(&LoggingDBConn{transaction}),
		commit:   transaction.Commit,
		rollback: transaction.Rollback,
	}, nil
}
// now returns the result of `nowFunc()` wrapped in a sql.NullTime.
func (db *DB) now() sql.NullTime {
return sql.NullTime{

@ -0,0 +1,53 @@
-- name: FetchAssignedAndRunnableTaskOfWorker :one
-- Fetch a task that's assigned to this worker, and is in a runnable state.
--
-- A task matches when it has the given task status, is assigned to the given
-- worker, and belongs to a job that has one of the given job statuses. At most
-- one task is returned.
--
-- The slice parameter is kept last in the WHERE clause, see
-- https://github.com/sqlc-dev/sqlc/issues/2452 for more info.
SELECT sqlc.embed(tasks)
FROM tasks
INNER JOIN jobs ON tasks.job_id = jobs.id
WHERE tasks.status=@active_task_status
AND tasks.worker_id=@worker_id
AND jobs.status IN (sqlc.slice('active_job_statuses'))
LIMIT 1;
-- name: FindRunnableTask :one
-- Find a task to be run by a worker. This is the core of the task scheduler.
--
-- Note that this query doesn't check for the assigned worker. Tasks that have a
-- 'schedulable' status might have been assigned to a worker, representing the
-- last worker to touch it -- it's not meant to indicate "ownership" of the
-- task.
--
-- Candidate tasks are ordered by job priority and then by task priority,
-- highest first.
--
-- The order in the WHERE clause is important, slices should come last. See
-- https://github.com/sqlc-dev/sqlc/issues/2452 for more info.
SELECT sqlc.embed(tasks)
FROM tasks
INNER JOIN jobs ON tasks.job_id = jobs.id
LEFT JOIN task_failures TF ON tasks.id = TF.task_id AND TF.worker_id=@worker_id
WHERE TF.worker_id IS NULL -- Not failed by this worker before.
AND tasks.id NOT IN (
-- Find all task IDs that have incomplete dependencies. These are not runnable.
SELECT tasks_incomplete.id
FROM tasks AS tasks_incomplete
INNER JOIN task_dependencies td ON tasks_incomplete.id = td.task_id
INNER JOIN tasks dep ON dep.id = td.dependency_id
WHERE dep.status != @task_status_completed
)
AND tasks.type NOT IN (
-- Find all task types of this job that this worker is blocked from running.
SELECT task_type
FROM job_blocks
WHERE job_blocks.worker_id = @worker_id
AND job_blocks.job_id = jobs.id
)
AND (
-- Only consider jobs that have no worker tag, or a tag that is in the given
-- list of worker tags.
jobs.worker_tag_id IS NULL
OR jobs.worker_tag_id IN (sqlc.slice('worker_tags')))
AND tasks.status IN (sqlc.slice('schedulable_task_statuses'))
AND jobs.status IN (sqlc.slice('schedulable_job_statuses'))
AND tasks.type IN (sqlc.slice('supported_task_types'))
ORDER BY jobs.priority DESC, tasks.priority DESC;
-- name: AssignTaskToWorker :exec
-- Assign the task to the worker, and set both its `last_touched_at` and
-- `updated_at` timestamps to the given time.
UPDATE tasks
SET worker_id=@worker_id, last_touched_at=@now, updated_at=@now
WHERE tasks.id=@task_id;

@ -0,0 +1,191 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.26.0
// source: query_task_scheduler.sql
package sqlc
import (
"context"
"database/sql"
"strings"
)
const assignTaskToWorker = `-- name: AssignTaskToWorker :exec
UPDATE tasks
SET worker_id=?1, last_touched_at=?2, updated_at=?2
WHERE tasks.id=?3
`
type AssignTaskToWorkerParams struct {
WorkerID sql.NullInt64
Now sql.NullTime
TaskID int64
}
func (q *Queries) AssignTaskToWorker(ctx context.Context, arg AssignTaskToWorkerParams) error {
_, err := q.db.ExecContext(ctx, assignTaskToWorker, arg.WorkerID, arg.Now, arg.TaskID)
return err
}
const fetchAssignedAndRunnableTaskOfWorker = `-- name: FetchAssignedAndRunnableTaskOfWorker :one
SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity
FROM tasks
INNER JOIN jobs ON tasks.job_id = jobs.id
WHERE tasks.status=?1
AND tasks.worker_id=?2
AND jobs.status IN (/*SLICE:active_job_statuses*/?)
LIMIT 1
`
type FetchAssignedAndRunnableTaskOfWorkerParams struct {
ActiveTaskStatus string
WorkerID sql.NullInt64
ActiveJobStatuses []string
}
type FetchAssignedAndRunnableTaskOfWorkerRow struct {
Task Task
}
// Fetch a task that's assigned to this worker, and is in a runnable state.
func (q *Queries) FetchAssignedAndRunnableTaskOfWorker(ctx context.Context, arg FetchAssignedAndRunnableTaskOfWorkerParams) (FetchAssignedAndRunnableTaskOfWorkerRow, error) {
query := fetchAssignedAndRunnableTaskOfWorker
var queryParams []interface{}
queryParams = append(queryParams, arg.ActiveTaskStatus)
queryParams = append(queryParams, arg.WorkerID)
if len(arg.ActiveJobStatuses) > 0 {
for _, v := range arg.ActiveJobStatuses {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:active_job_statuses*/?", strings.Repeat(",?", len(arg.ActiveJobStatuses))[1:], 1)
} else {
query = strings.Replace(query, "/*SLICE:active_job_statuses*/?", "NULL", 1)
}
row := q.db.QueryRowContext(ctx, query, queryParams...)
var i FetchAssignedAndRunnableTaskOfWorkerRow
err := row.Scan(
&i.Task.ID,
&i.Task.CreatedAt,
&i.Task.UpdatedAt,
&i.Task.UUID,
&i.Task.Name,
&i.Task.Type,
&i.Task.JobID,
&i.Task.Priority,
&i.Task.Status,
&i.Task.WorkerID,
&i.Task.LastTouchedAt,
&i.Task.Commands,
&i.Task.Activity,
)
return i, err
}
const findRunnableTask = `-- name: FindRunnableTask :one
SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity
FROM tasks
INNER JOIN jobs ON tasks.job_id = jobs.id
LEFT JOIN task_failures TF ON tasks.id = TF.task_id AND TF.worker_id=?1
WHERE TF.worker_id IS NULL -- Not failed by this worker before.
AND tasks.id NOT IN (
-- Find all tasks IDs that have incomplete dependencies. These are not runnable.
SELECT tasks_incomplete.id
FROM tasks AS tasks_incomplete
INNER JOIN task_dependencies td ON tasks_incomplete.id = td.task_id
INNER JOIN tasks dep ON dep.id = td.dependency_id
WHERE dep.status != ?2
)
AND tasks.type NOT IN (
SELECT task_type
FROM job_blocks
WHERE job_blocks.worker_id = ?1
AND job_blocks.job_id = jobs.id
)
AND (
jobs.worker_tag_id IS NULL
OR jobs.worker_tag_id IN (/*SLICE:worker_tags*/?))
AND tasks.status IN (/*SLICE:schedulable_task_statuses*/?)
AND jobs.status IN (/*SLICE:schedulable_job_statuses*/?)
AND tasks.type IN (/*SLICE:supported_task_types*/?)
ORDER BY jobs.priority DESC, tasks.priority DESC
`
type FindRunnableTaskParams struct {
WorkerID int64
TaskStatusCompleted string
WorkerTags []sql.NullInt64
SchedulableTaskStatuses []string
SchedulableJobStatuses []string
SupportedTaskTypes []string
}
type FindRunnableTaskRow struct {
Task Task
}
// Find a task to be run by a worker. This is the core of the task scheduler.
//
// Note that this query doesn't check for the assigned worker. Tasks that have a
// 'schedulable' status might have been assigned to a worker, representing the
// last worker to touch it -- it's not meant to indicate "ownership" of the
// task.
//
// The order in the WHERE clause is important, slices should come last. See
// https://github.com/sqlc-dev/sqlc/issues/2452 for more info.
func (q *Queries) FindRunnableTask(ctx context.Context, arg FindRunnableTaskParams) (FindRunnableTaskRow, error) {
query := findRunnableTask
var queryParams []interface{}
queryParams = append(queryParams, arg.WorkerID)
queryParams = append(queryParams, arg.TaskStatusCompleted)
if len(arg.WorkerTags) > 0 {
for _, v := range arg.WorkerTags {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:worker_tags*/?", strings.Repeat(",?", len(arg.WorkerTags))[1:], 1)
} else {
query = strings.Replace(query, "/*SLICE:worker_tags*/?", "NULL", 1)
}
if len(arg.SchedulableTaskStatuses) > 0 {
for _, v := range arg.SchedulableTaskStatuses {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:schedulable_task_statuses*/?", strings.Repeat(",?", len(arg.SchedulableTaskStatuses))[1:], 1)
} else {
query = strings.Replace(query, "/*SLICE:schedulable_task_statuses*/?", "NULL", 1)
}
if len(arg.SchedulableJobStatuses) > 0 {
for _, v := range arg.SchedulableJobStatuses {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:schedulable_job_statuses*/?", strings.Repeat(",?", len(arg.SchedulableJobStatuses))[1:], 1)
} else {
query = strings.Replace(query, "/*SLICE:schedulable_job_statuses*/?", "NULL", 1)
}
if len(arg.SupportedTaskTypes) > 0 {
for _, v := range arg.SupportedTaskTypes {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:supported_task_types*/?", strings.Repeat(",?", len(arg.SupportedTaskTypes))[1:], 1)
} else {
query = strings.Replace(query, "/*SLICE:supported_task_types*/?", "NULL", 1)
}
row := q.db.QueryRowContext(ctx, query, queryParams...)
var i FindRunnableTaskRow
err := row.Scan(
&i.Task.ID,
&i.Task.CreatedAt,
&i.Task.UpdatedAt,
&i.Task.UUID,
&i.Task.Name,
&i.Task.Type,
&i.Task.JobID,
&i.Task.Priority,
&i.Task.Status,
&i.Task.WorkerID,
&i.Task.LastTouchedAt,
&i.Task.Commands,
&i.Task.Activity,
)
return i, err
}

@ -49,6 +49,10 @@ SELECT * FROM workers WHERE workers.uuid = @uuid and deleted_at is NULL;
-- FetchWorkerUnconditional ignores soft-deletion status and just returns the worker.
SELECT * FROM workers WHERE workers.uuid = @uuid;
-- name: FetchWorkerUnconditionalByID :one
-- FetchWorkerUnconditionalByID ignores soft-deletion status and just returns
-- the worker with the given database ID.
SELECT * FROM workers WHERE workers.id = @worker_id;
-- name: FetchWorkerTags :many
SELECT worker_tags.*
FROM worker_tags

@ -196,6 +196,35 @@ func (q *Queries) FetchWorkerUnconditional(ctx context.Context, uuid string) (Wo
return i, err
}
const fetchWorkerUnconditionalByID = `-- name: FetchWorkerUnconditionalByID :one
SELECT id, created_at, updated_at, uuid, secret, name, address, platform, software, status, last_seen_at, status_requested, lazy_status_request, supported_task_types, deleted_at, can_restart FROM workers WHERE workers.id = ?1
`
// FetchWorkerUnconditional ignores soft-deletion status and just returns the worker.
func (q *Queries) FetchWorkerUnconditionalByID(ctx context.Context, workerID int64) (Worker, error) {
row := q.db.QueryRowContext(ctx, fetchWorkerUnconditionalByID, workerID)
var i Worker
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.UUID,
&i.Secret,
&i.Name,
&i.Address,
&i.Platform,
&i.Software,
&i.Status,
&i.LastSeenAt,
&i.StatusRequested,
&i.LazyStatusRequest,
&i.SupportedTaskTypes,
&i.DeletedAt,
&i.CanRestart,
)
return i, err
}
const fetchWorkers = `-- name: FetchWorkers :many
SELECT workers.id, workers.created_at, workers.updated_at, workers.uuid, workers.secret, workers.name, workers.address, workers.platform, workers.software, workers.status, workers.last_seen_at, workers.status_requested, workers.lazy_status_request, workers.supported_task_types, workers.deleted_at, workers.can_restart FROM workers
WHERE deleted_at IS NULL

@ -4,11 +4,15 @@ package persistence
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"projects.blender.org/studio/flamenco/internal/manager/persistence/sqlc"
"projects.blender.org/studio/flamenco/pkg/api"
)
@ -26,149 +30,139 @@ func (db *DB) ScheduleTask(ctx context.Context, w *Worker) (*Task, error) {
logger := log.With().Str("worker", w.UUID).Logger()
logger.Trace().Msg("finding task for worker")
hasWorkerTags, err := db.HasWorkerTags(ctx)
// Run all queries in a single transaction.
//
// After this point, all queries should use this transaction. Otherwise SQLite
// will deadlock, as it will make any other query wait until this transaction
// is done.
qtx, err := db.queriesWithTX()
if err != nil {
return nil, err
}
// Run two queries in one transaction:
// 1. find task, and
// 2. assign the task to the worker.
var task *Task
txErr := db.gormDB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var err error
task, err = findTaskForWorker(tx, w, hasWorkerTags)
if err != nil {
if isDatabaseBusyError(err) {
logger.Trace().Err(err).Msg("database busy while finding task for worker")
return errDatabaseBusy
}
logger.Error().Err(err).Msg("finding task for worker")
return fmt.Errorf("finding task for worker: %w", err)
}
if task == nil {
// No task found, which is fine.
return nil
}
defer qtx.rollback()
// Found a task, now assign it to the requesting worker.
if err := assignTaskToWorker(tx, w, task); err != nil {
if isDatabaseBusyError(err) {
logger.Trace().Err(err).Msg("database busy while assigning task to worker")
return errDatabaseBusy
}
logger.Warn().
Str("taskID", task.UUID).
Err(err).
Msg("assigning task to worker")
return fmt.Errorf("assigning task to worker: %w", err)
}
return nil
})
if txErr != nil {
return nil, txErr
task, err := db.scheduleTask(ctx, qtx.queries, w, logger)
if err != nil {
return nil, err
}
if task == nil {
logger.Debug().Msg("no task for worker")
// No task means no changes to the database.
// It's fine to just roll back the transaction.
return nil, nil
}
gormTask, err := convertSqlTaskWithJobAndWorker(ctx, qtx.queries, *task)
if err != nil {
return nil, err
}
if err := qtx.commit(); err != nil {
return nil, fmt.Errorf(
"could not commit database transaction after scheduling task %s for worker %s: %w",
task.UUID, w.UUID, err)
}
return gormTask, nil
}
// scheduleTask finds a task for the given worker and returns it. All database
// access goes through the given queries, so that the caller decides which
// transaction (if any) it runs in.
//
// If a task is already active and assigned to this worker, that task is
// returned as-is. Otherwise a runnable task is looked up and assigned to the
// worker before it is returned.
//
// Returns (nil, nil) when there is no task for this worker, and
// errDatabaseBusy when the database reported being busy. Panics when the
// worker has a zero ID, as that means it was never stored in the database.
func (db *DB) scheduleTask(ctx context.Context, queries *sqlc.Queries, w *Worker, logger zerolog.Logger) (*sqlc.Task, error) {
	if w.ID == 0 {
		panic("worker should be in database, but has zero ID")
	}
	workerID := sql.NullInt64{Int64: int64(w.ID), Valid: true}

	// If a task is already active & assigned to this worker, return just that.
	// Note that this task type could be blocklisted or no longer supported by the
	// Worker, but since it's active that is unlikely.
	{
		row, err := queries.FetchAssignedAndRunnableTaskOfWorker(ctx, sqlc.FetchAssignedAndRunnableTaskOfWorkerParams{
			ActiveTaskStatus:  string(api.TaskStatusActive),
			ActiveJobStatuses: convertJobStatuses(schedulableJobStatuses),
			WorkerID:          workerID,
		})

		switch {
		case errors.Is(err, sql.ErrNoRows):
			// Fine, just means there was no task assigned yet.
		case isDatabaseBusyError(err):
			// Report a busy database in the same way as for the other queries
			// of the scheduler, so that callers can recognise it.
			logger.Trace().Err(err).Msg("database busy while fetching task assigned to worker")
			return nil, errDatabaseBusy
		case err != nil:
			return nil, fmt.Errorf("fetching task assigned to worker: %w", err)
		case row.Task.ID > 0:
			return &row.Task, nil
		}
	}

	task, err := findTaskForWorker(ctx, queries, w)

	switch {
	case errors.Is(err, sql.ErrNoRows):
		// Fine, just means there is no runnable task for this worker.
		return nil, nil
	case isDatabaseBusyError(err):
		logger.Trace().Err(err).Msg("database busy while finding task for worker")
		return nil, errDatabaseBusy
	case err != nil:
		logger.Error().Err(err).Msg("finding task for worker")
		return nil, fmt.Errorf("finding task for worker: %w", err)
	}

	// Assign the task to the worker.
	err = queries.AssignTaskToWorker(ctx, sqlc.AssignTaskToWorkerParams{
		WorkerID: workerID,
		Now:      db.now(),
		TaskID:   task.ID,
	})

	switch {
	case isDatabaseBusyError(err):
		logger.Trace().Err(err).Msg("database busy while assigning task to worker")
		return nil, errDatabaseBusy
	case err != nil:
		logger.Warn().
			Str("taskID", task.UUID).
			Err(err).
			Msg("assigning task to worker")
		return nil, fmt.Errorf("assigning task to worker: %w", err)
	}

	// Make sure the returned task matches the database.
	task.WorkerID = workerID

	logger.Info().
		Str("taskID", task.UUID).
		Msg("assigned task to worker")

	// findTaskForWorker returns the task by value, so take its address here.
	return &task, nil
}
func findTaskForWorker(tx *gorm.DB, w *Worker, checkWorkerTags bool) (*Task, error) {
task := Task{}
// If a task is alreay active & assigned to this worker, return just that.
// Note that this task type could be blocklisted or no longer supported by the
// Worker, but since it's active that is unlikely.
assignedTaskResult := taskAssignedAndRunnableQuery(tx.Model(&task), w).
Preload("Job").
Find(&task)
if assignedTaskResult.Error != nil {
return nil, assignedTaskResult.Error
}
if assignedTaskResult.RowsAffected > 0 {
return &task, nil
}
// Produce the 'current task ID' by selecting all its incomplete dependencies.
// This can then be used in a subquery to filter out such tasks.
// `tasks.id` is the task ID from the outer query.
incompleteDepsQuery := tx.Table("tasks as tasks2").
Select("tasks2.id").
Joins("left join task_dependencies td on tasks2.id = td.task_id").
Joins("left join tasks dep on dep.id = td.dependency_id").
Where("tasks2.id = tasks.id").
Where("dep.status is not NULL and dep.status != ?", api.TaskStatusCompleted)
blockedTaskTypesQuery := tx.Model(&JobBlock{}).
Select("job_blocks.task_type").
Where("job_blocks.worker_id = ?", w.ID).
Where("job_blocks.job_id = jobs.id")
// Note that this query doesn't check for the assigned worker. Tasks that have
// a 'schedulable' status might have been assigned to a worker, representing
// the last worker to touch it -- it's not meant to indicate "ownership" of
// the task.
findTaskQuery := tx.Model(&task).
Joins("left join jobs on tasks.job_id = jobs.id").
Joins("left join task_failures TF on tasks.id = TF.task_id and TF.worker_id=?", w.ID).
Where("tasks.status in ?", schedulableTaskStatuses). // Schedulable task statuses
Where("jobs.status in ?", schedulableJobStatuses). // Schedulable job statuses
Where("tasks.type in ?", w.TaskTypes()). // Supported task types
Where("tasks.id not in (?)", incompleteDepsQuery). // Dependencies completed
Where("TF.worker_id is NULL"). // Not failed before
Where("tasks.type not in (?)", blockedTaskTypesQuery) // Non-blocklisted
if checkWorkerTags {
// The system has one or more tags, so limit the available jobs to those
// that have no tag, or overlap with the Worker's tags.
if len(w.Tags) == 0 {
// Tagless workers only get tagless jobs.
findTaskQuery = findTaskQuery.
Where("jobs.worker_tag_id is NULL")
} else {
// Taged workers get tagless jobs AND jobs of their own tags.
tagIDs := []uint{}
for _, tag := range w.Tags {
tagIDs = append(tagIDs, tag.ID)
}
findTaskQuery = findTaskQuery.
Where("jobs.worker_tag_id is NULL or worker_tag_id in ?", tagIDs)
}
}
findTaskResult := findTaskQuery.
Order("jobs.priority desc"). // Highest job priority
Order("tasks.priority desc"). // Highest task priority
Limit(1).
Preload("Job").
Find(&task)
if findTaskResult.Error != nil {
return nil, findTaskResult.Error
}
if task.ID == 0 {
// No task fetched, which doesn't result in an error with Limt(1).Find(&task).
return nil, nil
}
return &task, nil
}
func assignTaskToWorker(tx *gorm.DB, w *Worker, t *Task) error {
return tx.Model(t).
Select("WorkerID", "LastTouchedAt").
Updates(Task{WorkerID: &w.ID, LastTouchedAt: tx.NowFunc()}).Error
// findTaskForWorker looks up a task that can be run by the given worker, by
// running the FindRunnableTask query with the worker's ID, tags, and supported
// task types.
//
// It only finds the task; it does not assign it to the worker. When there is
// no runnable task, sql.ErrNoRows is returned.
func findTaskForWorker(
	ctx context.Context,
	queries *sqlc.Queries,
	w *Worker,
) (sqlc.Task, error) {
	// Construct the list of worker tags to check.
	workerTags := make([]sql.NullInt64, len(w.Tags))
	for index, tag := range w.Tags {
		workerTags[index] = sql.NullInt64{Int64: int64(tag.ID), Valid: true}
	}

	row, err := queries.FindRunnableTask(ctx, sqlc.FindRunnableTaskParams{
		WorkerID:                int64(w.ID),
		SchedulableTaskStatuses: convertTaskStatuses(schedulableTaskStatuses),
		SchedulableJobStatuses:  convertJobStatuses(schedulableJobStatuses),
		SupportedTaskTypes:      w.TaskTypes(),
		TaskStatusCompleted:     string(api.TaskStatusCompleted),
		WorkerTags:              workerTags,
	})
	if err != nil {
		return sqlc.Task{}, err
	}
	if row.Task.ID == 0 {
		// No task was fetched. Report this as "no rows" rather than returning a
		// zero-valued Task with a nil error, so that a caller can never mistake
		// it for a real task and try to assign it.
		return sqlc.Task{}, sql.ErrNoRows
	}
	return row.Task, nil
}
// taskAssignedAndRunnableQuery appends some GORM clauses to query for a task

@ -34,3 +34,20 @@ sql:
jobuuid: "JobUUID"
taskUUID: "TaskUUID"
workeruuid: "WorkerUUID"
- engine: "sqlite"
schema: "internal/manager/persistence/sqlc/schema.sql"
queries: "internal/manager/persistence/sqlc/query_task_scheduler.sql"
gen:
go:
out: "internal/manager/persistence/sqlc"
overrides:
- db_type: "jsonb"
go_type:
import: "encoding/json"
type: "RawMessage"
rename:
uuid: "UUID"
uuids: "UUIDs"
jobuuid: "JobUUID"
taskUUID: "TaskUUID"
workeruuid: "WorkerUUID"