未验证 提交 3f815e76 编写于 作者: L LiYuRio 提交者: GitHub

Export task node to python (#37509)

上级 ed7a21de
......@@ -21,6 +21,17 @@ namespace {
using OperatorBase = TaskNode::OperatorBase;
}
TaskNode::TaskNode(const framework::ProgramDesc& program, int64_t rank,
int64_t max_run_times, int64_t max_slot_nums)
: program_(program),
rank_(rank),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
// Should be serially invoked, not thread-safe
static int64_t task_node_cnt = 0;
task_id_ = task_node_cnt++;
}
TaskNode::TaskNode(int32_t role, const std::vector<OperatorBase*>& ops,
int64_t rank, int64_t task_id, int64_t max_run_times,
int64_t max_slot_nums)
......@@ -55,10 +66,14 @@ std::unique_ptr<TaskNode> TaskNode::CreateTaskNode(
max_slot_nums);
}
void TaskNode::AddUpstreamTask(int64_t task_id) { upstream_.insert(task_id); }
bool TaskNode::AddUpstreamTask(int64_t task_id) {
const auto& ret = upstream_.insert(task_id);
return *ret.first == task_id;
}
void TaskNode::AddDownstreamTask(int64_t task_id) {
downstream_.insert(task_id);
bool TaskNode::AddDownstreamTask(int64_t task_id) {
const auto& ret = downstream_.insert(task_id);
return *ret.first == task_id;
}
std::string TaskNode::DebugString() const {
......
......@@ -19,6 +19,7 @@
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
......@@ -34,6 +35,8 @@ class TaskNode final {
int64_t max_slot_nums);
TaskNode(int32_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums);
TaskNode(const paddle::framework::ProgramDesc& program, int64_t rank,
int64_t max_run_times, int64_t max_slot_nums);
~TaskNode() = default;
int64_t rank() const { return rank_; }
......@@ -44,9 +47,10 @@ class TaskNode final {
const std::unordered_set<int64_t>& upstream() const { return upstream_; }
const std::unordered_set<int64_t>& downstream() const { return downstream_; }
const std::string& type() const { return type_; }
const paddle::framework::ProgramDesc& program() const { return program_; }
void AddUpstreamTask(int64_t task_id);
void AddDownstreamTask(int64_t task_id);
bool AddUpstreamTask(int64_t task_id);
bool AddDownstreamTask(int64_t task_id);
std::string DebugString() const;
static std::unique_ptr<TaskNode> CreateEmptyTaskNode(int32_t role,
......@@ -64,6 +68,7 @@ class TaskNode final {
std::vector<OperatorBase*> ops_;
std::unordered_set<int64_t> upstream_;
std::unordered_set<int64_t> downstream_;
framework::ProgramDesc program_;
int32_t role_;
int64_t rank_;
int64_t task_id_;
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/pybind/bind_fleet_executor.h"
#include <pybind11/stl.h>
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/program_desc.h"
namespace py = pybind11;
......@@ -23,12 +24,19 @@ namespace paddle {
namespace pybind {
using paddle::distributed::FleetExecutor;
using paddle::distributed::TaskNode;
void BindFleetExecutor(py::module* m) {
py::class_<FleetExecutor>(*m, "FleetExecutor")
.def(py::init<const std::string&>())
.def("init", &FleetExecutor::Init)
.def("run", &FleetExecutor::Run);
py::class_<TaskNode>(*m, "TaskNode")
.def(py::init<const framework::ProgramDesc&, int64_t, int64_t, int64_t>())
.def("task_id", &TaskNode::task_id)
.def("add_upstream_task", &TaskNode::AddUpstreamTask)
.def("add_downstream_task", &TaskNode::AddDownstreamTask);
}
} // namespace pybind
} // namespace paddle
......@@ -144,6 +144,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_disable_signal_handler)
LIST(REMOVE_ITEM TEST_OPS test_fleet_executor)
LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_multi_devices)
LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_task_node)
endif()
# Temporally disable test_deprecated_decorator
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.fluid.core as core
paddle.enable_static()
class TestFleetExecutorTaskNode(unittest.TestCase):
def test_task_node(self):
program = paddle.static.Program()
task_node_0 = core.TaskNode(program.desc, 0, 1, 1)
task_node_1 = core.TaskNode(program.desc, 0, 1, 1)
task_node_2 = core.TaskNode(program.desc, 0, 1, 1)
self.assertEqual(task_node_0.task_id(), 0)
self.assertEqual(task_node_1.task_id(), 1)
self.assertEqual(task_node_2.task_id(), 2)
self.assertTrue(task_node_0.add_downstream_task(task_node_1.task_id()))
self.assertTrue(task_node_1.add_upstream_task(task_node_0.task_id()))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册