Unverified Commit 9ccdb5fa authored by LiYuRio, committed by GitHub

[FleetExecutor] Using program to be the only interface of TaskNode (#43869)

Parent c8d3b82c
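At a high level, this change makes the ProgramDesc the primary handle for building a TaskNode: a new constructor takes only a program plus scheduling parameters, the Python Executor gains a private _fleet_executor_with_standalone switch that is threaded down to the 1F1B task builders, and the Python-side TaskNode can defer C++ construction via lazy_initialize. A minimal sketch of the new Python-side entry point, mirroring the test added below (keyword names are taken from this diff; exact behavior is version-dependent):

import paddle
from paddle.distributed.fleet.fleet_executor_utils import TaskNode

paddle.enable_static()

program = paddle.static.Program()
# The program is the only interface: no op lists are handed over.
task = TaskNode(program=program,
                rank=0,
                max_run_times=1,
                max_slot_times=1,
                lazy_initialize=True)
task_node = task.task_node()  # the underlying C++ node is built here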
@@ -41,6 +41,20 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
task_id_ = task_node_cnt++;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: program_(program),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
// TODO(liyurui): Will be removed when executing the program is supported.
Init();
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
: program_(program), rank_(rank), task_id_(rank) {
max_run_times_ = 1;
......
@@ -55,6 +55,12 @@ class TaskNode final {
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
// TODO(liyurui): This will be the only constructor for task node
TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
~TaskNode() = default;
void SetProgram(paddle::framework::ProgramDesc* program);
......
@@ -165,6 +165,11 @@ void BindFleetExecutor(py::module* m) {
"run", &FleetExecutor::Run, py::call_guard<py::gil_scoped_release>());
py::class_<TaskNode>(*m, "TaskNode")
.def(py::init<framework::ProgramDesc*,
int64_t,
int64_t,
int64_t,
int64_t>())
.def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>())
.def(py::init<int32_t,
const std::vector<framework::OpDesc*>&,
......
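The first py::init above binds the new five-argument constructor, so it can be exercised directly from Python. A hedged sketch of such a call (argument order follows the task_node.cc definition, rank then task_id; passing program.desc for the ProgramDesc* is an assumption about typical usage):

import paddle
import paddle.fluid.core as core

paddle.enable_static()

program = paddle.static.Program()
# rank=0, task_id=0, max_run_times=1, max_slot_nums=1, matching the
# parameter order of the C++ definition in task_node.cc.
node = core.TaskNode(program.desc, 0, 0, 1, 1)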
@@ -717,6 +717,9 @@ class Executor(object):
self._executor_cache = _ExecutorCache(self.place)
self._fleet_executor = None
# TODO(liyurui): This option will be removed and always be true once the
# functionality of fleet executor with standalone executor is ready.
self._fleet_executor_with_standalone = False
def _get_scope_cache(self, program_cache_key):
return self.scope_caches.get(program_cache_key, None)
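The new flag defaults to False, so the standalone-executor path stays opt-in for now; enabling it means setting the private attribute on an Executor instance. A minimal sketch (the attribute name comes from this diff; the rest is illustrative):

import paddle

paddle.enable_static()

exe = paddle.static.Executor(paddle.CPUPlace())
# Route fleet-executor runs through the standalone executor.
# Defaults to False while the integration is still in progress.
exe._fleet_executor_with_standalone = True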
@@ -1323,9 +1326,12 @@ class Executor(object):
# Move prepare here for port conflict with nccl in startup program
if self._fleet_executor is None:
self._fleet_executor = _prepare_fleet_executor()
return self._run_using_fleet_executor(program=program,
feed=feed,
fetch_list=fetch_list)
return self._run_using_fleet_executor(
program=program,
feed=feed,
fetch_list=fetch_list,
with_standalone_executor=self.
_fleet_executor_with_standalone)
if "startup_program" in program._pipeline_opt:
program = program._pipeline_opt["startup_program"]
else:
@@ -2131,7 +2137,8 @@ class Executor(object):
carrier_id="",
program=None,
scope=None,
fleet_opt=None):
fleet_opt=None,
with_standalone_executor=False):
num_micro_batches = fleet_opt[
"num_micro_batches"] if "num_micro_batches" in fleet_opt else 1
cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
@@ -2157,7 +2164,8 @@ class Executor(object):
warnings.warn("Using 1F1B scheduler with pp_degree == 1.")
tasks, task_id_to_rank = run1f1b(
program, cur_rank, fleet_opt.get('num_micro_batches', 1),
fleet_opt.get('dist_strategy', {}), nrank)
fleet_opt.get('dist_strategy', {}), nrank,
with_standalone_executor)
elif scheduler == 'Origin':
from paddle.distributed.fleet.fleet_executor_utils import origin
if "dist_strategy" in fleet_opt and \
@@ -2186,7 +2194,8 @@ class Executor(object):
feed=None,
feed_var_name="feed",
fetch_var_name="fetch",
fetch_list=None):
fetch_list=None,
with_standalone_executor=False):
cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
cached_program = self._get_program_cache(cache_key)
cached_scope = self._get_scope_cache(cache_key)
@@ -2249,10 +2258,12 @@ class Executor(object):
core.op_proto_and_checker_maker.OpRole.Optimize)
fetch_task.set_program(fetch_program)
self._prepare_fleet_executor_carrier(cache_key,
program=cached_program,
scope=cached_scope,
fleet_opt=fleet_opt)
self._prepare_fleet_executor_carrier(
cache_key,
program=cached_program,
scope=cached_scope,
fleet_opt=fleet_opt,
with_standalone_executor=with_standalone_executor)
if feed:
# NOTE: don't have to traverse programs in task nodes,
......
@@ -15,6 +15,7 @@
import unittest
import paddle
import paddle.fluid.core as core
from paddle.distributed.fleet.fleet_executor_utils import TaskNode
paddle.enable_static()
@@ -33,6 +34,15 @@ class TestFleetExecutorTaskNode(unittest.TestCase):
task_node_0.add_downstream_task(task_node_1.task_id(), 1))
self.assertTrue(task_node_1.add_upstream_task(task_node_0.task_id(), 1))
def test_lazy_task_node(self):
program = paddle.static.Program()
task = TaskNode(program=program,
rank=0,
max_run_times=1,
max_slot_times=1,
lazy_initialize=True)
task_node = task.task_node()
if __name__ == "__main__":
unittest.main()
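The test_lazy_task_node case above relies on lazy_initialize deferring construction of the underlying C++ node until task_node() is called. A rough illustration of that pattern, assuming the four-argument binding shown earlier (this sketches the idea only, not the actual fleet_executor_utils implementation):

import paddle.fluid.core as core

class LazyTaskNodeSketch:
    """Illustrative stand-in for TaskNode's lazy path."""

    def __init__(self, program, rank, max_run_times, max_slot_times):
        self._program = program
        self._rank = rank
        self._max_run_times = max_run_times
        self._max_slot_times = max_slot_times
        self._node = None  # nothing is built yet

    def task_node(self):
        # Build the core.TaskNode from the program alone, on first use.
        if self._node is None:
            self._node = core.TaskNode(self._program.desc, self._rank,
                                       self._max_run_times,
                                       self._max_slot_times)
        return self._node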
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.fluid.core as core
from paddle.distributed.fleet.fleet_executor_utils import TaskNode, FleetExecutorUtils
paddle.enable_static()
class TestFleetExecutorUtils(unittest.TestCase):
def test_construct_program(self):
# TODO(liyurui): These functions are not ready yet.
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sharding_configs = {
"dp_degree": 2,
"mp_degree": 2,
"pp_degree": 2
}
fleet_executor_utils = FleetExecutorUtils(
dist_strategy=strategy.sharding_configs,
rank=0,
nrank=1,
max_run_times=1)
op_list = {"lr": [], "fwd": [], "bwd": [], "opt": []}
program_map = fleet_executor_utils.convert_op_list_to_program(
op_list, paddle.static.Program())
task_node_map = fleet_executor_utils.construct_task_nodes_1f1b(
program_map)
if __name__ == "__main__":
unittest.main()
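The op_list buckets in the test above mirror the 1F1B pipeline split: learning-rate-schedule, forward, backward, and optimizer ops become separate programs, which in turn back separate task nodes. A toy illustration of how such buckets could be filled from a program's op roles (the classification below is invented for illustration; the real splitting lives inside fleet_executor_utils):

import paddle
import paddle.fluid.core as core

paddle.enable_static()

OpRole = core.op_proto_and_checker_maker.OpRole
role_attr = core.op_proto_and_checker_maker.kOpRoleAttrName()

def bucket_ops_1f1b(program):
    # Toy splitter: group ops by their op_role bits into the four
    # buckets that the test above feeds to convert_op_list_to_program.
    buckets = {"lr": [], "fwd": [], "bwd": [], "opt": []}
    for op in program.global_block().ops:
        role = op.attr(role_attr)
        if role & int(OpRole.LRSched):
            buckets["lr"].append(op)
        elif role & int(OpRole.Backward):
            buckets["bwd"].append(op)
        elif role & int(OpRole.Optimize):
            buckets["opt"].append(op)
        else:
            buckets["fwd"].append(op)
    return buckets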
@@ -51,9 +51,11 @@ class TestFleetExecutor(unittest.TestCase):
task_node = TaskNode(
# must clone; if copied, there would be two fetches and two feeds
program=empty_program.clone(),
cur_rank=0,
rank=0,
node_type="Compute",
max_run_times=1,
max_slot_times=1)
max_slot_times=1,
lazy_initialize=True)
empty_program._pipeline_opt = {
"fleet_opt": {
'tasks': [task_node],
......
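Putting the pieces together, the truncated fleet_opt above continues with a mapping from task ids to ranks before the program is handed to Executor.run. A rough end-to-end reconstruction (the task_id_to_rank field and the task_id() call are assumptions patterned on other fleet-executor tests):

import paddle
from paddle.distributed.fleet.fleet_executor_utils import TaskNode

paddle.enable_static()

empty_program = paddle.static.Program()
task_node = TaskNode(program=empty_program.clone(),
                     rank=0,
                     node_type="Compute",
                     max_run_times=1,
                     max_slot_times=1,
                     lazy_initialize=True)
empty_program._pipeline_opt = {
    "fleet_opt": {
        'tasks': [task_node],
        # assumed: pin the single task to rank 0
        'task_id_to_rank': {task_node.task_id(): 0},
    },
}
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(empty_program)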