未验证 提交 decbb588 编写于 作者: L LiYuRio 提交者: GitHub

[FleetExecutor] Remove max_slot_num and implement multi-scope fetch (#50041)

* remove max_slot_num

* fix test case
上级 bad49b51
......@@ -24,33 +24,14 @@ namespace {
using OperatorBase = TaskNode::OperatorBase;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums)
: program_(program),
rank_(rank),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
// Should be serially invoked, not thread-safe
// NOTE: when a TaskNode is instantiated with a program, the task node is not
// initialized immediately, since the provided program may be updated later
// (with high probability) by adding_feed_fetch_ops or by RuntimeGraph.
// So, the init part is delayed to the Init() function.
static int64_t task_node_cnt = 0;
task_id_ = task_node_cnt++;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: program_(program),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
max_run_times_(max_run_times) {
// TODO(liyurui): Will be removed when execute program is supported.
Init();
}
......@@ -58,7 +39,6 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
: program_(program), rank_(rank), task_id_(rank) {
max_run_times_ = 1;
max_slot_nums_ = 1;
LOG(INFO)
<< "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
<< rank
......@@ -98,13 +78,11 @@ TaskNode::TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
max_run_times_(max_run_times) {
if (op_descs.empty()) {
return;
}
......@@ -121,25 +99,21 @@ TaskNode::TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: ops_(ops),
role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
max_run_times_(max_run_times) {}
TaskNode::TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
max_run_times_(max_run_times) {}
bool TaskNode::AddUpstreamTask(int64_t task_id,
int64_t buff_size,
......
......@@ -37,34 +37,23 @@ class TaskNode final {
public:
using OperatorBase = paddle::framework::OperatorBase;
TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
int64_t max_run_times);
TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums);
int64_t max_run_times);
TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
// TODO(liyurui): This will be the only constructor for task node
TaskNode(paddle::framework::ProgramDesc* program,
int64_t task_id,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums);
int64_t max_run_times);
~TaskNode() = default;
......@@ -74,7 +63,6 @@ class TaskNode final {
int64_t task_id() const { return task_id_; }
int32_t role() const { return role_; }
int64_t max_run_times() const { return max_run_times_; }
int64_t max_slot_nums() const { return max_slot_nums_; }
int64_t run_per_steps() const { return run_per_steps_; }
int64_t run_at_offset() const { return run_at_offset_; }
int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
......@@ -151,7 +139,6 @@ class TaskNode final {
int64_t rank_;
int64_t task_id_;
int64_t max_run_times_;
int64_t max_slot_nums_;
int64_t run_per_steps_{1};
int64_t run_at_offset_{0};
......
......@@ -77,9 +77,8 @@ TEST(ComputeInterceptor, Compute) {
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 2); // rank, task_id, max_run_times
TaskNode* node_a =
new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
TaskNode* node_a = new TaskNode(0, ops, 0, 0, 2); // role, ops, rank, task_id, max_run_times
TaskNode* node_b = new TaskNode(0, 0, 1, 2);
TaskNode* sink = new TaskNode(0, SINK_ID, 2);
// source->a->b->sink
......
......@@ -37,8 +37,8 @@ TEST(ComputeInterceptor, Compute) {
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 3); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_a = new TaskNode(0, 0, 0, 3);
TaskNode* node_b = new TaskNode(0, 0, 1, 3);
TaskNode* sink = new TaskNode(0, SINK_ID, 3);
// source->a->b->sink
......
......@@ -71,12 +71,12 @@ TEST(AmplifierInterceptor, Amplifier) {
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
TaskNode* node_a = new TaskNode(0, 0, 0, 1); // role, rank, task_id, max_run_times
TaskNode* node_b = new TaskNode(0, 0, 1, 1);
TaskNode* node_c = new TaskNode(0, 0, 2, 1);
TaskNode* node_d = new TaskNode(0, 0, 3, 1);
TaskNode* node_e = new TaskNode(0, 0, 4, 1);
TaskNode* node_f = new TaskNode(0, 0, 5, 1);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->e->f->sink
......
......@@ -83,11 +83,10 @@ TEST(AmplifierInterceptor, Amplifier) {
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a =
new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps); // role, rank, task_id, max_run_times
TaskNode* node_b = new TaskNode(0, 0, 1, 3);
TaskNode* node_c = new TaskNode(0, 0, 2, 3);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->sink
......
......@@ -62,10 +62,9 @@ TEST(SourceInterceptor, Source) {
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0); // role, rank, task_id
TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3); // role, rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 3); // role, rank, task_id, max_run_times
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3); // role, rank, task_id, max_run_times
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1);
......
......@@ -61,9 +61,8 @@ TEST(SourceInterceptor, Source) {
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3); // role, rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 3); // role, rank, task_id, max_run_times
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1);
......
......@@ -171,17 +171,11 @@ void BindFleetExecutor(py::module* m) {
.value("STOP_LOOP", DependType::STOP_LOOP);
py::class_<TaskNode>(*m, "TaskNode")
.def(py::init<framework::ProgramDesc*,
int64_t,
int64_t,
int64_t,
int64_t>())
.def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>())
.def(py::init<int32_t,
const std::vector<framework::OpDesc*>&,
int64_t,
int64_t,
int64_t,
int64_t>())
.def("task_id", &TaskNode::task_id)
.def("add_upstream_task", &TaskNode::AddUpstreamTask)
......
......@@ -26,7 +26,6 @@ class TaskNode:
self,
rank,
max_run_times,
max_slot_times,
role=None,
node_type=None,
task_id=0,
......@@ -38,7 +37,6 @@ class TaskNode:
"""
:param rank (int): Current rank of the task node.
:param max_run_times (int): The max run times of the task node.
:param max_slot_times (int): The max slot times of the task node.
:param role (int): The role of the task node. (Will be removed in the future)
:param node_type (str): The type of the task node.
:param task_id (int): The id of task node.
......@@ -56,7 +54,6 @@ class TaskNode:
self.id = int(task_id)
self.rank = rank
self.max_run_times = max_run_times
self.max_slot_times = max_slot_times
self.node_type = node_type
self.program = program
self.lazy_initialize = lazy_initialize
......@@ -72,11 +69,18 @@ class TaskNode:
role is not None and task_id is not None
), "If init task node with ops, should provide `role` and `task_id`."
self.node = core.TaskNode(
role, ops, rank, task_id, max_run_times, max_slot_times
role,
ops,
rank,
task_id,
max_run_times,
)
else:
self.node = core.TaskNode(
program.desc, rank, self.id, max_run_times, max_slot_times
program.desc,
rank,
self.id,
max_run_times,
)
if self.node_type:
self.node.set_type(self.node_type)
......@@ -88,7 +92,6 @@ class TaskNode:
self.rank,
self.id,
self.max_run_times,
self.max_slot_times,
)
if self.node_type:
self.node.set_type(self.node_type)
......@@ -318,33 +321,28 @@ class FleetExecutorUtils:
return task_node_map
def construct_task_nodes_1f1b(self, program_map):
max_slot_times = int(self.max_run_times - self.coord['pp_idx'])
cur_start_id = int(self.rank * self.num_of_functionality)
lr_task_node = TaskNode(
rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["lr"],
task_id=cur_start_id,
)
fwd_task_node = TaskNode(
rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["fwd"],
task_id=cur_start_id + 1,
)
bwd_task_node = TaskNode(
rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["bwd"],
task_id=cur_start_id + 2,
)
opt_task_node = TaskNode(
rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["opt"],
task_id=cur_start_id + 3,
)
......@@ -363,12 +361,10 @@ class FleetExecutorUtils:
return task_id_to_rank
def construct_task_nodes_1f1b_op_list(self, op_list_map):
max_slot_times = int(self.max_run_times - self.coord['pp_idx'])
cur_start_id = int(self.rank * self.num_of_functionality)
lr_task_node = TaskNode(
rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Optimize.LRSched),
ops=op_list_map["lr"],
task_id=cur_start_id,
......@@ -378,7 +374,6 @@ class FleetExecutorUtils:
fwd_task_node = TaskNode(
rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Forward),
ops=op_list_map["fwd"],
task_id=cur_start_id + 1,
......@@ -387,7 +382,6 @@ class FleetExecutorUtils:
bwd_task_node = TaskNode(
rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Backward),
ops=op_list_map["bwd"],
task_id=cur_start_id + 2,
......@@ -396,7 +390,6 @@ class FleetExecutorUtils:
opt_task_node = TaskNode(
rank=self.rank,
max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Optimize),
ops=op_list_map["opt"],
task_id=cur_start_id + 3,
......@@ -480,7 +473,6 @@ def origin(program, rank):
rank=rank,
node_type="Compute",
max_run_times=1,
max_slot_times=1,
)
task_id_to_rank = {task_node.task_id(): rank}
return [task_node.task_node()], task_id_to_rank
......@@ -94,7 +94,6 @@ class TestFleetExecutor(unittest.TestCase):
task_a = TaskNode(
0,
num_micro_batches,
0,
node_type="Compute",
task_id=0,
program=program_a,
......@@ -103,7 +102,6 @@ class TestFleetExecutor(unittest.TestCase):
task_b = TaskNode(
0,
num_micro_batches,
0,
node_type="Cond",
task_id=1,
program=paddle.static.Program(),
......@@ -113,7 +111,6 @@ class TestFleetExecutor(unittest.TestCase):
task_c = TaskNode(
0,
num_micro_batches,
0,
node_type="Compute",
task_id=2,
program=program_b,
......@@ -122,7 +119,6 @@ class TestFleetExecutor(unittest.TestCase):
task_d = TaskNode(
0,
num_micro_batches,
0,
node_type="Compute",
task_id=3,
program=paddle.static.Program(),
......@@ -131,7 +127,6 @@ class TestFleetExecutor(unittest.TestCase):
task_e = TaskNode(
0,
num_micro_batches,
0,
node_type="Compute",
task_id=4,
program=paddle.static.Program(),
......
......@@ -24,9 +24,9 @@ paddle.enable_static()
class TestFleetExecutorTaskNode(unittest.TestCase):
def test_task_node(self):
program = paddle.static.Program()
task_node_0 = core.TaskNode(program.desc, 0, 1, 1)
task_node_0 = core.TaskNode(program.desc, 0, 0, 1)
task_node_1 = core.TaskNode(program.desc, 0, 1, 1)
task_node_2 = core.TaskNode(program.desc, 0, 1, 1)
task_node_2 = core.TaskNode(program.desc, 0, 2, 1)
self.assertEqual(task_node_0.task_id(), 0)
self.assertEqual(task_node_1.task_id(), 1)
self.assertEqual(task_node_2.task_id(), 2)
......@@ -47,7 +47,6 @@ class TestFleetExecutorTaskNode(unittest.TestCase):
program=program,
rank=0,
max_run_times=1,
max_slot_times=1,
lazy_initialize=True,
)
task_node = task.task_node()
......
......@@ -59,7 +59,6 @@ class TestFleetExecutor(unittest.TestCase):
rank=0,
node_type="Compute",
max_run_times=1,
max_slot_times=1,
lazy_initialize=True,
)
empty_program._pipeline_opt = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册