未验证 提交 0dd41a2a 编写于 作者: L LiYuRio 提交者: GitHub

Add start interceptor and fix bug in switch scope (#50225)

上级 2261390c
...@@ -37,6 +37,7 @@ cc_library( ...@@ -37,6 +37,7 @@ cc_library(
compute_interceptor.cc compute_interceptor.cc
amplifier_interceptor.cc amplifier_interceptor.cc
cond_interceptor.cc cond_interceptor.cc
start_interceptor.cc
source_interceptor.cc source_interceptor.cc
sink_interceptor.cc sink_interceptor.cc
message_service.cc message_service.cc
...@@ -69,6 +70,8 @@ if(WITH_DISTRIBUTE) ...@@ -69,6 +70,8 @@ if(WITH_DISTRIBUTE)
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties( set_source_files_properties(
cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
start_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties( set_source_files_properties(
source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties( set_source_files_properties(
......
...@@ -36,6 +36,7 @@ USE_INTERCEPTOR(Compute); ...@@ -36,6 +36,7 @@ USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier); USE_INTERCEPTOR(Amplifier);
USE_INTERCEPTOR(Sink); USE_INTERCEPTOR(Sink);
USE_INTERCEPTOR(Cond); USE_INTERCEPTOR(Cond);
USE_INTERCEPTOR(Start);
void Carrier::Init( void Carrier::Init(
int64_t rank, int64_t rank,
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/errors.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -33,14 +34,18 @@ void ComputeInterceptor::PrepareDeps() { ...@@ -33,14 +34,18 @@ void ComputeInterceptor::PrepareDeps() {
auto& downstream = node_->downstream(); auto& downstream = node_->downstream();
for (auto up : upstream) { for (auto up : upstream) {
in_readys_.emplace(up.first, std::make_pair(up.second, 0)); std::map<int64_t, int64_t> ready_size_map;
for (int64_t i = 0; i < node_->max_run_times(); ++i) {
ready_size_map.emplace(i, 0);
}
in_readys_.emplace(up.first, std::make_pair(up.second, ready_size_map));
} }
for (auto down : downstream) { for (auto down : downstream) {
out_buffs_.emplace(down.first, std::make_pair(down.second, 0)); out_buffs_.emplace(down.first, std::make_pair(down.second, 0));
} }
} }
void ComputeInterceptor::IncreaseReady(int64_t up_id) { void ComputeInterceptor::IncreaseReady(int64_t up_id, int64_t scope_id) {
auto it = in_readys_.find(up_id); auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it, PADDLE_ENFORCE_NE(it,
in_readys_.end(), in_readys_.end(),
...@@ -48,8 +53,11 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) { ...@@ -48,8 +53,11 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
"Cannot find upstream=%lld in in_readys.", up_id)); "Cannot find upstream=%lld in in_readys.", up_id));
auto max_ready_size = it->second.first; auto max_ready_size = it->second.first;
auto ready_size = it->second.second; const auto& ready_scope_map = it->second.second;
ready_size += 1; int64_t ready_size = 0;
for (auto& scope_iter : ready_scope_map) {
ready_size += scope_iter.second;
}
if (max_ready_size != INFINITE_BUFFER_SIZE) { if (max_ready_size != INFINITE_BUFFER_SIZE) {
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
ready_size, ready_size,
...@@ -61,7 +69,14 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) { ...@@ -61,7 +69,14 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
ready_size, ready_size,
max_ready_size)); max_ready_size));
} }
it->second.second = ready_size; PADDLE_ENFORCE_NE(
it->second.second.find(scope_id),
it->second.second.end(),
platform::errors::OutOfRange(
"Interceptor %lld can not find scope %lld in upstream ready map",
interceptor_id_,
scope_id));
it->second.second.at(scope_id) = ready_scope_map.at(scope_id) + 1;
} }
void ComputeInterceptor::DecreaseBuff(int64_t down_id) { void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
...@@ -83,16 +98,21 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) { ...@@ -83,16 +98,21 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
} }
bool ComputeInterceptor::IsInputReady() { bool ComputeInterceptor::IsInputReady() {
for (auto& ins : in_readys_) { for (int64_t i = 0; i < node_->max_run_times(); ++i) {
auto ready_size = ins.second.second; bool flag = true;
// not ready, return false for (auto& ins : in_readys_) {
if (ready_size == 0) { auto ready_size_map = ins.second.second;
VLOG(3) << "Interceptor " << GetInterceptorId() flag = flag && (ready_size_map.at(i) != 0);
}
if (flag) {
cur_scope_id_ = i;
return true;
} else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< "'s upstreams aren't all ready."; << "'s upstreams aren't all ready.";
return false;
} }
} }
return true; return false;
} }
bool ComputeInterceptor::CanWriteOutput() { bool ComputeInterceptor::CanWriteOutput() {
...@@ -144,7 +164,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() { ...@@ -144,7 +164,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
void ComputeInterceptor::ReplyCompletedToUpStream() { void ComputeInterceptor::ReplyCompletedToUpStream() {
for (auto& ins : in_readys_) { for (auto& ins : in_readys_) {
auto up_id = ins.first; auto up_id = ins.first;
auto ready_size = ins.second.second; auto ready_size = ins.second.second.at(cur_scope_id_);
ready_size -= 1; ready_size -= 1;
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
ready_size, ready_size,
...@@ -153,7 +173,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -153,7 +173,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
"upstream=%lld ready_size must >= 0, but now got %lld", "upstream=%lld ready_size must >= 0, but now got %lld",
up_id, up_id,
ready_size)); ready_size));
ins.second.second = ready_size; ins.second.second[cur_scope_id_] = ready_size;
VLOG(3) << "ComputeInterceptor " << interceptor_id_ VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id << " Reply data_is_useless msg to " << up_id
...@@ -187,11 +207,8 @@ void ComputeInterceptor::RunOps() { ...@@ -187,11 +207,8 @@ void ComputeInterceptor::RunOps() {
void ComputeInterceptor::Run() { void ComputeInterceptor::Run() {
while (IsInputReady() && CanWriteOutput()) { while (IsInputReady() && CanWriteOutput()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; VLOG(0) << "id=" << GetInterceptorId()
<< " ComputeInterceptor running in scope " << cur_scope_id_;
// get the ready scope id from queue
cur_scope_id_ = ready_queue_.front();
ready_queue_.pop();
RunOps(); RunOps();
...@@ -204,10 +221,15 @@ void ComputeInterceptor::Run() { ...@@ -204,10 +221,15 @@ void ComputeInterceptor::Run() {
void ComputeInterceptor::Compute(const InterceptorMessage& msg) { void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) { if (msg.message_type() == DATA_IS_READY) {
IncreaseReady(msg.src_id()); VLOG(3) << "Compute interceptor " << interceptor_id_
ready_queue_.push(msg.scope_idx()); << " receive data_is_ready " << msg.src_id() << " "
<< msg.scope_idx() << " ";
IncreaseReady(msg.src_id(), msg.scope_idx());
Run(); Run();
} else if (msg.message_type() == DATA_IS_USELESS) { } else if (msg.message_type() == DATA_IS_USELESS) {
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_is_useless " << msg.src_id() << " "
<< msg.scope_idx() << " ";
DecreaseBuff(msg.src_id()); DecreaseBuff(msg.src_id());
Run(); Run();
} }
......
...@@ -32,25 +32,24 @@ class ComputeInterceptor : public Interceptor { ...@@ -32,25 +32,24 @@ class ComputeInterceptor : public Interceptor {
virtual void RunOps(); virtual void RunOps();
virtual void SendDataReadyToDownStream(); virtual void SendDataReadyToDownStream();
virtual void ReplyCompletedToUpStream(); virtual void ReplyCompletedToUpStream();
virtual void Compute(const InterceptorMessage& msg);
void Run();
void IncreaseReady(int64_t up_id, int64_t scope_id);
void DecreaseBuff(int64_t down_id);
std::queue<int64_t> ready_queue_;
int64_t cur_scope_id_; int64_t cur_scope_id_;
// upstream_id-->(max_ready_size, scope-->ready_size)
std::map<int64_t, std::pair<int64_t, std::map<int64_t, int64_t>>>
in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
private: private:
void PrepareDeps(); void PrepareDeps();
void IncreaseReady(int64_t up_id);
void DecreaseBuff(int64_t down_id);
bool IsInputReady(); bool IsInputReady();
bool CanWriteOutput(); bool CanWriteOutput();
void Run();
void Compute(const InterceptorMessage& msg);
// upstream_id-->(max_ready_size, ready_size)
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
}; };
} // namespace distributed } // namespace distributed
......
...@@ -98,8 +98,6 @@ void CondInterceptor::ReplyDataIsUseless(int64_t up_id) { ...@@ -98,8 +98,6 @@ void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
} }
void CondInterceptor::Compute() { void CondInterceptor::Compute() {
cur_scope_id_ = ready_queue_.front();
ready_queue_.pop();
bool cond = GetCondResult(); bool cond = GetCondResult();
VLOG(3) << "Cond interceptor get condition var " << node_->cond_var() VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
<< " with value " << cond; << " with value " << cond;
...@@ -109,14 +107,14 @@ void CondInterceptor::Compute() { ...@@ -109,14 +107,14 @@ void CondInterceptor::Compute() {
SendDataReady(down_id); SendDataReady(down_id);
} }
} else { } else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_; VLOG(0) << "Finish loop in scope " << cur_scope_id_;
SendDataReady(stop_loop_id_); SendDataReady(stop_loop_id_);
} }
} }
void CondInterceptor::Run(const InterceptorMessage& msg) { void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) { if (msg.message_type() == DATA_IS_READY) {
ready_queue_.push(msg.scope_idx()); cur_scope_id_ = msg.scope_idx();
Compute(); Compute();
} else if (msg.message_type() == DATA_IS_USELESS) { } else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) { if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
......
...@@ -39,7 +39,6 @@ class CondInterceptor final : public Interceptor { ...@@ -39,7 +39,6 @@ class CondInterceptor final : public Interceptor {
void SendDataReady(int64_t down_id); void SendDataReady(int64_t down_id);
void ReplyDataIsUseless(int64_t up_id); void ReplyDataIsUseless(int64_t up_id);
std::queue<int64_t> ready_queue_;
int64_t cur_scope_id_; int64_t cur_scope_id_;
std::set<int64_t> normal_in_id_; std::set<int64_t> normal_in_id_;
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/start_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
StartInterceptor::StartInterceptor(int64_t interceptor_id, TaskNode* node)
: ComputeInterceptor(interceptor_id, node) {
auto& downstream = node_->downstream();
PADDLE_ENFORCE_EQ(
downstream.size(),
1,
platform::errors::OutOfRange(
"The downstream for StartInterceptor only support 1 for now."));
for (auto down : downstream) {
batch_size_ = down.second;
}
bool evenly_divisible = ((node_->max_run_times() % batch_size_) == 0);
PADDLE_ENFORCE(
evenly_divisible,
platform::errors::Fatal(
"Wrong config: Num of step should be divided by batch_size,"
"num_step=%lld, batch_size=%lld",
node_->max_run_times(),
batch_size_));
}
void StartInterceptor::RunOps() {
finish_count_++;
ComputeInterceptor::RunOps();
}
void StartInterceptor::SendDataReadyToDownStream() {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
auto max_buff_size = outs.second.first;
auto used_size = outs.second.second;
used_size += 1;
if (max_buff_size != INFINITE_BUFFER_SIZE) {
PADDLE_ENFORCE_LE(
used_size,
max_buff_size,
platform::errors::OutOfRange("downstream=%lld used buff size must <= "
"max_buff_size, but now used_size=%lld, "
"max_buff_size=%lld",
down_id,
used_size,
max_buff_size));
}
outs.second.second = used_size;
}
if (finish_count_ == batch_size_) {
for (int64_t i = 0; i < batch_size_; ++i) {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(step_);
VLOG(3) << "StartInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << step_;
Send(down_id, ready_msg);
}
step_++;
}
}
}
void StartInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
VLOG(3) << "Start interceptor " << interceptor_id_
<< " receive data_is_ready " << msg.src_id() << " "
<< msg.scope_idx() << " ";
IncreaseReady(msg.src_id(), msg.scope_idx());
Run();
} else if (msg.message_type() == DATA_IS_USELESS) {
VLOG(3) << "Start interceptor receive data_is_useless " << msg.src_id()
<< " " << finish_count_;
finish_count_--;
if (finish_count_ == 0) {
for (int64_t i = 0; i < batch_size_; ++i) {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
DecreaseBuff(down_id);
}
}
for (int64_t i = 0; i < batch_size_; ++i) {
Run();
}
}
}
}
REGISTER_INTERCEPTOR(Start, StartInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
namespace paddle {
namespace distributed {
class StartInterceptor final : public ComputeInterceptor {
public:
StartInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void SendDataReadyToDownStream() override;
void RunOps() override;
void Compute(const InterceptorMessage& msg) override;
int64_t batch_size_{0};
int64_t finish_count_{0};
int64_t step_{0};
};
} // namespace distributed
} // namespace paddle
...@@ -21,6 +21,9 @@ limitations under the License. */ ...@@ -21,6 +21,9 @@ limitations under the License. */
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
...@@ -66,7 +66,7 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -66,7 +66,7 @@ TEST(AmplifierInterceptor, Amplifier) {
MessageBus* msg_bus = GlobalVal<MessageBus>::Create(); MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
int64_t micro_steps = 3; int64_t micro_steps = 1;
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source = TaskNode* source =
......
...@@ -84,8 +84,8 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -84,8 +84,8 @@ TEST(AmplifierInterceptor, Amplifier) {
TaskNode* source = TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps); // role, rank, task_id TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 3); TaskNode* node_b = new TaskNode(0, 0, 1, micro_steps);
TaskNode* node_c = new TaskNode(0, 0, 2, 3); TaskNode* node_c = new TaskNode(0, 0, 2, micro_steps);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps); TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps); TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
......
...@@ -33,7 +33,7 @@ def body(i, ten, data): ...@@ -33,7 +33,7 @@ def body(i, ten, data):
return [i, ten, data] return [i, ten, data]
num_micro_batches = 3 num_micro_batches = 4
def batch_generator_creator(): def batch_generator_creator():
...@@ -126,7 +126,7 @@ class TestFleetExecutor(unittest.TestCase): ...@@ -126,7 +126,7 @@ class TestFleetExecutor(unittest.TestCase):
task_a = TaskNode( task_a = TaskNode(
0, 0,
num_micro_batches, num_micro_batches,
node_type="Compute", node_type="Start",
task_id=0, task_id=0,
program=program_a, program=program_a,
lazy_initialize=True, lazy_initialize=True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册