import fcntl
import io
import logging
import os
import random
import re
import select
import shlex
import sys
import threading
import time
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field
from subprocess import STDOUT
from typing import (
IO,
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
)
from .context import Context
from .util import FatalError, Process, require_program, run
# TODO: rewrite this to use
# https://docs.python.org/3/library/concurrent.futures.html?
@dataclass
class Job:
    """A single process managed by a :class:`Pool`, plus its bookkeeping."""

    # Handle to the (deferred) process started via util.run().
    proc: Process
    # Human-readable identifier used in status/log messages.
    jobid: str
    # Path(s) of the file(s) the job's output is written to.
    outfiles: List[str]
    # Number of nodes/slots this job occupies in the pool's queue.
    nnodes: int = field(default=1, init=False)
    # Creation timestamp, used to report elapsed time when the job finishes.
    start_time: float = field(default_factory=time.time, init=False)
    # Optional user callbacks invoked by the pool on completion/failure.
    onsuccess: Optional[Callable[["Job"], None]] = field(default=None, init=False)
    onerror: Optional[Callable[["Job"], None]] = field(default=None, init=False)
    # Accumulated decoded output read from the process so far.
    output: str = field(default="", init=False)

    @property
    def stdout(self) -> IO:
        # Raw stdout stream of the underlying process (the pools start the
        # process with stderr redirected into stdout).
        return self.proc.stdout_io
@dataclass
class ProcessJob(Job):
    """Job running as a local subprocess (used by :class:`ProcessPool`)."""

    # Open binary file handle that the job's output is written to.
    outfile_handle: IO
@dataclass
class SSHJob(Job):
    """Job running on a remote host via ssh (used by :class:`SSHPool`)."""

    # Open binary file handle that the job's output is written to.
    outfile_handle: IO
    # The ssh host this job was assigned to; handed back to the pool's
    # available_nodes list when the job finishes.
    node: str
    # Local and remote ports of an optional ssh -L tunnel to the node.
    tunnel_src: Optional[int] = None
    tunnel_dest: Optional[int] = None
@dataclass
class PrunJob(Job):
    """Job scheduled through the prun cluster scheduler (:class:`PrunPool`)."""

    # Number of cluster nodes reserved for this job; counts towards the
    # pool's --parallelmax queue space.
    nnodes: int
    # Set once prun's node allocation has been parsed and logged, so the
    # allocation is only reported once.
    logged: bool = False
