[源码解析] PyTorch 分布式之弹性训练(7)---节点变化

0x00 摘要

本文分析如何处理节点变化。即对成员更改作出反应,并使用新的成员来重启所有workers,从而实现弹性训练。

总体思路是和当工作进程失败时的处理一样:相应elastic agent将杀死该节点上的所有工作进程,与其他代理建立会合(rendezvous),并使用新的会合(rendezvous)信息重新启动所有工作进程。

弹性训练系列文章如下:

[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

[源码解析] PyTorch 分布式之弹性训练(3)---代理

[源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

[源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎

[源码解析] PyTorch 分布式之弹性训练(6)---监控/容错

0x01 变化方式

节点变化有两种方式。

1.1 Scale-down

节点离开(scale-down)的处理如下:

  • 当Scale down事件发生时,rendezvous将不会通知 torchelastic agent。
  • torchelastic agent 自己会监控到有进程错误,从而进行处理。
  • 如果TE agent以max_restarts=0配置启动,它依赖于底层调度程序来处理作业重新启动。
  • 如果max_restarts>0,TE代理将终止workers并开始新一轮rendezvous。
    • 代理得到节点离开的通知,于是所有节点上的现有workers都会停止。
    • 这些workers将形成一个新的“WorkerGroup”,所有worker都将以新的 RANK 和 WORLD_SIZE 运行。

1.2 Scale-up

节点加入(scale-up)的处理如下:

  • 当Scale up事件发生时,新节点被提交到作业,torchelastic rendezvous将检测到有新节点试图加入。
    • 如果rendezvous已经达到最大节点数,新节点不会被加入等待列表,因为没有必要拆除一个已经满员的rendezvous。此时新节点将一直等待,直到超时(默认为600秒)。
    • 新节点会定期检查参与节点数目。如果参与节点数目小于 max_nodes,它就会被加入等待列表;否则它将在600秒之后超时。
  • 当代理决定处理 Scale up时:
    • torchelastic rendezvous将停止所有workers并执行新一轮的 re-rendezvous。
    • 这些workers(现有以及新加入的)将形成一个新的“WorkerGroup”,所有worker都将以新的 RANK 和 WORLD_SIZE 运行。

注:scale up发生时,max_restarts 将不会减少。
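
为了更直观地理解弹性区间与 max_restarts 的配置,下面给出一个最小示意:假设使用 torch.distributed.launcher 提供的 Python API,其中 run_id、rdzv_endpoint、nproc_per_node、max_restarts 等取值均为示例值,与前文的 torch.distributed.run 命令行方式大致等价,仅作参考。

# 最小示意(参数取值均为假设):用 Python API 配置弹性区间与 max_restarts
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

def trainer(arg1):
    # 这里放真正的训练逻辑
    print(f"training with {arg1}")

config = LaunchConfig(
    min_nodes=1,                      # 弹性区间下限,对应 --nnodes=1:4 中的 1
    max_nodes=4,                      # 弹性区间上限,对应 --nnodes=1:4 中的 4
    nproc_per_node=2,                 # 每个节点上的 worker 数目,示例值
    run_id="my_job",                  # 对应 --rdzv_id,示例值
    rdzv_backend="c10d",              # 对应 --rdzv_backend
    rdzv_endpoint="localhost:29400",  # 对应 --rdzv_endpoint,示例值
    max_restarts=3,                   # 允许的重启次数;如上所述,scale up 不会消耗该计数
)

if __name__ == "__main__":
    elastic_launch(config, trainer)("hello")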

0x02 节点加入

2.1 新节点加入

假设目前已经有了一个弹性训练集群正在运行,弹性区间为 (min=1, max=4)。目前已经有2个节点在运行,用户想启动第三个节点,于是使用如下方法启动一个新进程。

python -m torch.distributed.run \
    --nnodes=1:4 \
    --nproc_per_node=$NUM_TRAINERS \
    --rdzv_id=$JOB_ID \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$HOST_NODE_ADDR \
    YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

新进程会启动一个代理。代理经过一系列操作,调用 next_rendezvous,其中启动一个 ExitOp,一个 JoinOp 。

def next_rendezvous(self) -> Tuple[Store, int, int]:
    exit_op = _RendezvousExitOp()
    join_op = _RendezvousJoinOp()

    self._op_executor.run(exit_op, deadline)
    self._op_executor.run(join_op, deadline)
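
补充一个简化示意(非源码),说明 next_rendezvous 返回值在代理一侧如何被使用:返回的三元组分别是 c10d Store、本节点在本轮 rendezvous 中的 group rank,以及参与节点总数;代理随后据此为本节点上的每个 worker 派生全局 RANK 和 WORLD_SIZE。其中 rdzv_handler 参数假设是一个已构建好的 DynamicRendezvousHandler。

# 简化示意(非源码):代理一侧如何消费 next_rendezvous 的返回值
from typing import Tuple
from torch.distributed import Store

def run_rendezvous(rdzv_handler) -> Tuple[Store, int, int]:
    # 阻塞直到本轮 rendezvous 完成(期间会依次执行上面的 ExitOp / JoinOp)
    store, group_rank, group_world_size = rdzv_handler.next_rendezvous()
    # group_rank / group_world_size 是"节点级别"的信息,
    # 代理随后据此为本节点上的每个 worker 计算全局 RANK 和 WORLD_SIZE
    print(f"node rank = {group_rank}, world size = {group_world_size}")
    return store, group_rank, group_world_size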

2.2 处理 Join 操作

以下操作是在 _DistributedRendezvousOpExecutor 之中。

有了前文分析,我们知道业务流程是:run 调用 Join 算子来决定下一个 Action,然后根据 Action 执行对应的业务操作。

2.2.1 run处理

_DistributedRendezvousOpExecutor.run 函数实现了基础逻辑,就是依据 action 类型进行各种操作。对于我们示例,state_handler 就是_RendezvousJoinOp。

    def run(
        self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
    ) -> None:
        """See base class."""
        action = None

        while action != _Action.FINISH:  # 一直循环,直到结束
            # 这里很重要,在所有node之间做信息同步
            has_set = self._state_holder.sync()  # 因为最新状态在 rendezvous

            self._state = self._state_holder.state
            # 利用最新状态构建了 ctx
            ctx = _RendezvousContext(self._node, self._state, self._settings)

            # Determine the next action to take based on the current state of
            # the rendezvous.
            action = state_handler(ctx, deadline)  # 调用 _RendezvousJoinOp,决定下一个操作

            # 省略后续部分

2.2.2 Join操作

因为之前做了同步,所以这里的ctx就包括了最新的state,这就是Rendezvous的全局状态。因为此时,Rendezvous 已经结束了,所以 state 的状态是 complete,进入如下流程,返回 _Action.ADD_TO_WAIT_LIST。

    if state.complete:
        # If we are here, it means we are not part of the rendezvous. In
        # case the rendezvous has capacity for additional participants add
        # ourself to the wait list for the next round.
        if len(state.participants) < ctx.settings.max_nodes:  # 如果当前节点数目小于最大配置
            if ctx.node not in state.wait_list:  # 如果当前node不在等待列表之中
                return _Action.ADD_TO_WAIT_LIST  # 发送一个等待action

总体代码如下:

class _RendezvousJoinOp:
    """Represents a rendezvous join operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        state = ctx.state  # 从上下文之中提取 _RendezvousState 状态

        # A closed rendezvous means that it no longer accepts new nodes.
        if state.closed:
            return _Action.ERROR_CLOSED  # 如果已经关闭,就返回 _Action.ERROR_CLOSED

        is_participant = ctx.node in state.participants  # 看看本节点是不是参与者

        # If we are part of the rendezvous and it is already complete there is
        # no further action to take.
        if state.complete and is_participant:  # 如果是参与者且状态结束,就返回 _Action.FINISH
            return _Action.FINISH

        now = time.monotonic()
        if now > deadline:  # 如果已经超时
            rollback_period = 5  # 5 seconds

            # If we still have time to rollback (a short period on top of the
            # operation deadline), try to remove ourself from the rendezvous.
            # It is okay if we can't though as our keep-alive will eventually
            # expire.
            if now <= deadline + rollback_period:  # 如果还有时间来 rollback
                # If we are part of the rendezvous, it means we couldn't find
                # enough participants to complete it on time.
                if is_participant:  # 已经是参与者了
                    return _Action.REMOVE_FROM_PARTICIPANTS  # 需要从参与者列表移除
                # If we are in the wait list, it means we couldn't wait till the
                # next round of the rendezvous.
                if ctx.node in state.wait_list:  # 已经在等待列表之中
                    return _Action.REMOVE_FROM_WAIT_LIST  # 需要从等待列表移除
            return _Action.ERROR_TIMEOUT  # 返回超时

        if state.complete:  # 如果 rendezvous 已经结束
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes:  # 如果还没有达到最大节点数
                if ctx.node not in state.wait_list:  # 如果当前node不在等待列表之中
                    return _Action.ADD_TO_WAIT_LIST  # 就加入到等待列表,发送一个等待action
        elif is_participant:  # 如果已经在参与者列表
            # If the rendezvous has enough number of participants including us,
            # check whether we have passed the rendezvous deadline. If yes,
            # complete it.
            if len(state.participants) >= ctx.settings.min_nodes:  # 如果达到了最小节点数
                if cast(datetime, state.deadline) < datetime.utcnow():  # 如果达到了超时
                    return _Action.MARK_RENDEZVOUS_COMPLETE  # 标示 rendezvous 已经结束
        else:  # 否则就直接加入到参与者
            # The rendezvous is not complete yet and we are not part of it. Try
            # to join.
            return _Action.ADD_TO_PARTICIPANTS

        if _should_keep_alive(ctx):  # 如果需要保持心跳,就返回 _Action.KEEP_ALIVE
            return _Action.KEEP_ALIVE

        # At this point either the rendezvous is not complete, but we are part
        # of it, which means we have to wait for other participants to join; or
        # the rendezvous is complete, but we are not part of it, which means we
        # have to wait for the next round.
        return _Action.SYNC  # 否则返回同步状态 _Action.SYNC
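
为了便于理解上面的状态机,这里给出一个抽离依赖后的简化示意(非 PyTorch 源码):用普通 Python 对象模拟 ctx 和 state,其中 FakeState、join_decision 等名字都是为演示而假设的。走一遍"本轮已经 complete、当前节点不是参与者"这条分支,可以看到返回的正是 ADD_TO_WAIT_LIST。

# 简化示意(非源码):模拟 Join 算子在 "本轮已 complete、本节点未参与" 情况下的决策
from dataclasses import dataclass, field
from enum import Enum, auto

class Action(Enum):
    ADD_TO_PARTICIPANTS = auto()
    ADD_TO_WAIT_LIST = auto()
    FINISH = auto()
    SYNC = auto()

@dataclass
class FakeState:
    complete: bool = True
    participants: dict = field(default_factory=lambda: {"node1": 0, "node2": 1})
    wait_list: set = field(default_factory=set)

def join_decision(node: str, state: FakeState, max_nodes: int) -> Action:
    if state.complete and node in state.participants:
        return Action.FINISH
    if state.complete:  # 本轮已结束,只能排队等下一轮
        if len(state.participants) < max_nodes and node not in state.wait_list:
            return Action.ADD_TO_WAIT_LIST
        return Action.SYNC
    return Action.ADD_TO_PARTICIPANTS  # 本轮未结束,直接加入参与者

state = FakeState()
print(join_decision("node3", state, max_nodes=4))  # Action.ADD_TO_WAIT_LIST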

2.2.3 等待业务操作

_DistributedRendezvousOpExecutor 之中,run 函数实现了基础逻辑,就是依据 action 类型进行各种操作。

    def run(
        self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
    ) -> None:
        """See base class."""
        action = None

        while action != _Action.FINISH:  # 一直循环,直到结束
            # 这里很重要,在所有node之间做信息同步
            has_set = self._state_holder.sync()  # 因为最新状态在 rendezvous

            self._state = self._state_holder.state
            # 使用最新state构建ctx
            ctx = _RendezvousContext(self._node, self._state, self._settings)

            # Determine the next action to take based on the current state of
            # the rendezvous.
            action = state_handler(ctx, deadline)  # 调用 _RendezvousJoinOp,这里得到了 _Action.ADD_TO_WAIT_LIST

            if action == _Action.SYNC:
                _delay(seconds=1)
            else:
                if action == _Action.KEEP_ALIVE:
                    self._keep_alive()
                elif action == _Action.ADD_TO_WAIT_LIST:  # 从 Join 算子得到了 _Action.ADD_TO_WAIT_LIST
                    self._add_to_wait_list()  # 进行业务逻辑
                # 省略其他action

                # Attempt to sync our changes back to other nodes.
                self._state_holder.mark_dirty()  # 同步回其他节点

具体处理等待操作就是加入到等待列表。

def _add_to_wait_list(self) -> None:
    self._state.wait_list.add(self._node)
    self._keep_alive()
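
_add_to_wait_list 同时调用了 _keep_alive。作为参考,下面给出 _keep_alive 核心逻辑的简化示意(省略了日志等细节,并非完整源码):它把当前节点的心跳时间写入 last_heartbeats,其他节点据此判断某个节点是否还活着。

# 简化示意(非完整源码):_keep_alive 的核心就是刷新本节点的心跳时间
from datetime import datetime

def _keep_alive(self) -> None:
    self._state.last_heartbeats[self._node] = datetime.utcnow()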

我们回忆一下 _RendezvousState。_RendezvousState 是rendezvous的状态。是动态信息。

  • round:Rendezvous的当前轮次
  • complete:一个布尔值,指示rendezvous当前一轮是否完成了。
  • deadline:截止时间。如果当前轮次一直在等待节点加入,而这个参数又被设置了,它就是等待的截止时间。
  • closed:一个布尔值,指示rendezvous是否结束了。
  • participants:字典,存放参与者和它们对应ranks。
  • wait_list:set结构,存放等待参与下一轮rendezvous操作的一组节点
  • last_heartbeats:字典,包含每个节点上次心跳时间。

class _RendezvousState:
    round: int
    complete: bool
    deadline: Optional[datetime]
    closed: bool
    participants: Dict[_NodeDesc, int]  # 参与者,未来会用到的成员变量
    wait_list: Set[_NodeDesc]  # 等待者,这里用到的成员变量
    last_heartbeats: Dict[_NodeDesc, datetime]

    def __init__(self) -> None:
        self.round = 0
        self.complete = False
        self.deadline = None
        self.closed = False
        self.participants = {}
        self.wait_list = set()  # 这里用到的成员变量
        self.last_heartbeats = {}
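
需要注意,_RendezvousState 是全体节点共享的:每个节点在 sync 时把整个状态序列化后写入 rendezvous 后端(例如 c10d 后端底层的 store),或者从后端读回并反序列化。下面用 pickle 给出一个极简示意(非 PyTorch 源码,SimpleState、backend、键名等均为演示假设),帮助理解"为什么新节点加入 wait_list 之后,其他节点很快就能看到"。

# 极简示意(非源码):用 pickle 模拟 _RendezvousState 在节点间的共享
import pickle
from dataclasses import dataclass, field
from typing import Dict, Set

@dataclass
class SimpleState:                      # 模拟 _RendezvousState 的关键字段
    round: int = 0
    complete: bool = True
    participants: Dict[str, int] = field(default_factory=lambda: {"node1": 0, "node2": 1})
    wait_list: Set[str] = field(default_factory=set)

backend: Dict[str, bytes] = {}          # 用一个 dict 模拟共享存储(真实实现是 c10d/etcd 等后端)

# 节点1/2 先把当前状态写入后端
backend["rdzv_state"] = pickle.dumps(SimpleState())

# 节点3:读回最新状态,把自己加入 wait_list,再写回
# (对应 sync + _add_to_wait_list + mark_dirty 的组合)
local = pickle.loads(backend["rdzv_state"])
local.wait_list.add("node3")
backend["rdzv_state"] = pickle.dumps(local)

# 其他节点:下一次 sync 时就能看到 wait_list 里多了 node3
print(pickle.loads(backend["rdzv_state"]).wait_list)   # {'node3'}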

目前逻辑如下:

  1. 启动一个新 worker。此时下图右侧上方的 _RendezvousState 之中,wait_list 为空。
  2. 调用 next_rendezvous,发起新一轮 rendezvous。
  3. _RendezvousJoinOp 内部运行,生成 ADD_TO_WAIT_LIST。
  4. executor . run 内部运行 _add_to_wait_list。
  5. 往 wait_list 添加一个新的 node。此时下图右侧上方的 _RendezvousState 之中,wait_list 多了一个 1。
  python -m torch.distributed.run             +-------------------------+     +
--nnodes=xxx TRAINING_SCRIPT.py | _RendezvousState | |
+ | | |
| | participants = [1,2] | |
| 1 | | |
v | wait_list = [] | |
next_rendezvous | | |
+ +------------+------------+ |
| 2 | |
| | |
v | |
+----------------+-----------------------+ | |
| _op_executor.run(_RendezvousJoinOp) | | |
| + + | | |
| | | 3 | | |
| | | | | |
| | v | | |
| | _Action.ADD_TO_WAIT_LIST | v |
| | + | |
| | | | +--------------------------+ |
| +<-------------+ | | _RendezvousState | |
| | | | | |
| | | | participants = [1,2] | |
| v 4 | 5 | | |
| self._add_to_wait_list() +----------------> wait_list = [3] | |
| | | | |
+----------------------------------------+ +--------------------------+ |
|
v Timeline

2.3 Agent 处理

_DistributedRendezvousOpExecutor.run 处理之后,操作回到了代理之中。在代理主循环之中,程序会进入 while 循环,然后通过 _monitor_workers 定期轮询用户程序运行情况,依据情况作出判断。

    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        self._initialize_workers(self._worker_group)  # 启动worker
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            # 定期监控
            time.sleep(monitor_interval)
            # 监控客户程序运行情况
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state  # 进程运行情况
            self._worker_group.state = state

            if state == WorkerState.SUCCEEDED:
                # 程序正常结束
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # 程序出错
                if self._remaining_restarts > 0:  # 重试
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                else:
                    self._stop_workers(self._worker_group)  # 重试次数达到,结束workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # 程序正常运行
                # 节点成员关系有变化,比如scale up
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # 如果有新的节点在waiting,就重启所有workers
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

所以,代理定期运行 _monitor_workers 监控worker运行情况才是关键。run_result.state 是进程运行情况,当状态是 WorkerState.HEALTHY,说明原有程序正常运行,接下来看看节点成员关系是否有变化。

调用 rdzv_handler.num_nodes_waiting() 拿到等待列表数目,如果有新的节点在waiting,就说明有新的节点试图加入集群,这时就会发生一个Re-rendezvous。代理将重启所有workers。重启时候,会把等待列表中的节点加入到参与列表之中。我们依次看看如何处理。

2.3.1 检查等待列表

处理时候,首先会调用 num_nodes_waiting 看看还有多少节点在等待,具体是看看 state.wait_list 的长度。我们通过之前 Join 操作知道,如果有新节点,会插入到这个列表之中。

num_nodes_waiting 方法的作用是返回在 rendezvous barrier 上等待的节点数目(这些节点不包括在当前工作组之内)。调用者应该周期性地调用这个方法,以确定是否有新节点在等候加入当前工作组;如果有,就应调用 next_rendezvous()(重新 rendezvous)以便把它们纳入进来。

def num_nodes_waiting(self) -> int:
    """See base class."""
    with self._heartbeat_lock:
        self._state_holder.sync()
        return len(self._state_holder.state.wait_list)
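
结合 num_nodes_waiting 的语义,调用方(也就是代理)的用法可以抽象成下面这个极简轮询示意(非代理源码,函数名 watch_membership 为假设):周期性检查等待节点数,一旦大于 0 就停掉本地 workers 并重新 rendezvous,让等待的节点得以加入。

# 极简轮询示意(非源码):按照 num_nodes_waiting 的约定,周期性检查并触发 re-rendezvous
import time

def watch_membership(rdzv_handler, restart_workers, monitor_interval: float = 30.0) -> None:
    """rdzv_handler 假设为 DynamicRendezvousHandler 实例;
    restart_workers 假设为一个回调,对应代理中的 _restart_workers。"""
    while True:
        time.sleep(monitor_interval)
        if rdzv_handler.num_nodes_waiting() > 0:  # 有节点在 rendezvous barrier 上等待
            restart_workers()  # 停掉本地 workers 并发起新一轮 rendezvous,等待节点随之加入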

目前逻辑如下:

  1. 启动一个新 worker。
  2. 调用 next_rendezvous,发起新一轮 rendezvous。
  3. _RendezvousJoinOp 内部运行,生成 ADD_TO_WAIT_LIST。
  4. executor.run 内部运行 _add_to_wait_list。
  5. 往 wait_list 添加一个新的 node。
  6. Agent 之中,定期(比如 30S)运行一次 _monitor_workers,获取worker 子进程状态。
  7. 如果是 HEALTHY,则调用num_nodes_waiting 获取 wait_list 个数。
  8. 如果 wait_list 之中等待节点数目大于 0,则:
  9. 调用 _restart_workers 重启进程组。
  python -m torch.distributed.run             +-------------------------+     +
--nnodes=xxx TRAINING_SCRIPT.py | _RendezvousState | |
+ | | |
| | participants = [1,2] | |
| 1 | | |
v | wait_list = [] | |
next_rendezvous | | |
+ +------------+------------+ |
| 2 | |
| | |
v | |
+----------------+-----------------------+ | |
| _op_executor.run(_RendezvousJoinOp) | | |
| + + | | |
| | | 3 | | |
| | | | | |
| | v | | |
| | _Action.ADD_TO_WAIT_LIST | v |
| | + | |
| | | | +--------------------------+ |
| +<-------------+ | | _RendezvousState | |
| | | | | |
| | | | participants = [1,2] | |
| v 4 | 5 | | |
| self._add_to_wait_list() +----------------> wait_list = [3] | |
| | | | |
+----------------------------------------+ +------------+-------------+ |
| |
+----------------------------------------+ | |
| agent._invoke_run | | |
| | | |
| | | |
| _monitor_workers Every 30S | | |
| + | | |
| | 6 | | |
| | | v |
| v | |
| WorkerState.HEALTHY | +--------------------------+ |
| + | | _RendezvousState | |
| | | | | |
| | 7 | | participants = [1,2] | |
| v | 8 | | |
| num_nodes_waiting <--------------------> wait_list = [3] | |
| + | | | |
| | 9 | | | |
| | | +--------------------------+ |
| v | |
| _restart_workers | v
| |
+----------------------------------------+ Timeline

2.3.2 重启worker组

如果等待列表之中有节点,就会重启workers。我们走一下这个流程。

@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """
    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)

2.3.2.1 _stop_workers

首先会停止目前 workers,代码在torch/distributed/elastic/agent/server/local_elastic_agent.py。

@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()

2.3.2.2 _shutdown

_shutdown 就是让上下文关闭。

def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()

2.3.2.3 关闭上下文

在 MultiprocessContext 之中,close 方法是关闭所有子进程,然后等待其全部停止。

    def _close(self) -> None:
        if self._pc:
            for proc in self._pc.processes:
                proc.terminate()
                proc.join()

2.3.2.4 _initialize_workers

当关闭了所有当前运行的子进程之后,会重新全部初始化。

@prof
def _initialize_workers(self, worker_group: WorkerGroup) -> None:
    r"""
    Starts a fresh set of workers for the worker_group.
    Essentially a rendezvous followed by a start_workers.

    The caller should first call ``_stop_workers()`` to stop running workers
    prior to calling this method.

    Optimistically sets the state of the worker group that
    just started as ``HEALTHY`` and delegates the actual monitoring
    of state to ``_monitor_workers()`` method
    """
    role = worker_group.spec.role

    # TODO after stopping workers, wait at least monitor_interval*2 for
    # workers on different nodes to fail on a collective op before waiting
    # on the rdzv barrier, this way we ensure that nodes enter rdzv
    # at around the same time and reduce false positive rdzv timeout errors
    self._rendezvous(worker_group)

    worker_ids = self._start_workers(worker_group)
    for local_rank, w_id in worker_ids.items():
        worker = worker_group.workers[local_rank]
        worker.id = w_id

    worker_group.state = WorkerState.HEALTHY

_rendezvous经过一系列操作,调用 next_rendezvous,在其中启动一个 ExitOp,一个 JoinOp 。

def next_rendezvous(self) -> Tuple[Store, int, int]:
    exit_op = _RendezvousExitOp()
    join_op = _RendezvousJoinOp()

    self._op_executor.run(exit_op, deadline)
    self._op_executor.run(join_op, deadline)
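
顺带补充一下这里先执行的 _RendezvousExitOp:它的逻辑比较简单,如果本节点还在旧一轮的参与者列表里,就先把自己移除(避免带着旧身份进入新一轮),否则直接结束。下面是其逻辑的简化示意(非逐字源码):

# 简化示意(非逐字源码):_RendezvousExitOp 先把本节点从旧一轮的参与者中退出
class _RendezvousExitOp:
    """Represents a rendezvous exit operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        if ctx.node in ctx.state.participants:       # 如果本节点还是参与者
            if time.monotonic() > deadline:          # 超时则报错
                return _Action.ERROR_TIMEOUT
            return _Action.REMOVE_FROM_PARTICIPANTS  # 先从参与者列表移除自己
        return _Action.FINISH                        # 不是参与者,无事可做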
2.3.2.5 _RendezvousJoinOp

我们又回来了,这是新一轮 Rendezvous 操作。_DistributedRendezvousOpExecutor 之中,run 函数实现了基础逻辑,就是依据 action 类型进行各种操作。对于我们的示例,state_handler 就是 _RendezvousJoinOp。

def run(
    self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
    """See base class."""
    action = None

    while action != _Action.FINISH:
        # Reads or writes the latest rendezvous state shared by all nodes in
        # the rendezvous. Note that our local changes might get overridden
        # by another node if that node synced its changes before us.
        has_set = self._state_holder.sync()

        self._state = self._state_holder.state

        ctx = _RendezvousContext(self._node, self._state, self._settings)

        # Determine the next action to take based on the current state of
        # the rendezvous.
        # 调用到_RendezvousJoinOp,大家可以过一下 _RendezvousJoinOp 代码,发现此时将返回 ADD_TO_PARTICIPANTS
        action = state_handler(ctx, deadline)

        if action == _Action.SYNC:
            # Delay the execution by one second to avoid overloading the
            # backend if we are asked to poll for state changes.
            _delay(seconds=1)
        else:
            if action == _Action.KEEP_ALIVE:
                self._keep_alive()
            elif action == _Action.ADD_TO_PARTICIPANTS:  # 运行到这里
                self._add_to_participants()
            elif action == _Action.ADD_TO_WAIT_LIST:
                self._add_to_wait_list()
            elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                self._remove_from_participants()
            elif action == _Action.REMOVE_FROM_WAIT_LIST:
                self._remove_from_wait_list()
            elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                self._mark_rendezvous_complete()
            elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                self._mark_rendezvous_closed()

            # Attempt to sync our changes back to other nodes.
            self._state_holder.mark_dirty()

这次会生成 ADD_TO_PARTICIPANTS。

class _RendezvousJoinOp:
    """Represents a rendezvous join operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        state = ctx.state  # 从上下文之中提取 _RendezvousState 状态

        # A closed rendezvous means that it no longer accepts new nodes.
        if state.closed:
            return _Action.ERROR_CLOSED  # 如果已经关闭,就返回 _Action.ERROR_CLOSED

        is_participant = ctx.node in state.participants  # 看看本节点是不是参与者

        # If we are part of the rendezvous and it is already complete there is
        # no further action to take.
        if state.complete and is_participant:  # 如果是参与者且状态结束,就返回 _Action.FINISH
            return _Action.FINISH

        now = time.monotonic()
        if now > deadline:  # 如果已经超时
            rollback_period = 5  # 5 seconds

            # If we still have time to rollback (a short period on top of the
            # operation deadline), try to remove ourself from the rendezvous.
            # It is okay if we can't though as our keep-alive will eventually
            # expire.
            if now <= deadline + rollback_period:  # 如果还有时间来 rollback
                # If we are part of the rendezvous, it means we couldn't find
                # enough participants to complete it on time.
                if is_participant:  # 已经是参与者了
                    return _Action.REMOVE_FROM_PARTICIPANTS  # 需要从参与者列表移除
                # If we are in the wait list, it means we couldn't wait till the
                # next round of the rendezvous.
                if ctx.node in state.wait_list:  # 已经在等待列表之中
                    return _Action.REMOVE_FROM_WAIT_LIST  # 需要从等待列表移除
            return _Action.ERROR_TIMEOUT  # 返回超时

        if state.complete:  # 如果 rendezvous 已经结束
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes:  # 如果还没有达到最大节点数
                if ctx.node not in state.wait_list:  # 如果当前node不在等待列表之中
                    return _Action.ADD_TO_WAIT_LIST  # 就加入到等待列表,发送一个等待action
        elif is_participant:  # 如果已经在参与者列表
            # If the rendezvous has enough number of participants including us,
            # check whether we have passed the rendezvous deadline. If yes,
            # complete it.
            if len(state.participants) >= ctx.settings.min_nodes:  # 如果达到了最小节点数
                if cast(datetime, state.deadline) < datetime.utcnow():  # 如果达到了超时
                    return _Action.MARK_RENDEZVOUS_COMPLETE  # 标示 rendezvous 已经结束
        else:  # 否则就直接加入到参与者
            # The rendezvous is not complete yet and we are not part of it. Try
            # to join.
            return _Action.ADD_TO_PARTICIPANTS

        if _should_keep_alive(ctx):  # 如果需要保持心跳,就返回 _Action.KEEP_ALIVE
            return _Action.KEEP_ALIVE

        # At this point either the rendezvous is not complete, but we are part
        # of it, which means we have to wait for other participants to join; or
        # the rendezvous is complete, but we are not part of it, which means we
        # have to wait for the next round.
        return _Action.SYNC  # 否则返回同步状态 _Action.SYNC

2.3.2.6 _add_to_participants

引擎收到 ADD_TO_PARTICIPANTS 之后,会调用 _add_to_participants 从 wait_list 移除节点,插入到 participants。

def _add_to_participants(self) -> None:
    log.debug(
        f"The node '{self._node}' added itself to the participants of round "
        f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
    )

    state = self._state

    state.wait_list.remove(self._node)  # 移除节点

    # The ranks of the participants will be set once the rendezvous is
    # complete.
    state.participants[self._node] = 0  # 重新插入

    self._keep_alive()

    if len(state.participants) == self._settings.min_nodes:
        state.deadline = datetime.utcnow() + self._settings.timeout.last_call

    if len(state.participants) == self._settings.max_nodes:
        self._mark_rendezvous_complete()
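
当参与者数目达到 max_nodes(或者超过 min_nodes 且到达 last_call 截止时间)时,会调用 _mark_rendezvous_complete 结束本轮 rendezvous。作为补充,下面给出其核心逻辑的简化示意(省略日志等细节,并非完整源码):把 complete 置位,并按节点排序统一分配 rank,这也解释了为什么 _add_to_participants 里先临时填 0。

# 简化示意(非完整源码):标记本轮 rendezvous 结束,并给参与者分配 rank
def _mark_rendezvous_complete(self) -> None:
    state = self._state
    state.complete = True
    state.deadline = None

    # rank 在本轮 rendezvous 结束时才统一确定
    for rank, node in enumerate(sorted(state.participants)):
        state.participants[node] = rank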

我们这次从 _restart_workers 开始绘制。

  1. 调用 _stop_workers 来关闭worker子进程。此时下图右侧上方 _RendezvousState之中,participants=[1,2]。
  2. 通过 MultiprocessContext.close() 完成关闭操作。
  3. 通过 _initialize_workers 重新初始化 worker。
  4. 调用 next_rendezvous 完成新的同步操作。
  5. _RendezvousJoinOp 这次返回ADD_TO_PARTICIPANTS。
  6. 调用 _add_to_participants 进行状态切换。
  7. wait_list 之中的Node被移动到 participants。此时下图右侧上方 _RendezvousState之中,participants=[1,2,3]。
                         +-----------------------------+   +------------------------+  |
| agent._invoke_run | | _RendezvousState | |
| | | | |
| _restart_workers | | participants = [1,2] | |
| + | | | |
+----------------------+ | | | | wait_list = [3] | |
| MultiprocessContext | | | 1 | | | |
| | | 2 v | +------------------------+ |
| close() <-----------+ _stop_workers | |
| | | + | |
+----------------------+ | | | |
| | 3 | |
| v | |
| _initialize_workers | |
| + | |
| | | |
+-----------------------------+ |
| |
| 4 |
v |
next_rendezvous |
+ |
| |
v |
+---------------------------+---------------+ |
| _op_executor.run(_RendezvousJoinOp) | |
| + + | |
| | | | |
| | | 5 | |
| | v | |
| | ADD_TO_PARTICIPANTS | |
| | + | +-----------------------+ |
| | | | | _RendezvousState | |
| | <-------------+ | | | |
| | | | participants = [1,2,3]| |
| v 6 7 | | | |
| _add_to_participants +--------------> | wait_list = [] | |
| | | | |
+-------------------------------------------+ +-----------------------+ v Timeline

0x03 节点离开

3.1 处理机制

节点离开(scale-down)的处理如下:

  • 当Scale down事件发生时,rendezvous将不会通知 torchelastic agent。
  • 如果TE agent以“max_restarts=0”启动,它依赖于底层调度程序来处理作业重新启动。
  • 如果“max_restarts>0”,TE代理将终止workers并开始新一轮rendezvous。
    • 代理得到节点离开的通知,于是所有节点上的现有workers都会停止。
    • 这些workers将形成一个新的“WorkerGroup”,所有worker都将以新的 RANK 和 WORLD_SIZE 运行。

3.2 如何模拟

如果想模拟调试的同学,可以在 test/distributed/elastic/agent/server/test/local_elastic_agent_test.py 之中找到示例代码。

def test_double_agent_elastic(self):
    """
    start ``nnodes`` agents, kill odd ones (do not restart), validate
    elasticity (scale-down) works. (scale-up covered in fault_tolerance test)
    """
    min_nodes = 1
    max_nodes = 2
    wait = 2
    node_conf = Conf(entrypoint=_dist_sum, args=(wait,), local_world_size=2)
    agent_results = mp.Queue()
    agent_args = {
        "conf": node_conf,
        "agent_results": agent_results,
        "min_nodes": min_nodes,
        "max_nodes": max_nodes,
        "max_restarts": 2,
    }

    procs = []
    for _ in range(max_nodes):
        p = mp.Process(
            target=self.run_agent,
            kwargs=agent_args,
        )
        procs.append(p)
        p.start()

    # kill odd agents
    for i in range(max_nodes):
        if i % 2 != 0:
            procs[i].kill()

    for i in range(max_nodes):
        p = procs[i]
        p.join()
        if i % 2 == 0:
            self.assertEqual(0, p.exitcode)
        else:
            self.assertEqual(-signal.SIGKILL, p.exitcode)

3.3 如何处理

节点离开与工作进程出错走的是同一段处理代码。错误处理代码如下:如果重试次数尚未达到最大值,则试图重启workers;如果已经达到了最大次数,则停止 workers。

    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:

        # 省略

        while True:
            # 定期监控
            time.sleep(monitor_interval)
            # 监控客户程序运行情况
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state
            self._worker_group.state = state

            # 省略 SUCCEEDED / HEALTHY 等分支

            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # 程序出错
                if self._remaining_restarts > 0:  # 重试
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)  # 进行重启
                else:
                    self._stop_workers(self._worker_group)  # 重试次数达到,结束workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result

3.3.1 重启

_restart_workers 会停掉所有 workers,然后重新发起一轮 rendezvous。

@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """
    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)

3.3.2 停止

停止 workers 就是关闭上下文。

def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()

@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()

在 MultiprocessContext 之中,close 方法是关闭所有子进程,然后等待其全部停止。

    def _close(self) -> None:
        if self._pc:
            for proc in self._pc.processes:
                proc.terminate()
                proc.join()

流程图如下:

  1. 监控子进程状态。
  2. 发现 UNHEALTHY 或者 FAILED,看看是否还有剩余重启次数。我们假定是 3 号节点上的进程失败。
  3. 如果没有,就调用 _stop_workers 结束子进程。
  4. 调用 MultiprocessContext.close 进行具体结束操作。
  5. 如果还可以重启,调用_restart_workers。
  6. 调用 _stop_workers 结束子进程。
  7. 调用 MultiprocessContext.close 进行具体结束操作。
  8. 调用 _initialize_workers 重新初始化worker。
  9. 调用 next_rendezvous 重新同步。
  10. 进行后续操作。
                                                                                 +
+-------------------------------------------+ +---------------------------+ |
| agent._invoke_run | | _RendezvousState | |
| | | | |
| | | | |
| _monitor_workers Every 30S | | participants = [1,2,3] | |
| + | | | |
| | 1 | | wait_list = [ ] | |
| | | | | |
| v | +---------------------------+ |
| WorkerState.UNHEALTHY,FAILED | |
| + | |
| | | |
| | 2 | |
| v | |
| self._remaining_restarts > 0 ? +--+ | |
| + | | |
| 5 | YES NO | 3 | |
| | | | |
| v v | +----------------------+ |
| _restart_workers _stop_workers | | MultiprocessContext | |
| + + | | | |
| | 6 | 4 | | | |
| | +--------> | | |
| v | | close() | |
| _stop_workers +-------------------------> | | |
| + 7 | +----------------------+ |
| | | |
| | 8 | |
| v | |
| _initialize_workers | |
| + | |
| | | |
+-------------------------------------------+ |
| 9 |
| |
v +--------------------------+ |
next_rendezvous | _RendezvousState | |
+ | | |
| 10 | participants = [1,2] | |
+----------------------------> | | |
| | wait_list = [ ] | v
| 10 +--------------------------+
v Timeline

至此,弹性训练全部分析完毕,或者说PyTorch分布式分析就告一段落,我们下文会介绍其他框架/库的分布式实现,敬请期待。

0xFF 参考

[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

[源码解析] PyTorch 分布式之弹性训练(3)---代理

[源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

[源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