From a4afb97ac2212bfd687463a8d277297c9347276b Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Tue, 21 Dec 2021 16:57:58 +0800
Subject: [PATCH] [fleet_executor] Python side fleet executor and task node
 (#38290)

---
 .../distributed/fleet_executor/task_node.cc   |  28 ++-
 .../distributed/fleet_executor/task_node.h    |   8 +-
 paddle/fluid/pybind/bind_fleet_executor.cc    |   7 +-
 .../distributed/fleet/fleet_executor_utils.py | 146 +++++++++++++--
 python/paddle/fluid/executor.py               | 173 +++++++++++++++---
 .../fluid/tests/unittests/CMakeLists.txt      |   1 +
 .../test_fleet_executor_with_task_nodes.py    |  82 +++++++++
 7 files changed, 384 insertions(+), 61 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py

diff --git a/paddle/fluid/distributed/fleet_executor/task_node.cc b/paddle/fluid/distributed/fleet_executor/task_node.cc
index f03ee0acb47..656cfc431cd 100644
--- a/paddle/fluid/distributed/fleet_executor/task_node.cc
+++ b/paddle/fluid/distributed/fleet_executor/task_node.cc
@@ -23,20 +23,35 @@ namespace {
 using OperatorBase = TaskNode::OperatorBase;
 }
 
-TaskNode::TaskNode(const framework::ProgramDesc& program, int64_t rank,
+TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
                    int64_t max_run_times, int64_t max_slot_nums)
     : program_(program),
      rank_(rank),
      max_run_times_(max_run_times),
      max_slot_nums_(max_slot_nums) {
   // Should be serially invoked, not thread-safe
+  // NOTE: when a TaskNode is instantiated with a program, it is not
+  // initialized immediately, since the provided program will most likely be
+  // updated later by adding feed/fetch ops or by the RuntimeGraph.
+  // Therefore the initialization is delayed to the Init() function.
   static int64_t task_node_cnt = 0;
   task_id_ = task_node_cnt++;
-  for (const auto& op_desc : program.Block(0).AllOps()) {
-    ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*op_desc));
-  }
-  for (const auto& op : ops_vec_) {
-    ops_.emplace_back(op.get());
+}
+
+void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
+  program_ = program;
+}
+
+void TaskNode::Init() {
+  if (ops_.empty()) {
+    // Q (for fleet executor dev): do we need another reset function here?
+    VLOG(3) << "Task node will be initialized by calling Init().";
+    for (const auto& op_desc : program_->Block(0).AllOps()) {
+      ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*op_desc));
+    }
+    for (const auto& op : ops_vec_) {
+      ops_.emplace_back(op.get());
+    }
   }
 }
 
@@ -52,6 +67,7 @@ TaskNode::TaskNode(int32_t role,
   if (op_descs.empty()) {
     return;
   }
+  VLOG(3) << "Task node will be initialized by the provided list of ops.";
   for (const auto& desc : op_descs) {
     ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*desc));
   }
diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h
index d43cd99926d..b9c1361dc99 100644
--- a/paddle/fluid/distributed/fleet_executor/task_node.h
+++ b/paddle/fluid/distributed/fleet_executor/task_node.h
@@ -40,10 +40,12 @@ class TaskNode final {
   TaskNode(int32_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
            int64_t task_id, int64_t max_run_times, int64_t max_slot_nums);
-  TaskNode(const paddle::framework::ProgramDesc& program, int64_t rank,
+  TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
            int64_t max_run_times, int64_t max_slot_nums);
   ~TaskNode() = default;
 
+  void SetProgram(paddle::framework::ProgramDesc* program);
+  void Init();
   int64_t rank() const { return rank_; }
   int64_t task_id() const { return task_id_; }
   int32_t role() const { return role_; }
@@ -60,7 +62,7 @@ class TaskNode final {
     return downstream_;
   }
   const std::string& type() const { return type_; }
-  const paddle::framework::ProgramDesc& program() const { return program_; }
+  const paddle::framework::ProgramDesc* program() const { return program_; }
   const std::vector<OperatorBase*>& ops() const { return ops_; }
   const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
     return ops_vec_;
@@ -94,7 +96,7 @@ class TaskNode final {
   // task_id-->buff_size
   std::unordered_map<int64_t, int64_t> upstream_;
   std::unordered_map<int64_t, int64_t> downstream_;
-  framework::ProgramDesc program_;
+  framework::ProgramDesc* program_;
   std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
   std::unordered_map<const OperatorBase*, std::vector<std::string>>
       unused_vars_;
diff --git a/paddle/fluid/pybind/bind_fleet_executor.cc b/paddle/fluid/pybind/bind_fleet_executor.cc
index cd7e559aa16..b2ace4c0b57 100644
--- a/paddle/fluid/pybind/bind_fleet_executor.cc
+++ b/paddle/fluid/pybind/bind_fleet_executor.cc
@@ -29,6 +29,7 @@ namespace pybind {
 using paddle::distributed::FleetExecutor;
 using paddle::distributed::TaskNode;
 using paddle::framework::OpDesc;
+using paddle::framework::ProgramDesc;
 
 void BindFleetExecutor(py::module* m) {
   py::class_<FleetExecutor>(*m, "FleetExecutor")
@@ -38,7 +39,7 @@ void BindFleetExecutor(py::module* m) {
       py::call_guard<py::gil_scoped_release>());
 
   py::class_<TaskNode>(*m, "TaskNode")
-      .def(py::init<const framework::ProgramDesc&, int64_t, int64_t, int64_t>())
+      .def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>())
       .def(py::init<int32_t, const std::vector<OpDesc*>&, int64_t,
                     int64_t, int64_t, int64_t>())
       .def("task_id", &TaskNode::task_id)
@@ -47,7 +48,9 @@ void BindFleetExecutor(py::module* m) {
       .def("set_run_pre_steps", &TaskNode::SetRunPerSteps)
       .def("set_run_at_offset", &TaskNode::SetRunAtOffset)
       .def("set_type", &TaskNode::SetType)
-      .def("role", &TaskNode::role);
+      .def("role", &TaskNode::role)
+      .def("init", &TaskNode::Init)
+      .def("set_program", &TaskNode::SetProgram);
 }
 }  // namespace pybind
 }  // namespace paddle
diff --git a/python/paddle/distributed/fleet/fleet_executor_utils.py b/python/paddle/distributed/fleet/fleet_executor_utils.py
index dba3388d18f..67b4b5e8fe2 100644
--- a/python/paddle/distributed/fleet/fleet_executor_utils.py
+++ b/python/paddle/distributed/fleet/fleet_executor_utils.py
@@ -16,6 +16,94 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY
 from paddle.fluid import core
 
 
+class TaskNode:
+    """
+    Python side TaskNode, a wrapper of the C++ side TaskNode
+    """
+
+    # track the previous init method
+    previous = None
+
+    def __init__(self,
+                 cur_rank,
+                 max_run_times,
+                 max_slot_times,
+                 role=None,
+                 node_type='Compute',
+                 task_id=None,
+                 ops=None,
+                 program=None):
+        """
+        :param cur_rank (int): Current rank of the task node.
+        :param max_run_times (int): The max run times of the task node.
+        :param max_slot_times (int): The max slot times of the task node.
+        :param role (int): The role of the task node.
+        :param node_type (str): The type of the task node, default is 'Compute'.
+        :param task_id (int): The task id of the task node.
+        :param ops (list): A list of op.desc to init the task node.
+        :param program (Program): An instance of Program to init the task node.
+        """
+        # NOTE: ops should be checked by `is not None`, since it may be an empty list
+        assert ((ops is not None) ^ (program is not None)), \
+            "Should provide only one of ops or program to the task node."
+        if not self.previous:
+            self.previous = 'program' if program else 'ops'
+        assert (program is not None and self.previous == 'program') or \
+            (ops is not None and self.previous == 'ops'), \
+            "In one program, task nodes should be inited in the same way, either all by ops or all by program."
+        if ops is not None:
+            assert role is not None and task_id is not None, \
+                "If the task node is inited with ops, `role` and `task_id` must be provided."
+            self.node = core.TaskNode(role, ops, cur_rank,
+                                      int(task_id), max_run_times,
+                                      max_slot_times)
+            print("Creating task node by ops. The role is:",
+                  self.role(), "and the id is:", self.task_id())
+        else:
+            self.program = program
+            self.node = core.TaskNode(program.desc, cur_rank, max_run_times,
+                                      max_slot_times)
+            print("Creating task node by program. The id is:", self.task_id())
+        self.node.set_type(node_type)
+
+    def set_type(self, interceptor_type):
+        self.node.set_type(interceptor_type)
+
+    def task_node(self):
+        if hasattr(self, 'program'):
+            print(
+                "The task node has been instantiated by program, calling init before passing it to the fleet executor."
+            )
+            self.node.init()
+        return self.node
+
+    def set_program(self, program):
+        self.program = program
+        self.node.set_program(program.desc)
+
+    def get_program(self):
+        assert hasattr(self, 'program'), 'There is no program to get'
+        return self.program
+
+    def set_run_pre_steps(self, steps):
+        self.node.set_run_pre_steps(steps)
+
+    def set_run_at_offset(self, offset):
+        self.node.set_run_at_offset(offset)
+
+    def add_upstream_task(self, upstream, buffer_size):
+        self.node.add_upstream_task(upstream, buffer_size)
+
+    def add_downstream_task(self, downstream, buffer_size):
+        self.node.add_downstream_task(downstream, buffer_size)
+
+    def role(self):
+        return self.node.role()
+
+    def task_id(self):
+        return self.node.task_id()
+
+
 class CoordSys:
     """
     This class is used to mapping rank to (mp rank, sharding rank, pp rank, dp rank).
@@ -81,12 +169,12 @@ def is_lr_sched_op(op_role): def is_forward_op(op_role): return (op_role == int(OpRole.Forward)) or \ - (op_role == (int(OpRole.Forward) ^ int(OpRole.Loss))) + (op_role == (int(OpRole.Forward) | int(OpRole.Loss))) def is_backward_op(op_role): return (op_role == int(OpRole.Backward)) or \ - (op_role == (int(OpRole.Backward) ^ int(OpRole.Loss))) + (op_role == (int(OpRole.Backward) | int(OpRole.Loss))) def run1f1b(program, cur_rank, max_run_times, dist_opt, nrank): @@ -110,14 +198,6 @@ def run1f1b(program, cur_rank, max_run_times, dist_opt, nrank): max_slot_times = int(max_run_times - coord['pp_idx']) num_of_functionality = 4 - def create_task_node(role, ops, offset, node_type): - task_id = int(cur_rank * num_of_functionality + offset) - print("Creating task node with role:", role, "and with id:", task_id) - node = core.TaskNode(role, ops, cur_rank, task_id, max_run_times, - max_slot_times) - node.set_type(node_type) - return node - lr_ops, fwd_ops, bwd_ops, opt_ops = [], [], [], [] for op in program.block(0).ops: # split the program based on the op_role @@ -138,14 +218,39 @@ def run1f1b(program, cur_rank, max_run_times, dist_opt, nrank): # Create task nodes. # The lr_sched and opt should be 'amplifier interceptor. # The fwd and bwd should be 'compute interceptor'. - lr_task_node = create_task_node( - int(OpRole.Optimize.LRSched), lr_ops, 0, "Amplifier") + lr_task_node = TaskNode( + cur_rank=cur_rank, + max_run_times=max_run_times, + max_slot_times=max_slot_times, + role=int(OpRole.Optimize.LRSched), + ops=lr_ops, + task_id=int(cur_rank * num_of_functionality + 0), + node_type="Amplifier") lr_task_node.set_run_pre_steps(max_run_times) - fwd_task_node = create_task_node(int(OpRole.Forward), fwd_ops, 1, "Compute") - bwd_task_node = create_task_node( - int(OpRole.Backward), bwd_ops, 2, "Compute") - opt_task_node = create_task_node( - int(OpRole.Optimize), opt_ops, 3, "Amplifier") + fwd_task_node = TaskNode( + cur_rank=cur_rank, + max_run_times=max_run_times, + max_slot_times=max_slot_times, + role=int(OpRole.Forward), + ops=fwd_ops, + task_id=int(cur_rank * num_of_functionality + 1), + node_type="Compute") + bwd_task_node = TaskNode( + cur_rank=cur_rank, + max_run_times=max_run_times, + max_slot_times=max_slot_times, + role=int(OpRole.Backward), + ops=bwd_ops, + task_id=int(cur_rank * num_of_functionality + 2), + node_type="Compute") + opt_task_node = TaskNode( + cur_rank=cur_rank, + max_run_times=max_run_times, + max_slot_times=max_slot_times, + role=int(OpRole.Optimize), + ops=opt_ops, + task_id=int(cur_rank * num_of_functionality + 3), + node_type="Amplifier") opt_task_node.set_run_pre_steps(max_run_times) opt_task_node.set_run_at_offset(max_run_times - 1) task_nodes = [lr_task_node, fwd_task_node, bwd_task_node, opt_task_node] @@ -200,7 +305,7 @@ def run1f1b(program, cur_rank, max_run_times, dist_opt, nrank): for i in range(nrank): for j in range(num_of_functionality): task_id_to_rank[int(i * num_of_functionality + j)] = i - return task_nodes, task_id_to_rank + return [task_node.task_node() for task_node in task_nodes], task_id_to_rank def origin(program, cur_rank): @@ -213,8 +318,9 @@ def origin(program, cur_rank): task_id_to_rank (dict): a fake dict, since there is no upstream or downstream, this dict won't be used """ print("fleet executor will use python side origin scheduler.") - task_node = core.TaskNode(program.desc, cur_rank, 1, 1) + task_node = TaskNode( + program=program, cur_rank=cur_rank, max_run_times=1, max_slot_times=1) task_node.set_type("Compute") 
     task_id = task_node.task_id()
     task_id_to_rank = {task_id: cur_rank}
-    return [task_node], task_id_to_rank
+    return [task_node.task_node()], task_id_to_rank
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index a65370d99a8..710e86e14b5 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1958,7 +1958,6 @@ class Executor(object):
 
     def _prepare_fleet_executor(self, program=None, scope=None, fleet_opt=None):
         from ..distributed.fleet.proto import fleet_executor_desc_pb2
-        from google.protobuf import text_format
         assert program, "Program for fleet executor should not be None"
         assert fleet_opt, "Configurations for fleet executor should not be None"
         trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS", "")
@@ -1974,35 +1973,44 @@ class Executor(object):
             fleet_exe_desc.cluster_info.append(rank_info)
         if "num_micro_batches" in fleet_opt:
             fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
-        assert 'scheduler' in fleet_opt, \
-            "Fleet executor need configuration for scheduler, you can choose from 1F1B or Origin."
-        scheduler = fleet_opt['scheduler']
-        if scheduler == '1F1B':
-            from paddle.distributed.fleet.fleet_executor_utils import run1f1b
-            if "dist_strategy" not in fleet_opt or \
-                "pp_degree" not in fleet_opt["dist_strategy"] or \
-                fleet_opt["dist_strategy"]["pp_degree"] == 1:
-                warnings.warn("Using 1F1B scheduler with pp_degree == 1.")
-            tasks, task_id_to_rank = run1f1b(
-                program, cur_rank,
-                fleet_opt.get('num_micro_batches', 1),
-                fleet_opt.get('dist_strategy', {}), nrank)
-        elif scheduler == 'Origin':
-            from paddle.distributed.fleet.fleet_executor_utils import origin
-            if "dist_strategy" in fleet_opt and \
-                "pp_degree" in fleet_opt["dist_strategy"]:
-                assert fleet_opt["dist_strategy"]["pp_degree"] == 1, \
-                    "For pipeline mode, the scheduler should be 1F1B instead of Origin."
-            if "num_micro_batches" in fleet_opt:
-                assert fleet_opt["num_micro_batches"] == 1, \
-                    "For origin scheduler mode, the num micro batches should be 1."
-            tasks, task_id_to_rank = origin(program, cur_rank)
+
+        assert 'scheduler' in fleet_opt or 'tasks' in fleet_opt, \
+            "Fleet executor needs a scheduler configuration, you can choose from 1F1B or Origin. " \
+            "Or you can provide a list of task nodes to init fleet executor directly."
+        if 'tasks' in fleet_opt:
+            assert 'task_id_to_rank' in fleet_opt, "If you provide tasks to init fleet executor," \
+                                                   " task_id_to_rank should also be provided."
+            print('fleet executor will use user-defined task nodes')
+            tasks = [task.task_node() for task in fleet_opt['tasks']]
+            task_id_to_rank = fleet_opt['task_id_to_rank']
         else:
-            raise "Fleet_executor only supports 1F1B and Origin scheduler, " \
-                  "but received " + str(scheduler) + "."
-        # NOTE: have to hold these vars, otherwise will be destructed
-        fleet_opt['tasks'] = tasks
-        fleet_opt['task_id_to_rank'] = task_id_to_rank
+            scheduler = fleet_opt['scheduler']
+            if scheduler == '1F1B':
+                from paddle.distributed.fleet.fleet_executor_utils import run1f1b
+                if "dist_strategy" not in fleet_opt or \
+                        "pp_degree" not in fleet_opt["dist_strategy"] or \
+                        fleet_opt["dist_strategy"]["pp_degree"] == 1:
+                    warnings.warn("Using 1F1B scheduler with pp_degree == 1.")
+                tasks, task_id_to_rank = run1f1b(
+                    program, cur_rank,
+                    fleet_opt.get('num_micro_batches', 1),
+                    fleet_opt.get('dist_strategy', {}), nrank)
+            elif scheduler == 'Origin':
+                from paddle.distributed.fleet.fleet_executor_utils import origin
+                if "dist_strategy" in fleet_opt and \
+                        "pp_degree" in fleet_opt["dist_strategy"]:
+                    assert fleet_opt["dist_strategy"]["pp_degree"] == 1, \
+                        "For pipeline mode, the scheduler should be 1F1B instead of Origin."
+                if "num_micro_batches" in fleet_opt:
+                    assert fleet_opt["num_micro_batches"] == 1, \
+                        "For origin scheduler mode, the num micro batches should be 1."
+                tasks, task_id_to_rank = origin(program, cur_rank)
+            else:
+                raise ValueError("Fleet_executor only supports 1F1B and Origin "
+                                 "scheduler, but received " + str(scheduler) + ".")
+            # NOTE: have to hold these vars, otherwise they will be destructed
+            fleet_opt['tasks'] = tasks
+            fleet_opt['task_id_to_rank'] = task_id_to_rank
         fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
         place = core.Place()
         place.set_place(self.place)
@@ -2019,11 +2027,11 @@ class Executor(object):
         cached_ctx = self._get_ctx_cache(cache_key)
         cached_scope = self._get_scope_cache(cache_key)
         cached_program = self._get_program_cache(cache_key)
+        real_feed = [] if feed is None else feed
         if cached_scope is None:
             cached_scope = global_scope()
             self._add_scope_cache(cache_key, cached_scope)
         if cached_program is None:
-            real_feed = [] if feed is None else feed
             real_program = program
             if "section_program" in program._pipeline_opt:
                 real_program = program._pipeline_opt["section_program"]
@@ -2044,10 +2052,48 @@ class Executor(object):
             self._add_program_cache(cache_key, cached_program)
         if cached_ctx is None:
             fleet_opt = program._pipeline_opt["fleet_opt"]
+            if 'tasks' in fleet_opt:
+                # Insert feed/fetch ops into the cloned program of each task node;
+                # these ops have already been inserted into the origin program.
+                # To avoid having feed/fetch ops in every task node, only insert
+                # feed ops into the first task node and fetch ops into the last
+                # task node.
+
+                # Insert feed ops
+                feed_task = fleet_opt['tasks'][0]
+                print("Inserting feed ops for task", feed_task.task_id())
+                feed_program = feed_task.get_program()
+                feed_program = self._add_feed_ops(
+                    program=feed_program,
+                    feed=real_feed,
+                    feed_var_name=feed_var_name)
+                feed_task.set_program(feed_program)
+
+                # Insert fetch ops
+                fetch_task = fleet_opt['tasks'][-1]
+                print("Inserting fetch ops for task", fetch_task.task_id())
+                fetch_program = fetch_task.get_program()
+                fetch_program = self._add_fetch_ops(
+                    program=fetch_program,
+                    fetch_list=fetch_list,
+                    fetch_var_name=fetch_var_name)
+                main_block = fetch_program.block(0)
+                for op in main_block.ops:
+                    # set the op_role of the fetch ops to Optimize so that the
+                    # fetched vars are not erased by GC in pipeline mode
+                    if op.type == 'fetch':
+                        op._set_attr(
+                            'op_role',
+                            core.op_proto_and_checker_maker.OpRole.Optimize)
+                fetch_task.set_program(fetch_program)
+
             cached_ctx = self._prepare_fleet_executor(
                 program=cached_program, scope=cached_scope, fleet_opt=fleet_opt)
             self._add_ctx_cache(cache_key, cached_ctx)
         if feed:
+            # NOTE: no need to traverse the programs held by the task nodes,
+            # since they are all sub-programs of the cached program and the
+            # feed/fetch vars have already been added to the cached program
             self._feed_data(cached_program, feed, feed_var_name, cached_scope)
 
         from paddle.optimizer.lr import LRScheduler
@@ -2068,6 +2114,73 @@ class Executor(object):
             return as_numpy(tensors)
         return None
 
+    def _add_feed_ops(self, program, feed, feed_var_name):
+        tmp_program = program.clone()
+
+        global_block = tmp_program.global_block()
+
+        if feed_var_name in global_block.vars:
+            feed_var = global_block.var(feed_var_name)
+        else:
+            feed_var = global_block.create_var(
+                name=feed_var_name,
+                type=core.VarDesc.VarType.FEED_MINIBATCH,
+                persistable=True)
+
+        # prepend feed operators
+        if not has_feed_operators(global_block, feed, feed_var_name):
+            for i, name in enumerate(feed):
+                if global_block.has_var(name):
+                    out = global_block.var(name)
+                    global_block._prepend_op(
+                        type='feed',
+                        inputs={'X': [feed_var]},
+                        outputs={'Out': [out]},
+                        attrs={'col': i})
+                else:
+                    warnings.warn(
+                        "The variable %s is not found in program. It is not declared or is pruned."
+ % name) + + return tmp_program + + def _add_fetch_ops(self, + program, + fetch_list, + fetch_var_name, + use_fetch_v2=False): + tmp_program = program.clone() + + global_block = tmp_program.global_block() + + if fetch_var_name in global_block.vars: + fetch_var = global_block.var(fetch_var_name) + else: + fetch_var = global_block.create_var( + name=fetch_var_name, + type=core.VarDesc.VarType.FETCH_LIST, + persistable=True) + + if use_fetch_v2: + fetch_op = 'fetch_v2' + else: + fetch_op = 'fetch' + + # append fetch_operators + if not has_fetch_operators(global_block, fetch_list, fetch_var_name, + fetch_op): + for i, var in enumerate(fetch_list): + assert isinstance(var, Variable) or isinstance( + var, six.string_types), ( + "Wrong type for fetch_list[%s]: %s" % (i, type(var))) + global_block.append_op( + type=fetch_op, + inputs={'X': [var]}, + outputs={'Out': [fetch_var]}, + attrs={'col': i}) + + return tmp_program + def _run_pipeline(self, program=None, dataset=None, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 598ca24f595..791a8bf77ba 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -145,6 +145,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_fleet_gradient_scale) LIST(REMOVE_ITEM TEST_OPS test_disable_signal_handler) LIST(REMOVE_ITEM TEST_OPS test_fleet_executor) + LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_with_task_nodes) LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_multi_devices) LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_origin_scheduler) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_mapper) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py b/python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py new file mode 100644 index 00000000000..61064175266 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddle.distributed.fleet.fleet_executor_utils import TaskNode
+
+paddle.enable_static()
+
+
+class TestFleetExecutor(unittest.TestCase):
+    def run_fleet_executor(self, place, x_data, y_data):
+        exe = paddle.static.Executor(place)
+        empty_program = paddle.static.Program()
+        with fluid.program_guard(empty_program, empty_program):
+            x = fluid.layers.data(
+                name='x', shape=x_data.shape, dtype=x_data.dtype)
+            y = fluid.layers.data(
+                name='y', shape=y_data.shape, dtype=y_data.dtype)
+            z = x + y
+            a = 2 * x + 3 * y
+            loss = paddle.mean(a)
+            base_lr = 0.1
+            passes = [30, 60, 80, 90]
+            steps_per_pass = 10
+            bd = [steps_per_pass * p for p in passes]
+            lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
+            lr_val = paddle.optimizer.lr.PiecewiseDecay(
+                boundaries=bd, values=lr)
+            opt = paddle.optimizer.AdamW(
+                learning_rate=lr_val,
+                grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
+            opt.minimize(loss)
+            # TODO: section_program will be removed in the future
+            task_node = TaskNode(
+                # must clone; otherwise feed/fetch ops would be added to the same program twice
+                program=empty_program.clone(),
+                cur_rank=0,
+                max_run_times=1,
+                max_slot_times=1)
+            empty_program._pipeline_opt = {
+                "fleet_opt": {
+                    'tasks': [task_node],
+                    'task_id_to_rank': {
+                        task_node.task_id(): 0
+                    }
+                },
+                "section_program": empty_program
+            }
+        res = exe.run(empty_program,
+                      feed={'x': x_data,
+                            'y': y_data},
+                      fetch_list=[z.name, a.name])
+        return res
+
+    def test_executor_on_single_device(self):
+        if fluid.is_compiled_with_cuda():
+            shape = (10000, 3462)
+            x_data = np.random.rand(*shape)
+            y_data = np.random.rand(*shape)
+            z_data = x_data + y_data
+            a_data = 2 * x_data + 3 * y_data
+            res = self.run_fleet_executor(fluid.CUDAPlace(0), x_data, y_data)
+            self.assertTrue(np.allclose(res[0], z_data))
+            self.assertTrue(np.allclose(res[1], a_data))
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
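
Taken together, the patch enables a user-defined task-node path through the executor: build a static program, wrap it (or its split op lists) in Python-side TaskNode objects, and hand the nodes plus a task_id_to_rank map to exe.run via program._pipeline_opt["fleet_opt"]; the executor then inserts the feed/fetch ops into the node's program and calls init() on it before constructing the C++ FleetExecutor. The snippet below is a minimal sketch of that flow, distilled from the unit test added above rather than taken from the commit itself; it assumes a single rank, a GPU build (as the test does), and uses arbitrary data shapes and variable names.

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.distributed.fleet.fleet_executor_utils import TaskNode

paddle.enable_static()

x_data = np.random.rand(4, 2).astype('float32')

main_program = paddle.static.Program()
with fluid.program_guard(main_program, main_program):
    x = fluid.layers.data(name='x', shape=x_data.shape, dtype='float32')
    y = 2 * x + 1

# One task node wrapping the whole program; init() is deferred until
# exe.run has inserted the feed/fetch ops into the node's program.
task_node = TaskNode(
    program=main_program.clone(),  # clone so feed/fetch ops are added only once
    cur_rank=0,
    max_run_times=1,
    max_slot_times=1)

main_program._pipeline_opt = {
    "fleet_opt": {
        "tasks": [task_node],
        "task_id_to_rank": {task_node.task_id(): 0},
    },
    "section_program": main_program,
}

exe = paddle.static.Executor(fluid.CUDAPlace(0))
out = exe.run(main_program, feed={'x': x_data}, fetch_list=[y.name])
print(out[0])  # expected to equal 2 * x_data + 1

For pipeline schedules, the same TaskNode class is used with the ops/role/task_id form instead of a program, as run1f1b in fleet_executor_utils.py shows.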