class Pool(metaclass=ABCMeta):
    """
    A pool is used to run processes in parallel as jobs when ``--parallel`` is
    specified on the command line. The pool is created automatically by
    :class:`Setup` and passed to :func:`Target.build` and :func:`Target.run`.
    However, the pool is only passed if the method implementation defines a
    parameter for the pool, i.e.::

        class MyTarget(Target):
            def build(self, ctx, instance, pool):  # receives Pool instance
                ...
            def run(self, ctx, instance):  # does not receive it
                ...

    The maximum number of parallel jobs is controlled by ``--parallelmax``. For
    ``--parallel=proc`` this is simply the number of parallel processes on the
    current machine. For ``--parallel=prun`` it is the maximum number of
    simultaneous jobs in the job queue (pending or running).
    """

    poll_interval: float = 0.050  # seconds to wait for blocking actions

    # Maps the stdout file descriptor of each running job to its Job object.
    jobs: Dict[int, Job]
    # Daemon thread draining job output and firing callbacks; started lazily
    # by the first call to run().
    pollthread: Optional[threading.Thread]

    @abstractmethod
    def make_jobs(
        self,
        ctx: Context,
        cmd: Union[str, Iterable[str]],
        jobid_base: str,
        outfile_base: str,
        nnodes: int,
        **kwargs: Any,
    ) -> Iterator[Job]:
        """Start (or schedule) the processes for one run() call, yielding one
        Job handle per started process."""
        pass

    @abstractmethod
    def process_job_output(self, job: Job) -> None:
        """Consume output currently pending on the job's stdout pipe."""
        pass

    def __init__(self, logger: logging.Logger, parallelmax: int):
        """
        :param logger: logging object for status updates (set to ``ctx.log``)
        :param parallelmax: value of ``--parallelmax``
        """
        self.log = logger
        self.parallelmax = parallelmax
        self.jobs = {}
        self.pollthread = None

    def __del__(self) -> None:
        # Signal the poller thread to stop and give it one poll interval to
        # wind down; it is a daemon thread, so this is best-effort cleanup.
        if self.pollthread is not None:
            self.done = True
            self.pollthread.join(self.poll_interval)

    def _start_poller(self) -> None:
        # Lazily create the epoll object and the daemon thread that watches
        # job stdout descriptors. Idempotent: only the first call starts it.
        if self.pollthread is None:
            self.poller = select.epoll()
            self.pollthread = threading.Thread(
                target=self._poller_thread, name="pool-poller"
            )
            self.pollthread.daemon = True
            self.done = False
            self.pollthread.start()

    def _poller_thread(self) -> None:
        # monitor the job queue for finished jobs, remove them from the queue
        # and call success/error callbacks
        while not self.done:
            for fd, flags in self.poller.poll(timeout=self.poll_interval):
                # Data is readable: let the subclass drain it into the job.
                if flags & (select.EPOLLIN | select.EPOLLPRI):
                    self.process_job_output(self.jobs[fd])
                # Pipe error: drop the job and report failure.
                if flags & select.EPOLLERR:
                    self.poller.unregister(fd)
                    job = self.jobs.pop(fd)
                    self.onerror(job)
                # Hangup: the process closed its end of the pipe; only finish
                # the job once a return code is available.
                if flags & select.EPOLLHUP:
                    job = self.jobs[fd]
                    if job.proc.poll() is None:
                        self.log.debug(
                            f"job {job.jobid} hung up but does not yet have a "
                            "return code, check later"
                        )
                        continue
                    self.poller.unregister(fd)
                    del self.jobs[fd]
                    if job.proc.poll() == 0:
                        self.onsuccess(job)
                    else:
                        self.onerror(job)

    def _wait_for_queue_space(self, nodes_needed: int) -> None:
        # Busy-wait until enough slots are free to start a new job. The
        # poller thread shrinks self.jobs as jobs finish.
        if self.parallelmax is not None:

            def nodes_in_use() -> int:
                return sum(job.nnodes for job in self.jobs.values())

            while nodes_in_use() + nodes_needed > self.parallelmax:
                time.sleep(self.poll_interval)

    def wait_all(self) -> None:
        """
        Block (busy-wait) until all jobs in the queue have been completed.
        Called automatically by :class:`Setup` after the ``build`` and ``run``
        commands.
        """
        while len(self.jobs):
            time.sleep(self.poll_interval)

    def run(
        self,
        ctx: Context,
        cmd: Union[str, Iterable[str]],
        jobid: str,
        outfile: str,
        nnodes: int,
        onsuccess: Optional[Callable[[Job], None]] = None,
        onerror: Optional[Callable[[Job], None]] = None,
        **kwargs: Any,
    ) -> Iterable[Job]:
        """
        A non-blocking wrapper for :func:`util.run`, to be used when
        ``--parallel`` is specified.

        :param ctx: the configuration context
        :param cmd: the command to run
        :param jobid: a human-readable ID for status reporting
        :param outfile: full path to target file for command output
        :param nnodes: number of cores or machines to run the command on
        :param onsuccess: callback when the job finishes successfully
        :param onerror: callback when the job exits with (typically I/O) error
        :param kwargs: passed directly to :func:`util.run`
        :returns: handles to created job processes
        """
        # TODO: generate outfile from jobid
        self._start_poller()

        if isinstance(cmd, str):
            cmd = shlex.split(cmd)

        jobs = []
        for job in self.make_jobs(ctx, cmd, jobid, outfile, nnodes, **kwargs):
            job.onsuccess = onsuccess
            job.onerror = onerror
            job.output = ""
            # Index the job by its stdout fd and have the poller thread watch
            # that fd for output, errors and hangup.
            self.jobs[job.proc.stdout_io.fileno()] = job
            self.poller.register(
                job.proc.stdout_io,
                select.EPOLLIN | select.EPOLLPRI | select.EPOLLERR | select.EPOLLHUP,
            )
            jobs.append(job)
        return jobs

    def onsuccess(self, job: Job) -> None:
        # don't log if onsuccess() returns False
        if not job.onsuccess or job.onsuccess(job) is not False:
            self.log.info(f"job {job.jobid} finished {self._get_elapsed(job)}")
            self.log.debug(f"command: {job.proc.cmd_str}")

    def onerror(self, job: Job) -> None:
        # don't log if onerror() returns False
        if not job.onerror or job.onerror(job) is not False:
            self.log.error(
                f"job {job.jobid} returned status {job.proc.returncode} "
                f"{self._get_elapsed(job)}"
            )
            self.log.error(f"command: {job.proc.cmd_str}")
            sys.stdout.write(job.output)

    def _get_elapsed(self, job: Job) -> str:
        # Human-readable elapsed time since the job was created.
        elapsed = round(time.time() - job.start_time)
        return f"after {elapsed} seconds"
