未验证 提交 0633e14d 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] set pipeline 1f1b buffer size (#37807)

上级 b65708a8
...@@ -27,19 +27,15 @@ ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node) ...@@ -27,19 +27,15 @@ ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
} }
void ComputeInterceptor::PrepareDeps() { void ComputeInterceptor::PrepareDeps() {
auto& upstream = GetTaskNode()->upstream(); auto& upstream = node_->upstream();
auto& downstream = GetTaskNode()->downstream(); auto& downstream = node_->downstream();
// TODO(wangxi): get from task node for (auto up : upstream) {
int64_t in_buff_size = std::numeric_limits<int64_t>::max(); in_readys_.emplace(up.first, std::make_pair(up.second, 0));
int64_t out_buff_size = 2; in_stops_.emplace(up.first, false);
for (auto up_id : upstream) {
in_readys_.emplace(up_id, std::make_pair(in_buff_size, 0));
in_stops_.emplace(up_id, false);
} }
for (auto down_id : downstream) { for (auto down : downstream) {
out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0)); out_buffs_.emplace(down.first, std::make_pair(down.second, 0));
} }
// source compute node, should we add a new SourceInterceptor? // source compute node, should we add a new SourceInterceptor?
...@@ -114,8 +110,7 @@ bool ComputeInterceptor::CanWriteOutput() { ...@@ -114,8 +110,7 @@ bool ComputeInterceptor::CanWriteOutput() {
// only source node need reset // only source node need reset
bool ComputeInterceptor::ShouldReset() { bool ComputeInterceptor::ShouldReset() {
if (is_source_ && step_ == node_->max_run_times()) return true; return is_source_ && (step_ == node_->max_run_times());
return false;
} }
void ComputeInterceptor::SendDataReadyToDownStream() { void ComputeInterceptor::SendDataReadyToDownStream() {
......
...@@ -150,12 +150,14 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) { ...@@ -150,12 +150,14 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
} }
role_to_ops.at(new_op_role_id).emplace_back(op.get()); role_to_ops.at(new_op_role_id).emplace_back(op.get());
} }
int64_t cur_rank = exe_desc_.cur_rank(); int64_t cur_rank = exe_desc_.cur_rank();
DistCoordSys coord_sys(exe_desc_.dp_degree(), exe_desc_.pp_degree(), DistCoordSys coord_sys(exe_desc_.dp_degree(), exe_desc_.pp_degree(),
exe_desc_.mp_degree()); exe_desc_.mp_degree());
const auto& coord = coord_sys.RankToCoord(cur_rank); const auto& coord = coord_sys.RankToCoord(cur_rank);
int pipeline_stage = coord.pp_idx; int pipeline_stage = coord.pp_idx;
int64_t num_pipeline_stages = exe_desc_.pp_degree(); int64_t num_pipeline_stages = exe_desc_.pp_degree();
// TODO(fleet_executor dev): start up steps should be a config `num_slots` // TODO(fleet_executor dev): start up steps should be a config `num_slots`
int64_t start_up_steps = num_pipeline_stages - pipeline_stage; int64_t start_up_steps = num_pipeline_stages - pipeline_stage;
int64_t num_micro_batches = exe_desc_.num_micro_batches(); int64_t num_micro_batches = exe_desc_.num_micro_batches();
...@@ -199,36 +201,69 @@ void RuntimeGraph::FakeDependence() { ...@@ -199,36 +201,69 @@ void RuntimeGraph::FakeDependence() {
downstream_coord.pp_idx += 1; downstream_coord.pp_idx += 1;
int64_t pp_upstream = coord_sys.CoordToRank(upstream_coord); int64_t pp_upstream = coord_sys.CoordToRank(upstream_coord);
int64_t pp_downstream = coord_sys.CoordToRank(downstream_coord); int64_t pp_downstream = coord_sys.CoordToRank(downstream_coord);
bool is_first_stage = (pp_upstream == -1);
bool is_last_stage = (pp_downstream == -1);
int32_t num_of_functionality = functionality_order.size(); int32_t num_of_functionality = functionality_order.size();
// lr -> forward -> backward -> optimize // lr(1:m) -> forward -> backward -> (m:1)optimize
// | | // ↑ ↓
// lr -> forward -> backward -> optimize // lr(1:m) -> forward -> backward -> (m:1)optimize
// ↑ ↓
// lr(1:m) -> forward -> backward -> (m:1)optimize
for (std::size_t i = 0; i < task_nodes_.size(); ++i) { for (std::size_t i = 0; i < task_nodes_.size(); ++i) {
if (i != 0) { auto& node = task_nodes_[i];
task_nodes_[i]->AddUpstreamTask(cur_rank * num_of_functionality + i - 1); bool is_forward = IsForward(node->role());
bool is_backward = IsBackward(node->role());
int64_t cur_id = cur_rank * num_of_functionality + i;
int64_t prev_id = cur_id - 1;
int64_t next_id = cur_id + 1;
int64_t upstream_id = pp_upstream * num_of_functionality + i;
int64_t downstream_id = pp_downstream * num_of_functionality + i;
// 1F1B, last stage pp_buff_size should be 1, while first stage
// pp_buff_size should be pp_degree
int64_t pp_buff_size = exe_desc_.pp_degree() - coord.pp_idx;
std::vector<std::pair<int64_t, int64_t>> ups;
std::vector<std::pair<int64_t, int64_t>> downs;
if (i != 0) { // not lr
int64_t buff_size = is_backward ? pp_buff_size : 2;
ups.emplace_back(prev_id, buff_size);
} }
if (i != task_nodes_.size() - 1) { if (i != task_nodes_.size() - 1) { // not optimize
task_nodes_[i]->AddDownstreamTask(cur_rank * num_of_functionality + i + int64_t buff_size = is_forward ? pp_buff_size : 2;
1); downs.emplace_back(next_id, buff_size);
} }
if (IsForward(task_nodes_[i]->role())) {
if (pp_upstream != -1) { if (is_forward) {
task_nodes_[i]->AddUpstreamTask(pp_upstream * num_of_functionality + i); if (!is_first_stage) {
ups.emplace_back(upstream_id, 2);
} }
if (pp_downstream != -1) { if (!is_last_stage) {
task_nodes_[i]->AddDownstreamTask(pp_downstream * num_of_functionality + downs.emplace_back(downstream_id, 2);
i);
} }
} else if (IsBackward(task_nodes_[i]->role())) { } else if (is_backward) {
if (pp_downstream != -1) { if (!is_last_stage) {
task_nodes_[i]->AddUpstreamTask(pp_downstream * num_of_functionality + ups.emplace_back(downstream_id, 2);
i);
} }
if (pp_upstream != -1) { if (!is_first_stage) {
task_nodes_[i]->AddDownstreamTask(pp_upstream * num_of_functionality + downs.emplace_back(upstream_id, 2);
i);
} }
} }
for (auto up : ups) {
VLOG(3) << "Task(" << cur_id << ") AddUpstream Task(" << up.first
<< ") with buff_size=" << up.second;
node->AddUpstreamTask(up.first, up.second);
}
for (auto down : downs) {
VLOG(3) << "Task(" << cur_id << ") AddDownstream Task(" << down.first
<< ") with buff_size=" << down.second;
node->AddDownstreamTask(down.first, down.second);
}
} }
} }
......
...@@ -57,14 +57,14 @@ TaskNode::TaskNode(int32_t role, int64_t rank, int64_t task_id, ...@@ -57,14 +57,14 @@ TaskNode::TaskNode(int32_t role, int64_t rank, int64_t task_id,
max_run_times_(max_run_times), max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {} max_slot_nums_(max_slot_nums) {}
bool TaskNode::AddUpstreamTask(int64_t task_id) { bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
const auto& ret = upstream_.insert(task_id); const auto& ret = upstream_.emplace(task_id, buff_size);
return *ret.first == task_id; return ret.second;
} }
bool TaskNode::AddDownstreamTask(int64_t task_id) { bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
const auto& ret = downstream_.insert(task_id); const auto& ret = downstream_.emplace(task_id, buff_size);
return *ret.first == task_id; return ret.second;
} }
std::string TaskNode::DebugString() const { std::string TaskNode::DebugString() const {
......
...@@ -48,8 +48,12 @@ class TaskNode final { ...@@ -48,8 +48,12 @@ class TaskNode final {
int64_t run_at_offset() const { return run_at_offset_; } int64_t run_at_offset() const { return run_at_offset_; }
int64_t reply_up_per_steps() const { return reply_up_per_steps_; } int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
int64_t send_down_per_steps() const { return send_down_per_steps_; } int64_t send_down_per_steps() const { return send_down_per_steps_; }
const std::unordered_set<int64_t>& upstream() const { return upstream_; } const std::unordered_map<int64_t, int64_t>& upstream() const {
const std::unordered_set<int64_t>& downstream() const { return downstream_; } return upstream_;
}
const std::unordered_map<int64_t, int64_t>& downstream() const {
return downstream_;
}
const std::string& type() const { return type_; } const std::string& type() const { return type_; }
const paddle::framework::ProgramDesc& program() const { return program_; } const paddle::framework::ProgramDesc& program() const { return program_; }
const std::vector<OperatorBase*>& ops() const { return ops_; } const std::vector<OperatorBase*>& ops() const { return ops_; }
...@@ -60,8 +64,9 @@ class TaskNode final { ...@@ -60,8 +64,9 @@ class TaskNode final {
void SetSendDownPerSteps(int64_t value); void SetSendDownPerSteps(int64_t value);
void SetType(const std::string& type) { type_ = type; } void SetType(const std::string& type) { type_ = type; }
bool AddUpstreamTask(int64_t task_id); // upstream need buffs?
bool AddDownstreamTask(int64_t task_id); bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1);
bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1);
std::string DebugString() const; std::string DebugString() const;
private: private:
...@@ -69,8 +74,9 @@ class TaskNode final { ...@@ -69,8 +74,9 @@ class TaskNode final {
TaskNode() = default; TaskNode() = default;
// ops_ will be removed in the future // ops_ will be removed in the future
std::vector<OperatorBase*> ops_; std::vector<OperatorBase*> ops_;
std::unordered_set<int64_t> upstream_; // task_id-->buff_size
std::unordered_set<int64_t> downstream_; std::unordered_map<int64_t, int64_t> upstream_;
std::unordered_map<int64_t, int64_t> downstream_;
framework::ProgramDesc program_; framework::ProgramDesc program_;
std::vector<std::unique_ptr<OperatorBase>> ops_vec_; std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
int32_t role_; int32_t role_;
......
...@@ -56,8 +56,8 @@ TEST(ComputeInterceptor, Compute) { ...@@ -56,8 +56,8 @@ TEST(ComputeInterceptor, Compute) {
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
// a->b->c // a->b->c
node_a->AddDownstreamTask(1); node_a->AddDownstreamTask(1, 3);
node_b->AddUpstreamTask(0); node_b->AddUpstreamTask(0, 3);
node_b->AddDownstreamTask(2); node_b->AddDownstreamTask(2);
node_c->AddUpstreamTask(1); node_c->AddUpstreamTask(1);
......
...@@ -25,19 +25,34 @@ limitations under the License. */ ...@@ -25,19 +25,34 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
void LinkNodes(const std::vector<TaskNode*>& nodes) { int64_t GetBuffSize(
const std::map<std::pair<TaskNode*, TaskNode*>, int64_t> buffs,
TaskNode* from, TaskNode* to) {
if (buffs.find({from, to}) != buffs.end()) {
return buffs.at({from, to});
}
if (buffs.find({to, from}) != buffs.end()) {
return buffs.at({to, from});
}
return 2; // set default 2
}
void LinkNodes(const std::vector<TaskNode*>& nodes,
const std::map<std::pair<TaskNode*, TaskNode*>, int64_t> buffs) {
size_t size = nodes.size(); size_t size = nodes.size();
if (size <= 1) return; if (size <= 1) return;
{ // i = 0 { // i = 0
TaskNode* now = nodes[0]; TaskNode* now = nodes[0];
TaskNode* next = nodes[1]; TaskNode* next = nodes[1];
now->AddDownstreamTask(next->task_id()); auto buff_size = GetBuffSize(buffs, now, next);
now->AddDownstreamTask(next->task_id(), buff_size);
} }
{ // i = size - 1 { // i = size - 1
TaskNode* prev = nodes[size - 2]; TaskNode* prev = nodes[size - 2];
TaskNode* now = nodes[size - 1]; TaskNode* now = nodes[size - 1];
now->AddUpstreamTask(prev->task_id()); auto buff_size = GetBuffSize(buffs, prev, now);
now->AddUpstreamTask(prev->task_id(), buff_size);
} }
for (size_t i = 1; i < size - 1; ++i) { for (size_t i = 1; i < size - 1; ++i) {
...@@ -45,8 +60,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) { ...@@ -45,8 +60,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
TaskNode* now = nodes[i]; TaskNode* now = nodes[i];
TaskNode* next = nodes[i + 1]; TaskNode* next = nodes[i + 1];
now->AddUpstreamTask(prev->task_id()); auto buff_size = GetBuffSize(buffs, prev, now);
now->AddDownstreamTask(next->task_id()); now->AddUpstreamTask(prev->task_id(), buff_size);
buff_size = GetBuffSize(buffs, now, next);
now->AddDownstreamTask(next->task_id(), buff_size);
} }
} }
...@@ -55,7 +73,7 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -55,7 +73,7 @@ TEST(AmplifierInterceptor, Amplifier) {
MessageBus& msg_bus = MessageBus::Instance(); MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, ""}}, ""); msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, ""}}, "");
int64_t micro_steps = 3; int64_t micro_steps = 6;
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = TaskNode* node_a =
...@@ -65,7 +83,8 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -65,7 +83,8 @@ TEST(AmplifierInterceptor, Amplifier) {
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0); TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
// a->b->c->d // a->b->c->d
LinkNodes({node_a, node_b, node_c, node_d}); // LR->F->B->U
LinkNodes({node_a, node_b, node_c, node_d}, {{{node_b, node_c}, 1}});
node_a->SetRunPerSteps(micro_steps); node_a->SetRunPerSteps(micro_steps);
node_d->SetRunPerSteps(micro_steps); node_d->SetRunPerSteps(micro_steps);
......
...@@ -28,8 +28,9 @@ class TestFleetExecutorTaskNode(unittest.TestCase): ...@@ -28,8 +28,9 @@ class TestFleetExecutorTaskNode(unittest.TestCase):
self.assertEqual(task_node_0.task_id(), 0) self.assertEqual(task_node_0.task_id(), 0)
self.assertEqual(task_node_1.task_id(), 1) self.assertEqual(task_node_1.task_id(), 1)
self.assertEqual(task_node_2.task_id(), 2) self.assertEqual(task_node_2.task_id(), 2)
self.assertTrue(task_node_0.add_downstream_task(task_node_1.task_id())) self.assertTrue(
self.assertTrue(task_node_1.add_upstream_task(task_node_0.task_id())) task_node_0.add_downstream_task(task_node_1.task_id(), 1))
self.assertTrue(task_node_1.add_upstream_task(task_node_0.task_id(), 1))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册