Unverified commit 92c2dcbd authored by LiYuRio, committed by GitHub

Cherry-pick fleet executor and auto parallel (#50071)

Parent 4bacf2ab
......@@ -426,7 +426,8 @@ endif()
if(WITH_DISTRIBUTE
AND NOT WITH_PSLIB
AND NOT WITH_PSCORE)
AND NOT WITH_PSCORE
AND NOT WITH_RPC)
include(external/snappy)
list(APPEND third_party_deps extern_snappy)
......
......@@ -36,6 +36,8 @@ cc_library(
interceptor.cc
compute_interceptor.cc
amplifier_interceptor.cc
cond_interceptor.cc
start_interceptor.cc
source_interceptor.cc
sink_interceptor.cc
message_service.cc
......@@ -66,6 +68,10 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(
amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
start_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
......
......@@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() {
// run_per_steps_, run_at_offset_
// 4, 0 --> run at step 0, 4, 8, 12
// 4, 3 --> run at step 3, 7, 11, 15
if ((step_ % run_per_steps_) == run_at_offset_) {
if ((cur_scope_id_ % run_per_steps_) == run_at_offset_) {
ComputeInterceptor::RunOps();
}
}
......@@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() {
void AmplifierInterceptor::SendDataReadyToDownStream() {
// runs multiple times but sends data_is_ready to downstream only once,
// i.e. consumes input multiple times, produces output once
if (step_ % send_down_per_steps_ == 0) {
if (cur_scope_id_ % send_down_per_steps_ == 0) {
ComputeInterceptor::SendDataReadyToDownStream();
}
}
......@@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() {
void AmplifierInterceptor::ReplyCompletedToUpStream() {
// runs multiple times but replies to upstream only once,
// i.e. consumes input once, produces output multiple times
if (step_ % reply_up_per_steps_ == 0) {
if (cur_scope_id_ % reply_up_per_steps_ == 0) {
ComputeInterceptor::ReplyCompletedToUpStream();
}
}
......
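The amplifier change above swaps the global step counter step_ for the current scope id when deciding whether to run and whether to forward readiness or replies. A minimal sketch of the predicate, written as a standalone (hypothetical) helper rather than the member functions in the patch:

// With run_per_steps = 4 and run_at_offset = 0 the interceptor runs for
// scopes 0, 4, 8, 12; with run_at_offset = 3 it runs for scopes 3, 7, 11, 15.
bool ShouldRunForScope(int64_t cur_scope_id, int64_t run_per_steps,
                       int64_t run_at_offset) {
  return (cur_scope_id % run_per_steps) == run_at_offset;
}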
......@@ -21,7 +21,7 @@
namespace paddle {
namespace distributed {
class AmplifierInterceptor : public ComputeInterceptor {
class AmplifierInterceptor final : public ComputeInterceptor {
public:
AmplifierInterceptor(int64_t interceptor_id, TaskNode* node);
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
......@@ -24,6 +25,7 @@
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
namespace paddle {
......@@ -33,6 +35,8 @@ USE_INTERCEPTOR(Source);
USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);
USE_INTERCEPTOR(Sink);
USE_INTERCEPTOR(Cond);
USE_INTERCEPTOR(Start);
void Carrier::Init(
int64_t rank,
......@@ -54,24 +58,38 @@ void Carrier::Init(
framework::Scope* scope,
int64_t num_micro_batches,
const platform::Place& place,
const std::vector<std::string>& inference_root_scope_vars) {
const std::vector<std::string>& inference_root_scope_vars,
const std::vector<framework::Scope*>& micro_scope_list) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_id_to_node_ = interceptor_id_to_node;
place_ = place;
root_scope_ = scope;
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
bool need_create_scope = micro_scope_list.empty();
PADDLE_ENFORCE_NOT_NULL(
root_scope_,
platform::errors::InvalidArgument("root_scope can not be nullptr"));
minibatch_scope_ = &root_scope_->NewScope();
microbatch_scopes_.resize(num_micro_batches);
for (int i = 0; i < num_micro_batches; ++i) {
microbatch_scopes_[i] = &minibatch_scope_->NewScope();
CopyParameters(i, program, inference_root_scope_vars);
if (need_create_scope) {
minibatch_scope_ = &root_scope_->NewScope();
microbatch_scopes_.resize(num_micro_batches);
for (int i = 0; i < num_micro_batches; ++i) {
microbatch_scopes_[i] = &minibatch_scope_->NewScope();
CopyParameters(i, program, inference_root_scope_vars);
}
} else {
microbatch_scopes_ = micro_scope_list;
for (int i = 0; i < num_micro_batches; ++i) {
CopyParameters(i, program, inference_root_scope_vars);
}
}
// Add source and sink interceptor id to rank
interceptor_id_to_rank_.emplace(SOURCE_ID, rank);
interceptor_id_to_rank_.emplace(SINK_ID, rank);
// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
......@@ -93,29 +111,30 @@ void Carrier::CopyParameters(
int microbatch_id,
const framework::ProgramDesc& program,
const std::vector<std::string>& inference_root_scope_vars) {
auto& global_block = program.Block(0);
std::map<std::string, int> inference_root_scope_var_map;
for (auto var_name : inference_root_scope_vars) {
inference_root_scope_var_map.insert({var_name, 1});
}
for (auto& var : global_block.AllVars()) {
std::string var_name = var->Name();
bool force_root = inference_root_scope_var_map.find(var_name) !=
inference_root_scope_var_map.end();
if (force_root) {
VLOG(4) << var_name << " will be forced to be created in the root scope.";
}
if ((var->Persistable() || force_root) && microbatch_id == 0) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(5) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (!var->Persistable()) {
auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr << ".";
InitializeVariable(ptr, var->GetType());
for (size_t i = 0; i < program.Size(); ++i) {
for (auto& var : program.Block(i).AllVars()) {
std::string var_name = var->Name();
bool force_root = inference_root_scope_var_map.find(var_name) !=
inference_root_scope_var_map.end();
if (force_root) {
VLOG(4) << var_name
<< " will be forced to be created in the root scope.";
}
if ((var->Persistable() || force_root) && microbatch_id == 0) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(5) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (!var->Persistable()) {
auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr << ".";
InitializeVariable(ptr, var->GetType());
}
}
}
}
......@@ -159,16 +178,11 @@ void Carrier::Start() {
true,
platform::errors::PreconditionNotMet(
"Using carrier before initialized."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Start is sending start to source interceptor " << id
<< ".";
InterceptorMessage start_msg;
// source node data_is_ready is send by carrier, so set src_id=-1
start_msg.set_src_id(-1);
start_msg.set_dst_id(id);
start_msg.set_message_type(DATA_IS_READY);
Send(start_msg);
}
InterceptorMessage start_msg;
start_msg.set_src_id(SOURCE_ID);
start_msg.set_dst_id(SOURCE_ID);
start_msg.set_message_type(START);
Send(start_msg);
// TODO(wangxi): async step
Wait();
dev_ctx_->Wait();
......@@ -270,6 +284,38 @@ void Carrier::CreateInterceptors() {
auto gc = GetGC(place_);
// create source and sink task node
auto max_run_times = microbatch_scopes_.size();
TaskNode* source = new TaskNode(
rank_, SOURCE_ID, max_run_times); // rank, task_id, max_run_times
TaskNode* sink = new TaskNode(rank_, SINK_ID, max_run_times);
// find nodes without upstreams or without downstreams
std::vector<TaskNode*> origin_sources, origin_sinks;
for (const auto& item : interceptor_id_to_node_) {
TaskNode* task_node = item.second;
if (task_node->upstream().empty()) {
origin_sources.emplace_back(task_node);
}
if (task_node->downstream().empty()) {
origin_sinks.emplace_back(task_node);
}
}
// link source node with origin source
for (const auto& node : origin_sources) {
source->AddDownstreamTask(node->task_id(),
std::numeric_limits<int64_t>::max());
node->AddUpstreamTask(SOURCE_ID, std::numeric_limits<int64_t>::max());
}
// link sink node with origin sink
for (const auto& node : origin_sinks) {
sink->AddUpstreamTask(node->task_id(), std::numeric_limits<int64_t>::max());
node->AddDownstreamTask(SINK_ID, std::numeric_limits<int64_t>::max());
}
// create source and sink interceptor
SetInterceptor(SOURCE_ID,
InterceptorFactory::Create("Source", SOURCE_ID, source));
SetInterceptor(SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));
// create each Interceptor
// no auto init since there is no config
for (const auto& item : interceptor_id_to_node_) {
......@@ -303,9 +349,15 @@ void Carrier::CreateInterceptors() {
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< " with type: " << task_node->type() << ".";
if (task_node->upstream().empty()) {
source_interceptor_ids_.emplace_back(interceptor_id);
}
PADDLE_ENFORCE_EQ(
task_node->upstream().empty(),
false,
platform::errors::PreconditionNotMet(
"There should not have normal nodes as source nodes"));
PADDLE_ENFORCE_EQ(task_node->downstream().empty(),
false,
platform::errors::PreconditionNotMet(
"There should not have normal nodes as sink nodes"));
}
}
......
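With the carrier changes above, every graph is framed by one virtual Source and one virtual Sink task node: nodes without upstreams hang off SOURCE_ID, nodes without downstreams feed SINK_ID, and execution is kicked off by a single START message addressed to SOURCE_ID. A rough sketch of the wiring for a two-node graph a -> b in one rank (a and b are hypothetical TaskNode pointers; this condenses CreateInterceptors rather than quoting it):

TaskNode* source = new TaskNode(rank, SOURCE_ID, max_run_times);
TaskNode* sink = new TaskNode(rank, SINK_ID, max_run_times);
// a has no upstream, so it hangs under the virtual source with an
// effectively unbounded buffer; b has no downstream, so it feeds the sink.
source->AddDownstreamTask(a->task_id(), std::numeric_limits<int64_t>::max());
a->AddUpstreamTask(SOURCE_ID, std::numeric_limits<int64_t>::max());
sink->AddUpstreamTask(b->task_id(), std::numeric_limits<int64_t>::max());
b->AddDownstreamTask(SINK_ID, std::numeric_limits<int64_t>::max());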
......@@ -25,6 +25,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
......@@ -60,7 +61,8 @@ class Carrier final {
framework::Scope* scope,
int64_t num_micro_batches,
const platform::Place& place,
const std::vector<std::string>& inference_root_scope_vars = {});
const std::vector<std::string>& inference_root_scope_vars = {},
const std::vector<framework::Scope*>& micro_scope_list = {});
void CopyParameters(
int microbatch_id,
......@@ -100,8 +102,6 @@ class Carrier final {
std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_;
std::vector<int64_t> source_interceptor_ids_;
bool is_init_{false};
std::mutex running_mutex_;
......
......@@ -18,10 +18,85 @@
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/serialization.h"
#include "paddle/phi/core/utils/dim.h"
namespace paddle {
namespace distributed {
namespace {
template <typename T>
void SetVarResult(const std::string& name,
T value,
int64_t scope_id,
framework::Scope* scope,
const platform::Place& place,
const std::vector<int64_t>& dim_vec) {
auto* var = scope->FindVar(name);
if (!var) {
VLOG(3) << "Create var and memory for var " << name;
var = scope->Var(name);
// allocate a size-1 buffer for the newly created variable
auto* new_tensor = var->GetMutable<phi::DenseTensor>();
phi::DDim dims = phi::make_ddim(dim_vec);
new_tensor->Resize(dims);
new_tensor->mutable_data<T>(dims, place);
}
// fetch the tensor only after the variable is guaranteed to exist
auto* tensor = var->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
tensor->dims().size(),
1,
platform::errors::OutOfRange("Only support transfer size 1 value."));
PADDLE_ENFORCE_EQ(
tensor->dims().at(0),
1,
platform::errors::OutOfRange("Only support transfer size 1 value."));
if (platform::is_gpu_place(tensor->place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
auto dim = phi::make_ddim({1});
cpu_tensor.mutable_data<T>(dim, platform::CPUPlace());
auto* cpu_tensor_ptr = cpu_tensor.data<T>();
cpu_tensor_ptr[0] = value;
framework::TensorCopySync(cpu_tensor, tensor->place(), tensor);
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport device for cond interceptor."));
}
}
template <typename T>
T GetVarResult(const std::string& name,
int64_t scope_id,
framework::Scope* scope) {
auto* var = scope->FindVar(name);
PADDLE_ENFORCE(var,
platform::errors::NotFound(
"Variable %s not exists in scope %ld", name, scope_id));
const auto& tensor = var->Get<phi::DenseTensor>();
T res;
if (platform::is_gpu_place(tensor.place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
framework::TensorCopySync(tensor, platform::CPUPlace(), &cpu_tensor);
res = cpu_tensor.data<T>()[0];
#endif
} else if (platform::is_cpu_place(tensor.place())) {
res = tensor.data<T>()[0];
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport device for cond interceptor."));
}
return res;
}
} // namespace
ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
PrepareDeps();
......@@ -33,57 +108,49 @@ void ComputeInterceptor::PrepareDeps() {
auto& downstream = node_->downstream();
for (auto up : upstream) {
in_readys_.emplace(up.first, std::make_pair(up.second, 0));
in_stops_.emplace(up.first, false);
std::map<int64_t, int64_t> ready_size_map;
for (int64_t i = 0; i < node_->max_run_times(); ++i) {
ready_size_map.emplace(i, 0);
}
in_readys_.emplace(up.first, std::make_pair(up.second, ready_size_map));
}
for (auto down : downstream) {
out_buffs_.emplace(down.first, std::make_pair(down.second, 0));
}
// source compute node, should we add a new SourceInterceptor?
if (upstream.empty()) {
is_source_ = true;
PADDLE_ENFORCE_GT(node_->max_run_times(),
0,
platform::errors::InvalidArgument(
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld",
node_->max_run_times()));
in_readys_.emplace(-1,
std::make_pair(std::numeric_limits<int64_t>::max(), 0));
}
// If there is no downstream, or every downstream is in a different rank,
// then this interceptor is the last one for the current rank.
// This can be determined during init and cached for later use.
is_last_ = downstream.empty();
}
void ComputeInterceptor::IncreaseReady(int64_t up_id) {
void ComputeInterceptor::IncreaseReady(int64_t up_id, int64_t scope_id) {
auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it,
in_readys_.end(),
platform::errors::NotFound(
"Cannot find upstream=%lld in in_readys.", up_id));
// source node has no upstream, data_is_ready is send by carrier or others
if (is_source_ && up_id == -1) {
it->second.second += GetTaskNode()->max_run_times();
return;
}
auto max_ready_size = it->second.first;
auto ready_size = it->second.second;
ready_size += 1;
PADDLE_ENFORCE_LE(ready_size,
max_ready_size,
platform::errors::OutOfRange(
"upstream=%lld ready_size must <= max_ready_size, but "
"now ready_size=%lld, max_ready_size=%lld",
up_id,
ready_size,
max_ready_size));
it->second.second = ready_size;
const auto& ready_scope_map = it->second.second;
int64_t ready_size = 0;
for (auto& scope_iter : ready_scope_map) {
ready_size += scope_iter.second;
}
if (max_ready_size != INFINITE_BUFFER_SIZE) {
PADDLE_ENFORCE_LE(
ready_size,
max_ready_size,
platform::errors::OutOfRange(
"upstream=%lld ready_size must <= max_ready_size, but "
"now ready_size=%lld, max_ready_size=%lld",
up_id,
ready_size,
max_ready_size));
}
PADDLE_ENFORCE_NE(
it->second.second.find(scope_id),
it->second.second.end(),
platform::errors::OutOfRange(
"Interceptor %lld can not find scope %lld in upstream ready map",
interceptor_id_,
scope_id));
it->second.second.at(scope_id) = ready_scope_map.at(scope_id) + 1;
}
void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
......@@ -105,22 +172,40 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
}
bool ComputeInterceptor::IsInputReady() {
for (auto& ins : in_readys_) {
auto ready_size = ins.second.second;
// not ready, return false
if (ready_size == 0) {
VLOG(3) << "Interceptor " << GetInterceptorId()
for (int64_t i = 0; i < node_->max_run_times(); ++i) {
bool flag = true;
for (auto& ins : in_readys_) {
auto ready_size_map = ins.second.second;
flag = flag && (ready_size_map.at(i) != 0);
}
if (flag) {
for (auto iter : scope_id_to_finish_flag_) {
if (iter.first == i) {
break;
} else if (!iter.second) {
VLOG(3) << "The previous scope is not ready, waiting for the "
"previous scope "
<< iter.first;
return false;
}
}
cur_scope_id_ = i;
return true;
} else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< "'s upstreams aren't all ready.";
return false;
}
}
return true;
return false;
}
bool ComputeInterceptor::CanWriteOutput() {
for (auto& outs : out_buffs_) {
auto max_buffer_size = outs.second.first;
auto used_size = outs.second.second;
if (max_buffer_size == INFINITE_BUFFER_SIZE) {
continue;
}
// full, return false
if (used_size == max_buffer_size) {
VLOG(3) << "Interceptor " << GetInterceptorId()
......@@ -137,30 +222,76 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
auto max_buff_size = outs.second.first;
auto used_size = outs.second.second;
used_size += 1;
PADDLE_ENFORCE_LE(
used_size,
max_buff_size,
platform::errors::OutOfRange("downstream=%lld used buff size must <= "
"max_buff_size, but now used_size=%lld, "
"max_buff_size=%lld",
down_id,
used_size,
max_buff_size));
if (max_buff_size != INFINITE_BUFFER_SIZE) {
PADDLE_ENFORCE_LE(
used_size,
max_buff_size,
platform::errors::OutOfRange("downstream=%lld used buff size must <= "
"max_buff_size, but now used_size=%lld, "
"max_buff_size=%lld",
down_id,
used_size,
max_buff_size));
}
outs.second.second = used_size;
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " for step: " << step_;
Send(down_id, ready_msg);
bool need_send_vars = !(node_->vars_to_dtype().empty());
if (need_send_vars) {
InterceptorMessage ready_msg = PrepareVarsMsg();
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_with_vars msg to " << down_id
<< " in scope: " << cur_scope_id_;
Send(down_id, ready_msg);
} else {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << cur_scope_id_;
Send(down_id, ready_msg);
}
}
}
InterceptorMessage ComputeInterceptor::PrepareVarsMsg() {
PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
auto* scope = microbatch_scopes_[cur_scope_id_];
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_WITH_VARS);
ready_msg.set_scope_idx(cur_scope_id_);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
for (auto iter : node_->vars_to_dtype()) {
VarList* vars = ready_msg.add_vars_list();
const auto& var_name = iter.first;
vars->set_name(var_name);
std::ostringstream ss;
auto& dev_ctx = *pool.Get(place_);
auto* var = scope->FindVar(var_name);
PADDLE_ENFORCE(
var,
platform::errors::NotFound(
"Variable %s not exists in scope %ld", var_name, cur_scope_id_));
const auto& tensor = var->Get<phi::DenseTensor>();
SerializeToStream(ss, tensor, dev_ctx);
vars->set_stensor(ss.str());
VLOG(3) << "Prepare vars msg " << var_name << " with dimension "
<< tensor.dims() << " dtype " << tensor.dtype();
}
return ready_msg;
}
void ComputeInterceptor::ReplyCompletedToUpStream() {
for (auto& ins : in_readys_) {
auto up_id = ins.first;
auto ready_size = ins.second.second;
auto ready_size = ins.second.second.at(cur_scope_id_);
ready_size -= 1;
PADDLE_ENFORCE_GE(
ready_size,
......@@ -169,109 +300,114 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
"upstream=%lld ready_size must >= 0, but now got %lld",
up_id,
ready_size));
ins.second.second = ready_size;
ins.second.second[cur_scope_id_] = ready_size;
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id
<< " for step: " << step_;
if (is_source_ && up_id == -1) return;
<< " in scope: " << cur_scope_id_;
InterceptorMessage reply_msg;
reply_msg.set_message_type(DATA_IS_USELESS);
reply_msg.set_scope_idx(cur_scope_id_);
Send(up_id, reply_msg);
}
}
void ComputeInterceptor::RunOps() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ + 1 << " time.";
for (auto op : node_->ops()) {
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
op->Run(*microbatch_scopes_[cur_scope_id_], place_);
if (gc_) {
framework::DeleteUnusedTensors(
*microbatch_scopes_[step_ % node_->max_run_times()],
op,
node_->unused_vars(),
gc_.get());
framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_],
op,
node_->unused_vars(),
gc_.get());
}
}
}
void ComputeInterceptor::Run() {
while (IsInputReady() && CanWriteOutput()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
VLOG(3) << "id=" << GetInterceptorId()
<< " ComputeInterceptor running in scope " << cur_scope_id_;
RunOps();
++step_;
if (!scope_id_to_finish_flag_.empty()) {
PADDLE_ENFORCE_NE(
scope_id_to_finish_flag_.find(cur_scope_id_),
scope_id_to_finish_flag_.end(),
platform::errors::NotFound(
"Can not find scope %ld in scope_id_to_finish", cur_scope_id_));
scope_id_to_finish_flag_.erase(cur_scope_id_);
}
// send to downstream and increase buff used
SendDataReadyToDownStream();
// reply to upstream and decrease ready data
ReplyCompletedToUpStream();
// Try to stop Carrier
if (is_last_ && (step_ % node_->max_run_times() == 0)) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " is stopping carrier.";
// FIXME(wangxi): with multi sink interceptor
StopCarrier();
}
}
}
void ComputeInterceptor::ReceivedStop(int64_t up_id) {
received_stop_ = true;
// source node has no upstream, stop is send by carrier or others
if (is_source_ && up_id == -1) return;
auto it = in_stops_.find(up_id);
PADDLE_ENFORCE_NE(it,
in_stops_.end(),
platform::errors::NotFound(
"Cannot find upstream=%lld in in_stops.", up_id));
PADDLE_ENFORCE_EQ(
it->second,
false,
platform::errors::AlreadyExists("Already received stop from %lld, stop "
"cannot be send more than once."));
it->second = true;
}
void ComputeInterceptor::TryStop() {
if (!received_stop_) return;
// can stop only when all upstream is stop and
// downstream complete
for (auto& in_stop : in_stops_) {
if (!in_stop.second) return;
}
for (auto& out_buff : out_buffs_) {
auto used_size = out_buff.second.second;
if (used_size != 0) return;
}
// send stop to downstream
for (auto& out : out_buffs_) {
auto down_id = out.first;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(down_id, stop);
void ComputeInterceptor::DecodeMsgVars(const InterceptorMessage& msg) {
int64_t scope_id = msg.scope_idx();
PADDLE_ENFORCE_LT(scope_id,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
scope_id));
auto* scope = microbatch_scopes_[scope_id];
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
for (const auto& var_iter : msg.vars_list()) {
const std::string& name = var_iter.name();
auto& dev_ctx = *pool.Get(place_);
std::istringstream ss(var_iter.stensor());
auto* var = scope->Var(name);
auto* tensor = var->GetMutable<phi::DenseTensor>();
DeserializeFromStream(ss, tensor, dev_ctx);
VLOG(3) << "Set vars " << name << " with value in scope " << scope_id
<< " with dims " << tensor->dims() << " with dtype "
<< tensor->dtype();
}
stop_ = true;
}
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
IncreaseReady(msg.src_id());
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_is_ready " << msg.src_id() << " "
<< msg.scope_idx() << " ";
IncreaseReady(msg.src_id(), msg.scope_idx());
Run();
} else if (msg.message_type() == DATA_IS_USELESS) {
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_is_useless " << msg.src_id() << " "
<< msg.scope_idx() << " ";
DecreaseBuff(msg.src_id());
Run();
} else if (msg.message_type() == STOP) {
ReceivedStop(msg.src_id());
} else if (msg.message_type() == DATA_WITH_VARS) {
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_with_vars " << msg.src_id() << " "
<< msg.scope_idx() << " ";
DecodeMsgVars(msg);
IncreaseReady(msg.src_id(), msg.scope_idx());
Run();
} else if (msg.message_type() == START_LOOP) {
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive start_loop " << msg.src_id() << " " << msg.scope_idx()
<< " ";
IncreaseReady(msg.src_id(), msg.scope_idx());
scope_id_to_finish_flag_.emplace(msg.scope_idx(), false);
Run();
}
TryStop();
}
REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);
......
......@@ -14,6 +14,7 @@
#pragma once
#include <queue>
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
......@@ -21,6 +22,8 @@
namespace paddle {
namespace distributed {
const int64_t INFINITE_BUFFER_SIZE = -1;
class ComputeInterceptor : public Interceptor {
public:
ComputeInterceptor(int64_t interceptor_id, TaskNode* node);
......@@ -29,33 +32,27 @@ class ComputeInterceptor : public Interceptor {
virtual void RunOps();
virtual void SendDataReadyToDownStream();
virtual void ReplyCompletedToUpStream();
virtual void Compute(const InterceptorMessage& msg);
void Run();
void IncreaseReady(int64_t up_id, int64_t scope_id);
void DecreaseBuff(int64_t down_id);
int64_t cur_scope_id_;
int64_t step_{0};
// upstream_id-->(max_ready_size, scope-->ready_size)
std::map<int64_t, std::pair<int64_t, std::map<int64_t, int64_t>>>
in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
private:
void PrepareDeps();
InterceptorMessage PrepareVarsMsg();
void DecodeMsgVars(const InterceptorMessage& msg);
void IncreaseReady(int64_t up_id);
void DecreaseBuff(int64_t down_id);
bool IsInputReady();
bool CanWriteOutput();
void Run();
void Compute(const InterceptorMessage& msg);
void ReceivedStop(int64_t up_id);
void TryStop();
bool is_source_{false};
bool is_last_{false};
// upstream_id-->(max_ready_size, ready_size)
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
bool received_stop_{false};
std::map<int64_t, bool> in_stops_{};
std::map<int64_t, bool> scope_id_to_finish_flag_;
};
} // namespace distributed
......
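The readiness bookkeeping in ComputeInterceptor is now kept per scope: in_readys_ maps each upstream id to (max_ready_size, scope_id -> ready_size), IsInputReady picks the lowest scope id for which every upstream has something ready, and a max size of INFINITE_BUFFER_SIZE (-1) disables the capacity checks. A small illustration of the structure, assuming two hypothetical upstream ids and three microbatch scopes:

#include <cstdint>
#include <map>
#include <utility>

// mirrors INFINITE_BUFFER_SIZE from compute_interceptor.h
constexpr int64_t kInfiniteBufferSize = -1;

// upstream_id -> (max_ready_size, scope_id -> ready_size)
std::map<int64_t, std::pair<int64_t, std::map<int64_t, int64_t>>> in_readys = {
    {7, {2, {{0, 1}, {1, 0}, {2, 0}}}},
    {9, {kInfiniteBufferSize, {{0, 1}, {1, 1}, {2, 0}}}},
};
// Scope 0 is runnable: every upstream has ready_size != 0 for it.
// Scope 1 is not: upstream 7 has nothing ready there yet.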
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h"
#include <algorithm>
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
CondInterceptor::CondInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
PrepareDeps();
RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}
void CondInterceptor::PrepareDeps() {
auto& upstream = node_->upstream();
auto& downstream = node_->downstream();
auto& id_to_dep_type = node_->id_to_dep_type();
for (const auto& up : upstream) {
if (id_to_dep_type.at(up.first) == DependType::NORMAL) {
normal_in_id_.insert(up.first);
} else if (id_to_dep_type.at(up.first) == DependType::LOOP) {
loop_id_ = up.first;
}
}
for (const auto& down : downstream) {
if (id_to_dep_type.at(down.first) == DependType::NORMAL) {
normal_out_id_.insert(down.first);
} else if (id_to_dep_type.at(down.first) == DependType::STOP_LOOP) {
stop_loop_id_ = down.first;
}
}
}
bool CondInterceptor::GetCondResult() {
PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
auto* cond_var =
microbatch_scopes_[cur_scope_id_]->FindVar(node_->cond_var());
PADDLE_ENFORCE(cond_var,
platform::errors::NotFound(
"Condition variable %s not exists in scope %ld",
node_->cond_var(),
cur_scope_id_));
const auto& cond_tensor = cond_var->Get<phi::DenseTensor>();
bool res = false;
if (platform::is_gpu_place(cond_tensor.place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else if (platform::is_cpu_place(cond_tensor.place())) {
res = cond_tensor.data<bool>()[0];
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport device for cond interceptor."));
}
return res;
}
void CondInterceptor::SendDataReady(int64_t down_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
Send(down_id, ready_msg);
}
void CondInterceptor::SendStartLoop(int64_t down_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(START_LOOP);
ready_msg.set_scope_idx(cur_scope_id_);
Send(down_id, ready_msg);
}
void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_USELESS);
ready_msg.set_scope_idx(cur_scope_id_);
Send(up_id, ready_msg);
}
void CondInterceptor::Compute() {
bool cond = GetCondResult();
VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
<< " with value " << cond;
if (cond) {
VLOG(3) << "Loop again in scope " << cur_scope_id_;
for (auto& down_id : normal_out_id_) {
SendStartLoop(down_id);
}
++num_of_scopes_;
} else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_;
SendDataReady(stop_loop_id_);
}
}
void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY ||
msg.message_type() == DATA_WITH_VARS) {
if (msg.src_id() == loop_id_) {
--num_of_scopes_;
VLOG(3) << "Receving loop again message from " << msg.src_id()
<< " waiting other " << num_of_scopes_ << " scopes ready";
ready_scope_id_.emplace_back(msg.scope_idx());
if (num_of_scopes_ == 0) {
std::sort(ready_scope_id_.begin(), ready_scope_id_.end());
for (auto scope_id : ready_scope_id_) {
VLOG(3) << "Start a new loop in scope " << scope_id;
cur_scope_id_ = scope_id;
Compute();
}
ready_scope_id_.clear();
}
} else {
cur_scope_id_ = msg.scope_idx();
Compute();
}
} else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
for (auto& up_id : normal_in_id_) {
ReplyDataIsUseless(up_id);
}
// Gc the variable in while block
int64_t scope_id = msg.scope_idx();
if (gc_) {
VLOG(3) << "Release vars in while block in scope " << scope_id;
framework::DeleteUnusedTensors(*microbatch_scopes_[scope_id],
node_->while_block_vars(),
gc_.get());
}
}
}
}
REGISTER_INTERCEPTOR(Cond, CondInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iomanip>
#include <queue>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
/* Condition Interceptor
* A special interceptor whose task node holds only one condition op.
* It has two downstreams:
* 1. If the condition evaluates to true, it selects one downstream;
* otherwise it selects the other.
* 2. It is used to implement the while op in a program.
*/
class CondInterceptor final : public Interceptor {
public:
CondInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void PrepareDeps();
void Run(const InterceptorMessage& msg);
void Compute();
bool GetCondResult();
void SendDataReady(int64_t down_id);
void SendStartLoop(int64_t down_id);
void ReplyDataIsUseless(int64_t up_id);
int64_t cur_scope_id_;
std::set<int64_t> normal_in_id_;
std::set<int64_t> normal_out_id_;
int64_t stop_loop_id_;
int64_t loop_id_;
int64_t num_of_scopes_{0};
std::vector<int64_t> ready_scope_id_;
};
} // namespace distributed
} // namespace paddle
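The cond interceptor above drives the while loop: every DATA_IS_READY / DATA_WITH_VARS message from a normal upstream opens a scope, the boolean condition variable is read from that scope, and a true value sends START_LOOP back into the loop-body nodes while a false value forwards DATA_IS_READY to the STOP_LOOP downstream. A trivial sketch of that per-scope decision (the helper and enum are hypothetical; the real logic lives in CondInterceptor::Compute and Run):

// true  -> send START_LOOP for this scope to the while-body downstreams
// false -> send DATA_IS_READY for this scope to the STOP_LOOP downstream
enum class NextStep { kStartLoop, kExitLoop };

NextStep DecideForScope(bool cond_value) {
  return cond_value ? NextStep::kStartLoop : NextStep::kExitLoop;
}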
......@@ -14,6 +14,8 @@
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include <algorithm>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
......@@ -24,6 +26,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace distributed {
......@@ -51,40 +54,40 @@ FleetExecutor::~FleetExecutor() {
}
}
void FleetExecutor::Init(
const std::string& carrier_id,
const framework::ProgramDesc& program_desc,
framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars) {
PADDLE_ENFORCE_GT(task_nodes.size(),
0,
platform::errors::InvalidArgument(
"Fleet executor is inited with empty task node"));
// TODO(fleet_exe devs): the unused_vars should be got from run time graph
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (auto task_node : task_nodes) {
for (auto op : task_node->ops()) {
ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
namespace {
void GetSubBlockTask(const std::vector<TaskNode*>& tasks,
TaskNode* cur_task,
std::set<TaskNode*>* sub_block_task) {
auto& downstream = cur_task->downstream();
auto& id_to_dep_type = cur_task->id_to_dep_type();
for (auto& down : downstream) {
int64_t task_id = down.first;
if (id_to_dep_type.at(task_id) == DependType::NORMAL) {
for (const auto& task : tasks) {
if (task->task_id() == task_id) {
sub_block_task->emplace(task);
GetSubBlockTask(tasks, task, sub_block_task);
}
}
}
}
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
// NOTE: For inference, the vars in inference_root_scope_vars
// shouldn't be deleted during inf, for that they may be the result of the
// inf. If they are GCed, it will cause error during ZeroCopy the result.
}
void PreventVarsDelete(
std::unordered_map<const framework::OperatorBase*,
std::vector<std::string>>* unused_vars,
const std::vector<std::string>& vars_not_gc) {
std::vector<const framework::OperatorBase*> changed_ops;
for (auto pair : unused_vars) {
for (const auto& pair : *unused_vars) {
const framework::OperatorBase* op = pair.first;
std::vector<std::string> unused = pair.second;
for (auto name : inference_root_scope_vars) {
auto iter = std::find(unused.begin(), unused.end(), name);
if (iter != unused.end()) {
std::vector<std::string> cur_unused = pair.second;
for (auto name : vars_not_gc) {
auto iter = std::find(cur_unused.begin(), cur_unused.end(), name);
if (iter != cur_unused.end()) {
VLOG(3) << "Removing var: [" << name
<< "] from the unused vars list of op: [" << op->Type() << "]";
unused.erase(iter);
cur_unused.erase(iter);
if (std::find(changed_ops.begin(), changed_ops.end(), op) ==
changed_ops.end()) {
// record the op whose unused vars have been updated
......@@ -93,28 +96,120 @@ void FleetExecutor::Init(
}
}
// update the unused vars list in the map
unused_vars[op] = unused;
unused_vars->at(op) = cur_unused;
}
for (auto op : changed_ops) {
auto iter = unused_vars.find(op);
const auto& iter = unused_vars->find(op);
if (iter->second.empty()) {
// remove those ops in the map that have empty unused vars list
VLOG(3) << "Removing op: [" << op->Type() << "] from unused_vars map.";
unused_vars.erase(iter);
unused_vars->erase(iter);
}
}
}
std::vector<std::string> GetUnusedVarsAfterWhile(
const framework::ProgramDesc& program_desc,
TaskNode* cond_task,
const std::vector<std::string>& vars_not_gc) {
// NOTE: Since the while op won't appear in any task node, we rebuild the
// whole program to analyze which vars should be freed after the while op
// has run. Vars used inside the while block must not be freed until the
// while op finishes. In short, the vars to free after the while op are:
// 1. Vars defined in the parent block and used in the while block.
// 2. Local vars defined only in the while block.
// These unused vars are freed by the cond interceptor.
std::vector<std::string> while_block_vars;
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (const auto& desc : program_desc.Block(0).AllOps()) {
ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
}
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
PreventVarsDelete(&unused_vars, vars_not_gc);
for (const auto& pair : unused_vars) {
if (pair.first->Type() == "while") {
for (const auto& var_name : pair.second) {
while_block_vars.emplace_back(var_name);
}
for (auto& var : program_desc.Block(1).AllVars()) {
while_block_vars.emplace_back(var->Name());
}
}
}
return while_block_vars;
}
} // namespace
void FleetExecutor::Init(
const std::string& carrier_id,
const framework::ProgramDesc& program_desc,
framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars,
const std::vector<framework::Scope*>& micro_scope_list) {
PADDLE_ENFORCE_GT(task_nodes.size(),
0,
platform::errors::InvalidArgument(
"Fleet executor is inited with empty task node"));
// Set the unused var after running while op
std::set<TaskNode*> sub_block_tasks;
std::vector<std::string> while_block_vars;
for (const auto& task_node : task_nodes) {
if (task_node->type() == "Cond") {
GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
while_block_vars = GetUnusedVarsAfterWhile(
program_desc, task_node, inference_root_scope_vars);
VLOG(3) << "Vars will be gced after while op";
for (auto var : while_block_vars) {
VLOG(3) << var;
}
task_node->SetWhileBlockVars(while_block_vars);
}
}
std::vector<framework::OperatorBase*> sub_block_ops;
for (const auto& task_node : sub_block_tasks) {
for (const auto& op : task_node->ops()) {
sub_block_ops.emplace_back(op);
}
}
// Analyze the unused vars in block 0. The operators in block 1 are passed
// in as well so that vars they still need are not released too early;
// the unused vars in block 1 themselves are analyzed separately.
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (const auto& task_node : task_nodes) {
for (const auto& op : task_node->ops()) {
ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
}
}
auto global_unused_vars =
framework::GetUnusedVars(program_desc.Block(0), ops, {});
for (auto& unique_op : ops) {
unique_op.release();
}
// NOTE: For inference, the vars in inference_root_scope_vars
// shouldn't be deleted during inference, since they may hold the results.
// If they were GCed, ZeroCopy-ing the results would fail.
PreventVarsDelete(&global_unused_vars, inference_root_scope_vars);
runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars);
if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
task_node->SetUnusedVars(global_unused_vars);
}
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
}
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
for (auto& unique_op : ops) {
unique_op.release();
}
VLOG(5) << runtime_graph_->DebugString();
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
......@@ -126,7 +221,8 @@ void FleetExecutor::Init(
place,
num_micro_batches,
program_desc,
inference_root_scope_vars);
inference_root_scope_vars,
micro_scope_list);
GlobalVal<MessageBus>::Get()->Barrier();
}
......@@ -136,7 +232,8 @@ void FleetExecutor::InitCarrier(
const platform::Place& place,
int64_t num_micro_batches,
const framework::ProgramDesc& program_desc,
const std::vector<std::string>& inference_root_scope_vars) {
const std::vector<std::string>& inference_root_scope_vars,
const std::vector<framework::Scope*>& micro_scope_list) {
carrier->Init(exe_desc_.cur_rank(),
runtime_graph_->interceptor_id_to_rank(),
runtime_graph_->interceptor_id_to_node(),
......@@ -144,7 +241,8 @@ void FleetExecutor::InitCarrier(
scope,
num_micro_batches,
place,
inference_root_scope_vars);
inference_root_scope_vars,
micro_scope_list);
}
void FleetExecutor::InitMessageBus() {
......
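GetUnusedVarsAfterWhile and PreventVarsDelete together decide which variables the cond interceptor may garbage-collect once a while iteration ends, while keeping inference outputs alive. A hedged sketch of how PreventVarsDelete is applied (the protected name below is made up; ops and program_desc are as in FleetExecutor::Init above):

// unused_vars maps each op to the variables that may be GCed right after it
// runs; any name listed in the second argument is stripped from every list.
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
PreventVarsDelete(&unused_vars, /*vars_not_gc=*/{"fetch_out"});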
......@@ -18,6 +18,7 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
......@@ -45,7 +46,8 @@ class FleetExecutor final {
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars = {});
const std::vector<std::string>& inference_root_scope_vars = {},
const std::vector<framework::Scope*>& micro_scope_list = {});
void Run(const std::string& carrier_id);
private:
......@@ -57,7 +59,8 @@ class FleetExecutor final {
const platform::Place& place,
int64_t num_micro_batches,
const framework::ProgramDesc& program_desc,
const std::vector<std::string>& inference_root_scope_vars = {});
const std::vector<std::string>& inference_root_scope_vars = {},
const std::vector<framework::Scope*>& micro_scope_list = {});
FleetExecutorDesc exe_desc_;
std::shared_ptr<RuntimeGraph> runtime_graph_;
std::unordered_set<std::string> carrier_ids_;
......
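FleetExecutor::Init gains the same trailing micro_scope_list parameter as Carrier::Init, so a caller can hand in pre-built microbatch scopes instead of letting the carrier create them. A sketch of a hypothetical call site, assuming fleet_exe, program_desc, scope, place, num_micro_batches, task_nodes and task_id_to_rank already exist:

std::vector<framework::Scope*> micro_scopes;
for (int64_t i = 0; i < num_micro_batches; ++i) {
  micro_scopes.push_back(&scope->NewScope());
}
fleet_exe->Init("carrier_0", program_desc, scope, place, num_micro_batches,
                task_nodes, task_id_to_rank,
                /*inference_root_scope_vars=*/{}, micro_scopes);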
......@@ -93,7 +93,6 @@ class Interceptor {
TaskNode* node_;
// for stop
bool stop_{false};
void StopCarrier();
// for runtime
......@@ -114,9 +113,6 @@ class Interceptor {
std::mutex mutex_;
std::deque<InterceptorMessage> messages_;
int64_t already_run_times_{0};
int64_t used_slot_nums_{0};
};
class InterceptorFactory {
......
......@@ -24,6 +24,21 @@ enum MessageType {
ERR = 4; // current Interceptor encounters error
RESET = 5; // reset the status
START = 6;
DATA_WITH_VARS = 7;
START_LOOP = 8;
}
enum ValueType {
INT32 = 0;
INT64 = 1;
FLOAT = 2;
DOUBLE = 3;
BOOL = 4;
}
message VarList {
required string name = 1;
required string stensor = 2;
}
message InterceptorMessage {
......@@ -32,6 +47,7 @@ message InterceptorMessage {
optional MessageType message_type = 3 [ default = RESET ];
optional bool ctrl_message = 4 [ default = false ];
optional int64 scope_idx = 5 [ default = 0 ];
repeated VarList vars_list = 6;
}
message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
......
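The new DATA_WITH_VARS message type lets an interceptor ship serialized tensors together with the readiness signal; each VarList entry pairs a variable name with the byte stream produced by SerializeToStream, and DecodeMsgVars on the receiving side deserializes them into the matching microbatch scope. A sketch of building such a message with the generated protobuf setters (the variable name and payload below are placeholders):

InterceptorMessage msg;
msg.set_message_type(DATA_WITH_VARS);
msg.set_scope_idx(cur_scope_id);            // cur_scope_id: assumed in scope
VarList* entry = msg.add_vars_list();
entry->set_name("loop_cond");               // hypothetical variable name
entry->set_stensor(serialized_tensor_str);  // output of SerializeToStream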
......@@ -25,7 +25,7 @@ namespace distributed {
* 1. record the num of micro-step
* 2. check whether to notify carrier the current step is finished
*/
class SinkInterceptor : public Interceptor {
class SinkInterceptor final : public Interceptor {
public:
SinkInterceptor(int64_t interceptor_id, TaskNode* node);
......
......@@ -25,7 +25,7 @@ namespace distributed {
* 1. receive `start` message from carrier
* 2. send num_of_steps `data_is_ready` message to downstream
*/
class SourceInterceptor : public Interceptor {
class SourceInterceptor final : public Interceptor {
public:
SourceInterceptor(int64_t interceptor_id, TaskNode* node);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/start_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
StartInterceptor::StartInterceptor(int64_t interceptor_id, TaskNode* node)
: ComputeInterceptor(interceptor_id, node) {
auto& downstream = node_->downstream();
PADDLE_ENFORCE_EQ(
downstream.size(),
1,
platform::errors::OutOfRange(
"The downstream for StartInterceptor only support 1 for now."));
for (auto down : downstream) {
batch_size_ = down.second;
}
bool evenly_divisible = ((node_->max_run_times() % batch_size_) == 0);
PADDLE_ENFORCE(
evenly_divisible,
platform::errors::Fatal(
"Wrong config: Num of step should be divided by batch_size,"
"num_step=%lld, batch_size=%lld",
node_->max_run_times(),
batch_size_));
}
void StartInterceptor::RunOps() {
finish_count_++;
ComputeInterceptor::RunOps();
}
void StartInterceptor::SendDataReadyToDownStream() {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
auto max_buff_size = outs.second.first;
auto used_size = outs.second.second;
used_size += 1;
if (max_buff_size != INFINITE_BUFFER_SIZE) {
PADDLE_ENFORCE_LE(
used_size,
max_buff_size,
platform::errors::OutOfRange("downstream=%lld used buff size must <= "
"max_buff_size, but now used_size=%lld, "
"max_buff_size=%lld",
down_id,
used_size,
max_buff_size));
}
outs.second.second = used_size;
}
if (finish_count_ == batch_size_) {
for (int64_t i = 0; i < batch_size_; ++i) {
int64_t scope_id = step_ % node_->max_run_times();
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(scope_id);
VLOG(3) << "StartInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << scope_id;
Send(down_id, ready_msg);
}
step_++;
}
}
}
void StartInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
VLOG(3) << "Start interceptor " << interceptor_id_
<< " receive data_is_ready " << msg.src_id() << " "
<< msg.scope_idx() << " ";
IncreaseReady(msg.src_id(), msg.scope_idx());
Run();
} else if (msg.message_type() == DATA_IS_USELESS) {
VLOG(3) << "Start interceptor receive data_is_useless " << msg.src_id()
<< " " << finish_count_;
finish_count_--;
if (finish_count_ == 0) {
for (int64_t i = 0; i < batch_size_; ++i) {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
DecreaseBuff(down_id);
}
}
for (int64_t i = 0; i < batch_size_; ++i) {
Run();
}
}
}
}
REGISTER_INTERCEPTOR(Start, StartInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
namespace paddle {
namespace distributed {
class StartInterceptor final : public ComputeInterceptor {
public:
StartInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void SendDataReadyToDownStream() override;
void RunOps() override;
void Compute(const InterceptorMessage& msg) override;
int64_t batch_size_{0};
int64_t finish_count_{0};
int64_t step_{0};
};
} // namespace distributed
} // namespace paddle
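StartInterceptor batches the pipeline kick-off: it counts finished runs in finish_count_, and only when a full batch of batch_size_ scopes (the single downstream's buff size) has run does it emit one DATA_IS_READY per scope, later draining the DATA_IS_USELESS replies the same way; this is why max_run_times must be a multiple of batch_size_. A small numeric sketch of the constraint and of the scope ids released per batch (the numbers are made up):

const int64_t batch_size = 4;          // downstream buff size
const int64_t max_run_times = 8;       // number of microbatch scopes
bool evenly_divisible = (max_run_times % batch_size) == 0;  // true, config ok
// The first full batch releases scope ids 0, 1, 2, 3 (step_ % max_run_times);
// the second releases scope ids 4, 5, 6, 7.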
......@@ -24,33 +24,14 @@ namespace {
using OperatorBase = TaskNode::OperatorBase;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums)
: program_(program),
rank_(rank),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
// Should be serially invoked, not thread-safe
// NOTE: when instantiate TaskNode with program, won't init task node
// immediately, since the provided program may be updated later (with
// high probability) by adding_feed_fetch_ops or by RuntimeGraph.
// So, delay the init part to the Init() function.
static int64_t task_node_cnt = 0;
task_id_ = task_node_cnt++;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: program_(program),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
max_run_times_(max_run_times) {
// TODO(liyurui): Will be removed when execute program is supported.
Init();
}
......@@ -58,7 +39,6 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
: program_(program), rank_(rank), task_id_(rank) {
max_run_times_ = 1;
max_slot_nums_ = 1;
LOG(INFO)
<< "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
<< rank
......@@ -69,6 +49,16 @@ void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
program_ = program;
}
void TaskNode::SetVarsToDtype(
const std::map<std::string, std::string>& vars_to_dtype) {
vars_to_dtype_ = vars_to_dtype;
}
void TaskNode::SetVarsToShape(
const std::map<std::string, std::vector<int64_t>>& vars_to_shape) {
vars_to_shape_ = vars_to_shape;
}
void TaskNode::Init(bool use_feed_fetch_ops) {
if (!use_feed_fetch_ops) {
VLOG(3) << "TaskNode will be inited without feed and fetch ops";
......@@ -98,13 +88,11 @@ TaskNode::TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
max_run_times_(max_run_times) {
if (op_descs.empty()) {
return;
}
......@@ -121,33 +109,35 @@ TaskNode::TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: ops_(ops),
role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
max_run_times_(max_run_times) {}
TaskNode::TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
max_run_times_(max_run_times) {}
bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
bool TaskNode::AddUpstreamTask(int64_t task_id,
int64_t buff_size,
DependType type) {
const auto& ret = upstream_.emplace(task_id, buff_size);
id_to_dep_type_.emplace(task_id, type);
return ret.second;
}
bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
bool TaskNode::AddDownstreamTask(int64_t task_id,
int64_t buff_size,
DependType type) {
const auto& ret = downstream_.emplace(task_id, buff_size);
id_to_dep_type_.emplace(task_id, type);
return ret.second;
}
......
......@@ -14,8 +14,10 @@
#pragma once
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
......@@ -29,38 +31,30 @@ class OpDesc;
} // namespace framework
namespace distributed {
enum class DependType { NORMAL, LOOP, STOP_LOOP };
class TaskNode final {
public:
using OperatorBase = paddle::framework::OperatorBase;
TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
int64_t max_run_times);
TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums);
int64_t max_run_times);
TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
// TODO(liyurui): This will be the only constructor for task node
TaskNode(paddle::framework::ProgramDesc* program,
int64_t task_id,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums);
int64_t max_run_times);
~TaskNode() = default;
void SetProgram(paddle::framework::ProgramDesc* program);
......@@ -69,11 +63,11 @@ class TaskNode final {
int64_t task_id() const { return task_id_; }
int32_t role() const { return role_; }
int64_t max_run_times() const { return max_run_times_; }
int64_t max_slot_nums() const { return max_slot_nums_; }
int64_t run_per_steps() const { return run_per_steps_; }
int64_t run_at_offset() const { return run_at_offset_; }
int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
int64_t send_down_per_steps() const { return send_down_per_steps_; }
const std::string& cond_var() const { return cond_var_; }
const std::unordered_map<int64_t, int64_t>& upstream() const {
return upstream_;
}
......@@ -86,11 +80,20 @@ class TaskNode final {
const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
return ops_vec_;
}
const std::unordered_map<int64_t, DependType> id_to_dep_type() const {
return id_to_dep_type_;
}
const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
unused_vars() const {
return unused_vars_;
}
const std::vector<std::string> while_block_vars() const {
return while_block_vars_;
}
void SetCondVarName(const std::string& cond_var_name) {
cond_var_ = cond_var_name;
}
void SetRunPerSteps(int64_t value);
void SetRunAtOffset(int64_t value);
void SetReplyUpPerSteps(int64_t value);
......@@ -101,11 +104,27 @@ class TaskNode final {
unused_vars) {
unused_vars_ = unused_vars;
}
void SetWhileBlockVars(const std::vector<std::string>& vars) {
while_block_vars_ = vars;
}
// upstream need buffs?
bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1);
bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1);
bool AddUpstreamTask(int64_t task_id,
int64_t buff_size = 1,
DependType type = DependType::NORMAL);
bool AddDownstreamTask(int64_t task_id,
int64_t buff_size = 1,
DependType type = DependType::NORMAL);
std::string DebugString() const;
const std::map<std::string, std::string>& vars_to_dtype() const {
return vars_to_dtype_;
}
void SetVarsToDtype(const std::map<std::string, std::string>& vars_to_dtype);
const std::map<std::string, std::vector<int64_t>>& vars_to_shape() const {
return vars_to_shape_;
}
void SetVarsToShape(
const std::map<std::string, std::vector<int64_t>>& vars_to_shape);
private:
DISABLE_COPY_AND_ASSIGN(TaskNode);
......@@ -115,16 +134,22 @@ class TaskNode final {
// task_id-->buff_size
std::unordered_map<int64_t, int64_t> upstream_;
std::unordered_map<int64_t, int64_t> downstream_;
// task_id-->type
std::unordered_map<int64_t, DependType> id_to_dep_type_;
framework::ProgramDesc* program_;
std::string cond_var_;
std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
std::unordered_map<const OperatorBase*, std::vector<std::string>>
unused_vars_;
std::vector<std::string> while_block_vars_;
std::map<std::string, std::string> vars_to_dtype_;
std::map<std::string, std::vector<int64_t>> vars_to_shape_;
int32_t role_;
int64_t rank_;
int64_t task_id_;
int64_t max_run_times_;
int64_t max_slot_nums_;
int64_t run_per_steps_{1};
int64_t run_at_offset_{0};
......
......@@ -77,9 +77,8 @@ TEST(ComputeInterceptor, Compute) {
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 2); // rank, task_id, max_run_times
TaskNode* node_a =
new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
TaskNode* node_a = new TaskNode(0, ops, 0, 0, 2); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 2);
TaskNode* sink = new TaskNode(0, SINK_ID, 2);
// source->a->b->sink
......
......@@ -21,61 +21,49 @@ limitations under the License. */
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle {
namespace distributed {
class StartInterceptor : public Interceptor {
public:
StartInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
}
void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(1, stop); // stop 1, compute
return;
}
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
}
};
TEST(ComputeInterceptor, Compute) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}});
carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
// a->b->c
TaskNode* source =
new TaskNode(0, SOURCE_ID, 3); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 3);
TaskNode* node_b = new TaskNode(0, 0, 1, 3);
TaskNode* sink = new TaskNode(0, SINK_ID, 3);
// source->a->b->sink
source->AddDownstreamTask(0);
node_a->AddUpstreamTask(SOURCE_ID);
node_a->AddDownstreamTask(1, 3);
node_b->AddUpstreamTask(0, 3);
node_b->AddDownstreamTask(2);
node_c->AddUpstreamTask(1);
node_b->AddDownstreamTask(SINK_ID);
sink->AddUpstreamTask(1);
Interceptor* a =
carrier->SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a));
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
// test run three times
a->Send(1, msg);
a->Send(1, msg);
a->Send(1, msg);
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
......
......@@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor {
void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_
......
......@@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor {
void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
StopCarrier();
return;
}
......
......@@ -66,17 +66,17 @@ TEST(AmplifierInterceptor, Amplifier) {
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
int64_t micro_steps = 3;
int64_t micro_steps = 1;
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
TaskNode* node_a = new TaskNode(0, 0, 0, 1); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 1);
TaskNode* node_c = new TaskNode(0, 0, 2, 1);
TaskNode* node_d = new TaskNode(0, 0, 3, 1);
TaskNode* node_e = new TaskNode(0, 0, 4, 1);
TaskNode* node_f = new TaskNode(0, 0, 5, 1);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->e->f->sink
......
......@@ -83,11 +83,10 @@ TEST(AmplifierInterceptor, Amplifier) {
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a =
new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, micro_steps);
TaskNode* node_c = new TaskNode(0, 0, 2, micro_steps);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->sink
......
......@@ -62,10 +62,9 @@ TEST(SourceInterceptor, Source) {
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0); // role, rank, task_id
TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3); // role, rank, task_id
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3); // role, rank, task_id
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1);
......
......@@ -61,9 +61,8 @@ TEST(SourceInterceptor, Source) {
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3); // role, rank, task_id
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1);
......
......@@ -112,5 +112,7 @@ REGISTER_OP_CUDA_KERNEL(c_broadcast,
ops::CBroadcastOpCUDAKernel<plat::bfloat16>,
#endif
ops::CBroadcastOpCUDAKernel<int>,
ops::CBroadcastOpCUDAKernel<uint8_t>,
ops::CBroadcastOpCUDAKernel<int8_t>,
ops::CBroadcastOpCUDAKernel<int64_t>,
ops::CBroadcastOpCUDAKernel<plat::float16>);
......@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
DECLARE_bool(cudnn_deterministic);
namespace paddle {
namespace operators {
......@@ -83,6 +85,32 @@ __global__ void CEmbeddingGrad(T *table,
}
}
template <typename T, typename IndexT>
__global__ void CEmbeddingGradSerial(T *table,
const T *output,
const IndexT *ids,
const int rows,
const int columns,
const int64_t N,
const int64_t start_idx,
const int64_t end_idx,
const int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
if (i == 0) {
for (int j = 0; j < limit; j++) {
size_t row = j / columns;
size_t col = j % columns;
auto id = ids[row];
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
paddle::platform::CudaAtomicAdd(&table[real_idx * columns + col],
output[i]);
}
}
}
}
}
template <typename T>
class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
public:
......@@ -163,28 +191,56 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
const auto &index_type = framework::TransToProtoVarType(ids_t->dtype());
if (index_type == framework::proto::VarType::INT32) {
CEmbeddingGrad<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids_t->data<int32_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
} else if (index_type == framework::proto::VarType::INT64) {
CEmbeddingGrad<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids_t->data<int64_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of embedding with single thread.";
blocks = 1;
if (index_type == framework::proto::VarType::INT32) {
CEmbeddingGradSerial<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids_t->data<int32_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
} else if (index_type == framework::proto::VarType::INT64) {
CEmbeddingGradSerial<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids_t->data<int64_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
}
} else {
if (index_type == framework::proto::VarType::INT32) {
CEmbeddingGrad<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids_t->data<int32_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
} else if (index_type == framework::proto::VarType::INT64) {
CEmbeddingGrad<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids_t->data<int64_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
}
}
}
};
......
......@@ -65,6 +65,7 @@ struct npy_format_descriptor<paddle::platform::float16> {
namespace paddle {
namespace pybind {
using paddle::distributed::DependType;
using paddle::distributed::DistModel;
using paddle::distributed::DistModelConfig;
using paddle::distributed::DistModelDataBuf;
......@@ -164,18 +165,17 @@ void BindFleetExecutor(py::module* m) {
.def(
"run", &FleetExecutor::Run, py::call_guard<py::gil_scoped_release>());
py::enum_<DependType>(*m, "DependType")
.value("NORMAL", DependType::NORMAL)
.value("LOOP", DependType::LOOP)
.value("STOP_LOOP", DependType::STOP_LOOP);
py::class_<TaskNode>(*m, "TaskNode")
.def(py::init<framework::ProgramDesc*,
int64_t,
int64_t,
int64_t,
int64_t>())
.def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>())
.def(py::init<int32_t,
const std::vector<framework::OpDesc*>&,
int64_t,
int64_t,
int64_t,
int64_t>())
.def("task_id", &TaskNode::task_id)
.def("add_upstream_task", &TaskNode::AddUpstreamTask)
......@@ -183,7 +183,10 @@ void BindFleetExecutor(py::module* m) {
.def("set_run_pre_steps", &TaskNode::SetRunPerSteps)
.def("set_run_at_offset", &TaskNode::SetRunAtOffset)
.def("set_type", &TaskNode::SetType)
.def("set_cond_var_name", &TaskNode::SetCondVarName)
.def("role", &TaskNode::role)
.def("set_vars_to_shape", &TaskNode::SetVarsToShape)
.def("set_vars_to_dtype", &TaskNode::SetVarsToDtype)
.def("init", [](TaskNode& self) { self.Init(); })
.def("set_program", &TaskNode::SetProgram);
......
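The pybind hunk above exposes the new DependType enum and the slimmed-down TaskNode constructors to Python. Below is a minimal, hedged sketch of driving those bindings from the Python side; the `paddle.fluid.core` import path, the surviving 4-argument constructor order (program_desc, task_id, rank, max_run_times), and the condition variable name are assumptions inferred from this diff rather than documented API.

```python
# Hedged sketch: exercise the new TaskNode / DependType bindings.
import paddle
from paddle.fluid import core

paddle.enable_static()
prog = paddle.static.default_main_program()

cond_task = core.TaskNode(prog.desc, 0, 0, 1)   # task_id=0, rank=0, max_run_times=1
body_task = core.TaskNode(prog.desc, 1, 0, 1)   # task_id=1

# DependType marks how an edge participates in a while-style loop.
body_task.add_upstream_task(cond_task.task_id(), 1, core.DependType.NORMAL)
cond_task.add_upstream_task(body_task.task_id(), 1, core.DependType.LOOP)
cond_task.set_cond_var_name("while_cond_0")     # hypothetical condition variable
```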
......@@ -23,6 +23,8 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename InT, typename OutT>
......@@ -101,6 +103,12 @@ struct EmbeddingGradCUDAFunctor {
const int gridx = 2 * dev_ctx_.GetSMCount();
dim3 threads(128, 8);
dim3 grids(gridx, 1);
if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of embedding with single thread.";
grids.x = 1;
threads.y = 1;
}
EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
d_table, d_output, ids, N, K, D);
}
......
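Both embedding-gradient changes above gate a serialized accumulation path behind FLAGS_cudnn_deterministic, trading speed for reproducible gradients. A hedged sketch of turning the flag on from Python follows; it exercises the plain embedding path (the c_embedding variant applies the same flag in model-parallel runs), and the effect only shows up on a CUDA build.

```python
# Hedged sketch: request the serialized (deterministic) embedding-gradient path.
import paddle

paddle.set_flags({'FLAGS_cudnn_deterministic': True})

ids = paddle.randint(low=0, high=100, shape=[8, 16], dtype='int64')
emb = paddle.nn.Embedding(100, 32)
loss = emb(ids).sum()
loss.backward()  # gradient accumulation now runs serially instead of via atomics
```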
......@@ -94,6 +94,16 @@ set_field_default_config(GRADIENT_MERGE, "enable", False)
set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
set_field_default_config(GRADIENT_MERGE, "avg", True)
#########################################
# pipeline configuration
#########################################
PIPELINE = "pipeline"
set_field_default_config(PIPELINE, "enable", False)
set_field_default_config(PIPELINE, "schedule_mode", "1F1B")
set_field_default_config(PIPELINE, "micro_batch_size", 1)
set_field_default_config(PIPELINE, "accumulate_steps", 1)
set_field_default_config(PIPELINE, "generation_batch_size", 1)
#########################################
# quantization configuration
#########################################
......
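The hunk above registers defaults for a new `pipeline` section of the auto-parallel strategy config. Assuming the section is surfaced through `auto.Strategy()` attribute access in the same way as the existing amp/recompute/gradient_merge sections (an assumption, not verified here), enabling 1F1B scheduling might look like this sketch.

```python
# Hedged sketch: enable the pipeline section whose defaults are registered above.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.pipeline.enable = True
strategy.pipeline.schedule_mode = "1F1B"
strategy.pipeline.micro_batch_size = 2
strategy.pipeline.accumulate_steps = 4
# generation_batch_size keeps its registered default of 1 here.

# The strategy is then handed to the Engine as usual, e.g.
# engine = auto.Engine(model, loss, optimizer, strategy=strategy)
```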
......@@ -556,8 +556,8 @@ def get_cost_from_engine(engine, mode):
)
serial_startup_prog = (
engine._serial_startup_progs[mode].clone()
if mode in engine._serial_startup_progs
engine._fwd_dist_contexts[mode]._original_serial_main_program.clone()
if mode in engine._fwd_dist_contexts
else engine._orig_startup_prog.clone()
)
losses = (
......
......@@ -27,7 +27,6 @@ from .utils import convert_to_shard_spec, verify_shard_spec
class DistributedOperator:
def __init__(self, serial_op, dist_attr=None):
self._serial_op = serial_op
self._serial_inputs = {}
......@@ -78,28 +77,34 @@ class DistributedOperator:
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
if (
tensor.type == core.VarDesc.VarType.READER
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
):
tensor_shape = []
else:
tensor_shape = tensor.shape
if self._dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping)
self._dist_attr.set_input_dims_mapping(
tensor_name, tensor_dims_mapping
)
for tensor_name in self._serial_op.output_arg_names:
tensor = self._serial_op.block._var_recursive(tensor_name)
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
if (
tensor.type == core.VarDesc.VarType.READER
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
or tensor.type == core.VarDesc.VarType.STEP_SCOPES
):
tensor_shape = []
else:
tensor_shape = tensor.shape
self._serial_outputs[tensor_name] = tensor
if self._dist_attr.get_output_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_output_dims_mapping(tensor_name,
tensor_dims_mapping)
self._dist_attr.set_output_dims_mapping(
tensor_name, tensor_dims_mapping
)
if self._dist_attr.op_type is None:
self._dist_attr.op_type = self.serial_op.type
if self._dist_attr.impl_type is None:
......@@ -117,8 +122,10 @@ class DistributedOperator:
new_dist_attr = {}
for key, value in dist_attr.items():
if isinstance(key, Variable):
if key.name in self._serial_op.input_arg_names \
or key.name in self._serial_op.output_arg_names:
if (
key.name in self._serial_op.input_arg_names
or key.name in self._serial_op.output_arg_names
):
new_dist_attr[key] = value
else:
new_dist_attr[key] = value
......@@ -129,13 +136,15 @@ class DistributedOperator:
for tensor_name in self._serial_op.input_arg_names:
tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name)
if tensor_dist_attr:
new_dist_attr.set_input_dist_attr(tensor_name,
tensor_dist_attr)
new_dist_attr.set_input_dist_attr(
tensor_name, tensor_dist_attr
)
for tensor_name in self._serial_op.output_arg_names:
tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name)
if tensor_dist_attr:
new_dist_attr.set_output_dist_attr(tensor_name,
tensor_dist_attr)
new_dist_attr.set_output_dist_attr(
tensor_name, tensor_dist_attr
)
else:
assert False, "Cannot recognize the {} parameter.".format(dist_attr)
return new_dist_attr
......@@ -146,8 +155,10 @@ class DistributedOperator:
for name in self.serial_op.input_arg_names:
input_dist_attr = self.dist_attr.get_input_dist_attr(name)
dims_mapping = input_dist_attr.dims_mapping
if self.get_serial_input(
name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
if (
self.get_serial_input(name).type
== core.VarDesc.VarType.LOD_TENSOR_ARRAY
):
shape = []
else:
shape = self.get_serial_input(name).shape
......@@ -155,7 +166,8 @@ class DistributedOperator:
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
self.dist_attr.process_mesh.topology
):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1:
......@@ -166,8 +178,12 @@ class DistributedOperator:
for name in self.serial_op.output_arg_names:
output_dist_attr = self.dist_attr.get_output_dist_attr(name)
dims_mapping = output_dist_attr.dims_mapping
if self.get_serial_output(name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY\
or self.get_serial_output(name).type == core.VarDesc.VarType.STEP_SCOPES:
if (
self.get_serial_output(name).type
== core.VarDesc.VarType.LOD_TENSOR_ARRAY
or self.get_serial_output(name).type
== core.VarDesc.VarType.STEP_SCOPES
):
shape = []
else:
shape = self.get_serial_output(name).shape
......@@ -175,7 +191,8 @@ class DistributedOperator:
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
self.dist_attr.process_mesh.topology
):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1:
......@@ -185,8 +202,9 @@ class DistributedOperator:
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.serial_op.desc.type(),
self.serial_op.desc.id())
str = "{{op type: {}, op id: {}".format(
self.serial_op.desc.type(), self.serial_op.desc.id()
)
# str += ", {}".format(self.dist_attr)
# return str
......@@ -195,8 +213,9 @@ class DistributedOperator:
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.dist_attr.process_mesh)
str += ", process_mesh ({}): {}".format(
annotated_str, self.dist_attr.process_mesh
)
for arg_name in self.serial_op.desc.input_arg_names():
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
......@@ -212,7 +231,8 @@ class DistributedOperator:
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
arg_name, annotated_str, is_parameter_str, dims_mapping
)
for arg_name in self.serial_op.desc.output_arg_names():
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
......@@ -228,12 +248,14 @@ class DistributedOperator:
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
arg_name, annotated_str, is_parameter_str, dims_mapping
)
str += ", pipeline stage: {}".format(None)
str += ", dist_impl idx: {} , dist_impl type {} }}".format(
self.dist_attr._impl_idx, self.dist_attr._impl_type)
self.dist_attr._impl_idx, self.dist_attr._impl_type
)
return str
......@@ -242,7 +264,11 @@ class DistributedOperator:
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_serial_op" or k == "_serial_inputs" or k == "_serial_outputs":
if (
k == "_serial_op"
or k == "_serial_inputs"
or k == "_serial_outputs"
):
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
......@@ -250,9 +276,9 @@ class DistributedOperator:
class DistributedOperatorHelper:
def __init__(self, serial_op, process_mesh, in_dims_mappings,
out_dims_mappings):
def __init__(
self, serial_op, process_mesh, in_dims_mappings, out_dims_mappings
):
self._serial_op = serial_op
self._process_mesh = process_mesh
self._in_dims_mappings = in_dims_mappings
......@@ -262,8 +288,11 @@ class DistributedOperatorHelper:
tensor_to_dims_mapping = {}
index = 0
if self._in_dims_mappings:
assert len(args) + len(kwargs) == len(self._in_dims_mappings), \
"The length of dims_mapping {} does not matching the length output {}.".format(len(self._in_dims_mappings), len(args) + len(kwargs))
assert len(args) + len(kwargs) == len(
self._in_dims_mappings
), "The length of dims_mapping {} does not matching the length output {}.".format(
len(self._in_dims_mappings), len(args) + len(kwargs)
)
for arg in args:
if isinstance(arg, Variable) and self._in_dims_mappings:
tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
......@@ -287,13 +316,17 @@ class DistributedOperatorHelper:
raise ValueError("Unrecognized outpout.")
if self._out_dims_mappings:
assert len(new_output) == len(self._out_dims_mappings), \
"The length of dims_mapping {} does not matching the length output {}.".format(len(self._out_dims_mappings), len(new_output))
assert len(new_output) == len(
self._out_dims_mappings
), "The length of dims_mapping {} does not matching the length output {}.".format(
len(self._out_dims_mappings), len(new_output)
)
for i, item in enumerate(new_output):
if isinstance(item, Variable) and self._out_dims_mappings:
tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i]
from .dist_context import get_default_distributed_context
default_dist_ctx = get_default_distributed_context()
for idx in range(op_size, new_op_size):
op = cur_block.ops[idx]
......@@ -302,53 +335,68 @@ class DistributedOperatorHelper:
if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_input(name)
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
name)
name
)
dims_mapping = tensor_to_dims_mapping[name]
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
if (
tensor.type == core.VarDesc.VarType.READER
or tensor.type
== core.VarDesc.VarType.LOD_TENSOR_ARRAY
or tensor.type == core.VarDesc.VarType.STEP_SCOPES
):
tensor_shape = []
else:
tensor_shape = tensor.shape
if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh)
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh)
dims_mapping, self._process_mesh
)
assert verify_shard_spec(
shard_spec, tensor_shape, self._process_mesh
), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh
)
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping")
for name in dist_op.serial_op.output_arg_names:
if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_output(name)
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
name)
name
)
dims_mapping = tensor_to_dims_mapping[name]
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
if (
tensor.type == core.VarDesc.VarType.READER
or tensor.type
== core.VarDesc.VarType.LOD_TENSOR_ARRAY
or tensor.type == core.VarDesc.VarType.STEP_SCOPES
):
tensor_shape = []
else:
tensor_shape = tensor.shape
if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh)
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh)
dims_mapping, self._process_mesh
)
assert verify_shard_spec(
shard_spec, tensor_shape, self._process_mesh
), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh
)
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping")
dist_op.dist_attr.process_mesh = self._process_mesh
if self._process_mesh is not None:
dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op)
default_dist_ctx.add_process_mesh(self._process_mesh)
return output
......@@ -34,6 +34,7 @@ from paddle.fluid.framework import Operator, _non_static_mode
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet
from paddle.distributed.parallel import _is_global_parallel_initialize
from .callbacks import config_callbacks
from .converter import Converter
......@@ -160,7 +161,6 @@ class Engine:
" or `paddle.fluid.optimizer.Optimizer`."
)
self._optimizer = validate_opt(optimizer)
self._orig_optimizer = copy.deepcopy(self._optimizer)
metrics = metrics or []
for metric in to_list(metrics):
......@@ -185,12 +185,18 @@ class Engine:
self._strategy = strategy or Strategy()
self._logger = get_logger(logging.INFO)
if os.getenv("POD_NAME"):
if os.getenv("POD_NAME") and not _is_global_parallel_initialize():
self._logger.info(
"Distribute training by paddle.distributed.launch"
)
fleet.init(is_collective=True)
# for compute cost
# TODO: remove _fwd_main_progs and _orig_optimizer
self._fwd_dist_contexts = {}
self._fwd_main_progs = {}
self._orig_optimizer = copy.deepcopy(self._optimizer)
self._executor = None
self._cur_rank = paddle.distributed.get_rank()
self._nranks = paddle.distributed.get_world_size()
......@@ -200,14 +206,6 @@ class Engine:
self._orig_startup_prog = static.default_startup_program()
self._orig_dist_context = get_default_distributed_context()
self._dist_contexts = {}
self._fwd_main_progs = {}
self._fwd_dist_contexts = {}
self._serial_main_progs = {}
self._serial_startup_progs = {}
self._dist_main_progs = defaultdict(dict) # dist main programs
self._dist_startup_progs = defaultdict(dict) # dist startup programs
self._feed_vars = {}
self._fetch_vars = {}
self._planners = {}
self._has_prepared = {"train": False, "eval": False, "predict": False}
self._has_prepared_reader = {
......@@ -338,9 +336,9 @@ class Engine:
return inputs, labels
def _prepare_reader(self):
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
def _prepare_reader(self, feed_list=[]):
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_main_block = dist_main_prog.global_block()
# NOTE: this list may be changed if Paddle changes the existing rules.
......@@ -361,10 +359,13 @@ class Engine:
if op.type in related_reader_ops:
reader_op_indices.append(idx)
# Step 2: insert the new reader ops to cpp
# record the read ops' descs so they can be inserted into the forward task_node's program
read_ops_desc = []
new_reader_ops = []
for idx in reversed(reader_op_indices):
new_op_desc = dist_main_block.desc._prepend_op()
new_op_desc.copy_from(dist_main_block.ops[idx].desc)
read_ops_desc.append(new_op_desc)
new_op = Operator(
dist_main_block, new_op_desc, type=new_op_desc.type()
)
......@@ -383,6 +384,29 @@ class Engine:
dist_main_block._sync_with_cpp()
self._has_prepared_reader[self._mode] = True
# Insert read op to forward TaskNode if 1F1B pass is set
if self.main_program._pipeline_opt:
assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
fwd_task = fleet_opt["tasks"][0]
fwd_prog = fwd_task.get_program()
fwd_block = fwd_prog.global_block()
for var in feed_list:
if var.name not in fwd_block.vars:
fwd_block._clone_variable(var)
for op_desc in read_ops_desc:
new_op_desc = fwd_block.desc._prepend_op()
new_op_desc.copy_from(op_desc)
new_op = Operator(
fwd_block, new_op_desc, type=new_op_desc.type()
)
fwd_block.ops.insert(0, new_op)
fwd_block._sync_with_cpp()
fwd_task.set_program(fwd_prog)
def _prepare_feed(self, data, user_feeds, mode):
feeds = {}
if data is not None:
......@@ -430,14 +454,16 @@ class Engine:
fetch_names.append([])
fetch_indices.append(group_indices)
dist_context = self._dist_contexts[mode]
fetch_vars = dist_context.serial_fetch_vars
if mode != "predict":
_process_fetch_group("loss", self._fetch_vars[mode]["loss"])
_process_fetch_group("loss", fetch_vars["loss"])
if mode != "predict":
metrics = self._fetch_vars[mode]["metrics"]
metrics = fetch_vars["metrics"]
for i, var_list in enumerate(metrics):
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", self._fetch_vars[mode]["outputs"])
_process_fetch_group("outputs", fetch_vars["outputs"])
user_fetches_collection = [
item[1] for item in get_collection(CollectionNames.FETCHES)
]
......@@ -471,7 +497,8 @@ class Engine:
logs["loss"] = outs[idx][0]
group_idx += 1
# logging metrics
metric_vars = self._fetch_vars[mode]["metrics"]
dist_context = self._dist_contexts[mode]
metric_vars = dist_context.serial_fetch_vars["metrics"]
if metric_vars:
for metric in self._metrics:
metrics_indices = fetch_indices[group_idx]
......@@ -502,15 +529,18 @@ class Engine:
logs["fetches"] = logs_fetch
return logs
def _prepare_program(self, mode):
def _prepare_program(self, mode, init_parameters=True):
# Do the build process
self._build(mode)
# Do the planning process
self._plan(mode)
# Do the parallel process
self._parallel(mode)
# Init comm and startup program
self._initialize(mode)
# Init comm
self._init_comm()
if init_parameters:
# startup program
self._initialize(mode)
self._has_prepared[mode] = True
def _build(self, mode):
......@@ -542,8 +572,8 @@ class Engine:
paddle.enable_static()
else:
# build program in static mode
serial_main_prog = self._serial_main_progs.get(mode, None)
if serial_main_prog is not None:
dist_context = self._dist_contexts.get(mode, None)
if dist_context is not None:
return
outputs = []
......@@ -581,7 +611,7 @@ class Engine:
metric.compute(*(outputs + self._labels))
)
)
else:
elif mode == "train":
assert isinstance(
self._loss, Variable
), "the type of `loss` of the Engine arguments should be Variable."
......@@ -724,37 +754,21 @@ class Engine:
)
dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)
def _initialize(self, mode):
# Get the current content from the distributed context
self._serial_main_progs[mode] = self._dist_contexts[
mode
].serial_main_program
self._serial_startup_progs[mode] = self._dist_contexts[
mode
].serial_startup_program
self._dist_main_progs[mode] = self._dist_contexts[
mode
].dist_main_programs
self._dist_startup_progs[mode] = self._dist_contexts[
mode
].dist_startup_programs
self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
self._optimizer = self._dist_contexts[mode]._serial_optimizer
def _init_comm(self):
if self._nranks > 1:
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
if self._strategy.auto_mode == "full":
initialize_pg_in_full_mode(all_process_groups, cur_rank)
initialize_pg_in_full_mode(all_process_groups, self._cur_rank)
else:
for process_group in all_process_groups:
if self._cur_rank not in process_group.ranks:
continue
process_group.instantiate()
def _initialize(self, mode):
place = _get_device()
if isinstance(place, fluid.CUDAPlace):
place = fluid.CUDAPlace(ParallelEnv().dev_id)
......@@ -764,15 +778,17 @@ class Engine:
np.random.seed(self._strategy.seed + self._dp_ranks[0])
random.seed(self._strategy.seed + self._dp_ranks[0])
dist_context = self._dist_contexts[mode]
if self._dygraph_mode:
dist_context = self._dist_contexts[mode]
dist_main_program = self._dist_main_progs[mode][self._cur_rank]
dist_main_program = dist_context.dist_main_programs[self._cur_rank]
self.program_helper.init(dist_main_program, place, dist_context)
if self._executor is None:
self._executor = paddle.static.Executor(place)
uninitialized = []
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[
self._cur_rank
]
for var in dist_startup_prog.list_vars():
scope_var = global_scope().find_var(var.name)
if scope_var and scope_var.get_tensor()._is_initialized():
......@@ -789,7 +805,9 @@ class Engine:
if self._strategy.reinit:
self._logger.info("NOTE: parameters will be re-initialized.")
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[
self._cur_rank
]
self._executor.run(dist_startup_prog)
def fit(
......@@ -926,7 +944,7 @@ class Engine:
)
except core.EOFException:
break
lr = get_lr(self._optimizer)
lr = get_lr(self.optimizer)
logs = self._prepare_logger(
outs,
epoch,
......@@ -1262,6 +1280,7 @@ class Engine:
main_program=None,
startup_program=None,
mode=None,
init_parameters=True,
):
if mode is not None:
self.to_mode(mode)
......@@ -1304,7 +1323,7 @@ class Engine:
self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
self._inputs, self._labels = inputs, labels
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
self._prepare_program(self._mode, init_parameters)
else:
self._switch_mode(self._mode)
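prepare() now takes an init_parameters switch, so the distributed programs and communication groups can be built without re-initializing parameters, for example when a checkpoint is restored immediately afterwards. A hedged sketch with a toy stand-in model (the shapes and checkpoint path are placeholders, not from this change):

```python
# Hedged sketch: build programs/comm groups but skip parameter initialization.
import paddle
from paddle.distributed.fleet import auto

model = paddle.nn.Linear(8, 2)                     # toy stand-in model
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

input_spec = paddle.static.InputSpec([None, 8], 'float32', 'x')
label_spec = paddle.static.InputSpec([None, 1], 'int64', 'label')

engine = auto.Engine(model, loss, optimizer)
engine.prepare(inputs_spec=[input_spec], labels_spec=[label_spec],
               mode="train", init_parameters=False)
# engine.load("path/to/checkpoint")  # hypothetical path that restores the weights
```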
......@@ -1355,16 +1374,17 @@ class Engine:
)
batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list, then insert dataloader op with sharded var shape.
# Because predict_program does not contain the labels var,
# we add the labels var from serial_program to dist_program,
# keeping the length of feed_list equal to the length of the dataset's values.
inputs_var = self._feed_vars[self._mode]["inputs"]
labels_var = self._feed_vars[self._mode]["labels"]
inputs_var = dist_context.serial_feed_vars["inputs"]
labels_var = dist_context.serial_feed_vars["labels"]
feed_list = []
for var in inputs_var + labels_var:
if var.name in dist_main_block.vars:
......@@ -1423,16 +1443,17 @@ class Engine:
)
batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list, then insert dataloader op with sharded var shape.
# Because predict_program does not contain the labels var,
# we add the labels var from serial_program to dist_program,
# keeping the length of feed_list equal to the length of the dataset's values.
inputs_var = self._feed_vars[self._mode]["inputs"]
labels_var = self._feed_vars[self._mode]["labels"]
inputs_var = dist_context.serial_feed_vars["inputs"]
labels_var = dist_context.serial_feed_vars["labels"]
feed_list = []
for var in inputs_var + labels_var:
if var.name in dist_main_block.vars:
......@@ -1462,7 +1483,7 @@ class Engine:
data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks,
)
self._prepare_reader()
self._prepare_reader(feed_list)
return dataloader
def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
......@@ -1551,10 +1572,9 @@ class Engine:
def _switch_mode(self, mode):
assert (
mode in self._dist_main_progs
mode in self._dist_contexts
), "{} model is not ready, please call `prepare()` first.".format(mode)
self.to_mode(mode)
self._optimizer = self._dist_contexts[mode]._serial_optimizer
def to_mode(self, mode):
assert mode in [
......@@ -1565,8 +1585,8 @@ class Engine:
self._mode = mode
def _set_state_dict(self, mode, strict, state_dict, dist_attr):
program = self._dist_main_progs[mode][self._cur_rank]
dist_context = self._dist_contexts[mode]
program = dist_context.dist_main_programs[self._cur_rank]
cur_dist_attr = get_dist_attr(program, dist_context)
converter = Converter(state_dict, dist_attr, cur_dist_attr)
state_dict = converter.convert(strict=strict)
......@@ -1618,10 +1638,10 @@ class Engine:
"""
if training:
assert self._mode in self._serial_main_progs
serial_program = self._serial_main_progs[self._mode]
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
assert self._mode in self._dist_contexts
dist_context = self._dist_contexts[self._mode]
serial_program = dist_context.serial_main_program
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
self._saver.save(
path,
serial_program=serial_program,
......@@ -1629,10 +1649,11 @@ class Engine:
dist_context=dist_context,
)
else:
assert "predict" in self._dist_main_progs
feed_vars = self._feed_vars["predict"]['inputs']
fetch_vars = self._fetch_vars["predict"]['outputs']
dist_main_prog = self._dist_main_progs["predict"][self._cur_rank]
assert "predict" in self._dist_contexts
dist_context = self._dist_contexts["predict"]
feed_vars = dist_context.serial_feed_vars['inputs']
fetch_vars = dist_context.serial_fetch_vars['outputs']
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
self._saver.save_inference_model(
path,
feed_vars,
......@@ -1758,11 +1779,13 @@ class Engine:
@property
def main_program(self):
return self._dist_main_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
return dist_context.dist_main_programs[self._cur_rank]
@property
def startup_program(self):
return self._dist_startup_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
return dist_context.dist_startup_programs[self._cur_rank]
@property
def dist_context(self):
......@@ -1770,15 +1793,30 @@ class Engine:
@property
def serial_main_program(self):
return self._serial_main_progs[self._mode]
dist_context = self._dist_contexts[self._mode]
return dist_context.serial_main_program
@property
def serial_startup_program(self):
return self._serial_startup_progs[self._mode]
dist_context = self._dist_contexts[self._mode]
return dist_context.serial_startup_program
@property
def feed_vars(self):
dist_context = self._dist_contexts[self._mode]
return dist_context.serial_feed_vars
@property
def fetch_vars(self):
return self._fetch_vars[self._mode]
dist_context = self._dist_contexts[self._mode]
return dist_context.serial_fetch_vars
@property
def optimizer(self):
dist_context = self._dist_contexts[self._mode]
if dist_context._serial_optimizer:
return dist_context._serial_optimizer
return self._optimizer
@property
def inputs(self):
......
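After this refactor the per-mode programs and variables live only on the DistributedContext, and the Engine properties above are thin views over it. Continuing the `engine` from the prepare() sketch earlier, inspecting them might look like this (illustrative only):

```python
# Hedged sketch: properties now resolve through the current mode's DistributedContext.
print(engine.main_program)           # dist main program for the current rank
print(engine.serial_main_program)    # un-partitioned serial program
print(engine.fetch_vars["loss"])     # serial fetch vars, grouped as in this diff
print(engine.optimizer)              # serial optimizer held by the context
```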
......@@ -67,29 +67,43 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
"""
if process_mesh is not None:
assert isinstance(process_mesh, ProcessMesh), \
"Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh)
assert isinstance(
process_mesh, ProcessMesh
), "Argument process_mesh {} is not an instance of ProcessMesh".format(
process_mesh
)
else:
process_mesh = get_current_process_mesh()
assert process_mesh is not None, \
"Specify the process mesh argument or use ProcessMesh context manager first."
assert isinstance(shard_spec, list), \
"Argument shard_spec {} is not an instance of list".format(shard_spec)
dist_tensor = DistributedTensor(x)
assert (
process_mesh is not None
), "Specify the process mesh argument or use ProcessMesh context manager first."
assert isinstance(
shard_spec, list
), "Argument shard_spec {} is not an instance of list".format(shard_spec)
if isinstance(x, str):
x = paddle.fluid.default_main_program().global_block()._var_recursive(x)
dist_tensor = DistributedTensor(x)
else:
dist_tensor = DistributedTensor(x)
serial_tensor = dist_tensor.serial_tensor
dist_tensor.dist_attr.process_mesh = process_mesh
if serial_tensor.type == core.VarDesc.VarType.READER \
or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
if (
serial_tensor.type == core.VarDesc.VarType.READER
or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
):
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
if shard_spec is not None:
assert verify_shard_spec(shard_spec, tensor_shape, process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
serial_tensor.name, shard_spec, tensor_shape, process_mesh)
assert verify_shard_spec(
shard_spec, tensor_shape, process_mesh
), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
serial_tensor.name, shard_spec, tensor_shape, process_mesh
)
dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping(
shard_spec, process_mesh)
shard_spec, process_mesh
)
if process_mesh is not None:
dist_tensor.dist_attr.mark_annotated("process_mesh")
if shard_spec is not None:
......@@ -97,6 +111,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
default_dist_ctx = get_default_distributed_context()
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
dist_tensor = default_dist_ctx.get_dist_tensor_for_program(x)
default_dist_ctx.add_process_mesh(process_mesh)
return x
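With this change, shard_tensor also accepts a plain variable name and resolves it from the default main program, in addition to accepting the Variable itself. A hedged sketch under static graph mode (mesh shape and dim names are illustrative):

```python
# Hedged sketch: shard_tensor with either a Variable or its name string.
import paddle
from paddle.distributed.fleet import auto

paddle.enable_static()

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")

# Passing the name string is the new path added here;
# auto.shard_tensor(x, mesh, ["dp", None]) behaves the same way.
auto.shard_tensor("x", process_mesh=mesh, shard_spec=["dp", None])
```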
......@@ -144,41 +159,54 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
"""
if process_mesh is not None:
assert isinstance(process_mesh, ProcessMesh), \
"Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh)
assert isinstance(
process_mesh, ProcessMesh
), "Argument process_mesh {} is not an instance of ProcessMesh".format(
process_mesh
)
else:
process_mesh = get_current_process_mesh()
assert process_mesh is not None, \
"Specify the process mesh argument or use ProcessMesh context manager first."
assert (
process_mesh is not None
), "Specify the process mesh argument or use ProcessMesh context manager first."
in_dims_mappings = []
if in_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in in_shard_specs), \
"in_shard_spec {} is not a list of list or None".format(in_shard_specs)
assert all(
(isinstance(shard_spec, list) or shard_spec is None)
for shard_spec in in_shard_specs
), "in_shard_spec {} is not a list of list or None".format(
in_shard_specs
)
for shard_spec in in_shard_specs:
if shard_spec is not None:
in_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh))
convert_to_dims_mapping(shard_spec, process_mesh)
)
else:
in_dims_mappings.append(None)
out_dims_mappings = []
if out_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in out_shard_specs), \
"out_shard_spec {} is not a list of list or None".format(out_shard_specs)
assert all(
(isinstance(shard_spec, list) or shard_spec is None)
for shard_spec in out_shard_specs
), "out_shard_spec {} is not a list of list or None".format(
out_shard_specs
)
for shard_spec in out_shard_specs:
if shard_spec is not None:
out_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh))
convert_to_dims_mapping(shard_spec, process_mesh)
)
else:
out_dims_mappings.append(None)
op = DistributedOperatorHelper(op, process_mesh, in_dims_mappings,
out_dims_mappings)
op = DistributedOperatorHelper(
op, process_mesh, in_dims_mappings, out_dims_mappings
)
return op
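The reformatted DistributedOperatorHelper above is what shard_op hands back to the caller. For context, a hedged usage sketch of shard_op itself, with a mesh and shard specs chosen purely for illustration:

```python
# Hedged sketch: annotate the ops created by a callable with shard specs.
import paddle
from paddle.distributed.fleet import auto

paddle.enable_static()

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
y = paddle.static.data(name="y", shape=[8, 4], dtype="float32")

sharded_matmul = auto.shard_op(
    paddle.matmul, mesh,
    in_shard_specs=[["dp", None], [None, "mp"]],
    out_shard_specs=[["dp", "mp"]],
)
out = sharded_matmul(x, y)
```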
def recompute(op):
class RecomputeOperator:
def __init__(self, op):
self._op = op
......@@ -219,11 +247,13 @@ def add_to_collection(collection_name, value, name=None):
_g_collections[collection_name] = []
if name is not None:
for _, v in _g_collections[collection_name]:
if v == value: return
if v == value:
return
_g_collections[collection_name].append((name, value))
else:
for _, v in _g_collections[collection_name]:
if v == value: return
if v == value:
return
_g_collections[collection_name].append((None, value))
......
......@@ -35,3 +35,4 @@ from . import dist_fused_attention
from . import dist_reduce_sum_p
from . import dist_shape
from . import dist_assign
from . import dist_scale
This diff has been collapsed.
This diff has been collapsed.