class ProcessPool(Pool):
    """
    Runs each job as a subprocess on the local machine, with at most
    ``--parallelmax`` processes alive at the same time.
    """

    def make_jobs(
        self,
        ctx: Context,
        cmd: Union[str, Iterable[str]],
        jobid_base: str,
        outfile_base: str,
        nnodes: int,
        **kwargs: Any,
    ) -> Iterator[Job]:
        """
        Start ``nnodes`` local processes running ``cmd``, yielding one
        :class:`ProcessJob` per process. Process ``i`` logs to
        ``<outfile_base>-<i>`` (no suffix when ``nnodes == 1``).
        """
        for i in range(nnodes):
            jobid = jobid_base
            outfile = outfile_base
            if nnodes > 1:
                # Distinguish the id/outfile of each parallel instance.
                jobid += f"-{i}"
                outfile += f"-{i}"

            # Block until a process slot is free.
            self._wait_for_queue_space(1)
            ctx.log.info("running " + jobid)
            proc = run(
                ctx,
                cmd,
                defer=True,
                stderr=STDOUT,
                bufsize=io.DEFAULT_BUFFER_SIZE,
                universal_newlines=False,
                **kwargs,
            )
            # The poller thread must never block on this pipe.
            _set_non_blocking(proc.stdout_io)

            # Fix: os.makedirs("") raises FileNotFoundError, so only create
            # the parent directory when outfile actually has one.
            outdir = os.path.dirname(outfile)
            if outdir:
                os.makedirs(outdir, exist_ok=True)
            yield ProcessJob(proc, jobid, [outfile], open(outfile, "wb"))

    def process_job_output(self, job: Job) -> None:
        """Drain pending process output into both the in-memory buffer and
        the job's output file."""
        assert isinstance(job, ProcessJob)
        buf = job.stdout.read(io.DEFAULT_BUFFER_SIZE)
        # A non-blocking read returns None when no data is available yet.
        if buf is not None:
            job.output += buf.decode("ascii", errors="replace")
            job.outfile_handle.write(buf)

    def onsuccess(self, job: Job) -> None:
        # Close the output file before reporting success.
        assert isinstance(job, ProcessJob)
        job.outfile_handle.close()
        super().onsuccess(job)

    def onerror(self, job: Job) -> None:
        # Close the output file before reporting the error.
        assert isinstance(job, ProcessJob)
        job.outfile_handle.close()
        super().onerror(job)
