Slurm execution
Integration for submitting executions to the Slurm scheduler.
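Usage example:

```python
from machinable import get

# `ranks` (and `nodes`) are read from each component's own config;
# `{ranks}` in `mpi_args` is substituted with that value.
with get("slurm", {"mpi": "mpirun", "mpi_args": "-n {ranks}"}):
    ...  # your component
```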
Source
py
from typing import Literal, Optional, Union
import os
import subprocess
import sys
import time
import arrow
import yaml
from machinable import Execution, Project
from machinable.errors import ExecutionFailed
from machinable.utils import chmodx, run_and_stream
from pydantic import BaseModel, ConfigDict
class Slurm(Execution):
    """Execution that submits each pending executable as a Slurm batch job.

    For every pending executable a bash script is generated from the computed
    resources (rendered as ``#SBATCH`` directives), an optional preamble and
    the executable's dispatch code. The script is saved under
    ``<executable.id>/slurm.sh`` and, unless ``dry`` is set, submitted with
    ``sbatch``.
    """

    class Config(BaseModel):
        model_config = ConfigDict(extra="forbid")

        # Shell text inserted verbatim between the #SBATCH block and the
        # dispatch code (no newline is appended automatically).
        preamble: Optional[str] = ""
        # Launcher prefixed to the python command; the executable's own
        # `mpi` config value takes precedence. A falsy value disables it.
        mpi: Optional[str] = "mpirun"
        # Extra launcher arguments; `{ranks}` is replaced by the
        # executable's `ranks` config value when that is set.
        mpi_args: str = ""
        # Python interpreter used in the job; defaults to sys.executable.
        python: Optional[str] = None
        # Seconds to sleep after each submission when several are pending.
        throttle: float = 0.5
        # Ask for interactive confirmation before dispatching.
        confirm: bool = True
        # rsync the project into the execution directory and run from there.
        copy_project_source: bool = True
        # Handling of executables with previous incomplete executions:
        # True submits anyway, "new" replaces the executable with
        # `executable.new().commit()`, "skip" skips it, and False raises
        # ExecutionFailed (in dry mode the error is only printed).
        resume_failed: Union[bool, Literal["new", "skip"]] = False
        # Generate and save scripts without running rsync or sbatch.
        dry: bool = False

    def on_before_dispatch(self):
        """Prompt for confirmation unless disabled or running dry.

        Returns the user's answer, or ``None`` when no prompt is shown.
        """
        if self.config.confirm and not self.config.dry:
            return confirm(self)

    def on_compute_default_resources(self, executable):
        """Return default resources for ``executable``.

        Always requests the ``development`` partition and a 2h time limit.
        ``--nodes`` / ``--ntasks-per-node`` are added from the executable's
        ``nodes`` / ``ranks`` config values unless those are missing, None
        or False. Short keys (``-p``, ``-t``) are presumably expanded by
        ``canonicalize_resources`` before use.
        """
        resources = {"-p": "development", "-t": "2:00:00"}
        # `not in` compares with ==, so 0 is treated like False (not set)
        nodes = executable.config.get("nodes", False)
        if nodes not in [None, False]:
            resources["--nodes"] = nodes
        ranks = executable.config.get("ranks", False)
        if ranks not in [None, False]:
            resources["--ntasks-per-node"] = ranks
        return resources

    @staticmethod
    def _sbatch_directives(resources: dict) -> list:
        """Render canonical ``--option`` resources as ``#SBATCH`` lines.

        Keys not starting with ``--`` (e.g. ``#``-prefixed, disabled entries)
        are skipped. ``None``/``True`` values render as bare flags.
        """
        lines = []
        for key, value in resources.items():
            if not key.startswith("--"):
                continue
            line = "#SBATCH " + key
            # identity checks: `1 in [None, True]` is True (1 == True) and
            # would wrongly turn `--nodes=1` into a bare `--nodes` flag
            if value is not None and value is not True:
                line += f"={value}"
            lines.append(line)
        return lines

    def __call__(self):
        """Generate, save and (unless ``dry``) submit one job per executable.

        Executables whose job is already PENDING/RUNNING in the queue are
        skipped. Submission results are written to
        ``<executable.id>/slurm.json``. Job ids from this call, or jobs found
        in the queue, are used to express ``afterok`` dependencies for
        executables that ``use`` other executables.

        Raises:
            ExecutionFailed: if an executable previously ran unsuccessfully
                and ``resume_failed`` is False (outside of dry mode).
            subprocess.CalledProcessError: if ``sbatch`` fails.
        """
        jobs = {}
        for executable in self.pending_executables:
            # check if job is already launched
            if job := Job.find_by_name(executable.id):
                if job.status in ["PENDING", "RUNNING"]:
                    print(
                        f"{executable.id} is already launched with job_id={job.job_id}, skipping ..."
                    )
                    continue

            if self.config.resume_failed is not True:
                if (
                    executable.executions.filter(
                        lambda x: x.is_incomplete(executable)
                    ).count()
                    > 0
                ):
                    if self.config.resume_failed == "new":
                        executable = executable.new().commit()
                    elif self.config.resume_failed == "skip":
                        continue
                    else:
                        err = f"{executable.module} <{executable.id}> has previously been executed unsuccessfully. Set `resume_failed` to True, 'new' or 'skip' to handle resubmission."
                        if self.config.dry:
                            print(err)
                        else:
                            raise ExecutionFailed(err)

            source_code = Project.get().path()
            if self.config.copy_project_source:
                print("Copy project source code ...")
                source_code = self.local_directory(executable.id, "source_code")
                cmd = [
                    "rsync",
                    "-rLptgoD",
                    "--exclude '.git'",
                    "--filter='dir-merge,- .gitignore'",
                    Project.get().path(""),
                    source_code,
                ]
                print(" ".join(cmd))
                if not self.config.dry:
                    run_and_stream(cmd, check=True)
                else:
                    print("Dry run, skipping rsync ...")

            script = "#!/usr/bin/env bash\n"

            resources = self.computed_resources(executable)
            mpi = executable.config.get("mpi", self.config.mpi)
            mpi_args = self.config.mpi_args
            ranks = executable.config.get("ranks", None)
            if ranks is not None and mpi_args:
                mpi_args = mpi_args.replace("{ranks}", str(ranks))
            python = self.config.python or sys.executable

            # usage dependencies
            if "--dependency" not in resources and (
                dependencies := executable.uses
            ):
                ds = []
                for dependency in dependencies:
                    dep_job_id = jobs.get(dependency.id)
                    if not dep_job_id:
                        # not submitted in this call, or its sbatch output
                        # could not be parsed (False): look it up in the queue
                        if dep_job := Job.find_by_name(dependency.id):
                            dep_job_id = dep_job.job_id
                    if dep_job_id:
                        ds.append(str(dep_job_id))
                if ds:
                    resources["--dependency"] = "afterok:" + (":".join(ds))

            if "--job-name" not in resources:
                resources["--job-name"] = f"{executable.id}"
            if "--output" not in resources:
                resources["--output"] = os.path.abspath(
                    self.local_directory(executable.id, "output.log")
                )
            if "--open-mode" not in resources:
                resources["--open-mode"] = "append"

            sbatch_arguments = self._sbatch_directives(resources)
            script += "\n".join(sbatch_arguments) + "\n\n"

            if self.config.preamble:
                script += self.config.preamble

            if mpi:
                if mpi[-1] != " ":
                    mpi += " "
                mpi = mpi + mpi_args
                if mpi[-1] != " ":
                    mpi += " "
                python = mpi + python

            script += executable.dispatch_code(
                project_directory=source_code, python=python
            )

            print(f"Submitting job {executable} with resources: ")
            print(yaml.dump(resources))

            # add debug information
            script += "\n\n"
            script += f"# generated at: {arrow.now()}\n"
            script += f"# {executable.module} <{executable.id}>\n"
            script += f"# {executable.local_directory()}\n\n"
            script += "# " + yaml.dump(executable.version()).replace(
                "\n", "\n# "
            )
            script += "\n"

            # submit to slurm
            script_file = chmodx(
                self.save_file([executable.id, "slurm.sh"], script)
            )
            cmd = ["sbatch", script_file]
            print(" ".join(cmd))

            self.save_file(
                [executable.id, "slurm.json"],
                data={
                    "job_id": None,
                    "cmd": sbatch_arguments,
                    "script": script,
                },
            )

            if self.config.dry:
                print("Dry run ... ", executable)
                continue

            try:
                output = subprocess.run(
                    cmd,
                    text=True,
                    check=True,
                    env=os.environ,
                    capture_output=True,
                )
            except subprocess.CalledProcessError as _ex:
                # sbatch reports its errors on stderr; show both streams
                print(_ex.stdout)
                if _ex.stderr:
                    print(_ex.stderr, file=sys.stderr)
                raise
            print(output.stdout)

            # sbatch prints e.g. "Submitted batch job <id>"
            try:
                job_id = int(output.stdout.rsplit(" ", maxsplit=1)[-1])
            except ValueError:
                job_id = False
            print(
                f"{job_id} named `{resources['--job-name']}` for {executable.local_directory()} (output at {resources['--output']})"
            )

            # update job information
            jobs[executable.id] = job_id
            self.save_file(
                [executable.id, "slurm.json"],
                data={
                    "job_id": job_id,
                    "cmd": sbatch_arguments,
                    "script": script,
                },
            )

            if self.config.throttle > 0 and len(self.pending_executables) > 1:
                time.sleep(self.config.throttle)

    def canonicalize_resources(self, resources):
        """Normalize resource keys to the long ``--option`` form.

        ``-p`` and ``p`` become ``--partition`` via the short-option table,
        ``--option`` is kept, and any other ``option`` becomes ``--option``.
        A leading ``#`` is preserved (such keys are not rendered as
        directives). Values are stringified, except ``None`` and booleans,
        which are kept so flag options render as a bare ``#SBATCH --flag``.

        Raises:
            ValueError: for an unknown or malformed short option.
        """
        if resources is None:
            return {}

        shorthands = {
            "A": "account",
            "B": "extra-node-info",
            "C": "constraint",
            "c": "cpus-per-task",
            "d": "dependency",
            "D": "workdir",
            "e": "error",
            "F": "nodefile",
            "H": "hold",
            "h": "help",
            "I": "immediate",
            "i": "input",
            "J": "job-name",
            "k": "no-kill",
            "L": "licenses",
            "M": "clusters",
            "m": "distribution",
            "N": "nodes",
            "n": "ntasks",
            "O": "overcommit",
            "o": "output",
            "p": "partition",
            "Q": "quiet",
            "s": "share",
            "t": "time",
            "u": "usage",
            "V": "version",
            "v": "verbose",
            "w": "nodelist",
            "x": "exclude",
            "g": "geometry",
            "R": "no-rotate",
        }

        def _value(v):
            # str(True) == "True" would render `--hold=True`, which sbatch
            # rejects; keep None/bools so they can be rendered as flags
            if v is None or isinstance(v, bool):
                return v
            return str(v)

        canonicalized = {}
        for k, v in resources.items():
            prefix = ""
            if k.startswith("#"):
                prefix = "#"
                k = k[1:]

            if k.startswith("--"):
                # already correct
                canonicalized[prefix + k] = _value(v)
                continue

            if k.startswith("-"):
                # -p => --partition
                if len(k) != 2 or k[1] not in shorthands:
                    raise ValueError(f"Invalid short option: {k}")
                canonicalized[prefix + "--" + shorthands[k[1]]] = _value(v)
                continue

            if len(k) == 1:
                # p => --partition
                if k not in shorthands:
                    raise ValueError(f"Invalid short option: -{k}")
                canonicalized[prefix + "--" + shorthands[k]] = _value(v)
            else:
                # option => --option
                canonicalized[prefix + "--" + k] = _value(v)

        return canonicalized
def yes_or_no(read=None) -> bool:
    """Read a yes/no answer; an empty answer counts as yes.

    Accepted answers (case-insensitive, surrounding whitespace ignored):
    ``""``, ``"y"``, ``"yes"`` -> ``True``; ``"n"``, ``"no"`` -> ``False``.
    Any other answer re-prompts instead of raising ``KeyError``.

    Args:
        read: optional zero-argument callable returning one line of input;
            defaults to the builtin ``input``.
    """
    if read is None:
        read = input
    answers = {"": True, "yes": True, "y": True, "no": False, "n": False}
    while True:
        choice = read().strip().lower()
        if choice in answers:
            return answers[choice]
        sys.stdout.write("Please answer 'y' or 'n' [Y/n]: ")
def confirm(execution: Execution) -> bool:
    """Ask on the terminal whether the pending jobs should be submitted.

    Prints the module of every pending executable, then a ``[Y/n]`` prompt
    with the pending and total counts, and echoes the parsed answer.

    Returns:
        ``True`` to proceed with the submission, ``False`` to abort.
    """
    listing = "\n".join(execution.pending_executables.map(lambda x: x.module))
    sys.stdout.write(listing)

    pending_count = len(execution.pending_executables)
    total_count = len(execution.executables)
    sys.stdout.write(
        f"\nSubmitting {pending_count} jobs ({total_count} total). Proceed? [Y/n]: "
    )

    proceed = yes_or_no()
    sys.stdout.write("yes\n" if proceed else "no\n")
    return proceed
class Job:
    """Thin wrapper around a Slurm job, queried via ``scontrol``/``squeue``.

    If the Slurm client tools are not installed, or the job is unknown, the
    job has empty ``details`` and an empty ``status``.
    """

    def __init__(self, job_id: str):
        self.job_id = job_id
        self.details = self._fetch_details()

    def _fetch_details(self) -> dict:
        """Parse ``scontrol show job <id>`` into a ``key -> value`` dict.

        Returns an empty dict (not ``None``) when the command is missing or
        fails, so ``status``/``info`` stay usable.
        """
        cmd = ["scontrol", "show", "job", str(self.job_id)]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True)
        except FileNotFoundError:
            # Slurm client tools are not available on this machine
            return {}
        if result.returncode != 0:
            return {}

        details = {}
        # scontrol prints whitespace-separated `Key=Value` tokens; values
        # containing spaces are not reconstructed by this simple split
        for item in result.stdout.split():
            key, sep, value = item.partition("=")
            if sep:
                details[key] = value
        return details

    @classmethod
    def find_by_name(cls, job_name: str) -> Optional["Job"]:
        """Return the queued job named ``job_name``, or ``None``.

        ``squeue`` may list several jobs with the same name (one id per
        line); the first one listed is used.
        """
        cmd = ["squeue", "--name", job_name, "--noheader", "--format=%i"]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True)
        except FileNotFoundError:
            # no Slurm client tools: nothing can be queued
            return None
        if result.returncode != 0:
            return None
        job_ids = result.stdout.split()
        if not job_ids:
            return None
        return cls(job_ids[0])

    @property
    def status(
        self,
    ) -> Literal[
        "",
        "PENDING",
        "RUNNING",
        "SUSPENDED",
        "CANCELLED",
        "COMPLETED",
        "FAILED",
        "TIMEOUT",
        "PREEMPTED",
    ]:
        """The ``JobState`` reported by scontrol, or ``""`` if unknown."""
        return self.details.get("JobState", "")

    @property
    def info(self) -> dict:
        """All parsed scontrol fields (empty dict if unavailable)."""
        return self.details

    def cancel(self) -> bool:
        """Cancel the job via ``scancel``; return whether it succeeded."""
        cmd = ["scancel", str(self.job_id)]
        try:
            result = subprocess.run(cmd, capture_output=True)
        except FileNotFoundError:
            return False
        return result.returncode == 0