From 3f815e761c204b2cb29c6fd2d293cbf3c137ec98 Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Thu, 25 Nov 2021 13:09:20 +0800 Subject: [PATCH] Export task node to python (#37509) --- .../distributed/fleet_executor/task_node.cc | 21 +++++++++-- .../distributed/fleet_executor/task_node.h | 9 +++-- paddle/fluid/pybind/bind_fleet_executor.cc | 8 +++++ .../fluid/tests/unittests/CMakeLists.txt | 1 + .../test_fleet_executor_task_node.py | 36 +++++++++++++++++++ 5 files changed, 70 insertions(+), 5 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py diff --git a/paddle/fluid/distributed/fleet_executor/task_node.cc b/paddle/fluid/distributed/fleet_executor/task_node.cc index 3eee0fa3cb0..07fd091b04d 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.cc +++ b/paddle/fluid/distributed/fleet_executor/task_node.cc @@ -21,6 +21,17 @@ namespace { using OperatorBase = TaskNode::OperatorBase; } +TaskNode::TaskNode(const framework::ProgramDesc& program, int64_t rank, + int64_t max_run_times, int64_t max_slot_nums) + : program_(program), + rank_(rank), + max_run_times_(max_run_times), + max_slot_nums_(max_slot_nums) { + // Should be serially invoked, not thread-safe + static int64_t task_node_cnt = 0; + task_id_ = task_node_cnt++; +} + TaskNode::TaskNode(int32_t role, const std::vector& ops, int64_t rank, int64_t task_id, int64_t max_run_times, int64_t max_slot_nums) @@ -55,10 +66,14 @@ std::unique_ptr TaskNode::CreateTaskNode( max_slot_nums); } -void TaskNode::AddUpstreamTask(int64_t task_id) { upstream_.insert(task_id); } +bool TaskNode::AddUpstreamTask(int64_t task_id) { + const auto& ret = upstream_.insert(task_id); + return *ret.first == task_id; +} -void TaskNode::AddDownstreamTask(int64_t task_id) { - downstream_.insert(task_id); +bool TaskNode::AddDownstreamTask(int64_t task_id) { + const auto& ret = downstream_.insert(task_id); + return *ret.first == task_id; } std::string TaskNode::DebugString() const { diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h index f438e491daa..ec2ea0c0093 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.h +++ b/paddle/fluid/distributed/fleet_executor/task_node.h @@ -19,6 +19,7 @@ #include #include +#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/macros.h" namespace paddle { @@ -34,6 +35,8 @@ class TaskNode final { int64_t max_slot_nums); TaskNode(int32_t role, const std::vector& ops, int64_t rank, int64_t task_id, int64_t max_run_times, int64_t max_slot_nums); + TaskNode(const paddle::framework::ProgramDesc& program, int64_t rank, + int64_t max_run_times, int64_t max_slot_nums); ~TaskNode() = default; int64_t rank() const { return rank_; } @@ -44,9 +47,10 @@ class TaskNode final { const std::unordered_set& upstream() const { return upstream_; } const std::unordered_set& downstream() const { return downstream_; } const std::string& type() const { return type_; } + const paddle::framework::ProgramDesc& program() const { return program_; } - void AddUpstreamTask(int64_t task_id); - void AddDownstreamTask(int64_t task_id); + bool AddUpstreamTask(int64_t task_id); + bool AddDownstreamTask(int64_t task_id); std::string DebugString() const; static std::unique_ptr CreateEmptyTaskNode(int32_t role, @@ -64,6 +68,7 @@ class TaskNode final { std::vector ops_; std::unordered_set upstream_; std::unordered_set downstream_; + framework::ProgramDesc program_; int32_t role_; int64_t rank_; int64_t task_id_; diff --git a/paddle/fluid/pybind/bind_fleet_executor.cc b/paddle/fluid/pybind/bind_fleet_executor.cc index 392cdfe19bd..726c0428390 100644 --- a/paddle/fluid/pybind/bind_fleet_executor.cc +++ b/paddle/fluid/pybind/bind_fleet_executor.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/pybind/bind_fleet_executor.h" #include #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" +#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/framework/program_desc.h" namespace py = pybind11; @@ -23,12 +24,19 @@ namespace paddle { namespace pybind { using paddle::distributed::FleetExecutor; +using paddle::distributed::TaskNode; void BindFleetExecutor(py::module* m) { py::class_(*m, "FleetExecutor") .def(py::init()) .def("init", &FleetExecutor::Init) .def("run", &FleetExecutor::Run); + + py::class_(*m, "TaskNode") + .def(py::init()) + .def("task_id", &TaskNode::task_id) + .def("add_upstream_task", &TaskNode::AddUpstreamTask) + .def("add_downstream_task", &TaskNode::AddDownstreamTask); } } // namespace pybind } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 4ec42995139..4698b1dcb27 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -144,6 +144,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_disable_signal_handler) LIST(REMOVE_ITEM TEST_OPS test_fleet_executor) LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_multi_devices) + LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_task_node) endif() # Temporally disable test_deprecated_decorator diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py b/python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py new file mode 100644 index 00000000000..2c944aa5dbc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import paddle.fluid.core as core + +paddle.enable_static() + + +class TestFleetExecutorTaskNode(unittest.TestCase): + def test_task_node(self): + program = paddle.static.Program() + task_node_0 = core.TaskNode(program.desc, 0, 1, 1) + task_node_1 = core.TaskNode(program.desc, 0, 1, 1) + task_node_2 = core.TaskNode(program.desc, 0, 1, 1) + self.assertEqual(task_node_0.task_id(), 0) + self.assertEqual(task_node_1.task_id(), 1) + self.assertEqual(task_node_2.task_id(), 2) + self.assertTrue(task_node_0.add_downstream_task(task_node_1.task_id())) + self.assertTrue(task_node_1.add_upstream_task(task_node_0.task_id())) + + +if __name__ == "__main__": + unittest.main() -- GitLab