class SSHPool(Pool):
    """
    An SSHPool runs jobs on remote nodes via ssh.

    The --ssh-nodes argument specified a list of ssh hosts to distribute the
    work over. These hosts are passed as-is to the ssh command; the best way for
    specifying alternative ssh ports, user, and other options is to add your
    hosts to the ~/.ssh/config file. Additionally, make sure the hosts can be
    reached without password prompts (e.g., by using passphrase-less keys or
    using an ssh agent).

    For targets that are being run via an SSHPool additional functionality is
    available, such as distributing files to/from nodes.
    """

    # Options passed to every ssh invocation.
    ssh_opts = [
        # Block stdin and background ssh before executing command.
        "-f",
        # Eliminate some of the yes/no questions ssh may ask.
        "-oStrictHostKeyChecking=accept-new",
    ]
    # Options passed to every scp invocation.
    scp_opts = [
        # Quiet mode to disable progress meter
        "-q",
        # Batch mode to prevent asking for password
        "-B",
        # Copy directories
        "-r",
    ]

    # Per-run scratch directory created on every node; None until created.
    _tempdir: Optional[str]

    def __init__(
        self, ctx: Context, logger: logging.Logger, parallelmax: int, nodes: List[str]
    ):
        """
        :param ctx: the configuration context
        :param logger: logging object for status updates
        :param parallelmax: maximum number of simultaneous jobs
        :param nodes: ssh hosts to distribute the jobs over
        """
        if parallelmax > len(nodes):
            raise FatalError(
                "parallelmax cannot be greater than number of available nodes"
            )
        super().__init__(logger, parallelmax)
        self._ctx = ctx
        # Copy the lists so mutating available_nodes does not affect callers.
        self.nodes = nodes[:]
        self.available_nodes = nodes[:]
        self.has_tested_nodes = False
        self.has_created_tempdirs = False

    @property
    def tempdir(self) -> str:
        # Lazily create the scratch dirs on all nodes on first access.
        if not self.has_created_tempdirs:
            self.create_tempdirs()
        assert self._tempdir is not None
        return self._tempdir

    def _ssh_cmd(
        self,
        node: str,
        cmd: Union[str, Iterable[str]],
        extra_opts: Optional[Sequence[Any]] = None,
    ) -> List[str]:
        # Build the full ssh argv for running `cmd` on `node`; non-string
        # commands are shell-quoted into a single remote command string.
        if not isinstance(cmd, str):
            cmd = " ".join(shlex.quote(str(c)) for c in cmd)
        extra_opts = extra_opts or []
        return ["ssh", *self.ssh_opts, *extra_opts, node, cmd]

    def test_nodes(self) -> None:
        # Verify each node is reachable non-interactively by echoing a
        # sentinel; abort the program if any node fails. Runs only once.
        if self.has_tested_nodes:
            return
        for node in self.nodes:
            cmd = ["ssh", *self.ssh_opts, node, "echo -n hi"]
            p = run(self._ctx, cmd, stderr=STDOUT, silent=True)
            if p.returncode or not str(p.stdout).endswith("hi"):
                self._ctx.log.error(
                    "Testing SSH node " + node + " failed:\n" + p.stdout
                )
                sys.exit(-1)
        self.has_tested_nodes = True

    def create_tempdirs(self) -> None:
        # Create a timestamped scratch dir (/tmp/infra-<starttime>) on every
        # node; runs only once per pool.
        if self.has_created_tempdirs:
            return
        self.test_nodes()
        starttime = self._ctx.starttime.strftime("%Y-%m-%d.%H-%M-%S")
        self._tempdir = os.path.join("/tmp", "infra-" + starttime)
        self._ctx.log.debug(
            f"creating SSHPool temp dir {self._tempdir} on nodes {self.nodes}"
        )
        for node in self.nodes:
            run(self._ctx, self._ssh_cmd(node, ["mkdir", "-p", self._tempdir]))
        self.has_created_tempdirs = True

    def cleanup_tempdirs(self) -> None:
        # Remove the scratch directories created by create_tempdirs().
        if not self.has_created_tempdirs:
            return
        assert self._tempdir is not None
        self._ctx.log.debug(
            f"cleaning up SSHPool temp directory {self._tempdir} on nodes {self.nodes}"
        )
        for node in self.nodes:
            run(self._ctx, self._ssh_cmd(node, ["rm", "-rf", self._tempdir]))
        self.has_created_tempdirs = False
        self._tempdir = None

    def sync_to_nodes(
        self,
        sources: Union[str, Iterable[str]],
        destination: str = "",
        target_nodes: Optional[Union[str, Iterable[str]]] = None,
    ) -> None:
        # Copy local file(s) via scp into the tempdir (under `destination`)
        # on the given nodes, or on all nodes by default.
        if isinstance(sources, str):
            sources = [sources]
        if isinstance(target_nodes, str):
            target_nodes = [target_nodes]
        nodes = target_nodes or self.nodes
        self._ctx.log.debug(
            f"syncing file to SSHPool nodes, sources={sources},"
            f"destination={destination}, nodes={nodes}"
        )
        for node in nodes:
            dest = f"{node}:{os.path.join(self.tempdir, destination)}"
            cmd = ["scp", *self.scp_opts, *sources, dest]
            run(self._ctx, cmd)

    def sync_from_nodes(
        self,
        source: str,
        destination: str = "",
        source_nodes: Optional[Sequence[str]] = None,
    ) -> None:
        # Fetch `source` (relative to the tempdir) from the given nodes (all
        # by default). With multiple nodes each local copy gets a ".<node>"
        # suffix, plus an index when node names repeat.
        if isinstance(source_nodes, str):
            source_nodes = [source_nodes]
        nodes = source_nodes or self.nodes
        self._ctx.log.debug(
            f"syncing file from SSHPool nodes, source={source},"
            f"destination={destination}, nodes={nodes}"
        )
        for i, node in enumerate(nodes):
            dest = destination or os.path.basename(source)
            if len(nodes) > 1:
                dest += "." + node
                if len(nodes) != len(set(nodes)):
                    dest = f"{dest}{i}"
            src = f"{node}:{os.path.join(self.tempdir, source)}"
            cmd = ["scp", *self.scp_opts, src, dest]
            run(self._ctx, cmd)

    def get_free_node(self, override_node: Optional[str] = None) -> str:
        # Claim a node from the available pool; a specific node can be
        # requested with override_node (it must currently be free).
        if override_node:
            assert override_node in self.nodes
            assert override_node in self.available_nodes
            self.available_nodes.remove(override_node)
            return override_node
        else:
            return self.available_nodes.pop()

    def make_jobs(
        self,
        ctx: Context,
        cmd: Union[str, Iterable[str]],
        jobid_base: str,
        outfile_base: str,
        nnodes: int,
        nodes: Optional[Union[str, List[str]]] = None,
        tunnel_to_nodes_dest: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterator[Job]:
        """Start ``cmd`` on ``nnodes`` free ssh nodes (or on the explicitly
        given ``nodes``), optionally opening an ssh tunnel to each node."""
        if isinstance(nodes, str):
            nodes = [nodes]
        self.test_nodes()
        for i in range(nnodes):
            jobid = jobid_base
            outfile = outfile_base
            if nnodes > 1:
                # Distinguish the id/outfile of each parallel instance.
                jobid += f"-{i}"
                outfile += f"-{i}"
            self._wait_for_queue_space(1)
            override_node = nodes[i] if nodes else None
            node = self.get_free_node(override_node)
            ctx.log.info("running " + jobid + " on " + node)

            ssh_node_opts = []
            tunnel_src = None
            if tunnel_to_nodes_dest:
                # Forward a random local port to the requested port on the
                # node (ssh -L).
                tunnel_src = random.randint(10000, 30000)
                ssh_node_opts += [
                    f"-Llocalhost:{tunnel_src}:0.0.0.0:{tunnel_to_nodes_dest}"
                ]

            ssh_cmd = self._ssh_cmd(node, cmd, ssh_node_opts)
            proc = run(
                ctx,
                ssh_cmd,
                defer=True,
                stderr=STDOUT,
                bufsize=io.DEFAULT_BUFFER_SIZE,
                universal_newlines=False,
                **kwargs,
            )
            # The poller thread must never block on this pipe.
            _set_non_blocking(proc.stdout_io)
            os.makedirs(os.path.dirname(outfile), exist_ok=True)
            job = SSHJob(proc, jobid, [outfile], open(outfile, "wb"), node)
            if tunnel_to_nodes_dest:
                job.tunnel_src = tunnel_src
                job.tunnel_dest = tunnel_to_nodes_dest
            yield job

    def process_job_output(self, job: Job) -> None:
        # Drain pending output into the in-memory buffer and the output file.
        assert isinstance(job, SSHJob)
        buf = job.stdout.read(io.DEFAULT_BUFFER_SIZE)
        if buf is not None:
            job.output += buf.decode("ascii", errors="replace")
            job.outfile_handle.write(buf)

    def onsuccess(self, job: Job) -> None:
        # Close the output file and return the node to the available pool.
        assert isinstance(job, SSHJob)
        job.outfile_handle.close()
        self.available_nodes.append(job.node)
        super().onsuccess(job)

    def onerror(self, job: Job) -> None:
        # Return the node to the available pool and close the output file.
        assert isinstance(job, SSHJob)
        self.available_nodes.append(job.node)
        job.outfile_handle.close()
        super().onerror(job)
class PrunPool(Pool):
    """Schedules jobs on a cluster through the ``prun`` job scheduler."""

    # NOTE(review): this constant is not referenced anywhere in the code shown
    # here; confirm whether the "not logged" behavior still exists elsewhere.
    default_job_time = 900  # if prun reserves this amount, it is not logged

    def __init__(
        self, logger: logging.Logger, parallelmax: int, prun_opts: Iterable[str]
    ):
        """
        :param logger: logging object for status updates (set to ``ctx.log``)
        :param parallelmax: value of ``--parallelmax``
        :param prun_opts: extra command-line options passed through to prun
        """
        super().__init__(logger, parallelmax)
        self.prun_opts = prun_opts

    def make_jobs(
        self,
        ctx: Context,
        cmd: Union[str, Iterable[str]],
        jobid_base: str,
        outfile_base: str,
        nnodes: int,
        **kwargs: Any,
    ) -> Iterator[Job]:
        """
        Submit one prun job spanning ``nnodes`` nodes; prun writes the output
        of node ``i`` to ``<outfile_base>.<i>``.
        """
        require_program(ctx, "prun")
        self._wait_for_queue_space(nnodes)
        ctx.log.info("scheduling " + jobid_base)
        cmd = [
            "prun",
            "-v",  # verbose: prints the node allocation parsed below
            "-np",
            str(nnodes),
            "-1",
            "-o",
            outfile_base,
            *self.prun_opts,
            *cmd,
        ]
        proc = run(
            ctx,
            cmd,
            defer=True,
            stderr=STDOUT,
            bufsize=0,
            universal_newlines=False,
            **kwargs,
        )
        # The poller thread must never block on this pipe.
        _set_non_blocking(proc.stdout_io)
        outfiles = [f"{outfile_base}.{i}" for i in range(nnodes)]
        yield PrunJob(proc, jobid_base, outfiles, nnodes)

    def process_job_output(self, job: Job) -> None:
        """
        Accumulate prun's verbose output and, once the scheduler has reported
        all allocated nodes, log a compact description of the allocation.
        """
        assert isinstance(job, PrunJob)

        def group_nodes(
            nodes: Sequence[Tuple[int, int]]
        ) -> List[Tuple[List[int], List[int]]]:
            # Merge (machine, core) pairs into ranges: consecutive cores on
            # the same machine(s), or consecutive machines on the same core.
            groups = [([m], [c]) for m, c in sorted(nodes)]
            for i in range(len(groups) - 1, 0, -1):
                lmachines, lcores = groups[i - 1]
                rmachines, rcores = groups[i]
                if lmachines == rmachines and lcores[-1] + 1 == rcores[0]:
                    groups[i - 1] = lmachines, lcores + rcores
                    del groups[i]
                elif (
                    len(lcores) == 1
                    and lmachines[-1] + 1 == rmachines[0]
                    and lcores == rcores
                ):
                    groups[i - 1] = lmachines + rmachines, lcores
                    del groups[i]
            return groups

        def stringify_groups(groups: List[Tuple[List[int], List[int]]]) -> str:
            # Render the grouped allocation, e.g. "node001" or
            # "node[001-004,007/0-3]".
            samecore = set(c for m, cores in groups for c in cores) == set([0])

            def join(n: Sequence[Any], fmt: str) -> str:
                if len(n) == 1:
                    return fmt % n[0]
                else:
                    return fmt % n[0] + "-" + fmt % n[-1]

            if samecore:
                # all on core 0, omit it
                groupstrings = (join(m, "%03d") for m, c in groups)
            else:
                # different cores, add /N suffix
                groupstrings = (f"{join(m, '%03d')}/{join(c, '%d')}" for m, c in groups)

            if len(groups) == 1:
                m, c = groups[0]
                if len(m) == 1 and len(c) == 1:
                    return "node" + next(groupstrings)

            return f"node[{','.join(groupstrings)}]"

        buf = job.stdout.read(1024)
        if buf is None:
            return
        # Fix: decode with errors="replace" for consistency with the other
        # pools; a stray non-ASCII byte in prun's output must not raise
        # UnicodeDecodeError inside the poller thread.
        job.output += buf.decode("ascii", errors="replace")

        # Only report the node allocation once.
        if job.logged:
            return

        # Parse prun's verbose output for the node list and the reservation
        # duration ("for N seconds").
        numseconds = None
        nodes: List[Tuple[int, int]] = []
        for line in job.output.splitlines():
            if line.startswith(":"):
                for m in re.finditer(r"node(\d+)/(\d+)", line):
                    nodes.append((int(m.group(1)), int(m.group(2))))
            elif numseconds is None:
                match = re.search(r"for (\d+) seconds", line)
                if match:
                    numseconds = int(match.group(1))

        if len(nodes) == job.nnodes:
            assert numseconds is not None
            nodestr = stringify_groups(group_nodes(nodes))
            self.log.info(f"running {job.jobid} on {nodestr}")
            # Restart the clock: the job only now started actually running.
            job.start_time = time.time()
            job.logged = True
def _set_non_blocking(f: IO) -> None:
flags = fcntl.fcntl(f, fcntl.F_GETFL)
fcntl.fcntl(f, fcntl.F_SETFL, flags | os.O_NONBLOCK)