Skip to content

slurm

execution hpc scheduler · maintainer machinable-org · upstream

Submit runs to the Slurm scheduler.

Install

Declare the remote in your project's provider; the module is fetched on first use:

python
# interface/project.py — your project's provider
from machinable import Project


class MyProject(Project):
    def on_resolve_remotes(self):
        """Return the mapping of remote names to the location they are fetched from."""
        # NOTE(review): the URL embeds what looks like a full commit hash,
        # presumably pinning the module version — see "Resolving remotes".
        return {
            "slurm": "url+https://raw.githubusercontent.com/machinable-org/machinable/5c73557505915231582d680ae26c6404b747e812/integrations/slurm/slurm.py",
        }

The module is fetched once and cached in interface/remotes/; machinable fetch slurm downloads it without importing so you can inspect the code first. See Resolving remotes for how pulling, pinning, and updating work.

Requires the Slurm command-line tools (sbatch, squeue, scancel) on the host.

Usage

python
from machinable import get

# NOTE(review): "ranks" is not a field of Slurm.Config (which sets
# extra="forbid"); the Slurm source reads `ranks` from the component's own
# config instead — confirm that this example validates as written.
with get("slurm", {"ranks": 8, "preamble": "mpirun"}):
    ...  # your component

Configuration

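| Option | Type | Default |
| --- | --- | --- |
| `preamble` | `str \| None` | `''` |
| `mpi` | `str \| None` | `'mpirun'` |
| `mpi_args` | `str` | `''` |
| `python` | `str \| None` | `None` |
| `throttle` | `float` | `0.5` |
| `confirm` | `bool` | `True` |
| `copy_project_source` | `bool` | `True` |
| `resume_failed` | `bool \| Literal['new', 'skip']` | `False` |
| `dry` | `bool` | `False` |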

Source

integrations/slurm/slurm.py

slurm.py
py
import os
import shutil
import subprocess
import sys
import time
from typing import Literal, Optional

import arrow
import yaml
from pydantic import BaseModel, ConfigDict

from machinable import Execution, Project
from machinable.errors import ExecutionFailed
from machinable.utils import chmodx, run_and_stream


class Slurm(Execution):
    class Config(BaseModel):
        # Options of the Slurm execution; extra="forbid" makes pydantic
        # reject any key that is not declared below.
        model_config = ConfigDict(extra="forbid")

        # Text written into the script after the #SBATCH header and before
        # the dispatch code; it is inserted verbatim (no newline is added).
        preamble: str | None = ""
        # Launcher command prefixed to the Python interpreter; a falsy value
        # disables it. An executable's own `mpi` config key takes precedence.
        mpi: str | None = "mpirun"
        # Arguments placed between the launcher and the interpreter; the
        # placeholder `{ranks}` is replaced by the executable's `ranks`
        # config value when that value is not None.
        mpi_args: str = ""
        # Interpreter used for the dispatch code; sys.executable when unset.
        python: str | None = None
        # Seconds to sleep after each submission when more than one
        # executable is pending; values <= 0 disable the pause.
        throttle: float = 0.5
        # Ask for interactive confirmation before dispatch (not asked if dry).
        confirm: bool = True
        # rsync the project into the execution's directory and pass that
        # copy as the project directory of the dispatch code.
        copy_project_source: bool = True
        # Handling of executables that have incomplete earlier executions:
        # True resubmits them, "new" submits a fresh copy, "skip" skips them,
        # False raises ExecutionFailed (or only prints the message if dry).
        resume_failed: bool | Literal["new", "skip"] = False
        # Write the scripts and print the commands, but run neither rsync
        # nor sbatch.
        dry: bool = False

    def on_before_dispatch(self):
        """Ask the user to confirm the submission unless disabled or dry.

        Returns:
            The result of ``confirm(self)`` when confirmation is enabled and
            this is not a dry run; ``None`` otherwise.
        """
        settings = self.config
        if not settings.confirm or settings.dry:
            # nothing to ask: confirmation is switched off or nothing is sent
            return None
        return confirm(self)

    def on_compute_default_resources(self, executable):
        """Build the default resource options for ``executable``.

        The partition (``-p``) and time limit (``-t``) are fixed defaults;
        ``--nodes`` and ``--ntasks-per-node`` are taken from the executable's
        ``nodes`` and ``ranks`` config values when those are set.

        Args:
            executable: Object whose ``config`` offers a dict-style ``get``.

        Returns:
            dict mapping option names to values.
        """
        defaults = {"-p": "development", "-t": "2:00:00"}

        config_to_option = (
            ("nodes", "--nodes"),
            ("ranks", "--ntasks-per-node"),
        )
        for config_key, option in config_to_option:
            value = executable.config.get(config_key, False)
            # membership is an equality test, so anything equal to False
            # (including 0) counts as "not set", exactly like None
            if value in (None, False):
                continue
            defaults[option] = value

        return defaults

    def __call__(self):
        """Generate an sbatch script for every pending executable and submit it.

        Per executable, in order: skip it if a job of the same name is
        already PENDING or RUNNING; apply the ``resume_failed`` policy;
        optionally rsync the project source; assemble the ``#SBATCH`` header,
        preamble and dispatch code; save ``slurm.sh`` and ``slurm.json``; and,
        unless ``dry`` is set, call ``sbatch`` and record the job id.

        Raises:
            ExecutionFailed: An executable has incomplete earlier executions,
                ``resume_failed`` is False and this is not a dry run.
            subprocess.CalledProcessError: ``sbatch`` exited with a non-zero
                status (its stdout and stderr are printed first).
        """
        # materialize the container record so submission scripts have a
        # durable home (dispatch only materializes the interfaces)
        if not self.is_materialized():
            super(Execution, self).materialize()

        # executable id -> job id, for jobs submitted during this call; used
        # to wire up `--dependency` between executables of the same batch
        jobs = {}
        for executable in self.pending_executables:
            # check if job is already launched
            if job := Job.find_by_name(executable.id):
                if job.status in ["PENDING", "RUNNING"]:
                    print(
                        f"{executable.id} is already launched"
                        f" with job_id={job.job_id}, skipping ..."
                    )
                    continue

            if self.config.resume_failed is not True:
                if (
                    executable.executions.filter(lambda x: x.is_incomplete()).count()
                    > 0
                ):
                    if self.config.resume_failed == "new":
                        executable = executable.new().materialize()
                    elif self.config.resume_failed == "skip":
                        continue
                    else:
                        err = (
                            f"{executable.module} <{executable.id}>"
                            " has previously been executed unsuccessfully."
                            " Set `resume_failed` to True, 'new' or 'skip'"
                            " to handle resubmission."
                        )
                        # a dry run only reports the problem and carries on
                        if self.config.dry:
                            print(err)
                        else:
                            raise ExecutionFailed(err)

            source_code = Project.get().path()
            if self.config.copy_project_source:
                print("Copy project source code ...")
                source_code = self.local_directory(executable.id, "source_code")
                cmd = [
                    "rsync",
                    "-rLptgoD",
                    "--exclude",
                    ".git",
                    "--filter",
                    "dir-merge,- .gitignore",
                    Project.get().path(""),
                    source_code,
                ]
                print(" ".join(cmd))
                if not self.config.dry:
                    run_and_stream(cmd, check=True)
                else:
                    print("Dry run, skipping rsync ...")

            script = "#!/usr/bin/env bash\n"

            resources = self.computed_resources(executable)
            # the executable's own `mpi` setting wins over the execution's
            mpi = executable.config.get("mpi", self.config.mpi)
            mpi_args = self.config.mpi_args
            ranks = executable.config.get("ranks", None)
            if ranks is not None:
                if mpi_args:
                    mpi_args = mpi_args.replace("{ranks}", str(ranks))
            python = self.config.python or sys.executable

            # usage dependencies
            if "--dependency" not in resources and (dependencies := executable.uses):
                ds = []
                for dependency in dependencies:
                    if dependency.id in jobs:
                        ds.append(str(jobs[dependency.id]))
                    else:
                        # not submitted in this call: look for a queued job
                        if job := Job.find_by_name(dependency.id):
                            ds.append(str(job.job_id))
                if ds:
                    resources["--dependency"] = "afterok:" + (":".join(ds))

            if "--job-name" not in resources:
                resources["--job-name"] = f"{executable.id}"
            if "--output" not in resources:
                resources["--output"] = os.path.abspath(
                    self.local_directory(executable.id, "output.log")
                )
            if "--open-mode" not in resources:
                resources["--open-mode"] = "append"

            sbatch_arguments = []
            for k, v in resources.items():
                # only long options are written into the #SBATCH header
                if not k.startswith("--"):
                    continue
                line = "#SBATCH " + k
                # None/True mark value-less flags
                if v not in [None, True]:
                    line += f"={v}"
                sbatch_arguments.append(line)

            script += "\n".join(sbatch_arguments) + "\n\n"

            if self.config.preamble:
                script += self.config.preamble

            if mpi:
                # build "<mpi> <mpi_args> <python>" with single separators
                if mpi[-1] != " ":
                    mpi += " "
                mpi = mpi + mpi_args
                if mpi[-1] != " ":
                    mpi += " "
                python = mpi + python

            script += self.dispatch_code(
                executable, project_directory=source_code, python=python
            )

            print(f"Submitting job {executable} with resources: ")
            print(yaml.dump(resources))

            # add debug information
            script += "\n\n"
            script += f"# generated at: {arrow.now()}\n"
            script += f"# {executable.module} <{executable.id}>\n"
            script += f"# {executable.local_directory()}\n\n"
            script += "# " + yaml.dump(executable.version()).replace("\n", "\n# ")
            script += "\n"

            # submit to slurm
            script_file = chmodx(self.save_file([executable.id, "slurm.sh"], script))

            cmd = ["sbatch", script_file]
            print(" ".join(cmd))

            # record the submission before running it; job_id is filled in
            # once sbatch has answered
            self.save_file(
                [executable.id, "slurm.json"],
                data={
                    "job_id": None,
                    "cmd": sbatch_arguments,
                    "script": script,
                },
            )

            if self.config.dry:
                print("Dry run ... ", executable)
                continue

            try:
                output = subprocess.run(
                    cmd,
                    text=True,
                    check=True,
                    env=os.environ,
                    capture_output=True,
                )
                print(output.stdout)
            except subprocess.CalledProcessError as _ex:
                print(_ex.output)
                # stderr is captured as well; surface it, since that is where
                # the failure reason would otherwise be lost
                if _ex.stderr:
                    print(_ex.stderr, file=sys.stderr)
                raise

            try:
                # the job id is expected to be the last word of the output
                job_id = int(output.stdout.rsplit(" ", maxsplit=1)[-1])
            except ValueError:
                job_id = False
            print(
                f"{job_id}  named `{resources['--job-name']}`"
                f" for {executable.local_directory()}"
                f" (output at {resources['--output']})"
            )

            # update job information; an unparseable id is not remembered so
            # that it can never end up as `afterok:False` in a dependency
            # (dependents then fall back to the lookup by job name)
            if job_id is not False:
                jobs[executable.id] = job_id
            self.save_file(
                [executable.id, "slurm.json"],
                data={
                    "job_id": job_id,
                    "cmd": sbatch_arguments,
                    "script": script,
                },
            )

            if self.config.throttle > 0 and len(self.pending_executables) > 1:
                time.sleep(self.config.throttle)

    def canonicalize_resources(self, resources):
        """Rewrite resource keys into long ``--option`` form.

        Accepted key spellings are ``--option`` (kept as is), ``-x`` and
        ``x`` (single-letter shorthands, expanded via the table below) and
        ``option`` (prefixed with ``--``). A leading ``#`` on a key is kept
        in front of the rewritten key. All values are converted with ``str``.

        Args:
            resources: Mapping of option keys to values, or ``None``.

        Returns:
            A new dict with rewritten keys; empty if ``resources`` is None.

        Raises:
            ValueError: A ``-x``/``x`` key is not a known shorthand, or a
                single-dash key is not exactly two characters long.
        """
        if resources is None:
            return {}

        long_names = {
            "A": "account",
            "B": "extra-node-info",
            "C": "constraint",
            "c": "cpus-per-task",
            "d": "dependency",
            "D": "workdir",
            "e": "error",
            "F": "nodefile",
            "H": "hold",
            "h": "help",
            "I": "immediate",
            "i": "input",
            "J": "job-name",
            "k": "no-kill",
            "L": "licenses",
            "M": "clusters",
            "m": "distribution",
            "N": "nodes",
            "n": "ntasks",
            "O": "overcommit",
            "o": "output",
            "p": "partition",
            "Q": "quiet",
            "s": "share",
            "t": "time",
            "u": "usage",
            "V": "version",
            "v": "verbose",
            "w": "nodelist",
            "x": "exclude",
            "g": "geometry",
            "R": "no-rotate",
        }

        result = {}
        for raw_key, value in resources.items():
            # split off an optional leading "#" so it can be re-attached
            if raw_key.startswith("#"):
                marker, name = "#", raw_key[1:]
            else:
                marker, name = "", raw_key

            if name.startswith("--"):
                # already in long form
                option = name
            elif name.startswith("-"):
                # -p => --partition
                try:
                    if len(name) != 2:
                        raise KeyError("Invalid length")
                    option = "--" + long_names[name[1]]
                except KeyError as _ex:
                    raise ValueError(f"Invalid short option: {name}") from _ex
            elif len(name) == 1:
                # p => --partition
                try:
                    option = "--" + long_names[name]
                except KeyError as _ex:
                    raise ValueError(f"Invalid short option: -{name}") from _ex
            else:
                # option => --option
                option = "--" + name

            result[marker + option] = str(value)

        return result


def yes_or_no() -> bool:
    """Read a yes/no answer from standard input.

    An empty answer counts as yes. Answers are matched case-insensitively
    and surrounding whitespace is ignored. Anything else is rejected and the
    question is asked again instead of failing with a ``KeyError``.

    Returns:
        True for "", "y" or "yes"; False for "n" or "no".

    Raises:
        EOFError: Standard input is exhausted before a valid answer is read.
    """
    answers = {"": True, "yes": True, "y": True, "no": False, "n": False}
    while True:
        choice = input().strip().lower()
        if choice in answers:
            return answers[choice]
        sys.stdout.write("Please answer 'yes' or 'no' [Y/n]: ")
        sys.stdout.flush()


def confirm(execution: Execution) -> bool:
    """Show what is about to be submitted and ask the user whether to proceed.

    Writes the module of every pending executable and a summary line to
    standard output, reads the answer via ``yes_or_no`` and echoes it.

    Args:
        execution: Execution whose pending executables are listed.

    Returns:
        True if the user agreed, False otherwise.
    """
    out = sys.stdout
    pending = execution.pending_executables

    out.write("\n".join(pending.map(lambda x: x.module)))
    out.write(
        f"\nSubmitting {len(pending)} jobs"
        f" ({len(execution.interfaces)} total). Proceed? [Y/n]: "
    )

    proceed = bool(yes_or_no())
    out.write("yes\n" if proceed else "no\n")
    return proceed


class Job:
    """Handle on a Slurm job, backed by the scontrol/squeue/scancel commands.

    The job's details are fetched once, when the object is created.
    """

    def __init__(self, job_id: str):
        """Store ``job_id`` and load the job's details via ``scontrol``."""
        self.job_id = job_id
        self.details = self._fetch_details()

    @staticmethod
    def _parse_details(text: str) -> dict:
        """Turn whitespace-separated ``Key=Value`` tokens into a dict.

        Tokens without ``=`` are ignored; only the first ``=`` of a token
        separates key and value. Because the text is split on whitespace, a
        value that itself contains whitespace is cut at the first gap.
        """
        details = {}
        for item in text.split():
            key, separator, value = item.partition("=")
            if separator:
                details[key] = value
        return details

    def _fetch_details(self) -> dict:
        """Return the fields of ``scontrol show job <job_id>``.

        Returns:
            The parsed fields, or an empty dict when ``scontrol`` is not
            installed or exits with a non-zero status. An empty dict rather
            than None keeps ``status`` and ``info`` usable in that case.
        """
        if shutil.which("scontrol") is None:
            return {}  # no Slurm on this host (e.g. a dry run)
        cmd = ["scontrol", "show", "job", str(self.job_id)]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            return {}

        return self._parse_details(result.stdout)

    @classmethod
    def find_by_name(cls, job_name: str) -> Optional["Job"]:
        """Look up a queued job by name via ``squeue``.

        Returns:
            A Job for the first id that ``squeue`` lists under ``job_name``,
            or None if ``squeue`` is missing, fails, or lists nothing.
        """
        if shutil.which("squeue") is None:
            return None  # no Slurm on this host (e.g. a dry run)
        cmd = ["squeue", "--name", job_name, "--noheader", "--format=%i"]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            return None

        # several jobs may share a name; squeue then prints one id per line
        # and the combined text would not be a valid job id
        job_ids = result.stdout.split()
        if not job_ids:
            return None
        return cls(job_ids[0])

    @property
    def status(
        self,
    ) -> Literal[
        "",
        "PENDING",
        "RUNNING",
        "SUSPENDED",
        "CANCELLED",
        "COMPLETED",
        "FAILED",
        "TIMEOUT",
        "PREEMPTED",
    ]:
        """The ``JobState`` field of the fetched details, or "" if absent."""
        return self.details.get("JobState", "")

    @property
    def info(self) -> dict:
        """All fields fetched for this job."""
        return self.details

    def cancel(self) -> bool:
        """Run ``scancel`` for this job; True if it exited with status 0."""
        cmd = ["scancel", str(self.job_id)]
        result = subprocess.run(cmd, capture_output=True)
        return result.returncode == 0

MIT Licensed