Unverified · Commit a4afb97a · Authored by: Y Yuang Liu · Committed by: GitHub

[fleet_executor] Python side fleet executor and task node (#38290)

Parent 2005b98b
@@ -23,20 +23,35 @@ namespace {
 using OperatorBase = TaskNode::OperatorBase;
 }
-TaskNode::TaskNode(const framework::ProgramDesc& program, int64_t rank,
+TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
                    int64_t max_run_times, int64_t max_slot_nums)
     : program_(program),
       rank_(rank),
       max_run_times_(max_run_times),
       max_slot_nums_(max_slot_nums) {
   // Should be serially invoked, not thread-safe
+  // NOTE: when instantiating TaskNode with a program, the task node is not
+  // inited immediately, since the provided program will very likely be
+  // updated later by adding feed/fetch ops or by RuntimeGraph.
+  // So, delay the init part to the Init() function.
   static int64_t task_node_cnt = 0;
   task_id_ = task_node_cnt++;
-  for (const auto& op_desc : program.Block(0).AllOps()) {
-    ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*op_desc));
-  }
-  for (const auto& op : ops_vec_) {
-    ops_.emplace_back(op.get());
-  }
+}
+
+void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
+  program_ = program;
+}
+
+void TaskNode::Init() {
+  if (ops_.empty()) {
+    // Q (for fleet executor dev): do we need another reset function?
+    VLOG(3) << "Task node will be inited by calling Init().";
+    for (const auto& op_desc : program_->Block(0).AllOps()) {
+      ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*op_desc));
+    }
+    for (const auto& op : ops_vec_) {
+      ops_.emplace_back(op.get());
+    }
+  }
 }
@@ -52,6 +67,7 @@ TaskNode::TaskNode(int32_t role,
   if (op_descs.empty()) {
     return;
   }
+  VLOG(3) << "Task node will be inited by providing a list of ops.";
   for (const auto& desc : op_descs) {
     ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*desc));
   }
...
@@ -40,10 +40,12 @@ class TaskNode final {
   TaskNode(int32_t role, const std::vector<framework::OperatorBase*>& ops,
            int64_t rank, int64_t task_id, int64_t max_run_times,
            int64_t max_slot_nums);
-  TaskNode(const paddle::framework::ProgramDesc& program, int64_t rank,
+  TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
            int64_t max_run_times, int64_t max_slot_nums);
   ~TaskNode() = default;
 
+  void SetProgram(paddle::framework::ProgramDesc* program);
+  void Init();
   int64_t rank() const { return rank_; }
   int64_t task_id() const { return task_id_; }
   int32_t role() const { return role_; }
@@ -60,7 +62,7 @@ class TaskNode final {
     return downstream_;
   }
   const std::string& type() const { return type_; }
-  const paddle::framework::ProgramDesc& program() const { return program_; }
+  const paddle::framework::ProgramDesc* program() const { return program_; }
   const std::vector<OperatorBase*>& ops() const { return ops_; }
   const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
     return ops_vec_;
@@ -94,7 +96,7 @@ class TaskNode final {
   // task_id-->buff_size
   std::unordered_map<int64_t, int64_t> upstream_;
   std::unordered_map<int64_t, int64_t> downstream_;
-  framework::ProgramDesc program_;
+  framework::ProgramDesc* program_;
   std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
   std::unordered_map<const OperatorBase*, std::vector<std::string>>
       unused_vars_;
...
@@ -29,6 +29,7 @@ namespace pybind {
 using paddle::distributed::FleetExecutor;
 using paddle::distributed::TaskNode;
 using paddle::framework::OpDesc;
+using paddle::framework::ProgramDesc;
 
 void BindFleetExecutor(py::module* m) {
   py::class_<FleetExecutor>(*m, "FleetExecutor")
@@ -38,7 +39,7 @@ void BindFleetExecutor(py::module* m) {
            py::call_guard<py::gil_scoped_release>());
   py::class_<TaskNode>(*m, "TaskNode")
-      .def(py::init<const framework::ProgramDesc&, int64_t, int64_t, int64_t>())
+      .def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>())
       .def(py::init<int32_t, const std::vector<framework::OpDesc*>&, int64_t,
                     int64_t, int64_t, int64_t>())
       .def("task_id", &TaskNode::task_id)
@@ -47,7 +48,9 @@ void BindFleetExecutor(py::module* m) {
       .def("set_run_pre_steps", &TaskNode::SetRunPerSteps)
       .def("set_run_at_offset", &TaskNode::SetRunAtOffset)
       .def("set_type", &TaskNode::SetType)
-      .def("role", &TaskNode::role);
+      .def("role", &TaskNode::role)
+      .def("init", &TaskNode::Init)
+      .def("set_program", &TaskNode::SetProgram);
 }
 }  // namespace pybind
 }  // namespace paddle
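Taken together, the new bindings expose the deferred-init flow to Python. A minimal sketch of that flow, with hypothetical values; `prog` is assumed to be an already-built paddle.static.Program:

from paddle.fluid import core

# Construct from a ProgramDesc: the operators are NOT created yet,
# since init is delayed (see the NOTE in the constructor above).
node = core.TaskNode(prog.desc, 0, 1, 1)  # program, rank, max_run_times, max_slot_nums

# The program may still be rewritten, e.g. after feed/fetch ops are inserted;
# point the node at the (possibly updated) ProgramDesc again.
node.set_program(prog.desc)

# Materialize the operators from the final program before the node is
# handed to the fleet executor.
node.init()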
@@ -16,6 +16,94 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY
 from paddle.fluid import core
 
 
+class TaskNode:
+    """
+    Python side TaskNode, the connection to the C++ side TaskNode.
+    """
+    # Track how the previous task node was inited. Recorded on the class so
+    # that all task nodes created for one program share the same choice.
+    previous = None
+
+    def __init__(self,
+                 cur_rank,
+                 max_run_times,
+                 max_slot_times,
+                 role=None,
+                 node_type='Compute',
+                 task_id=None,
+                 ops=None,
+                 program=None):
+        """
+        :param cur_rank (int): Current rank of the task node.
+        :param max_run_times (int): The max run times of the task node.
+        :param max_slot_times (int): The max slot times of the task node.
+        :param role (int): The role of the task node.
+        :param node_type (str): The type of the task node, default is 'Compute'.
+        :param task_id (int): The task id of the task node.
+        :param ops (list): A list of op.desc to init the task node.
+        :param program (Program): An instance of Program to init the task node.
+        """
+        # NOTE: ops should be checked by `is not None`, since it may be an empty list
+        assert ((ops is not None) ^ (program is not None)), \
+            "Should provide only one of ops or program to task node."
+        if not self.previous:
+            TaskNode.previous = 'program' if program else 'ops'
+        assert (program is not None and self.previous == 'program') or \
+            (ops is not None and self.previous == 'ops'), \
+            "In one program, task nodes should all be inited in the same way, either all by ops or all by program."
+        if ops is not None:
+            assert role is not None and task_id is not None, \
+                "If init task node with ops, should provide `role` and `task_id`."
+            self.node = core.TaskNode(role, ops, cur_rank,
+                                      int(task_id), max_run_times,
+                                      max_slot_times)
+            print("Creating task node by ops. The role is:",
+                  self.role(), "and the id is:", self.task_id())
+        else:
+            self.program = program
+            self.node = core.TaskNode(program.desc, cur_rank, max_run_times,
+                                      max_slot_times)
+            print("Creating task node by program. The id is:", self.task_id())
+        self.node.set_type(node_type)
+
+    def set_type(self, interceptor_type):
+        self.node.set_type(interceptor_type)
+
+    def task_node(self):
+        if hasattr(self, 'program'):
+            print(
+                "The task node has been instantiated by program, calling init before passing to fleet executor."
+            )
+            self.node.init()
+        return self.node
+
+    def set_program(self, program):
+        self.program = program
+        self.node.set_program(program.desc)
+
+    def get_program(self):
+        assert hasattr(self, 'program'), 'There is no program to get'
+        return self.program
+
+    def set_run_pre_steps(self, steps):
+        self.node.set_run_pre_steps(steps)
+
+    def set_run_at_offset(self, offset):
+        self.node.set_run_at_offset(offset)
+
+    def add_upstream_task(self, upstream, buffer_size):
+        self.node.add_upstream_task(upstream, buffer_size)
+
+    def add_downstream_task(self, downstream, buffer_size):
+        self.node.add_downstream_task(downstream, buffer_size)
+
+    def role(self):
+        return self.node.role()
+
+    def task_id(self):
+        return self.node.task_id()
+
+
 class CoordSys:
     """
     This class is used to map rank to (mp rank, sharding rank, pp rank, dp rank).
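For orientation, a minimal sketch of the two construction paths supported by the TaskNode wrapper added above. The inputs are hypothetical (`program` is assumed to be an already-built paddle.static.Program; OpRole comes from the import at the top of this module), and one program must stick to a single path, otherwise the `previous` check above fires:

# Path A: init by program; role/task_id are assigned on the C++ side and the
# underlying ops are created lazily when task_node() calls init().
node_a = TaskNode(
    cur_rank=0, max_run_times=1, max_slot_times=1, program=program)

# Path B: init by a list of op descs; role and task_id are mandatory and the
# ops are created eagerly. (Commented out here because mixing A and B within
# one program trips the assertion above.)
# node_b = TaskNode(
#     cur_rank=0, max_run_times=1, max_slot_times=1,
#     role=int(OpRole.Forward),
#     ops=[op.desc for op in program.block(0).ops],
#     task_id=0)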
@@ -81,12 +169,12 @@ def is_lr_sched_op(op_role):
 
 def is_forward_op(op_role):
     return (op_role == int(OpRole.Forward)) or \
-        (op_role == (int(OpRole.Forward) ^ int(OpRole.Loss)))
+        (op_role == (int(OpRole.Forward) | int(OpRole.Loss)))
 
 
 def is_backward_op(op_role):
     return (op_role == int(OpRole.Backward)) or \
-        (op_role == (int(OpRole.Backward) ^ int(OpRole.Loss)))
+        (op_role == (int(OpRole.Backward) | int(OpRole.Loss)))
 
 
 def run1f1b(program, cur_rank, max_run_times, dist_opt, nrank):
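The `^` to `|` change here is about stating intent: OpRole values are disjoint bit flags, and the op that computes the loss carries both the Forward (or Backward) role and the Loss role, so the composite value is their bitwise OR. Because the bits do not overlap, XOR happened to produce the same number. A small check; the concrete flag values below are an assumption about Paddle's OpRole enum, not taken from this diff:

# Assumed bit-flag values, for illustration only.
FORWARD, BACKWARD, LOSS = 0x0000, 0x0001, 0x0100

assert (FORWARD ^ LOSS) == (FORWARD | LOSS) == 0x0100
assert (BACKWARD ^ LOSS) == (BACKWARD | LOSS) == 0x0101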
@@ -110,14 +198,6 @@ def run1f1b(program, cur_rank, max_run_times, dist_opt, nrank):
     max_slot_times = int(max_run_times - coord['pp_idx'])
     num_of_functionality = 4
 
-    def create_task_node(role, ops, offset, node_type):
-        task_id = int(cur_rank * num_of_functionality + offset)
-        print("Creating task node with role:", role, "and with id:", task_id)
-        node = core.TaskNode(role, ops, cur_rank, task_id, max_run_times,
-                             max_slot_times)
-        node.set_type(node_type)
-        return node
-
     lr_ops, fwd_ops, bwd_ops, opt_ops = [], [], [], []
     for op in program.block(0).ops:
         # split the program based on the op_role
@@ -138,14 +218,39 @@ def run1f1b(program, cur_rank, max_run_times, dist_opt, nrank):
     # Create task nodes.
     # The lr_sched and opt should be 'amplifier' interceptors.
     # The fwd and bwd should be 'compute' interceptors.
-    lr_task_node = create_task_node(
-        int(OpRole.Optimize.LRSched), lr_ops, 0, "Amplifier")
+    lr_task_node = TaskNode(
+        cur_rank=cur_rank,
+        max_run_times=max_run_times,
+        max_slot_times=max_slot_times,
+        role=int(OpRole.Optimize.LRSched),
+        ops=lr_ops,
+        task_id=int(cur_rank * num_of_functionality + 0),
+        node_type="Amplifier")
     lr_task_node.set_run_pre_steps(max_run_times)
-    fwd_task_node = create_task_node(int(OpRole.Forward), fwd_ops, 1, "Compute")
-    bwd_task_node = create_task_node(
-        int(OpRole.Backward), bwd_ops, 2, "Compute")
-    opt_task_node = create_task_node(
-        int(OpRole.Optimize), opt_ops, 3, "Amplifier")
+    fwd_task_node = TaskNode(
+        cur_rank=cur_rank,
+        max_run_times=max_run_times,
+        max_slot_times=max_slot_times,
+        role=int(OpRole.Forward),
+        ops=fwd_ops,
+        task_id=int(cur_rank * num_of_functionality + 1),
+        node_type="Compute")
+    bwd_task_node = TaskNode(
+        cur_rank=cur_rank,
+        max_run_times=max_run_times,
+        max_slot_times=max_slot_times,
+        role=int(OpRole.Backward),
+        ops=bwd_ops,
+        task_id=int(cur_rank * num_of_functionality + 2),
+        node_type="Compute")
+    opt_task_node = TaskNode(
+        cur_rank=cur_rank,
+        max_run_times=max_run_times,
+        max_slot_times=max_slot_times,
+        role=int(OpRole.Optimize),
+        ops=opt_ops,
+        task_id=int(cur_rank * num_of_functionality + 3),
+        node_type="Amplifier")
     opt_task_node.set_run_pre_steps(max_run_times)
     opt_task_node.set_run_at_offset(max_run_times - 1)
     task_nodes = [lr_task_node, fwd_task_node, bwd_task_node, opt_task_node]
@@ -200,7 +305,7 @@ def run1f1b(program, cur_rank, max_run_times, dist_opt, nrank):
     for i in range(nrank):
         for j in range(num_of_functionality):
             task_id_to_rank[int(i * num_of_functionality + j)] = i
-    return task_nodes, task_id_to_rank
+    return [task_node.task_node() for task_node in task_nodes], task_id_to_rank
 
 
 def origin(program, cur_rank):
@@ -213,8 +318,9 @@ def origin(program, cur_rank):
         task_id_to_rank (dict): a fake dict, since there is no upstream or downstream, this dict won't be used
     """
     print("fleet executor will use python side origin scheduler.")
-    task_node = core.TaskNode(program.desc, cur_rank, 1, 1)
+    task_node = TaskNode(
+        program=program, cur_rank=cur_rank, max_run_times=1, max_slot_times=1)
     task_node.set_type("Compute")
     task_id = task_node.task_id()
     task_id_to_rank = {task_id: cur_rank}
-    return [task_node], task_id_to_rank
+    return [task_node.task_node()], task_id_to_rank
@@ -1958,7 +1958,6 @@ class Executor(object):
     def _prepare_fleet_executor(self, program=None, scope=None, fleet_opt=None):
         from ..distributed.fleet.proto import fleet_executor_desc_pb2
-        from google.protobuf import text_format
         assert program, "Program for fleet executor should not be None"
         assert fleet_opt, "Configurations for fleet executor should not be None"
         trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS", "")
@@ -1974,35 +1973,44 @@
             fleet_exe_desc.cluster_info.append(rank_info)
         if "num_micro_batches" in fleet_opt:
             fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
-        assert 'scheduler' in fleet_opt, \
-            "Fleet executor need configuration for scheduler, you can choose from 1F1B or Origin."
-        scheduler = fleet_opt['scheduler']
-        if scheduler == '1F1B':
-            from paddle.distributed.fleet.fleet_executor_utils import run1f1b
-            if "dist_strategy" not in fleet_opt or \
-                "pp_degree" not in fleet_opt["dist_strategy"] or \
-                fleet_opt["dist_strategy"]["pp_degree"] == 1:
-                warnings.warn("Using 1F1B scheduler with pp_degree == 1.")
-            tasks, task_id_to_rank = run1f1b(
-                program, cur_rank,
-                fleet_opt.get('num_micro_batches', 1),
-                fleet_opt.get('dist_strategy', {}), nrank)
-        elif scheduler == 'Origin':
-            from paddle.distributed.fleet.fleet_executor_utils import origin
-            if "dist_strategy" in fleet_opt and \
-                "pp_degree" in fleet_opt["dist_strategy"]:
-                assert fleet_opt["dist_strategy"]["pp_degree"] == 1, \
-                    "For pipeline mode, the scheduler should be 1F1B instead of Origin."
-            if "num_micro_batches" in fleet_opt:
-                assert fleet_opt["num_micro_batches"] == 1, \
-                    "For origin scheduler mode, the num micro batches should be 1."
-            tasks, task_id_to_rank = origin(program, cur_rank)
-        else:
-            raise "Fleet_executor only supports 1F1B and Origin scheduler, " \
-                "but received " + str(scheduler) + "."
+        assert 'scheduler' in fleet_opt or 'tasks' in fleet_opt, \
+            "Fleet executor needs a configuration for the scheduler, you can choose from 1F1B or Origin. " \
+            "Or you can provide a list of task nodes to init fleet executor directly."
+        if 'tasks' in fleet_opt:
+            assert 'task_id_to_rank' in fleet_opt, "If you provide tasks to init fleet executor," \
+                " task_id_to_rank should also be provided."
+            print('fleet executor will use user defined task nodes')
+            tasks = [task.task_node() for task in fleet_opt['tasks']]
+            task_id_to_rank = fleet_opt['task_id_to_rank']
+        else:
+            scheduler = fleet_opt['scheduler']
+            if scheduler == '1F1B':
+                from paddle.distributed.fleet.fleet_executor_utils import run1f1b
+                if "dist_strategy" not in fleet_opt or \
+                    "pp_degree" not in fleet_opt["dist_strategy"] or \
+                    fleet_opt["dist_strategy"]["pp_degree"] == 1:
+                    warnings.warn("Using 1F1B scheduler with pp_degree == 1.")
+                tasks, task_id_to_rank = run1f1b(
+                    program, cur_rank,
+                    fleet_opt.get('num_micro_batches', 1),
+                    fleet_opt.get('dist_strategy', {}), nrank)
+            elif scheduler == 'Origin':
+                from paddle.distributed.fleet.fleet_executor_utils import origin
+                if "dist_strategy" in fleet_opt and \
+                    "pp_degree" in fleet_opt["dist_strategy"]:
+                    assert fleet_opt["dist_strategy"]["pp_degree"] == 1, \
+                        "For pipeline mode, the scheduler should be 1F1B instead of Origin."
+                if "num_micro_batches" in fleet_opt:
+                    assert fleet_opt["num_micro_batches"] == 1, \
+                        "For origin scheduler mode, the num micro batches should be 1."
+                tasks, task_id_to_rank = origin(program, cur_rank)
+            else:
+                raise ValueError("Fleet_executor only supports 1F1B and Origin "
+                                 "scheduler, but received " + str(scheduler) + ".")
         # NOTE: have to hold these vars, otherwise they will be destructed
         fleet_opt['tasks'] = tasks
         fleet_opt['task_id_to_rank'] = task_id_to_rank
         fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
         place = core.Place()
         place.set_place(self.place)
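With the new 'tasks' branch, a caller can hand the executor pre-built task nodes and skip the built-in schedulers. A minimal sketch of the expected `fleet_opt` shape, mirroring the unit test added at the end of this commit; `task_node` is assumed to be a Python-side TaskNode from fleet_executor_utils and `program` the program being run:

fleet_opt = {
    'tasks': [task_node],  # user-defined Python-side TaskNode objects
    'task_id_to_rank': {task_node.task_id(): 0},  # required together with 'tasks'
}
program._pipeline_opt = {
    "fleet_opt": fleet_opt,
    "section_program": program,
}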
@@ -2019,11 +2027,11 @@
         cached_ctx = self._get_ctx_cache(cache_key)
         cached_scope = self._get_scope_cache(cache_key)
         cached_program = self._get_program_cache(cache_key)
+        real_feed = [] if feed is None else feed
         if cached_scope is None:
             cached_scope = global_scope()
             self._add_scope_cache(cache_key, cached_scope)
         if cached_program is None:
-            real_feed = [] if feed is None else feed
             real_program = program
             if "section_program" in program._pipeline_opt:
                 real_program = program._pipeline_opt["section_program"]
@@ -2044,10 +2052,48 @@
             self._add_program_cache(cache_key, cached_program)
         if cached_ctx is None:
             fleet_opt = program._pipeline_opt["fleet_opt"]
+            if 'tasks' in fleet_opt:
+                # Insert feed/fetch ops into the cloned program of each task node;
+                # these ops have already been inserted into the origin program.
+                # To avoid every task node having feed/fetch ops,
+                # only insert feed ops into the first task node,
+                # then insert fetch ops into the last task node.
+
+                # Insert feed ops
+                feed_task = fleet_opt['tasks'][0]
+                print("Inserting feed ops for task", feed_task.task_id())
+                feed_program = feed_task.get_program()
+                feed_program = self._add_feed_ops(
+                    program=feed_program,
+                    feed=real_feed,
+                    feed_var_name=feed_var_name)
+                feed_task.set_program(feed_program)
+
+                # Insert fetch ops
+                fetch_task = fleet_opt['tasks'][-1]
+                print("Inserting fetch ops for task", fetch_task.task_id())
+                fetch_program = fetch_task.get_program()
+                fetch_program = self._add_fetch_ops(
+                    program=fetch_program,
+                    fetch_list=fetch_list,
+                    fetch_var_name=fetch_var_name)
+                main_block = fetch_program.block(0)
+                for op in main_block.ops:
+                    # set the op_role of the fetch op to Optimize to prevent
+                    # the fetched vars from being erased by gc for pipeline
+                    if op.type == 'fetch':
+                        op._set_attr(
+                            'op_role',
+                            core.op_proto_and_checker_maker.OpRole.Optimize)
+                fetch_task.set_program(fetch_program)
+
             cached_ctx = self._prepare_fleet_executor(
                 program=cached_program, scope=cached_scope, fleet_opt=fleet_opt)
             self._add_ctx_cache(cache_key, cached_ctx)
         if feed:
+            # NOTE: no need to traverse the programs held by the task nodes,
+            # since they are all sub-programs of the cached program, and the
+            # cached program has the feed/fetch vars added as well
             self._feed_data(cached_program, feed, feed_var_name, cached_scope)
 
         from paddle.optimizer.lr import LRScheduler
@@ -2068,6 +2114,73 @@
                 return as_numpy(tensors)
             return None
 
+    def _add_feed_ops(self, program, feed, feed_var_name):
+        tmp_program = program.clone()
+
+        global_block = tmp_program.global_block()
+
+        if feed_var_name in global_block.vars:
+            feed_var = global_block.var(feed_var_name)
+        else:
+            feed_var = global_block.create_var(
+                name=feed_var_name,
+                type=core.VarDesc.VarType.FEED_MINIBATCH,
+                persistable=True)
+
+        # prepend feed operators
+        if not has_feed_operators(global_block, feed, feed_var_name):
+            for i, name in enumerate(feed):
+                if global_block.has_var(name):
+                    out = global_block.var(name)
+                    global_block._prepend_op(
+                        type='feed',
+                        inputs={'X': [feed_var]},
+                        outputs={'Out': [out]},
+                        attrs={'col': i})
+                else:
+                    warnings.warn(
+                        "The variable %s is not found in program. It is not declared or is pruned."
+                        % name)
+
+        return tmp_program
+
+    def _add_fetch_ops(self,
+                       program,
+                       fetch_list,
+                       fetch_var_name,
+                       use_fetch_v2=False):
+        tmp_program = program.clone()
+
+        global_block = tmp_program.global_block()
+
+        if fetch_var_name in global_block.vars:
+            fetch_var = global_block.var(fetch_var_name)
+        else:
+            fetch_var = global_block.create_var(
+                name=fetch_var_name,
+                type=core.VarDesc.VarType.FETCH_LIST,
+                persistable=True)
+
+        if use_fetch_v2:
+            fetch_op = 'fetch_v2'
+        else:
+            fetch_op = 'fetch'
+
+        # append fetch operators
+        if not has_fetch_operators(global_block, fetch_list, fetch_var_name,
+                                   fetch_op):
+            for i, var in enumerate(fetch_list):
+                assert isinstance(var, Variable) or isinstance(
+                    var, six.string_types), (
+                        "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
+                global_block.append_op(
+                    type=fetch_op,
+                    inputs={'X': [var]},
+                    outputs={'Out': [fetch_var]},
+                    attrs={'col': i})
+
+        return tmp_program
+
     def _run_pipeline(self,
                       program=None,
                       dataset=None,
...
@@ -145,6 +145,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_gradient_scale)
     LIST(REMOVE_ITEM TEST_OPS test_disable_signal_handler)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_executor)
+    LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_with_task_nodes)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_multi_devices)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_origin_scheduler)
    LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_mapper)
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.distributed.fleet.fleet_executor_utils import TaskNode

paddle.enable_static()


class TestFleetExecutor(unittest.TestCase):
    def run_fleet_executor(self, place, x_data, y_data):
        exe = paddle.static.Executor(place)
        empty_program = paddle.static.Program()
        with fluid.program_guard(empty_program, empty_program):
            x = fluid.layers.data(
                name='x', shape=x_data.shape, dtype=x_data.dtype)
            y = fluid.layers.data(
                name='y', shape=y_data.shape, dtype=y_data.dtype)
            z = x + y
            a = 2 * x + 3 * y
            loss = paddle.mean(a)
            base_lr = 0.1
            passes = [30, 60, 80, 90]
            steps_per_pass = 10
            bd = [steps_per_pass * p for p in passes]
            lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
            lr_val = paddle.optimizer.lr.PiecewiseDecay(
                boundaries=bd, values=lr)
            opt = paddle.optimizer.AdamW(
                learning_rate=lr_val,
                grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
            opt.minimize(loss)
        # TODO: section_program will be removed in the future
        task_node = TaskNode(
            # must pass a cloned program: if the same program object were
            # shared, the feed and fetch ops would be inserted twice
            program=empty_program.clone(),
            cur_rank=0,
            max_run_times=1,
            max_slot_times=1)
        empty_program._pipeline_opt = {
            "fleet_opt": {
                'tasks': [task_node],
                'task_id_to_rank': {
                    task_node.task_id(): 0
                }
            },
            "section_program": empty_program
        }
        res = exe.run(empty_program,
                      feed={'x': x_data,
                            'y': y_data},
                      fetch_list=[z.name, a.name])
        return res

    def test_executor_on_single_device(self):
        if fluid.is_compiled_with_cuda():
            shape = (10000, 3462)
            x_data = np.random.rand(*shape)
            y_data = np.random.rand(*shape)
            z_data = x_data + y_data
            a_data = 2 * x_data + 3 * y_data
            res = self.run_fleet_executor(fluid.CUDAPlace(0), x_data, y_data)
            self.assertTrue(np.allclose(res[0], z_data))
            self.assertTrue(np.allclose(res[1], a_data))


if __name__ == "__main__":
    unittest.main()