未验证 提交 9ccdb5fa 编写于 作者: L LiYuRio 提交者: GitHub

[FleetExecutor] Using program to be the only interface of TaskNode (#43869)

上级 c8d3b82c
......@@ -41,6 +41,20 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
task_id_ = task_node_cnt++;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: program_(program),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
// TODO(liyurui): Will be removed when execute program is supported.
Init();
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
: program_(program), rank_(rank), task_id_(rank) {
max_run_times_ = 1;
......
......@@ -55,6 +55,12 @@ class TaskNode final {
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
// TODO(liyurui): This will be the only constructor for task node
TaskNode(paddle::framework::ProgramDesc* program,
int64_t task_id,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums);
~TaskNode() = default;
void SetProgram(paddle::framework::ProgramDesc* program);
......
......@@ -165,6 +165,11 @@ void BindFleetExecutor(py::module* m) {
"run", &FleetExecutor::Run, py::call_guard<py::gil_scoped_release>());
py::class_<TaskNode>(*m, "TaskNode")
.def(py::init<framework::ProgramDesc*,
int64_t,
int64_t,
int64_t,
int64_t>())
.def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>())
.def(py::init<int32_t,
const std::vector<framework::OpDesc*>&,
......
......@@ -14,6 +14,7 @@
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY
from paddle.fluid import core
from paddle.static import Program
class TaskNode:
......@@ -21,86 +22,107 @@ class TaskNode:
Python side TaskNode, connection to the c++ side TaskNode
"""
# track the previous init method
previous = None
def __init__(self,
cur_rank,
rank,
max_run_times,
max_slot_times,
role=None,
node_type='Compute',
task_id=None,
node_type=None,
task_id=0,
ops=None,
program=None):
program=None,
lazy_initialize=False):
"""
:param cur_rank (int): Current rank of the task node.
:param rank (int): Current rank of the task node.
:param max_run_times (int): The max run times of the task node.
:param max_slot_times (int): The mas slot times of the task node.
:param role (int): The role of the task node.
:param node_type (str): The type of the task node, default is 'Compute'
:param task_id (int): The task id of the task node.
:param ops (list): A list of op.desc to init the task node.
:param role (int): The role of the task node. (Will be removed in the future)
:param node_type (str): The type of the task node.
:param task_id (int): The id of task node.
:param ops (list): A list of op.desc to init the task node. (Will be removed in the future)
:param program (Program): An instance of Program to init the task node.
:param lazy_initialize (bool): In user-defined task, the program may change adding feed/fetch op. As efficient consideration, the task node will have the C++ object later.
"""
# NOTE: ops should be checked by `is not None`, since it may be empty list
assert ((ops is not None) ^ (program is not None)), \
"Should provide only one of ops or program to task node."
if not self.previous:
self.previous = 'program' if program else 'ops'
assert (program is not None and self.previous == 'program') or \
(ops is not None and self.previous == 'ops'), \
"In one program, task node should be inited in the same way, all by ops or all by program."
if ops is not None:
assert role is not None and task_id is not None, \
"If init task node with ops, should provide `role` and `task_id`."
self.node = core.TaskNode(role, ops, cur_rank, int(task_id),
max_run_times, max_slot_times)
print("Creating task node by ops. The role is:", self.role(),
"and the id is:", self.task_id())
else:
self.program = program
self.node = core.TaskNode(program.desc, cur_rank, max_run_times,
max_slot_times)
print("Creating task node by program. The id is:", self.task_id())
self.node.set_type(node_type)
def set_type(self, interceptor_type):
self.node.set_type(interceptor_type)
assert (not ((ops is not None) and lazy_initialize)), \
"Lazy initialization doesn't support with ops list"
self.id = int(task_id)
self.rank = rank
self.max_run_times = max_run_times
self.max_slot_times = max_slot_times
self.node_type = node_type
self.program = program
self.lazy_initialize = lazy_initialize
self.run_pre_steps = None
self.run_at_offset = None
self.node = None
self.upstreams = []
self.downstreams = []
if not lazy_initialize:
if ops is not None:
assert role is not None and task_id is not None, \
"If init task node with ops, should provide `role` and `task_id`."
self.node = core.TaskNode(role, ops, rank, task_id,
max_run_times, max_slot_times)
else:
self.node = core.TaskNode(program.desc, rank, self.id,
max_run_times, max_slot_times)
if self.node_type:
self.node.set_type(self.node_type)
def task_node(self):
if hasattr(self, 'program'):
print(
"The task node has been instantiated by program, calling init before passing to fleet executor."
)
self.node.init()
if self.lazy_initialize:
self.node = core.TaskNode(self.program.desc, self.rank, self.id,
self.max_run_times, self.max_slot_times)
if self.node_type:
self.node.set_type(self.node_type)
if self.run_pre_steps:
self.node.set_run_pre_steps(self.run_pre_steps)
if self.run_at_offset:
self.node.set_run_at_offset(self.run_at_offset)
for up in self.upstreams:
self.node.add_upstream_task(up[0], up[1])
for down in self.downstreams:
self.node.add_downstream_task(down[0], down[1])
self.lazy_initialize = False
return self.node
def set_program(self, program):
assert self.lazy_initialize, \
"Inside program is unchangable for immediate initialized task node. Set the lazy_initialize to be true if the inside program need to be update. Remember to do all your change before eval node.task_node()."
self.program = program
self.node.set_program(program.desc)
def get_program(self):
assert hasattr(self, 'program'), 'There is no program to get'
assert self.program is not None, "The task node is not initialized using program"
return self.program
def set_run_pre_steps(self, steps):
self.node.set_run_pre_steps(steps)
if self.lazy_initialize:
self.run_pre_steps = steps
else:
self.node.set_run_pre_steps(steps)
def set_run_at_offset(self, offset):
self.node.set_run_at_offset(offset)
def add_upstream_task(self, upstream, buffer_size):
self.node.add_upstream_task(upstream, buffer_size)
if self.lazy_initialize:
self.run_at_offset = offset
else:
self.node.set_run_at_offset(offset)
def add_downstream_task(self, downstream, buffer_size):
self.node.add_downstream_task(downstream, buffer_size)
def add_upstream_task(self, upstream, buffer_size=2):
if self.lazy_initialize:
self.upstreams.append((upstream, buffer_size))
else:
self.node.add_upstream_task(upstream, buffer_size)
def role(self):
return self.node.role()
def add_downstream_task(self, downstream, buffer_size=2):
if self.lazy_initialize:
self.downstreams.append((downstream, buffer_size))
else:
self.node.add_downstream_task(downstream, buffer_size)
def task_id(self):
return self.node.task_id()
return self.id
class CoordSys:
......@@ -158,166 +180,244 @@ class CoordSys:
}
def is_optimizer_op(op_role):
return op_role == int(OpRole.Optimize)
def is_lr_sched_op(op_role):
return op_role == int(OpRole.Optimize.LRSched)
def is_forward_op(op_role):
return (op_role == int(OpRole.Forward)) or \
(op_role == (int(OpRole.Forward) | int(OpRole.Loss)))
class FleetExecutorUtils:
def __init__(self,
dist_strategy=None,
rank=None,
nrank=None,
max_run_times=None):
self.dist_strategy = dist_strategy
self.rank = rank
self.nrank = nrank
self.max_run_times = max_run_times
self.is_auto_parallel = True if dist_strategy is None else False
self.num_of_functionality = 4
self.coord_sys = None
self.coord = None
if dist_strategy:
self.coord_sys = CoordSys(dist_strategy)
self.coord = self.coord_sys.rank_to_coord(rank)
def is_optimizer_op(self, op_role):
return op_role == int(OpRole.Optimize)
def is_lr_sched_op(self, op_role):
return op_role == int(OpRole.Optimize.LRSched)
def is_forward_op(self, op_role):
return (op_role == int(OpRole.Forward)) or \
(op_role == (int(OpRole.Forward) | int(OpRole.Loss)))
def is_backward_op(self, op_role):
return (op_role == int(OpRole.Backward)) or \
(op_role == (int(OpRole.Backward) | int(OpRole.Loss)))
def split_program_to_op_list(self, program):
op_list_map = {"lr": [], "fwd": [], "bwd": [], "opt": []}
for op in program.block(0).ops:
# split the program based on the op_role
op_role = int(op.all_attrs()[OP_ROLE_KEY])
if self.is_lr_sched_op(op_role):
op_list_map["lr"].append(op)
elif self.is_forward_op(op_role):
op_list_map["fwd"].append(op)
elif self.is_backward_op(op_role):
op_list_map["bwd"].append(op)
elif self.is_optimizer_op(op_role):
op_list_map["opt"].append(op)
else:
raise "The op role: " + str(
op_role
) + " isn't one of LRSched, Forward, Backward or Optimizer."
return op_list_map
def convert_op_list_to_program(self, op_list, complete_program):
#TODO(liyurui): Complete this convert logic
program_map = {
"lr": Program(),
"fwd": Program(),
"bwd": Program(),
"opt": Program()
}
return program_map
def build_1f1b_dependency(self, task_node_map):
assert not self.is_auto_parallel, "Handly add dependency should not be invoked in auto parallel mode"
# Generated the dependency based on this graph:
# lr(1:m) -> forward -> backward -> (m:1)optimize
# ↑ ↓
# lr(1:m) -> forward -> backward -> (m:1)optimize
# ↑ ↓
# lr(1:m) -> forward -> backward -> (m:1)optimize
# add dependency intra stage
cur_start_id = self.rank * self.num_of_functionality
pp_buff_size = int(self.dist_strategy['pp_degree'] -
self.coord['pp_idx'])
task_node_map["lr"].add_downstream_task(cur_start_id + 1)
task_node_map["fwd"].add_upstream_task(cur_start_id)
task_node_map["fwd"].add_downstream_task(cur_start_id + 2, pp_buff_size)
task_node_map["bwd"].add_upstream_task(cur_start_id + 1, pp_buff_size)
task_node_map["bwd"].add_downstream_task(cur_start_id + 3)
task_node_map["opt"].add_upstream_task(cur_start_id + 2)
# add dependency inter stage
upstream_coord, downstream_coord = self.coord.copy(), self.coord.copy()
upstream_coord['pp_idx'] = upstream_coord['pp_idx'] - 1
downstream_coord['pp_idx'] = downstream_coord['pp_idx'] + 1
pp_upstream = self.coord_sys.coord_to_rank(upstream_coord)
pp_downstream = self.coord_sys.coord_to_rank(downstream_coord)
first_stage = (pp_upstream == -1)
last_stage = (pp_downstream == -1)
prev_pp_start_id = pp_upstream * self.num_of_functionality
next_pp_start_id = pp_downstream * self.num_of_functionality
if not first_stage:
task_node_map["fwd"].add_upstream_task(prev_pp_start_id + 1)
task_node_map["bwd"].add_downstream_task(prev_pp_start_id + 2)
if not last_stage:
task_node_map["fwd"].add_downstream_task(next_pp_start_id + 1)
task_node_map["bwd"].add_upstream_task(next_pp_start_id + 2)
return task_node_map
def construct_task_nodes_1f1b(self, program_map):
max_slot_times = int(self.max_run_times - self.coord['pp_idx'])
cur_start_id = int(self.rank * self.num_of_functionality)
lr_task_node = TaskNode(rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["lr"],
task_id=cur_start_id)
fwd_task_node = TaskNode(rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["fwd"],
task_id=cur_start_id + 1)
bwd_task_node = TaskNode(rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["bwd"],
task_id=cur_start_id + 2)
opt_task_node = TaskNode(rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["opt"],
task_id=cur_start_id + 3)
return {
"lr": lr_task_node,
"fwd": fwd_task_node,
"bwd": bwd_task_node,
"opt": opt_task_node
}
def is_backward_op(op_role):
return (op_role == int(OpRole.Backward)) or \
(op_role == (int(OpRole.Backward) | int(OpRole.Loss)))
def task_id_to_rank(self):
task_id_to_rank = {}
for i in range(self.nrank):
for j in range(self.num_of_functionality):
task_id_to_rank[int(i * self.num_of_functionality + j)] = i
return task_id_to_rank
def construct_task_nodes_1f1b_op_list(self, op_list_map):
max_slot_times = int(self.max_run_times - self.coord['pp_idx'])
cur_start_id = int(self.rank * self.num_of_functionality)
lr_task_node = TaskNode(rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Optimize.LRSched),
ops=op_list_map["lr"],
task_id=cur_start_id,
node_type="Amplifier")
lr_task_node.set_run_pre_steps(self.max_run_times)
fwd_task_node = TaskNode(rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Forward),
ops=op_list_map["fwd"],
task_id=cur_start_id + 1,
node_type="Compute")
bwd_task_node = TaskNode(rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Backward),
ops=op_list_map["bwd"],
task_id=cur_start_id + 2,
node_type="Compute")
opt_task_node = TaskNode(rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Optimize),
ops=op_list_map["opt"],
task_id=cur_start_id + 3,
node_type="Amplifier")
opt_task_node.set_run_pre_steps(self.max_run_times)
opt_task_node.set_run_at_offset(self.max_run_times - 1)
return {
"lr": lr_task_node,
"fwd": fwd_task_node,
"bwd": bwd_task_node,
"opt": opt_task_node
}
def run1f1b(program, cur_rank, max_run_times, dist_opt, nrank):
def run1f1b(program,
rank,
max_run_times,
dist_opt,
nrank,
with_standalone_executor=False):
"""
Split the program to support 1f1b pipeline scheduler.
This funct will split the program based on the op_role.
The program will be split into four parts: lr_sched, fwd, bwd, opt.
And will create task nodes based on the four parts of the program.
:param program: The origin program.
:param cur_rank: Current rank (can be got from fleet.worker_index()).
:param rank: Current rank (can be got from fleet.worker_index()).
:param max_run_times: Max run times for a micro batch. AKA number of micro steps.
:param dist_opt: The fleet_opt configured by user.
:param nrank: Number of workers (can be got from fleet.worker_num()).
:param with_standalone_executor: Experiment feature, use fleet executor with standalone executor.
:return:
task_nodes (list): four task nodes for current rank
task_id_to_rank (dict): task nodes' ids to it's corresponding rank
"""
print("fleet executor will use python side 1f1b scheduler.")
coord_sys = CoordSys(dist_opt)
coord = coord_sys.rank_to_coord(cur_rank)
max_slot_times = int(max_run_times - coord['pp_idx'])
num_of_functionality = 4
lr_ops, fwd_ops, bwd_ops, opt_ops = [], [], [], []
for op in program.block(0).ops:
# split the program based on the op_role
op_role = int(op.all_attrs()[OP_ROLE_KEY])
if is_lr_sched_op(op_role):
lr_ops.append(op.desc)
elif is_optimizer_op(op_role):
opt_ops.append(op.desc)
elif is_forward_op(op_role):
fwd_ops.append(op.desc)
elif is_backward_op(op_role):
bwd_ops.append(op.desc)
else:
raise "The op role: " + str(
op_role
) + " isn't one of LRSched, Forward, Backward or Optimizer."
# Create task nodes.
# The lr_sched and opt should be 'amplifier interceptor.
# The fwd and bwd should be 'compute interceptor'.
lr_task_node = TaskNode(cur_rank=cur_rank,
max_run_times=max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Optimize.LRSched),
ops=lr_ops,
task_id=int(cur_rank * num_of_functionality + 0),
node_type="Amplifier")
lr_task_node.set_run_pre_steps(max_run_times)
fwd_task_node = TaskNode(cur_rank=cur_rank,
max_run_times=max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Forward),
ops=fwd_ops,
task_id=int(cur_rank * num_of_functionality + 1),
node_type="Compute")
bwd_task_node = TaskNode(cur_rank=cur_rank,
max_run_times=max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Backward),
ops=bwd_ops,
task_id=int(cur_rank * num_of_functionality + 2),
node_type="Compute")
opt_task_node = TaskNode(cur_rank=cur_rank,
max_run_times=max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Optimize),
ops=opt_ops,
task_id=int(cur_rank * num_of_functionality + 3),
node_type="Amplifier")
opt_task_node.set_run_pre_steps(max_run_times)
opt_task_node.set_run_at_offset(max_run_times - 1)
task_nodes = [lr_task_node, fwd_task_node, bwd_task_node, opt_task_node]
# Generated the dependency based on this graph:
# lr(1:m) -> forward -> backward -> (m:1)optimize
# ↑ ↓
# lr(1:m) -> forward -> backward -> (m:1)optimize
# ↑ ↓
# lr(1:m) -> forward -> backward -> (m:1)optimize
upstream_coord, downstream_coord = coord.copy(), coord.copy()
upstream_coord['pp_idx'] = upstream_coord['pp_idx'] - 1
downstream_coord['pp_idx'] = downstream_coord['pp_idx'] + 1
pp_upstream = coord_sys.coord_to_rank(upstream_coord)
pp_downstream = coord_sys.coord_to_rank(downstream_coord)
first_stage = (pp_upstream == -1)
last_stage = (pp_downstream == -1)
for i in range(num_of_functionality):
task_node = task_nodes[i]
task_role = task_node.role()
cur_id = int(cur_rank * num_of_functionality + i)
prev_id = cur_id - 1
next_id = cur_id + 1
upstream_id = int(pp_upstream * num_of_functionality + i)
downstream_id = int(pp_downstream * num_of_functionality + i)
pp_buff_size = int(dist_opt['pp_degree'] - coord['pp_idx'])
ups = []
downs = []
if not is_lr_sched_op(task_role):
buf_size = pp_buff_size if is_backward_op(task_role) else 2
ups.append((prev_id, buf_size))
if not is_optimizer_op(task_role):
buf_size = pp_buff_size if is_forward_op(task_role) else 2
downs.append((next_id, buf_size))
if is_forward_op(task_role):
if not first_stage:
ups.append((upstream_id, 2))
if not last_stage:
downs.append((downstream_id, 2))
elif is_backward_op(task_role):
if not last_stage:
ups.append((downstream_id, 2))
if not first_stage:
downs.append((upstream_id, 2))
for up in ups:
print("Task:", cur_id, "'s upstream includes:", up[0])
task_node.add_upstream_task(up[0], up[1])
for down in downs:
print("Task:", cur_id, "'s downstream includes:", down[0])
task_node.add_downstream_task(down[0], down[1])
task_id_to_rank = {}
for i in range(nrank):
for j in range(num_of_functionality):
task_id_to_rank[int(i * num_of_functionality + j)] = i
return [task_node.task_node() for task_node in task_nodes], task_id_to_rank
def origin(program, cur_rank):
fleet_executor_utils = FleetExecutorUtils(dist_strategy=dist_opt,
rank=rank,
nrank=nrank,
max_run_times=max_run_times)
op_list_map = fleet_executor_utils.split_program_to_op_list(program)
task_node_map = None
if with_standalone_executor:
program_map = fleet_executor_utils.convert_op_list_to_program(
op_list_map, program)
task_node_map = fleet_executor_utils.construct_task_nodes_1f1b(
program_map)
else:
op_desc_list_map = {"lr": [], "fwd": [], "bwd": [], "opt": []}
for key in op_list_map:
for op in op_list_map[key]:
op_desc_list_map[key].append(op.desc)
task_node_map = fleet_executor_utils.construct_task_nodes_1f1b_op_list(
op_desc_list_map)
task_node_map = fleet_executor_utils.build_1f1b_dependency(task_node_map)
task_id_to_rank = fleet_executor_utils.task_id_to_rank()
task_node_list = [task_node_map[key].task_node() for key in task_node_map]
return task_node_list, task_id_to_rank
def origin(program, rank):
"""
Origin scheduler for fleet executor, supports non-pp mode
:param program: The origin program.
:param cur_rank: Current rank (can be got from fleet.worker_index()).
:param rank: Current rank (can be got from fleet.worker_index()).
:return:
task_nodes (list): four task nodes for current rank
task_id_to_rank (dict): a fake dict, since there is no upstream or downstream, this dict won't be used
"""
print("fleet executor will use python side origin scheduler.")
task_node = TaskNode(program=program,
cur_rank=cur_rank,
rank=rank,
node_type="Compute",
max_run_times=1,
max_slot_times=1)
task_node.set_type("Compute")
task_id = task_node.task_id()
task_id_to_rank = {task_id: cur_rank}
task_id_to_rank = {task_node.task_id(): rank}
return [task_node.task_node()], task_id_to_rank
......@@ -717,6 +717,9 @@ class Executor(object):
self._executor_cache = _ExecutorCache(self.place)
self._fleet_executor = None
# TODO(liyurui): This option will be removed and always true when the functionality
# of fleet executor with standalone executor is ready.
self._fleet_executor_with_standalone = False
def _get_scope_cache(self, program_cache_key):
return self.scope_caches.get(program_cache_key, None)
......@@ -1323,9 +1326,12 @@ class Executor(object):
# Move prepare here for port conflict with nccl in startup program
if self._fleet_executor is None:
self._fleet_executor = _prepare_fleet_executor()
return self._run_using_fleet_executor(program=program,
feed=feed,
fetch_list=fetch_list)
return self._run_using_fleet_executor(
program=program,
feed=feed,
fetch_list=fetch_list,
with_standalone_executor=self.
_fleet_executor_with_standalone)
if "startup_program" in program._pipeline_opt:
program = program._pipeline_opt["startup_program"]
else:
......@@ -2131,7 +2137,8 @@ class Executor(object):
carrier_id="",
program=None,
scope=None,
fleet_opt=None):
fleet_opt=None,
with_standalone_executor=False):
num_micro_batches = fleet_opt[
"num_micro_batches"] if "num_micro_batches" in fleet_opt else 1
cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
......@@ -2157,7 +2164,8 @@ class Executor(object):
warnings.warn("Using 1F1B scheduler with pp_degree == 1.")
tasks, task_id_to_rank = run1f1b(
program, cur_rank, fleet_opt.get('num_micro_batches', 1),
fleet_opt.get('dist_strategy', {}), nrank)
fleet_opt.get('dist_strategy', {}), nrank,
with_standalone_executor)
elif scheduler == 'Origin':
from paddle.distributed.fleet.fleet_executor_utils import origin
if "dist_strategy" in fleet_opt and \
......@@ -2186,7 +2194,8 @@ class Executor(object):
feed=None,
feed_var_name="feed",
fetch_var_name="fetch",
fetch_list=None):
fetch_list=None,
with_standalone_executor=False):
cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
cached_program = self._get_program_cache(cache_key)
cached_scope = self._get_scope_cache(cache_key)
......@@ -2249,10 +2258,12 @@ class Executor(object):
core.op_proto_and_checker_maker.OpRole.Optimize)
fetch_task.set_program(fetch_program)
self._prepare_fleet_executor_carrier(cache_key,
program=cached_program,
scope=cached_scope,
fleet_opt=fleet_opt)
self._prepare_fleet_executor_carrier(
cache_key,
program=cached_program,
scope=cached_scope,
fleet_opt=fleet_opt,
with_standalone_executor=with_standalone_executor)
if feed:
# NOTE: don't have to traverse programs in task nodes,
......
......@@ -15,6 +15,7 @@
import unittest
import paddle
import paddle.fluid.core as core
from paddle.distributed.fleet.fleet_executor_utils import TaskNode
paddle.enable_static()
......@@ -33,6 +34,15 @@ class TestFleetExecutorTaskNode(unittest.TestCase):
task_node_0.add_downstream_task(task_node_1.task_id(), 1))
self.assertTrue(task_node_1.add_upstream_task(task_node_0.task_id(), 1))
def test_lazy_task_node(self):
program = paddle.static.Program()
task = TaskNode(program=program,
rank=0,
max_run_times=1,
max_slot_times=1,
lazy_initialize=True)
task_node = task.task_node()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.fluid.core as core
from paddle.distributed.fleet.fleet_executor_utils import TaskNode, FleetExecutorUtils
paddle.enable_static()
class TestFleetExecutorUtils(unittest.TestCase):
def test_construct_program(self):
# TODO(liyurui): These functions are not ready now.
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sharding_configs = {
"dp_degree": 2,
"mp_degree": 2,
"pp_degree": 2
}
fleet_executor_utils = FleetExecutorUtils(
dist_strategy=strategy.sharding_configs,
rank=0,
nrank=1,
max_run_times=1)
op_list = {"lr": [], "fwd": [], "bwd": [], "opt": []}
program_map = fleet_executor_utils.convert_op_list_to_program(
op_list, paddle.static.Program())
task_node_map = fleet_executor_utils.construct_task_nodes_1f1b(
program_map)
if __name__ == "__main__":
unittest.main()
......@@ -51,9 +51,11 @@ class TestFleetExecutor(unittest.TestCase):
task_node = TaskNode(
# must clone, if copies, there will be two fetches and two feeds
program=empty_program.clone(),
cur_rank=0,
rank=0,
node_type="Compute",
max_run_times=1,
max_slot_times=1)
max_slot_times=1,
lazy_initialize=True)
empty_program._pipeline_opt = {
"fleet_opt": {
'tasks': [task_node],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册