Unverified · Commit 0dd41a2a authored by LiYuRio, committed by GitHub

Add start interceptor and fix bug in switch scope (#50225)

Parent 2261390c
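The core of the scope fix is visible in the ComputeInterceptor hunks below: `in_readys_` changes from `upstream_id --> (max_ready_size, ready_size)` to `upstream_id --> (max_ready_size, scope_id --> ready_size)`, and `IsInputReady()` now scans scopes for one whose upstreams are all ready instead of popping a `ready_queue_`. The following is a minimal, self-contained sketch of that bookkeeping; the toy IDs and helper names are illustrative and not the Paddle API:

```cpp
// Standalone sketch (not the actual Paddle classes) of per-scope ready
// bookkeeping: each upstream keeps a ready count per scope, and a scope is
// runnable only when every upstream has data for it.
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>

// upstream_id -> (max_ready_size, scope_id -> ready_size)
using ReadyMap =
    std::map<int64_t, std::pair<int64_t, std::map<int64_t, int64_t>>>;

// Record one DATA_IS_READY notification from `up_id` for `scope_id`.
void IncreaseReady(ReadyMap* in_readys, int64_t up_id, int64_t scope_id) {
  auto& scope_map = in_readys->at(up_id).second;
  scope_map[scope_id] += 1;
}

// Return the first scope whose upstreams are all ready, or -1 if none is.
int64_t FirstReadyScope(const ReadyMap& in_readys, int64_t max_run_times) {
  for (int64_t scope = 0; scope < max_run_times; ++scope) {
    bool all_ready = true;
    for (const auto& ins : in_readys) {
      auto it = ins.second.second.find(scope);
      all_ready = all_ready && it != ins.second.second.end() && it->second > 0;
    }
    if (all_ready) return scope;
  }
  return -1;
}

int main() {
  const int64_t max_run_times = 2;
  ReadyMap in_readys;
  // Two upstreams, each with a max_ready_size of 2 (unused in this sketch).
  for (int64_t up : {10, 11}) {
    std::map<int64_t, int64_t> scopes;
    for (int64_t i = 0; i < max_run_times; ++i) scopes.emplace(i, 0);
    in_readys.emplace(up, std::make_pair(int64_t{2}, scopes));
  }

  IncreaseReady(&in_readys, 10, 1);  // only one upstream ready for scope 1
  std::cout << FirstReadyScope(in_readys, max_run_times) << "\n";  // -1
  IncreaseReady(&in_readys, 11, 1);  // now scope 1 is fully ready
  std::cout << FirstReadyScope(in_readys, max_run_times) << "\n";  // 1
  return 0;
}
```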
......@@ -37,6 +37,7 @@ cc_library(
compute_interceptor.cc
amplifier_interceptor.cc
cond_interceptor.cc
start_interceptor.cc
source_interceptor.cc
sink_interceptor.cc
message_service.cc
......@@ -69,6 +70,8 @@ if(WITH_DISTRIBUTE)
${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
start_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
......
......@@ -36,6 +36,7 @@ USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);
USE_INTERCEPTOR(Sink);
USE_INTERCEPTOR(Cond);
USE_INTERCEPTOR(Start);
void Carrier::Init(
int64_t rank,
......
......@@ -18,6 +18,7 @@
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
......@@ -33,14 +34,18 @@ void ComputeInterceptor::PrepareDeps() {
auto& downstream = node_->downstream();
for (auto up : upstream) {
in_readys_.emplace(up.first, std::make_pair(up.second, 0));
std::map<int64_t, int64_t> ready_size_map;
for (int64_t i = 0; i < node_->max_run_times(); ++i) {
ready_size_map.emplace(i, 0);
}
in_readys_.emplace(up.first, std::make_pair(up.second, ready_size_map));
}
for (auto down : downstream) {
out_buffs_.emplace(down.first, std::make_pair(down.second, 0));
}
}
void ComputeInterceptor::IncreaseReady(int64_t up_id) {
void ComputeInterceptor::IncreaseReady(int64_t up_id, int64_t scope_id) {
auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it,
in_readys_.end(),
......@@ -48,8 +53,11 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
"Cannot find upstream=%lld in in_readys.", up_id));
auto max_ready_size = it->second.first;
auto ready_size = it->second.second;
ready_size += 1;
const auto& ready_scope_map = it->second.second;
int64_t ready_size = 0;
for (auto& scope_iter : ready_scope_map) {
ready_size += scope_iter.second;
}
if (max_ready_size != INFINITE_BUFFER_SIZE) {
PADDLE_ENFORCE_LE(
ready_size,
......@@ -61,7 +69,14 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
ready_size,
max_ready_size));
}
it->second.second = ready_size;
PADDLE_ENFORCE_NE(
it->second.second.find(scope_id),
it->second.second.end(),
platform::errors::OutOfRange(
"Interceptor %lld can not find scope %lld in upstream ready map",
interceptor_id_,
scope_id));
it->second.second.at(scope_id) = ready_scope_map.at(scope_id) + 1;
}
void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
......@@ -83,16 +98,21 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
}
bool ComputeInterceptor::IsInputReady() {
for (auto& ins : in_readys_) {
auto ready_size = ins.second.second;
// not ready, return false
if (ready_size == 0) {
VLOG(3) << "Interceptor " << GetInterceptorId()
for (int64_t i = 0; i < node_->max_run_times(); ++i) {
bool flag = true;
for (auto& ins : in_readys_) {
auto ready_size_map = ins.second.second;
flag = flag && (ready_size_map.at(i) != 0);
}
if (flag) {
cur_scope_id_ = i;
return true;
} else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< "'s upstreams aren't all ready.";
return false;
}
}
return true;
return false;
}
bool ComputeInterceptor::CanWriteOutput() {
......@@ -144,7 +164,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
void ComputeInterceptor::ReplyCompletedToUpStream() {
for (auto& ins : in_readys_) {
auto up_id = ins.first;
auto ready_size = ins.second.second;
auto ready_size = ins.second.second.at(cur_scope_id_);
ready_size -= 1;
PADDLE_ENFORCE_GE(
ready_size,
......@@ -153,7 +173,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
"upstream=%lld ready_size must >= 0, but now got %lld",
up_id,
ready_size));
ins.second.second = ready_size;
ins.second.second[cur_scope_id_] = ready_size;
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id
......@@ -187,11 +207,8 @@ void ComputeInterceptor::RunOps() {
void ComputeInterceptor::Run() {
while (IsInputReady() && CanWriteOutput()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
// get the ready scope id from queue
cur_scope_id_ = ready_queue_.front();
ready_queue_.pop();
VLOG(0) << "id=" << GetInterceptorId()
<< " ComputeInterceptor running in scope " << cur_scope_id_;
RunOps();
......@@ -204,10 +221,15 @@ void ComputeInterceptor::Run() {
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
IncreaseReady(msg.src_id());
ready_queue_.push(msg.scope_idx());
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_is_ready " << msg.src_id() << " "
<< msg.scope_idx() << " ";
IncreaseReady(msg.src_id(), msg.scope_idx());
Run();
} else if (msg.message_type() == DATA_IS_USELESS) {
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_is_useless " << msg.src_id() << " "
<< msg.scope_idx() << " ";
DecreaseBuff(msg.src_id());
Run();
}
......
......@@ -32,25 +32,24 @@ class ComputeInterceptor : public Interceptor {
virtual void RunOps();
virtual void SendDataReadyToDownStream();
virtual void ReplyCompletedToUpStream();
virtual void Compute(const InterceptorMessage& msg);
void Run();
void IncreaseReady(int64_t up_id, int64_t scope_id);
void DecreaseBuff(int64_t down_id);
std::queue<int64_t> ready_queue_;
int64_t cur_scope_id_;
// upstream_id-->(max_ready_size, scope-->ready_size)
std::map<int64_t, std::pair<int64_t, std::map<int64_t, int64_t>>>
in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
private:
void PrepareDeps();
void IncreaseReady(int64_t up_id);
void DecreaseBuff(int64_t down_id);
bool IsInputReady();
bool CanWriteOutput();
void Run();
void Compute(const InterceptorMessage& msg);
// upstream_id-->(max_ready_size, ready_size)
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
};
} // namespace distributed
......
......@@ -98,8 +98,6 @@ void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
}
void CondInterceptor::Compute() {
cur_scope_id_ = ready_queue_.front();
ready_queue_.pop();
bool cond = GetCondResult();
VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
<< " with value " << cond;
......@@ -109,14 +107,14 @@ void CondInterceptor::Compute() {
SendDataReady(down_id);
}
} else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_;
VLOG(0) << "Finish loop in scope " << cur_scope_id_;
SendDataReady(stop_loop_id_);
}
}
void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
ready_queue_.push(msg.scope_idx());
cur_scope_id_ = msg.scope_idx();
Compute();
} else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
......
......@@ -39,7 +39,6 @@ class CondInterceptor final : public Interceptor {
void SendDataReady(int64_t down_id);
void ReplyDataIsUseless(int64_t up_id);
std::queue<int64_t> ready_queue_;
int64_t cur_scope_id_;
std::set<int64_t> normal_in_id_;
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/start_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
StartInterceptor::StartInterceptor(int64_t interceptor_id, TaskNode* node)
: ComputeInterceptor(interceptor_id, node) {
auto& downstream = node_->downstream();
PADDLE_ENFORCE_EQ(
downstream.size(),
1,
platform::errors::OutOfRange(
"The downstream for StartInterceptor only support 1 for now."));
for (auto down : downstream) {
batch_size_ = down.second;
}
bool evenly_divisible = ((node_->max_run_times() % batch_size_) == 0);
PADDLE_ENFORCE(
evenly_divisible,
platform::errors::Fatal(
"Wrong config: Num of step should be divided by batch_size,"
"num_step=%lld, batch_size=%lld",
node_->max_run_times(),
batch_size_));
}
void StartInterceptor::RunOps() {
finish_count_++;
ComputeInterceptor::RunOps();
}
void StartInterceptor::SendDataReadyToDownStream() {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
auto max_buff_size = outs.second.first;
auto used_size = outs.second.second;
used_size += 1;
if (max_buff_size != INFINITE_BUFFER_SIZE) {
PADDLE_ENFORCE_LE(
used_size,
max_buff_size,
platform::errors::OutOfRange("downstream=%lld used buff size must <= "
"max_buff_size, but now used_size=%lld, "
"max_buff_size=%lld",
down_id,
used_size,
max_buff_size));
}
outs.second.second = used_size;
}
if (finish_count_ == batch_size_) {
for (int64_t i = 0; i < batch_size_; ++i) {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(step_);
VLOG(3) << "StartInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << step_;
Send(down_id, ready_msg);
}
step_++;
}
}
}
void StartInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
VLOG(3) << "Start interceptor " << interceptor_id_
<< " receive data_is_ready " << msg.src_id() << " "
<< msg.scope_idx() << " ";
IncreaseReady(msg.src_id(), msg.scope_idx());
Run();
} else if (msg.message_type() == DATA_IS_USELESS) {
VLOG(3) << "Start interceptor receive data_is_useless " << msg.src_id()
<< " " << finish_count_;
finish_count_--;
if (finish_count_ == 0) {
for (int64_t i = 0; i < batch_size_; ++i) {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
DecreaseBuff(down_id);
}
}
for (int64_t i = 0; i < batch_size_; ++i) {
Run();
}
}
}
}
REGISTER_INTERCEPTOR(Start, StartInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
namespace paddle {
namespace distributed {
class StartInterceptor final : public ComputeInterceptor {
public:
StartInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void SendDataReadyToDownStream() override;
void RunOps() override;
void Compute(const InterceptorMessage& msg) override;
int64_t batch_size_{0};
int64_t finish_count_{0};
int64_t step_{0};
};
} // namespace distributed
} // namespace paddle
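Taken together, the new StartInterceptor above withholds downstream DATA_IS_READY messages until `finish_count_` reaches `batch_size_`, then emits one message per buffered step, and on DATA_IS_USELESS drains the batch before running again. Below is a toy, self-contained sketch of that control flow; the class and callback names are invented for illustration and are not part of the fleet executor API:

```cpp
// Standalone sketch (not the Paddle implementation) of the batching behaviour
// added by StartInterceptor: downstream notifications are withheld until a
// whole batch of steps has run, then released as one DATA_IS_READY per step.
#include <cstdint>
#include <functional>
#include <iostream>

class ToyStartInterceptor {
 public:
  ToyStartInterceptor(int64_t batch_size,
                      std::function<void(int64_t)> send_ready)
      : batch_size_(batch_size), send_ready_(std::move(send_ready)) {}

  // Called once per step; buffers the notification until the batch is full.
  void RunOneStep() {
    ++finish_count_;
    if (finish_count_ == batch_size_) {
      for (int64_t i = 0; i < batch_size_; ++i) {
        send_ready_(step_++);  // one DATA_IS_READY per buffered step
      }
    }
  }

  // Called when the downstream reports DATA_IS_USELESS for one step.
  void OnDownstreamDone() {
    if (--finish_count_ == 0) {
      std::cout << "batch drained, ready for the next batch\n";
    }
  }

 private:
  int64_t batch_size_;
  int64_t finish_count_{0};
  int64_t step_{0};
  std::function<void(int64_t)> send_ready_;
};

int main() {
  ToyStartInterceptor start(/*batch_size=*/2, [](int64_t step) {
    std::cout << "send DATA_IS_READY for step " << step << "\n";
  });
  start.RunOneStep();        // nothing sent yet
  start.RunOneStep();        // batch full: sends steps 0 and 1
  start.OnDownstreamDone();
  start.OnDownstreamDone();  // batch drained
  return 0;
}
```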
......@@ -21,6 +21,9 @@ limitations under the License. */
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle {
namespace distributed {
......
......@@ -66,7 +66,7 @@ TEST(AmplifierInterceptor, Amplifier) {
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
int64_t micro_steps = 3;
int64_t micro_steps = 1;
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
......
......@@ -84,8 +84,8 @@ TEST(AmplifierInterceptor, Amplifier) {
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 3);
TaskNode* node_c = new TaskNode(0, 0, 2, 3);
TaskNode* node_b = new TaskNode(0, 0, 1, micro_steps);
TaskNode* node_c = new TaskNode(0, 0, 2, micro_steps);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
......
......@@ -33,7 +33,7 @@ def body(i, ten, data):
return [i, ten, data]
num_micro_batches = 3
num_micro_batches = 4
def batch_generator_creator():
......@@ -126,7 +126,7 @@ class TestFleetExecutor(unittest.TestCase):
task_a = TaskNode(
0,
num_micro_batches,
node_type="Compute",
node_type="Start",
task_id=0,
program=program_a,
lazy_initialize=True,
......