未验证 提交 92c2dcbd 编写于 作者: L LiYuRio 提交者: GitHub

Cherry-pick fleet executor and auto parallel (#50071)

上级 4bacf2ab
...@@ -426,7 +426,8 @@ endif() ...@@ -426,7 +426,8 @@ endif()
if(WITH_DISTRIBUTE if(WITH_DISTRIBUTE
AND NOT WITH_PSLIB AND NOT WITH_PSLIB
AND NOT WITH_PSCORE) AND NOT WITH_PSCORE
AND NOT WITH_RPC)
include(external/snappy) include(external/snappy)
list(APPEND third_party_deps extern_snappy) list(APPEND third_party_deps extern_snappy)
......
...@@ -36,6 +36,8 @@ cc_library( ...@@ -36,6 +36,8 @@ cc_library(
interceptor.cc interceptor.cc
compute_interceptor.cc compute_interceptor.cc
amplifier_interceptor.cc amplifier_interceptor.cc
cond_interceptor.cc
start_interceptor.cc
source_interceptor.cc source_interceptor.cc
sink_interceptor.cc sink_interceptor.cc
message_service.cc message_service.cc
...@@ -66,6 +68,10 @@ if(WITH_DISTRIBUTE) ...@@ -66,6 +68,10 @@ if(WITH_DISTRIBUTE)
set_source_files_properties( set_source_files_properties(
amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS}) ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
start_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties( set_source_files_properties(
source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties( set_source_files_properties(
......
...@@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() { ...@@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() {
// run_per_steps_, run_at_offset_ // run_per_steps_, run_at_offset_
// 4, 0 --> run at step 0, 4, 8, 12 // 4, 0 --> run at step 0, 4, 8, 12
// 4, 3 --> run at step 3, 7, 11, 15 // 4, 3 --> run at step 3, 7, 11, 15
if ((step_ % run_per_steps_) == run_at_offset_) { if ((cur_scope_id_ % run_per_steps_) == run_at_offset_) {
ComputeInterceptor::RunOps(); ComputeInterceptor::RunOps();
} }
} }
...@@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() { ...@@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() {
void AmplifierInterceptor::SendDataReadyToDownStream() { void AmplifierInterceptor::SendDataReadyToDownStream() {
// run multi times, send ready one times to downstream, that is // run multi times, send ready one times to downstream, that is
// input multi times, output one times // input multi times, output one times
if (step_ % send_down_per_steps_ == 0) { if (cur_scope_id_ % send_down_per_steps_ == 0) {
ComputeInterceptor::SendDataReadyToDownStream(); ComputeInterceptor::SendDataReadyToDownStream();
} }
} }
...@@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() { ...@@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() {
void AmplifierInterceptor::ReplyCompletedToUpStream() { void AmplifierInterceptor::ReplyCompletedToUpStream() {
// run multi times, reply one times to upstream, that is // run multi times, reply one times to upstream, that is
// input one times, output multi times // input one times, output multi times
if (step_ % reply_up_per_steps_ == 0) { if (cur_scope_id_ % reply_up_per_steps_ == 0) {
ComputeInterceptor::ReplyCompletedToUpStream(); ComputeInterceptor::ReplyCompletedToUpStream();
} }
} }
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
class AmplifierInterceptor : public ComputeInterceptor { class AmplifierInterceptor final : public ComputeInterceptor {
public: public:
AmplifierInterceptor(int64_t interceptor_id, TaskNode* node); AmplifierInterceptor(int64_t interceptor_id, TaskNode* node);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include <algorithm> #include <algorithm>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
...@@ -24,6 +25,7 @@ ...@@ -24,6 +25,7 @@
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
namespace paddle { namespace paddle {
...@@ -33,6 +35,8 @@ USE_INTERCEPTOR(Source); ...@@ -33,6 +35,8 @@ USE_INTERCEPTOR(Source);
USE_INTERCEPTOR(Compute); USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier); USE_INTERCEPTOR(Amplifier);
USE_INTERCEPTOR(Sink); USE_INTERCEPTOR(Sink);
USE_INTERCEPTOR(Cond);
USE_INTERCEPTOR(Start);
void Carrier::Init( void Carrier::Init(
int64_t rank, int64_t rank,
...@@ -54,24 +58,38 @@ void Carrier::Init( ...@@ -54,24 +58,38 @@ void Carrier::Init(
framework::Scope* scope, framework::Scope* scope,
int64_t num_micro_batches, int64_t num_micro_batches,
const platform::Place& place, const platform::Place& place,
const std::vector<std::string>& inference_root_scope_vars) { const std::vector<std::string>& inference_root_scope_vars,
const std::vector<framework::Scope*>& micro_scope_list) {
rank_ = rank; rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank; interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_id_to_node_ = interceptor_id_to_node; interceptor_id_to_node_ = interceptor_id_to_node;
place_ = place; place_ = place;
root_scope_ = scope; root_scope_ = scope;
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
bool need_create_scope = micro_scope_list.empty();
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
root_scope_, root_scope_,
platform::errors::InvalidArgument("root_scope can not be nullptr")); platform::errors::InvalidArgument("root_scope can not be nullptr"));
minibatch_scope_ = &root_scope_->NewScope();
microbatch_scopes_.resize(num_micro_batches); if (need_create_scope) {
for (int i = 0; i < num_micro_batches; ++i) { minibatch_scope_ = &root_scope_->NewScope();
microbatch_scopes_[i] = &minibatch_scope_->NewScope(); microbatch_scopes_.resize(num_micro_batches);
CopyParameters(i, program, inference_root_scope_vars); for (int i = 0; i < num_micro_batches; ++i) {
microbatch_scopes_[i] = &minibatch_scope_->NewScope();
CopyParameters(i, program, inference_root_scope_vars);
}
} else {
microbatch_scopes_ = micro_scope_list;
for (int i = 0; i < num_micro_batches; ++i) {
CopyParameters(i, program, inference_root_scope_vars);
}
} }
// Add source and sink interceptor id to rank
interceptor_id_to_rank_.emplace(SOURCE_ID, rank);
interceptor_id_to_rank_.emplace(SINK_ID, rank);
// TODO(fleet_exe dev): thread pool // TODO(fleet_exe dev): thread pool
thread_num_ = 1; thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_); thread_pool_.SetThreadNum(thread_num_);
...@@ -93,29 +111,30 @@ void Carrier::CopyParameters( ...@@ -93,29 +111,30 @@ void Carrier::CopyParameters(
int microbatch_id, int microbatch_id,
const framework::ProgramDesc& program, const framework::ProgramDesc& program,
const std::vector<std::string>& inference_root_scope_vars) { const std::vector<std::string>& inference_root_scope_vars) {
auto& global_block = program.Block(0);
std::map<std::string, int> inference_root_scope_var_map; std::map<std::string, int> inference_root_scope_var_map;
for (auto var_name : inference_root_scope_vars) { for (auto var_name : inference_root_scope_vars) {
inference_root_scope_var_map.insert({var_name, 1}); inference_root_scope_var_map.insert({var_name, 1});
} }
for (auto& var : global_block.AllVars()) { for (size_t i = 0; i < program.Size(); ++i) {
std::string var_name = var->Name(); for (auto& var : program.Block(i).AllVars()) {
bool force_root = inference_root_scope_var_map.find(var_name) != std::string var_name = var->Name();
inference_root_scope_var_map.end(); bool force_root = inference_root_scope_var_map.find(var_name) !=
if (force_root) { inference_root_scope_var_map.end();
VLOG(4) << var_name << " will be forced to be created in the root scope."; if (force_root) {
} VLOG(4) << var_name
if ((var->Persistable() || force_root) && microbatch_id == 0) { << " will be forced to be created in the root scope.";
auto* ptr = root_scope_->Var(var->Name()); }
InitializeVariable(ptr, var->GetType()); if ((var->Persistable() || force_root) && microbatch_id == 0) {
VLOG(5) << "Create persistable var: " << var->Name() auto* ptr = root_scope_->Var(var->Name());
<< ", which pointer is " << ptr; InitializeVariable(ptr, var->GetType());
} else if (!var->Persistable()) { VLOG(5) << "Create persistable var: " << var->Name()
auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name()); << ", which pointer is " << ptr;
VLOG(5) << "Create variable " << var->Name() << " for microbatch " } else if (!var->Persistable()) {
<< microbatch_id << ", which pointer is " << ptr << "."; auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
InitializeVariable(ptr, var->GetType()); VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr << ".";
InitializeVariable(ptr, var->GetType());
}
} }
} }
} }
...@@ -159,16 +178,11 @@ void Carrier::Start() { ...@@ -159,16 +178,11 @@ void Carrier::Start() {
true, true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Using carrier before initialized.")); "Using carrier before initialized."));
for (int64_t id : source_interceptor_ids_) { InterceptorMessage start_msg;
VLOG(3) << "Carrier Start is sending start to source interceptor " << id start_msg.set_src_id(SOURCE_ID);
<< "."; start_msg.set_dst_id(SOURCE_ID);
InterceptorMessage start_msg; start_msg.set_message_type(START);
// source node data_is_ready is send by carrier, so set src_id=-1 Send(start_msg);
start_msg.set_src_id(-1);
start_msg.set_dst_id(id);
start_msg.set_message_type(DATA_IS_READY);
Send(start_msg);
}
// TODO(wangxi): async step // TODO(wangxi): async step
Wait(); Wait();
dev_ctx_->Wait(); dev_ctx_->Wait();
...@@ -270,6 +284,38 @@ void Carrier::CreateInterceptors() { ...@@ -270,6 +284,38 @@ void Carrier::CreateInterceptors() {
auto gc = GetGC(place_); auto gc = GetGC(place_);
// create source and sink task node
auto max_run_times = microbatch_scopes_.size();
TaskNode* source = new TaskNode(
rank_, SOURCE_ID, max_run_times); // rank, task_id, max_run_times
TaskNode* sink = new TaskNode(rank_, SINK_ID, max_run_times);
// find nodes without upstreams or without downstreams
std::vector<TaskNode*> origin_sources, origin_sinks;
for (const auto& item : interceptor_id_to_node_) {
TaskNode* task_node = item.second;
if (task_node->upstream().empty()) {
origin_sources.emplace_back(task_node);
}
if (task_node->downstream().empty()) {
origin_sinks.emplace_back(task_node);
}
}
// link source node with origin source
for (const auto& node : origin_sources) {
source->AddDownstreamTask(node->task_id(),
std::numeric_limits<int64_t>::max());
node->AddUpstreamTask(SOURCE_ID, std::numeric_limits<int64_t>::max());
}
// link sink node with origin sink
for (const auto& node : origin_sinks) {
sink->AddUpstreamTask(node->task_id(), std::numeric_limits<int64_t>::max());
node->AddDownstreamTask(SINK_ID, std::numeric_limits<int64_t>::max());
}
// create source and sink interceptor
SetInterceptor(SOURCE_ID,
InterceptorFactory::Create("Source", SOURCE_ID, source));
SetInterceptor(SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));
// create each Interceptor // create each Interceptor
// no auto init since there is no config // no auto init since there is no config
for (const auto& item : interceptor_id_to_node_) { for (const auto& item : interceptor_id_to_node_) {
...@@ -303,9 +349,15 @@ void Carrier::CreateInterceptors() { ...@@ -303,9 +349,15 @@ void Carrier::CreateInterceptors() {
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< " with type: " << task_node->type() << "."; << " with type: " << task_node->type() << ".";
if (task_node->upstream().empty()) { PADDLE_ENFORCE_EQ(
source_interceptor_ids_.emplace_back(interceptor_id); task_node->upstream().empty(),
} false,
platform::errors::PreconditionNotMet(
"There should not have normal nodes as source nodes"));
PADDLE_ENFORCE_EQ(task_node->downstream().empty(),
false,
platform::errors::PreconditionNotMet(
"There should not have normal nodes as sink nodes"));
} }
} }
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h" #include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
...@@ -60,7 +61,8 @@ class Carrier final { ...@@ -60,7 +61,8 @@ class Carrier final {
framework::Scope* scope, framework::Scope* scope,
int64_t num_micro_batches, int64_t num_micro_batches,
const platform::Place& place, const platform::Place& place,
const std::vector<std::string>& inference_root_scope_vars = {}); const std::vector<std::string>& inference_root_scope_vars = {},
const std::vector<framework::Scope*>& micro_scope_list = {});
void CopyParameters( void CopyParameters(
int microbatch_id, int microbatch_id,
...@@ -100,8 +102,6 @@ class Carrier final { ...@@ -100,8 +102,6 @@ class Carrier final {
std::unordered_map<int64_t, std::unique_ptr<Interceptor>> std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_; interceptor_idx_to_interceptor_;
std::vector<int64_t> source_interceptor_ids_;
bool is_init_{false}; bool is_init_{false};
std::mutex running_mutex_; std::mutex running_mutex_;
......
...@@ -18,10 +18,85 @@ ...@@ -18,10 +18,85 @@
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/serialization.h"
#include "paddle/phi/core/utils/dim.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
namespace {
template <typename T>
void SetVarResult(const std::string& name,
T value,
int64_t scope_id,
framework::Scope* scope,
const platform::Place& place,
const std::vector<int64_t>& dim_vec) {
auto* var = scope->FindVar(name);
auto* tensor = var->GetMutable<phi::DenseTensor>();
if (!var) {
VLOG(3) << "Create var and memory for var " << name;
var = scope->Var(name);
phi::DDim dims = phi::make_ddim(dim_vec);
tensor->Resize(dims);
tensor->mutable_data<T>(dims, place);
}
PADDLE_ENFORCE_EQ(
tensor->dims().size(),
1,
platform::errors::OutOfRange("Only support transfer size 1 value."));
PADDLE_ENFORCE_EQ(
tensor->dims().at(0),
1,
platform::errors::OutOfRange("Only support transfer size 1 value."));
if (platform::is_gpu_place(tensor->place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
auto dim = phi::make_ddim({1});
cpu_tensor.mutable_data<T>(dim, platform::CPUPlace());
auto* cpu_tensor_ptr = cpu_tensor.data<T>();
cpu_tensor_ptr[0] = value;
framework::TensorCopySync(cpu_tensor, tensor->place(), tensor);
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport device for cond interceptor."));
}
}
template <typename T>
T GetVarResult(const std::string& name,
int64_t scope_id,
framework::Scope* scope) {
auto* var = scope->FindVar(name);
PADDLE_ENFORCE(var,
platform::errors::NotFound(
"Variable %s not exists in scope %ld", name, scope_id));
const auto& tensor = var->Get<phi::DenseTensor>();
T res;
if (platform::is_gpu_place(tensor.place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
framework::TensorCopySync(tensor, platform::CPUPlace(), &cpu_tensor);
res = cpu_tensor.data<T>()[0];
#endif
} else if (platform::is_cpu_place(tensor.place())) {
res = tensor.data<T>()[0];
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport device for cond interceptor."));
}
return res;
}
} // namespace
ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node) ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) { : Interceptor(interceptor_id, node) {
PrepareDeps(); PrepareDeps();
...@@ -33,57 +108,49 @@ void ComputeInterceptor::PrepareDeps() { ...@@ -33,57 +108,49 @@ void ComputeInterceptor::PrepareDeps() {
auto& downstream = node_->downstream(); auto& downstream = node_->downstream();
for (auto up : upstream) { for (auto up : upstream) {
in_readys_.emplace(up.first, std::make_pair(up.second, 0)); std::map<int64_t, int64_t> ready_size_map;
in_stops_.emplace(up.first, false); for (int64_t i = 0; i < node_->max_run_times(); ++i) {
ready_size_map.emplace(i, 0);
}
in_readys_.emplace(up.first, std::make_pair(up.second, ready_size_map));
} }
for (auto down : downstream) { for (auto down : downstream) {
out_buffs_.emplace(down.first, std::make_pair(down.second, 0)); out_buffs_.emplace(down.first, std::make_pair(down.second, 0));
} }
// source compute node, should we add a new SourceInterceptor?
if (upstream.empty()) {
is_source_ = true;
PADDLE_ENFORCE_GT(node_->max_run_times(),
0,
platform::errors::InvalidArgument(
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld",
node_->max_run_times()));
in_readys_.emplace(-1,
std::make_pair(std::numeric_limits<int64_t>::max(), 0));
}
// If there is no downstream or every downstream is in different rank,
// then this interceptor is the last one for current rank.
// This can be get during init, can be cached for later use.
is_last_ = downstream.empty();
} }
void ComputeInterceptor::IncreaseReady(int64_t up_id) { void ComputeInterceptor::IncreaseReady(int64_t up_id, int64_t scope_id) {
auto it = in_readys_.find(up_id); auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it, PADDLE_ENFORCE_NE(it,
in_readys_.end(), in_readys_.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Cannot find upstream=%lld in in_readys.", up_id)); "Cannot find upstream=%lld in in_readys.", up_id));
// source node has no upstream, data_is_ready is send by carrier or others
if (is_source_ && up_id == -1) {
it->second.second += GetTaskNode()->max_run_times();
return;
}
auto max_ready_size = it->second.first; auto max_ready_size = it->second.first;
auto ready_size = it->second.second; const auto& ready_scope_map = it->second.second;
ready_size += 1; int64_t ready_size = 0;
PADDLE_ENFORCE_LE(ready_size, for (auto& scope_iter : ready_scope_map) {
max_ready_size, ready_size += scope_iter.second;
platform::errors::OutOfRange( }
"upstream=%lld ready_size must <= max_ready_size, but " if (max_ready_size != INFINITE_BUFFER_SIZE) {
"now ready_size=%lld, max_ready_size=%lld", PADDLE_ENFORCE_LE(
up_id, ready_size,
ready_size, max_ready_size,
max_ready_size)); platform::errors::OutOfRange(
it->second.second = ready_size; "upstream=%lld ready_size must <= max_ready_size, but "
"now ready_size=%lld, max_ready_size=%lld",
up_id,
ready_size,
max_ready_size));
}
PADDLE_ENFORCE_NE(
it->second.second.find(scope_id),
it->second.second.end(),
platform::errors::OutOfRange(
"Interceptor %lld can not find scope %lld in upstream ready map",
interceptor_id_,
scope_id));
it->second.second.at(scope_id) = ready_scope_map.at(scope_id) + 1;
} }
void ComputeInterceptor::DecreaseBuff(int64_t down_id) { void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
...@@ -105,22 +172,40 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) { ...@@ -105,22 +172,40 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
} }
bool ComputeInterceptor::IsInputReady() { bool ComputeInterceptor::IsInputReady() {
for (auto& ins : in_readys_) { for (int64_t i = 0; i < node_->max_run_times(); ++i) {
auto ready_size = ins.second.second; bool flag = true;
// not ready, return false for (auto& ins : in_readys_) {
if (ready_size == 0) { auto ready_size_map = ins.second.second;
VLOG(3) << "Interceptor " << GetInterceptorId() flag = flag && (ready_size_map.at(i) != 0);
}
if (flag) {
for (auto iter : scope_id_to_finish_flag_) {
if (iter.first == i) {
break;
} else if (!iter.second) {
VLOG(3) << "The previous scope is not ready, waiting for the "
"previous scope "
<< iter.first;
return false;
}
}
cur_scope_id_ = i;
return true;
} else {
VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
<< "'s upstreams aren't all ready."; << "'s upstreams aren't all ready.";
return false;
} }
} }
return true; return false;
} }
bool ComputeInterceptor::CanWriteOutput() { bool ComputeInterceptor::CanWriteOutput() {
for (auto& outs : out_buffs_) { for (auto& outs : out_buffs_) {
auto max_buffer_size = outs.second.first; auto max_buffer_size = outs.second.first;
auto used_size = outs.second.second; auto used_size = outs.second.second;
if (max_buffer_size == INFINITE_BUFFER_SIZE) {
continue;
}
// full, return false // full, return false
if (used_size == max_buffer_size) { if (used_size == max_buffer_size) {
VLOG(3) << "Interceptor " << GetInterceptorId() VLOG(3) << "Interceptor " << GetInterceptorId()
...@@ -137,30 +222,76 @@ void ComputeInterceptor::SendDataReadyToDownStream() { ...@@ -137,30 +222,76 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
auto max_buff_size = outs.second.first; auto max_buff_size = outs.second.first;
auto used_size = outs.second.second; auto used_size = outs.second.second;
used_size += 1; used_size += 1;
PADDLE_ENFORCE_LE( if (max_buff_size != INFINITE_BUFFER_SIZE) {
used_size, PADDLE_ENFORCE_LE(
max_buff_size, used_size,
platform::errors::OutOfRange("downstream=%lld used buff size must <= " max_buff_size,
"max_buff_size, but now used_size=%lld, " platform::errors::OutOfRange("downstream=%lld used buff size must <= "
"max_buff_size=%lld", "max_buff_size, but now used_size=%lld, "
down_id, "max_buff_size=%lld",
used_size, down_id,
max_buff_size)); used_size,
max_buff_size));
}
outs.second.second = used_size; outs.second.second = used_size;
InterceptorMessage ready_msg; bool need_send_vars = !(node_->vars_to_dtype().empty());
ready_msg.set_message_type(DATA_IS_READY); if (need_send_vars) {
VLOG(3) << "ComputeInterceptor " << interceptor_id_ InterceptorMessage ready_msg = PrepareVarsMsg();
<< " Send data_is_ready msg to " << down_id VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " for step: " << step_; << " Send data_with_vars msg to " << down_id
Send(down_id, ready_msg); << " in scope: " << cur_scope_id_;
Send(down_id, ready_msg);
} else {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << cur_scope_id_;
Send(down_id, ready_msg);
}
}
}
InterceptorMessage ComputeInterceptor::PrepareVarsMsg() {
PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
auto* scope = microbatch_scopes_[cur_scope_id_];
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_WITH_VARS);
ready_msg.set_scope_idx(cur_scope_id_);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
for (auto iter : node_->vars_to_dtype()) {
VarList* vars = ready_msg.add_vars_list();
const auto& var_name = iter.first;
vars->set_name(var_name);
std::ostringstream ss;
auto& dev_ctx = *pool.Get(place_);
auto* var = scope->FindVar(var_name);
PADDLE_ENFORCE(
var,
platform::errors::NotFound(
"Variable %s not exists in scope %ld", var_name, cur_scope_id_));
const auto& tensor = var->Get<phi::DenseTensor>();
SerializeToStream(ss, tensor, dev_ctx);
vars->set_stensor(ss.str());
VLOG(3) << "Prepare vars msg " << var_name << " with dimension "
<< tensor.dims() << " dtype " << tensor.dtype();
} }
return ready_msg;
} }
void ComputeInterceptor::ReplyCompletedToUpStream() { void ComputeInterceptor::ReplyCompletedToUpStream() {
for (auto& ins : in_readys_) { for (auto& ins : in_readys_) {
auto up_id = ins.first; auto up_id = ins.first;
auto ready_size = ins.second.second; auto ready_size = ins.second.second.at(cur_scope_id_);
ready_size -= 1; ready_size -= 1;
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
ready_size, ready_size,
...@@ -169,109 +300,114 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -169,109 +300,114 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
"upstream=%lld ready_size must >= 0, but now got %lld", "upstream=%lld ready_size must >= 0, but now got %lld",
up_id, up_id,
ready_size)); ready_size));
ins.second.second = ready_size; ins.second.second[cur_scope_id_] = ready_size;
VLOG(3) << "ComputeInterceptor " << interceptor_id_ VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id << " Reply data_is_useless msg to " << up_id
<< " for step: " << step_; << " in scope: " << cur_scope_id_;
if (is_source_ && up_id == -1) return;
InterceptorMessage reply_msg; InterceptorMessage reply_msg;
reply_msg.set_message_type(DATA_IS_USELESS); reply_msg.set_message_type(DATA_IS_USELESS);
reply_msg.set_scope_idx(cur_scope_id_);
Send(up_id, reply_msg); Send(up_id, reply_msg);
} }
} }
void ComputeInterceptor::RunOps() { void ComputeInterceptor::RunOps() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ + 1 << " time.";
for (auto op : node_->ops()) { for (auto op : node_->ops()) {
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_); PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
op->Run(*microbatch_scopes_[cur_scope_id_], place_);
if (gc_) { if (gc_) {
framework::DeleteUnusedTensors( framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_],
*microbatch_scopes_[step_ % node_->max_run_times()], op,
op, node_->unused_vars(),
node_->unused_vars(), gc_.get());
gc_.get());
} }
} }
} }
void ComputeInterceptor::Run() { void ComputeInterceptor::Run() {
while (IsInputReady() && CanWriteOutput()) { while (IsInputReady() && CanWriteOutput()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; VLOG(3) << "id=" << GetInterceptorId()
<< " ComputeInterceptor running in scope " << cur_scope_id_;
RunOps(); RunOps();
++step_;
if (!scope_id_to_finish_flag_.empty()) {
PADDLE_ENFORCE_NE(
scope_id_to_finish_flag_.find(cur_scope_id_),
scope_id_to_finish_flag_.end(),
platform::errors::NotFound(
"Can not find scope %ld in scope_id_to_finish", cur_scope_id_));
scope_id_to_finish_flag_.erase(cur_scope_id_);
}
// send to downstream and increase buff used // send to downstream and increase buff used
SendDataReadyToDownStream(); SendDataReadyToDownStream();
// reply to upstream and decrease ready data // reply to upstream and decrease ready data
ReplyCompletedToUpStream(); ReplyCompletedToUpStream();
// Try to stop Carrier
if (is_last_ && (step_ % node_->max_run_times() == 0)) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " is stopping carrier.";
// FIXME(wangxi): with multi sink interceptor
StopCarrier();
}
} }
} }
void ComputeInterceptor::ReceivedStop(int64_t up_id) { void ComputeInterceptor::DecodeMsgVars(const InterceptorMessage& msg) {
received_stop_ = true; int64_t scope_id = msg.scope_idx();
PADDLE_ENFORCE_LT(scope_id,
// source node has no upstream, stop is send by carrier or others microbatch_scopes_.size(),
if (is_source_ && up_id == -1) return; platform::errors::InvalidArgument(
"Step out of range. There are %ld "
auto it = in_stops_.find(up_id); "microbatch_scopes, but recevice scope index %ld",
PADDLE_ENFORCE_NE(it, microbatch_scopes_.size(),
in_stops_.end(), scope_id));
platform::errors::NotFound( auto* scope = microbatch_scopes_[scope_id];
"Cannot find upstream=%lld in in_stops.", up_id)); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
PADDLE_ENFORCE_EQ( for (const auto& var_iter : msg.vars_list()) {
it->second, const std::string& name = var_iter.name();
false, auto& dev_ctx = *pool.Get(place_);
platform::errors::AlreadyExists("Already received stop from %lld, stop " std::istringstream ss(var_iter.stensor());
"cannot be send more than once.")); auto* var = scope->Var(name);
it->second = true; auto* tensor = var->GetMutable<phi::DenseTensor>();
} DeserializeFromStream(ss, tensor, dev_ctx);
void ComputeInterceptor::TryStop() { VLOG(3) << "Set vars " << name << " with value in scope " << scope_id
if (!received_stop_) return; << " with dims " << tensor->dims() << " with dtype "
<< tensor->dtype();
// can stop only when all upstream is stop and
// downstream complete
for (auto& in_stop : in_stops_) {
if (!in_stop.second) return;
}
for (auto& out_buff : out_buffs_) {
auto used_size = out_buff.second.second;
if (used_size != 0) return;
}
// send stop to downstream
for (auto& out : out_buffs_) {
auto down_id = out.first;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(down_id, stop);
} }
stop_ = true;
} }
void ComputeInterceptor::Compute(const InterceptorMessage& msg) { void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) { if (msg.message_type() == DATA_IS_READY) {
IncreaseReady(msg.src_id()); VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_is_ready " << msg.src_id() << " "
<< msg.scope_idx() << " ";
IncreaseReady(msg.src_id(), msg.scope_idx());
Run(); Run();
} else if (msg.message_type() == DATA_IS_USELESS) { } else if (msg.message_type() == DATA_IS_USELESS) {
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_is_useless " << msg.src_id() << " "
<< msg.scope_idx() << " ";
DecreaseBuff(msg.src_id()); DecreaseBuff(msg.src_id());
Run(); Run();
} else if (msg.message_type() == STOP) { } else if (msg.message_type() == DATA_WITH_VARS) {
ReceivedStop(msg.src_id()); VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_with_vars " << msg.src_id() << " "
<< msg.scope_idx() << " ";
DecodeMsgVars(msg);
IncreaseReady(msg.src_id(), msg.scope_idx());
Run();
} else if (msg.message_type() == START_LOOP) {
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive start_loop " << msg.src_id() << " " << msg.scope_idx()
<< " ";
IncreaseReady(msg.src_id(), msg.scope_idx());
scope_id_to_finish_flag_.emplace(msg.scope_idx(), false);
Run();
} }
TryStop();
} }
REGISTER_INTERCEPTOR(Compute, ComputeInterceptor); REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <queue>
#include <utility> #include <utility>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
...@@ -21,6 +22,8 @@ ...@@ -21,6 +22,8 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
const int64_t INFINITE_BUFFER_SIZE = -1;
class ComputeInterceptor : public Interceptor { class ComputeInterceptor : public Interceptor {
public: public:
ComputeInterceptor(int64_t interceptor_id, TaskNode* node); ComputeInterceptor(int64_t interceptor_id, TaskNode* node);
...@@ -29,33 +32,27 @@ class ComputeInterceptor : public Interceptor { ...@@ -29,33 +32,27 @@ class ComputeInterceptor : public Interceptor {
virtual void RunOps(); virtual void RunOps();
virtual void SendDataReadyToDownStream(); virtual void SendDataReadyToDownStream();
virtual void ReplyCompletedToUpStream(); virtual void ReplyCompletedToUpStream();
virtual void Compute(const InterceptorMessage& msg);
void Run();
void IncreaseReady(int64_t up_id, int64_t scope_id);
void DecreaseBuff(int64_t down_id);
int64_t cur_scope_id_;
int64_t step_{0}; // upstream_id-->(max_ready_size, scope-->ready_size)
std::map<int64_t, std::pair<int64_t, std::map<int64_t, int64_t>>>
in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
private: private:
void PrepareDeps(); void PrepareDeps();
InterceptorMessage PrepareVarsMsg();
void DecodeMsgVars(const InterceptorMessage& msg);
void IncreaseReady(int64_t up_id);
void DecreaseBuff(int64_t down_id);
bool IsInputReady(); bool IsInputReady();
bool CanWriteOutput(); bool CanWriteOutput();
std::map<int64_t, bool> scope_id_to_finish_flag_;
void Run();
void Compute(const InterceptorMessage& msg);
void ReceivedStop(int64_t up_id);
void TryStop();
bool is_source_{false};
bool is_last_{false};
// upstream_id-->(max_ready_size, ready_size)
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
bool received_stop_{false};
std::map<int64_t, bool> in_stops_{};
}; };
} // namespace distributed } // namespace distributed
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h"
#include <algorithm>
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
CondInterceptor::CondInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
PrepareDeps();
RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}
void CondInterceptor::PrepareDeps() {
auto& upstream = node_->upstream();
auto& downstream = node_->downstream();
auto& id_to_dep_type = node_->id_to_dep_type();
for (const auto& up : upstream) {
if (id_to_dep_type.at(up.first) == DependType::NORMAL) {
normal_in_id_.insert(up.first);
} else if (id_to_dep_type.at(up.first) == DependType::LOOP) {
loop_id_ = up.first;
}
}
for (const auto& down : downstream) {
if (id_to_dep_type.at(down.first) == DependType::NORMAL) {
normal_out_id_.insert(down.first);
} else if (id_to_dep_type.at(down.first) == DependType::STOP_LOOP) {
stop_loop_id_ = down.first;
}
}
}
bool CondInterceptor::GetCondResult() {
PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
auto* cond_var =
microbatch_scopes_[cur_scope_id_]->FindVar(node_->cond_var());
PADDLE_ENFORCE(cond_var,
platform::errors::NotFound(
"Condition variable %s not exists in scope %ld",
node_->cond_var(),
cur_scope_id_));
const auto& cond_tensor = cond_var->Get<phi::DenseTensor>();
bool res = false;
if (platform::is_gpu_place(cond_tensor.place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else if (platform::is_cpu_place(cond_tensor.place())) {
res = cond_tensor.data<bool>()[0];
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport device for cond interceptor."));
}
return res;
}
void CondInterceptor::SendDataReady(int64_t down_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
Send(down_id, ready_msg);
}
void CondInterceptor::SendStartLoop(int64_t down_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(START_LOOP);
ready_msg.set_scope_idx(cur_scope_id_);
Send(down_id, ready_msg);
}
void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_USELESS);
ready_msg.set_scope_idx(cur_scope_id_);
Send(up_id, ready_msg);
}
void CondInterceptor::Compute() {
bool cond = GetCondResult();
VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
<< " with value " << cond;
if (cond) {
VLOG(3) << "Loop again in scope " << cur_scope_id_;
for (auto& down_id : normal_out_id_) {
SendStartLoop(down_id);
}
++num_of_scopes_;
} else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_;
SendDataReady(stop_loop_id_);
}
}
void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY ||
msg.message_type() == DATA_WITH_VARS) {
if (msg.src_id() == loop_id_) {
--num_of_scopes_;
VLOG(3) << "Receving loop again message from " << msg.src_id()
<< " waiting other " << num_of_scopes_ << " scopes ready";
ready_scope_id_.emplace_back(msg.scope_idx());
if (num_of_scopes_ == 0) {
std::sort(ready_scope_id_.begin(), ready_scope_id_.end());
for (auto scope_id : ready_scope_id_) {
VLOG(3) << "Start a new loop in scope " << scope_id;
cur_scope_id_ = scope_id;
Compute();
}
ready_scope_id_.clear();
}
} else {
cur_scope_id_ = msg.scope_idx();
Compute();
}
} else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
for (auto& up_id : normal_in_id_) {
ReplyDataIsUseless(up_id);
}
// Gc the variable in while block
int64_t scope_id = msg.scope_idx();
if (gc_) {
VLOG(3) << "Release vars in while block in scope " << scope_id;
framework::DeleteUnusedTensors(*microbatch_scopes_[scope_id],
node_->while_block_vars(),
gc_.get());
}
}
}
}
REGISTER_INTERCEPTOR(Cond, CondInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iomanip>
#include <queue>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
/* Condition Interceptor
* This is a special interceptor and only one condition op in the task node.
* This interceptor has two downstreams,
* 1. If the program result is true, select one of the downstreams, otherwise
* select another.
* 2. Used to implement while op in program.
*/
class CondInterceptor final : public Interceptor {
public:
CondInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void PrepareDeps();
void Run(const InterceptorMessage& msg);
void Compute();
bool GetCondResult();
void SendDataReady(int64_t down_id);
void SendStartLoop(int64_t down_id);
void ReplyDataIsUseless(int64_t up_id);
int64_t cur_scope_id_;
std::set<int64_t> normal_in_id_;
std::set<int64_t> normal_out_id_;
int64_t stop_loop_id_;
int64_t loop_id_;
int64_t num_of_scopes_{0};
std::vector<int64_t> ready_scope_id_;
};
} // namespace distributed
} // namespace paddle
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include <algorithm> #include <algorithm>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...@@ -24,6 +26,7 @@ ...@@ -24,6 +26,7 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -51,40 +54,40 @@ FleetExecutor::~FleetExecutor() { ...@@ -51,40 +54,40 @@ FleetExecutor::~FleetExecutor() {
} }
} }
void FleetExecutor::Init( namespace {
const std::string& carrier_id, void GetSubBlockTask(const std::vector<TaskNode*>& tasks,
const framework::ProgramDesc& program_desc, TaskNode* cur_task,
framework::Scope* scope, std::set<TaskNode*>* sub_block_task) {
const platform::Place& place, auto& downstream = cur_task->downstream();
int64_t num_micro_batches, auto& id_to_dep_type = cur_task->id_to_dep_type();
const std::vector<TaskNode*>& task_nodes, for (auto& down : downstream) {
const std::unordered_map<int64_t, int64_t>& task_id_to_rank, int64_t task_id = down.first;
const std::vector<std::string>& inference_root_scope_vars) { if (id_to_dep_type.at(task_id) == DependType::NORMAL) {
PADDLE_ENFORCE_GT(task_nodes.size(), for (const auto& task : tasks) {
0, if (task->task_id() == task_id) {
platform::errors::InvalidArgument( sub_block_task->emplace(task);
"Fleet executor is inited with empty task node")); GetSubBlockTask(tasks, task, sub_block_task);
// TODO(fleet_exe devs): the unused_vars should be got from run time graph }
std::vector<std::unique_ptr<framework::OperatorBase>> ops; }
for (auto task_node : task_nodes) {
for (auto op : task_node->ops()) {
ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
} }
} }
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {}); }
// NOTE: For inference, the vars in inference_root_scope_vars
// shouldn't be deleted during inf, for that they may be the result of the void PreventVarsDelete(
// inf. If they are GCed, it will cause error during ZeroCopy the result. std::unordered_map<const framework::OperatorBase*,
std::vector<std::string>>* unused_vars,
const std::vector<std::string>& vars_not_gc) {
std::vector<const framework::OperatorBase*> changed_ops; std::vector<const framework::OperatorBase*> changed_ops;
for (auto pair : unused_vars) {
for (const auto& pair : *unused_vars) {
const framework::OperatorBase* op = pair.first; const framework::OperatorBase* op = pair.first;
std::vector<std::string> unused = pair.second; std::vector<std::string> cur_unused = pair.second;
for (auto name : inference_root_scope_vars) { for (auto name : vars_not_gc) {
auto iter = std::find(unused.begin(), unused.end(), name); auto iter = std::find(cur_unused.begin(), cur_unused.end(), name);
if (iter != unused.end()) { if (iter != cur_unused.end()) {
VLOG(3) << "Removing var: [" << name VLOG(3) << "Removing var: [" << name
<< "] from the unused vars list of op: [" << op->Type() << "]"; << "] from the unused vars list of op: [" << op->Type() << "]";
unused.erase(iter); cur_unused.erase(iter);
if (std::find(changed_ops.begin(), changed_ops.end(), op) == if (std::find(changed_ops.begin(), changed_ops.end(), op) ==
changed_ops.end()) { changed_ops.end()) {
// record the op whose unused vars have been updated // record the op whose unused vars have been updated
...@@ -93,28 +96,120 @@ void FleetExecutor::Init( ...@@ -93,28 +96,120 @@ void FleetExecutor::Init(
} }
} }
// update the unused vars list in the map // update the unused vars list in the map
unused_vars[op] = unused; unused_vars->at(op) = cur_unused;
} }
for (auto op : changed_ops) { for (auto op : changed_ops) {
auto iter = unused_vars.find(op); const auto& iter = unused_vars->find(op);
if (iter->second.empty()) { if (iter->second.empty()) {
// remove those ops in the map that have empty unused vars list // remove those ops in the map that have empty unused vars list
VLOG(3) << "Removing op: [" << op->Type() << "] from unused_vars map."; VLOG(3) << "Removing op: [" << op->Type() << "] from unused_vars map.";
unused_vars.erase(iter); unused_vars->erase(iter);
}
}
}
std::vector<std::string> GetUnusedVarsAfterWhile(
const framework::ProgramDesc& program_desc,
TaskNode* cond_task,
const std::vector<std::string>& vars_not_gc) {
// NOTE: Since while op won't appear in task node, in order to analyze
// the vars which should be free after calling while op, we rebuild the
// whole program and get the unused vars after calling while op.
// The vars in while block should not be free until the while op is finished.
// In a word, the vars need to be free after while op is:
// 1. Vars in parent block and being used in while block.
// 2. Local vars only defined in while block.
// The unused vars above will be free in cond interceptor.
std::vector<std::string> while_block_vars;
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (const auto& desc : program_desc.Block(0).AllOps()) {
ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
}
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
PreventVarsDelete(&unused_vars, vars_not_gc);
for (const auto& pair : unused_vars) {
if (pair.first->Type() == "while") {
for (const auto& var_name : pair.second) {
while_block_vars.emplace_back(var_name);
}
for (auto& var : program_desc.Block(1).AllVars()) {
while_block_vars.emplace_back(var->Name());
}
}
}
return while_block_vars;
}
} // namespace
void FleetExecutor::Init(
const std::string& carrier_id,
const framework::ProgramDesc& program_desc,
framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars,
const std::vector<framework::Scope*>& micro_scope_list) {
PADDLE_ENFORCE_GT(task_nodes.size(),
0,
platform::errors::InvalidArgument(
"Fleet executor is inited with empty task node"));
// Set the unused var after running while op
std::set<TaskNode*> sub_block_tasks;
std::vector<std::string> while_block_vars;
for (const auto& task_node : task_nodes) {
if (task_node->type() == "Cond") {
GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
while_block_vars = GetUnusedVarsAfterWhile(
program_desc, task_node, inference_root_scope_vars);
VLOG(3) << "Vars will be gced after while op";
for (auto var : while_block_vars) {
VLOG(3) << var;
}
task_node->SetWhileBlockVars(while_block_vars);
}
}
std::vector<framework::OperatorBase*> sub_block_ops;
for (const auto& task_node : sub_block_tasks) {
for (const auto& op : task_node->ops()) {
sub_block_ops.emplace_back(op);
} }
} }
// Analyse the unused vars in block 0. The operators in block 1
// should be passed in first for prevent vars been released but removed soon.
// Since the unused vars in block 1 need to analyse separately.
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (const auto& task_node : task_nodes) {
for (const auto& op : task_node->ops()) {
ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
}
}
auto global_unused_vars =
framework::GetUnusedVars(program_desc.Block(0), ops, {});
for (auto& unique_op : ops) {
unique_op.release();
}
// NOTE: For inference, the vars in inference_root_scope_vars
// shouldn't be deleted during inf, for that they may be the result of the
// inf. If they are GCed, it will cause error during ZeroCopy the result.
PreventVarsDelete(&global_unused_vars, inference_root_scope_vars);
runtime_graph_ = std::make_shared<RuntimeGraph>(); runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task; std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
for (auto task_node : task_nodes) { for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars); if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
task_node->SetUnusedVars(global_unused_vars);
}
int64_t interceptor_id = task_node->task_id(); int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node); interceptor_id_to_task.emplace(interceptor_id, task_node);
} }
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank); runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task); runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
for (auto& unique_op : ops) {
unique_op.release();
}
VLOG(5) << runtime_graph_->DebugString(); VLOG(5) << runtime_graph_->DebugString();
Carrier* carrier = Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id); GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
...@@ -126,7 +221,8 @@ void FleetExecutor::Init( ...@@ -126,7 +221,8 @@ void FleetExecutor::Init(
place, place,
num_micro_batches, num_micro_batches,
program_desc, program_desc,
inference_root_scope_vars); inference_root_scope_vars,
micro_scope_list);
GlobalVal<MessageBus>::Get()->Barrier(); GlobalVal<MessageBus>::Get()->Barrier();
} }
...@@ -136,7 +232,8 @@ void FleetExecutor::InitCarrier( ...@@ -136,7 +232,8 @@ void FleetExecutor::InitCarrier(
const platform::Place& place, const platform::Place& place,
int64_t num_micro_batches, int64_t num_micro_batches,
const framework::ProgramDesc& program_desc, const framework::ProgramDesc& program_desc,
const std::vector<std::string>& inference_root_scope_vars) { const std::vector<std::string>& inference_root_scope_vars,
const std::vector<framework::Scope*>& micro_scope_list) {
carrier->Init(exe_desc_.cur_rank(), carrier->Init(exe_desc_.cur_rank(),
runtime_graph_->interceptor_id_to_rank(), runtime_graph_->interceptor_id_to_rank(),
runtime_graph_->interceptor_id_to_node(), runtime_graph_->interceptor_id_to_node(),
...@@ -144,7 +241,8 @@ void FleetExecutor::InitCarrier( ...@@ -144,7 +241,8 @@ void FleetExecutor::InitCarrier(
scope, scope,
num_micro_batches, num_micro_batches,
place, place,
inference_root_scope_vars); inference_root_scope_vars,
micro_scope_list);
} }
void FleetExecutor::InitMessageBus() { void FleetExecutor::InitMessageBus() {
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
...@@ -45,7 +46,8 @@ class FleetExecutor final { ...@@ -45,7 +46,8 @@ class FleetExecutor final {
int64_t num_micro_batches, int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes, const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank, const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
const std::vector<std::string>& inference_root_scope_vars = {}); const std::vector<std::string>& inference_root_scope_vars = {},
const std::vector<framework::Scope*>& micro_scope_list = {});
void Run(const std::string& carrier_id); void Run(const std::string& carrier_id);
private: private:
...@@ -57,7 +59,8 @@ class FleetExecutor final { ...@@ -57,7 +59,8 @@ class FleetExecutor final {
const platform::Place& place, const platform::Place& place,
int64_t num_micro_batches, int64_t num_micro_batches,
const framework::ProgramDesc& program_desc, const framework::ProgramDesc& program_desc,
const std::vector<std::string>& inference_root_scope_vars = {}); const std::vector<std::string>& inference_root_scope_vars = {},
const std::vector<framework::Scope*>& micro_scope_list = {});
FleetExecutorDesc exe_desc_; FleetExecutorDesc exe_desc_;
std::shared_ptr<RuntimeGraph> runtime_graph_; std::shared_ptr<RuntimeGraph> runtime_graph_;
std::unordered_set<std::string> carrier_ids_; std::unordered_set<std::string> carrier_ids_;
......
...@@ -93,7 +93,6 @@ class Interceptor { ...@@ -93,7 +93,6 @@ class Interceptor {
TaskNode* node_; TaskNode* node_;
// for stop // for stop
bool stop_{false};
void StopCarrier(); void StopCarrier();
// for runtime // for runtime
...@@ -114,9 +113,6 @@ class Interceptor { ...@@ -114,9 +113,6 @@ class Interceptor {
std::mutex mutex_; std::mutex mutex_;
std::deque<InterceptorMessage> messages_; std::deque<InterceptorMessage> messages_;
int64_t already_run_times_{0};
int64_t used_slot_nums_{0};
}; };
class InterceptorFactory { class InterceptorFactory {
......
...@@ -24,6 +24,21 @@ enum MessageType { ...@@ -24,6 +24,21 @@ enum MessageType {
ERR = 4; // current Interceptor encounters error ERR = 4; // current Interceptor encounters error
RESET = 5; // reset the status RESET = 5; // reset the status
START = 6; START = 6;
DATA_WITH_VARS = 7;
START_LOOP = 8;
}
enum ValueType {
INT3 = 0;
INT6 = 1;
FLOAT = 2;
DOUBLE = 3;
BOOL = 4;
}
message VarList {
required string name = 1;
required string stensor = 2;
} }
message InterceptorMessage { message InterceptorMessage {
...@@ -32,6 +47,7 @@ message InterceptorMessage { ...@@ -32,6 +47,7 @@ message InterceptorMessage {
optional MessageType message_type = 3 [ default = RESET ]; optional MessageType message_type = 3 [ default = RESET ];
optional bool ctrl_message = 4 [ default = false ]; optional bool ctrl_message = 4 [ default = false ];
optional int64 scope_idx = 5 [ default = 0 ]; optional int64 scope_idx = 5 [ default = 0 ];
repeated VarList vars_list = 6;
} }
message InterceptorResponse { optional bool rst = 1 [ default = false ]; } message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
......
...@@ -25,7 +25,7 @@ namespace distributed { ...@@ -25,7 +25,7 @@ namespace distributed {
* 1. record the num of micro-step * 1. record the num of micro-step
* 2. check whether to notify carrier the current step is finished * 2. check whether to notify carrier the current step is finished
*/ */
class SinkInterceptor : public Interceptor { class SinkInterceptor final : public Interceptor {
public: public:
SinkInterceptor(int64_t interceptor_id, TaskNode* node); SinkInterceptor(int64_t interceptor_id, TaskNode* node);
......
...@@ -25,7 +25,7 @@ namespace distributed { ...@@ -25,7 +25,7 @@ namespace distributed {
* 1. receive `start` message from carrier * 1. receive `start` message from carrier
* 2. send num_of_steps `data_is_ready` message to downstream * 2. send num_of_steps `data_is_ready` message to downstream
*/ */
class SourceInterceptor : public Interceptor { class SourceInterceptor final : public Interceptor {
public: public:
SourceInterceptor(int64_t interceptor_id, TaskNode* node); SourceInterceptor(int64_t interceptor_id, TaskNode* node);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/start_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
StartInterceptor::StartInterceptor(int64_t interceptor_id, TaskNode* node)
: ComputeInterceptor(interceptor_id, node) {
auto& downstream = node_->downstream();
PADDLE_ENFORCE_EQ(
downstream.size(),
1,
platform::errors::OutOfRange(
"The downstream for StartInterceptor only support 1 for now."));
for (auto down : downstream) {
batch_size_ = down.second;
}
bool evenly_divisible = ((node_->max_run_times() % batch_size_) == 0);
PADDLE_ENFORCE(
evenly_divisible,
platform::errors::Fatal(
"Wrong config: Num of step should be divided by batch_size,"
"num_step=%lld, batch_size=%lld",
node_->max_run_times(),
batch_size_));
}
void StartInterceptor::RunOps() {
finish_count_++;
ComputeInterceptor::RunOps();
}
void StartInterceptor::SendDataReadyToDownStream() {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
auto max_buff_size = outs.second.first;
auto used_size = outs.second.second;
used_size += 1;
if (max_buff_size != INFINITE_BUFFER_SIZE) {
PADDLE_ENFORCE_LE(
used_size,
max_buff_size,
platform::errors::OutOfRange("downstream=%lld used buff size must <= "
"max_buff_size, but now used_size=%lld, "
"max_buff_size=%lld",
down_id,
used_size,
max_buff_size));
}
outs.second.second = used_size;
}
if (finish_count_ == batch_size_) {
for (int64_t i = 0; i < batch_size_; ++i) {
int64_t scope_id = step_ % node_->max_run_times();
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(scope_id);
VLOG(3) << "StartInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << scope_id;
Send(down_id, ready_msg);
}
step_++;
}
}
}
void StartInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
VLOG(3) << "Start interceptor " << interceptor_id_
<< " receive data_is_ready " << msg.src_id() << " "
<< msg.scope_idx() << " ";
IncreaseReady(msg.src_id(), msg.scope_idx());
Run();
} else if (msg.message_type() == DATA_IS_USELESS) {
VLOG(3) << "Start interceptor receive data_is_useless " << msg.src_id()
<< " " << finish_count_;
finish_count_--;
if (finish_count_ == 0) {
for (int64_t i = 0; i < batch_size_; ++i) {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
DecreaseBuff(down_id);
}
}
for (int64_t i = 0; i < batch_size_; ++i) {
Run();
}
}
}
}
REGISTER_INTERCEPTOR(Start, StartInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
namespace paddle {
namespace distributed {
class StartInterceptor final : public ComputeInterceptor {
public:
StartInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void SendDataReadyToDownStream() override;
void RunOps() override;
void Compute(const InterceptorMessage& msg) override;
int64_t batch_size_{0};
int64_t finish_count_{0};
int64_t step_{0};
};
} // namespace distributed
} // namespace paddle
...@@ -24,33 +24,14 @@ namespace { ...@@ -24,33 +24,14 @@ namespace {
using OperatorBase = TaskNode::OperatorBase; using OperatorBase = TaskNode::OperatorBase;
} }
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums)
: program_(program),
rank_(rank),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
// Should be serially invoked, not thread-safe
// NOTE: when instantiate TaskNode with program, won't init task node
// immediately, since the provided program may be updated later (with
// high probability) by adding_feed_fetch_ops or by RuntimeGraph.
// So, delay the init part to the Init() function.
static int64_t task_node_cnt = 0;
task_id_ = task_node_cnt++;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program, TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank, int64_t rank,
int64_t task_id, int64_t task_id,
int64_t max_run_times, int64_t max_run_times)
int64_t max_slot_nums)
: program_(program), : program_(program),
rank_(rank), rank_(rank),
task_id_(task_id), task_id_(task_id),
max_run_times_(max_run_times), max_run_times_(max_run_times) {
max_slot_nums_(max_slot_nums) {
// TODO(liyurui): Will be removed when execute program is supported. // TODO(liyurui): Will be removed when execute program is supported.
Init(); Init();
} }
...@@ -58,7 +39,6 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program, ...@@ -58,7 +39,6 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank) TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
: program_(program), rank_(rank), task_id_(rank) { : program_(program), rank_(rank), task_id_(rank) {
max_run_times_ = 1; max_run_times_ = 1;
max_slot_nums_ = 1;
LOG(INFO) LOG(INFO)
<< "Constructing TaskNode for DistModelInf. The TaskNode's id is: " << "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
<< rank << rank
...@@ -69,6 +49,16 @@ void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) { ...@@ -69,6 +49,16 @@ void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
program_ = program; program_ = program;
} }
void TaskNode::SetVarsToDtype(
const std::map<std::string, std::string>& vars_to_dtype) {
vars_to_dtype_ = vars_to_dtype;
}
void TaskNode::SetVarsToShape(
const std::map<std::string, std::vector<int64_t>>& vars_to_shape) {
vars_to_shape_ = vars_to_shape;
}
void TaskNode::Init(bool use_feed_fetch_ops) { void TaskNode::Init(bool use_feed_fetch_ops) {
if (!use_feed_fetch_ops) { if (!use_feed_fetch_ops) {
VLOG(3) << "TaskNode will be inited without feed and fetch ops"; VLOG(3) << "TaskNode will be inited without feed and fetch ops";
...@@ -98,13 +88,11 @@ TaskNode::TaskNode(int32_t role, ...@@ -98,13 +88,11 @@ TaskNode::TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs, const std::vector<framework::OpDesc*>& op_descs,
int64_t rank, int64_t rank,
int64_t task_id, int64_t task_id,
int64_t max_run_times, int64_t max_run_times)
int64_t max_slot_nums)
: role_(role), : role_(role),
rank_(rank), rank_(rank),
task_id_(task_id), task_id_(task_id),
max_run_times_(max_run_times), max_run_times_(max_run_times) {
max_slot_nums_(max_slot_nums) {
if (op_descs.empty()) { if (op_descs.empty()) {
return; return;
} }
...@@ -121,33 +109,35 @@ TaskNode::TaskNode(int32_t role, ...@@ -121,33 +109,35 @@ TaskNode::TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops, const std::vector<framework::OperatorBase*>& ops,
int64_t rank, int64_t rank,
int64_t task_id, int64_t task_id,
int64_t max_run_times, int64_t max_run_times)
int64_t max_slot_nums)
: ops_(ops), : ops_(ops),
role_(role), role_(role),
rank_(rank), rank_(rank),
task_id_(task_id), task_id_(task_id),
max_run_times_(max_run_times), max_run_times_(max_run_times) {}
max_slot_nums_(max_slot_nums) {}
TaskNode::TaskNode(int32_t role, TaskNode::TaskNode(int32_t role,
int64_t rank, int64_t rank,
int64_t task_id, int64_t task_id,
int64_t max_run_times, int64_t max_run_times)
int64_t max_slot_nums)
: role_(role), : role_(role),
rank_(rank), rank_(rank),
task_id_(task_id), task_id_(task_id),
max_run_times_(max_run_times), max_run_times_(max_run_times) {}
max_slot_nums_(max_slot_nums) {}
bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) { bool TaskNode::AddUpstreamTask(int64_t task_id,
int64_t buff_size,
DependType type) {
const auto& ret = upstream_.emplace(task_id, buff_size); const auto& ret = upstream_.emplace(task_id, buff_size);
id_to_dep_type_.emplace(task_id, type);
return ret.second; return ret.second;
} }
bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) { bool TaskNode::AddDownstreamTask(int64_t task_id,
int64_t buff_size,
DependType type) {
const auto& ret = downstream_.emplace(task_id, buff_size); const auto& ret = downstream_.emplace(task_id, buff_size);
id_to_dep_type_.emplace(task_id, type);
return ret.second; return ret.second;
} }
......
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -29,38 +31,30 @@ class OpDesc; ...@@ -29,38 +31,30 @@ class OpDesc;
} // namespace framework } // namespace framework
namespace distributed { namespace distributed {
enum class DependType { NORMAL, LOOP, STOP_LOOP };
class TaskNode final { class TaskNode final {
public: public:
using OperatorBase = paddle::framework::OperatorBase; using OperatorBase = paddle::framework::OperatorBase;
TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times); TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role, TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times);
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(int32_t role, TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs, const std::vector<framework::OpDesc*>& op_descs,
int64_t rank, int64_t rank,
int64_t task_id, int64_t task_id,
int64_t max_run_times, int64_t max_run_times);
int64_t max_slot_nums);
TaskNode(int32_t role, TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops, const std::vector<framework::OperatorBase*>& ops,
int64_t rank, int64_t rank,
int64_t task_id, int64_t task_id,
int64_t max_run_times, int64_t max_run_times);
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program, int64_t rank); TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
// TODO(liyurui): This will be the only constructor for task node // TODO(liyurui): This will be the only constructor for task node
TaskNode(paddle::framework::ProgramDesc* program, TaskNode(paddle::framework::ProgramDesc* program,
int64_t task_id, int64_t task_id,
int64_t rank, int64_t rank,
int64_t max_run_times, int64_t max_run_times);
int64_t max_slot_nums);
~TaskNode() = default; ~TaskNode() = default;
void SetProgram(paddle::framework::ProgramDesc* program); void SetProgram(paddle::framework::ProgramDesc* program);
...@@ -69,11 +63,11 @@ class TaskNode final { ...@@ -69,11 +63,11 @@ class TaskNode final {
int64_t task_id() const { return task_id_; } int64_t task_id() const { return task_id_; }
int32_t role() const { return role_; } int32_t role() const { return role_; }
int64_t max_run_times() const { return max_run_times_; } int64_t max_run_times() const { return max_run_times_; }
int64_t max_slot_nums() const { return max_slot_nums_; }
int64_t run_per_steps() const { return run_per_steps_; } int64_t run_per_steps() const { return run_per_steps_; }
int64_t run_at_offset() const { return run_at_offset_; } int64_t run_at_offset() const { return run_at_offset_; }
int64_t reply_up_per_steps() const { return reply_up_per_steps_; } int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
int64_t send_down_per_steps() const { return send_down_per_steps_; } int64_t send_down_per_steps() const { return send_down_per_steps_; }
const std::string& cond_var() const { return cond_var_; }
const std::unordered_map<int64_t, int64_t>& upstream() const { const std::unordered_map<int64_t, int64_t>& upstream() const {
return upstream_; return upstream_;
} }
...@@ -86,11 +80,20 @@ class TaskNode final { ...@@ -86,11 +80,20 @@ class TaskNode final {
const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const { const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
return ops_vec_; return ops_vec_;
} }
const std::unordered_map<int64_t, DependType> id_to_dep_type() const {
return id_to_dep_type_;
}
const std::unordered_map<const OperatorBase*, std::vector<std::string>>& const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
unused_vars() const { unused_vars() const {
return unused_vars_; return unused_vars_;
} }
const std::vector<std::string> while_block_vars() const {
return while_block_vars_;
}
void SetCondVarName(const std::string& cond_var_name) {
cond_var_ = cond_var_name;
}
void SetRunPerSteps(int64_t value); void SetRunPerSteps(int64_t value);
void SetRunAtOffset(int64_t value); void SetRunAtOffset(int64_t value);
void SetReplyUpPerSteps(int64_t value); void SetReplyUpPerSteps(int64_t value);
...@@ -101,11 +104,27 @@ class TaskNode final { ...@@ -101,11 +104,27 @@ class TaskNode final {
unused_vars) { unused_vars) {
unused_vars_ = unused_vars; unused_vars_ = unused_vars;
} }
void SetWhileBlockVars(const std::vector<std::string>& vars) {
while_block_vars_ = vars;
}
// upstream need buffs? // upstream need buffs?
bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1); bool AddUpstreamTask(int64_t task_id,
bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1); int64_t buff_size = 1,
DependType type = DependType::NORMAL);
bool AddDownstreamTask(int64_t task_id,
int64_t buff_size = 1,
DependType type = DependType::NORMAL);
std::string DebugString() const; std::string DebugString() const;
const std::map<std::string, std::string>& vars_to_dtype() const {
return vars_to_dtype_;
}
void SetVarsToDtype(const std::map<std::string, std::string>& vars_to_dtype);
const std::map<std::string, std::vector<int64_t>>& vars_to_shape() const {
return vars_to_shape_;
}
void SetVarsToShape(
const std::map<std::string, std::vector<int64_t>>& vars_to_shape);
private: private:
DISABLE_COPY_AND_ASSIGN(TaskNode); DISABLE_COPY_AND_ASSIGN(TaskNode);
...@@ -115,16 +134,22 @@ class TaskNode final { ...@@ -115,16 +134,22 @@ class TaskNode final {
// task_id-->buff_size // task_id-->buff_size
std::unordered_map<int64_t, int64_t> upstream_; std::unordered_map<int64_t, int64_t> upstream_;
std::unordered_map<int64_t, int64_t> downstream_; std::unordered_map<int64_t, int64_t> downstream_;
// task_id-->type
std::unordered_map<int64_t, DependType> id_to_dep_type_;
framework::ProgramDesc* program_; framework::ProgramDesc* program_;
std::string cond_var_;
std::vector<std::unique_ptr<OperatorBase>> ops_vec_; std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
std::unordered_map<const OperatorBase*, std::vector<std::string>> std::unordered_map<const OperatorBase*, std::vector<std::string>>
unused_vars_; unused_vars_;
std::vector<std::string> while_block_vars_;
std::map<std::string, std::string> vars_to_dtype_;
std::map<std::string, std::vector<int64_t>> vars_to_shape_;
int32_t role_; int32_t role_;
int64_t rank_; int64_t rank_;
int64_t task_id_; int64_t task_id_;
int64_t max_run_times_; int64_t max_run_times_;
int64_t max_slot_nums_;
int64_t run_per_steps_{1}; int64_t run_per_steps_{1};
int64_t run_at_offset_{0}; int64_t run_at_offset_{0};
......
...@@ -77,9 +77,8 @@ TEST(ComputeInterceptor, Compute) { ...@@ -77,9 +77,8 @@ TEST(ComputeInterceptor, Compute) {
// FIXME: don't delete, otherwise interceptor will use undefined node // FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* source = TaskNode* source =
new TaskNode(0, SOURCE_ID, 2); // rank, task_id, max_run_times new TaskNode(0, SOURCE_ID, 2); // rank, task_id, max_run_times
TaskNode* node_a = TaskNode* node_a = new TaskNode(0, ops, 0, 0, 2); // role, ops, rank, task_id
new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id TaskNode* node_b = new TaskNode(0, 0, 1, 2);
TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, 2); TaskNode* sink = new TaskNode(0, SINK_ID, 2);
// source->a->b->sink // source->a->b->sink
......
...@@ -21,61 +21,49 @@ limitations under the License. */ ...@@ -21,61 +21,49 @@ limitations under the License. */
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
class StartInterceptor : public Interceptor {
public:
StartInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
}
void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(1, stop); // stop 1, compute
return;
}
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
}
};
TEST(ComputeInterceptor, Compute) { TEST(ComputeInterceptor, Compute) {
std::string carrier_id = "0"; std::string carrier_id = "0";
Carrier* carrier = Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id); GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}); carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create(); MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id TaskNode* source =
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0); new TaskNode(0, SOURCE_ID, 3); // rank, task_id, max_run_times
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0); TaskNode* node_a = new TaskNode(0, 0, 0, 3);
TaskNode* node_b = new TaskNode(0, 0, 1, 3);
// a->b->c TaskNode* sink = new TaskNode(0, SINK_ID, 3);
// source->a->b->sink
source->AddDownstreamTask(0);
node_a->AddUpstreamTask(SOURCE_ID);
node_a->AddDownstreamTask(1, 3); node_a->AddDownstreamTask(1, 3);
node_b->AddUpstreamTask(0, 3); node_b->AddUpstreamTask(0, 3);
node_b->AddDownstreamTask(2); node_b->AddDownstreamTask(SINK_ID);
node_c->AddUpstreamTask(1); sink->AddUpstreamTask(1);
Interceptor* a = carrier->SetInterceptor(
carrier->SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a)); SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY); msg.set_message_type(START);
// test run three times msg.set_dst_id(SOURCE_ID);
a->Send(1, msg); carrier->EnqueueInterceptorMessage(msg);
a->Send(1, msg);
a->Send(1, msg);
carrier->Wait(); carrier->Wait();
carrier->Release(); carrier->Release();
......
...@@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor { ...@@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor {
void PingPong(const InterceptorMessage& msg) { void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) { if (msg.message_type() == STOP) {
stop_ = true;
return; return;
} }
std::cout << GetInterceptorId() << " recv msg, count=" << count_ std::cout << GetInterceptorId() << " recv msg, count=" << count_
......
...@@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor { ...@@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor {
void PingPong(const InterceptorMessage& msg) { void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) { if (msg.message_type() == STOP) {
stop_ = true;
StopCarrier(); StopCarrier();
return; return;
} }
......
...@@ -66,17 +66,17 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -66,17 +66,17 @@ TEST(AmplifierInterceptor, Amplifier) {
MessageBus* msg_bus = GlobalVal<MessageBus>::Create(); MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
int64_t micro_steps = 3; int64_t micro_steps = 1;
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source = TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id TaskNode* node_a = new TaskNode(0, 0, 0, 1); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0); TaskNode* node_b = new TaskNode(0, 0, 1, 1);
TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 1);
TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0); TaskNode* node_d = new TaskNode(0, 0, 3, 1);
TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0); TaskNode* node_e = new TaskNode(0, 0, 4, 1);
TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0); TaskNode* node_f = new TaskNode(0, 0, 5, 1);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps); TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->e->f->sink // source->a->b->c->d->e->f->sink
......
...@@ -83,11 +83,10 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -83,11 +83,10 @@ TEST(AmplifierInterceptor, Amplifier) {
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source = TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps); // role, rank, task_id
new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id TaskNode* node_b = new TaskNode(0, 0, 1, micro_steps);
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0); TaskNode* node_c = new TaskNode(0, 0, 2, micro_steps);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0); TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps); TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->sink // source->a->b->c->d->sink
......
...@@ -62,10 +62,9 @@ TEST(SourceInterceptor, Source) { ...@@ -62,10 +62,9 @@ TEST(SourceInterceptor, Source) {
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source = TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3); // role, rank, task_id
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id TaskNode* node_a = new TaskNode(0, 0, 0, 3); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3); // role, rank, task_id
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0); // role, rank, task_id
source->AddDownstreamTask(0, 1); source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1); node_a->AddUpstreamTask(SOURCE_ID, 1);
......
...@@ -61,9 +61,8 @@ TEST(SourceInterceptor, Source) { ...@@ -61,9 +61,8 @@ TEST(SourceInterceptor, Source) {
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source = TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3); // role, rank, task_id
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id TaskNode* node_a = new TaskNode(0, 0, 0, 3); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
source->AddDownstreamTask(0, 1); source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1); node_a->AddUpstreamTask(SOURCE_ID, 1);
......
...@@ -112,5 +112,7 @@ REGISTER_OP_CUDA_KERNEL(c_broadcast, ...@@ -112,5 +112,7 @@ REGISTER_OP_CUDA_KERNEL(c_broadcast,
ops::CBroadcastOpCUDAKernel<plat::bfloat16>, ops::CBroadcastOpCUDAKernel<plat::bfloat16>,
#endif #endif
ops::CBroadcastOpCUDAKernel<int>, ops::CBroadcastOpCUDAKernel<int>,
ops::CBroadcastOpCUDAKernel<uint8_t>,
ops::CBroadcastOpCUDAKernel<int8_t>,
ops::CBroadcastOpCUDAKernel<int64_t>, ops::CBroadcastOpCUDAKernel<int64_t>,
ops::CBroadcastOpCUDAKernel<plat::float16>); ops::CBroadcastOpCUDAKernel<plat::float16>);
...@@ -19,6 +19,8 @@ limitations under the License. */ ...@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
DECLARE_bool(cudnn_deterministic);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -83,6 +85,32 @@ __global__ void CEmbeddingGrad(T *table, ...@@ -83,6 +85,32 @@ __global__ void CEmbeddingGrad(T *table,
} }
} }
template <typename T, typename IndexT>
__global__ void CEmbeddingGradSerial(T *table,
const T *output,
const IndexT *ids,
const int rows,
const int columns,
const int64_t N,
const int64_t start_idx,
const int64_t end_idx,
const int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
if (i == 0) {
for (int j = 0; j < limit; j++) {
size_t row = j / columns;
size_t col = j % columns;
auto id = ids[row];
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
paddle::platform::CudaAtomicAdd(&table[real_idx * columns + col],
output[i]);
}
}
}
}
}
template <typename T> template <typename T>
class CEmbeddingCUDAKernel : public framework::OpKernel<T> { class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
public: public:
...@@ -163,28 +191,56 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> { ...@@ -163,28 +191,56 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0)); t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
const auto &index_type = framework::TransToProtoVarType(ids_t->dtype()); const auto &index_type = framework::TransToProtoVarType(ids_t->dtype());
if (index_type == framework::proto::VarType::INT32) { if (FLAGS_cudnn_deterministic) {
CEmbeddingGrad<T, int32_t> VLOG(2) << "Run grad kernel of embedding with single thread.";
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table, blocks = 1;
d_output, if (index_type == framework::proto::VarType::INT32) {
ids_t->data<int32_t>(), CEmbeddingGradSerial<T, int32_t>
K, <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
D, d_output,
N, ids_t->data<int32_t>(),
start_idx, K,
end_idx, D,
limit); N,
} else if (index_type == framework::proto::VarType::INT64) { start_idx,
CEmbeddingGrad<T, int64_t> end_idx,
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table, limit);
d_output, } else if (index_type == framework::proto::VarType::INT64) {
ids_t->data<int64_t>(), CEmbeddingGradSerial<T, int64_t>
K, <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
D, d_output,
N, ids_t->data<int64_t>(),
start_idx, K,
end_idx, D,
limit); N,
start_idx,
end_idx,
limit);
}
} else {
if (index_type == framework::proto::VarType::INT32) {
CEmbeddingGrad<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids_t->data<int32_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
} else if (index_type == framework::proto::VarType::INT64) {
CEmbeddingGrad<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids_t->data<int64_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
}
} }
} }
}; };
......
...@@ -65,6 +65,7 @@ struct npy_format_descriptor<paddle::platform::float16> { ...@@ -65,6 +65,7 @@ struct npy_format_descriptor<paddle::platform::float16> {
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
using paddle::distributed::DependType;
using paddle::distributed::DistModel; using paddle::distributed::DistModel;
using paddle::distributed::DistModelConfig; using paddle::distributed::DistModelConfig;
using paddle::distributed::DistModelDataBuf; using paddle::distributed::DistModelDataBuf;
...@@ -164,18 +165,17 @@ void BindFleetExecutor(py::module* m) { ...@@ -164,18 +165,17 @@ void BindFleetExecutor(py::module* m) {
.def( .def(
"run", &FleetExecutor::Run, py::call_guard<py::gil_scoped_release>()); "run", &FleetExecutor::Run, py::call_guard<py::gil_scoped_release>());
py::enum_<DependType>(*m, "DependType")
.value("NORMAL", DependType::NORMAL)
.value("LOOP", DependType::LOOP)
.value("STOP_LOOP", DependType::STOP_LOOP);
py::class_<TaskNode>(*m, "TaskNode") py::class_<TaskNode>(*m, "TaskNode")
.def(py::init<framework::ProgramDesc*,
int64_t,
int64_t,
int64_t,
int64_t>())
.def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>()) .def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>())
.def(py::init<int32_t, .def(py::init<int32_t,
const std::vector<framework::OpDesc*>&, const std::vector<framework::OpDesc*>&,
int64_t, int64_t,
int64_t, int64_t,
int64_t,
int64_t>()) int64_t>())
.def("task_id", &TaskNode::task_id) .def("task_id", &TaskNode::task_id)
.def("add_upstream_task", &TaskNode::AddUpstreamTask) .def("add_upstream_task", &TaskNode::AddUpstreamTask)
...@@ -183,7 +183,10 @@ void BindFleetExecutor(py::module* m) { ...@@ -183,7 +183,10 @@ void BindFleetExecutor(py::module* m) {
.def("set_run_pre_steps", &TaskNode::SetRunPerSteps) .def("set_run_pre_steps", &TaskNode::SetRunPerSteps)
.def("set_run_at_offset", &TaskNode::SetRunAtOffset) .def("set_run_at_offset", &TaskNode::SetRunAtOffset)
.def("set_type", &TaskNode::SetType) .def("set_type", &TaskNode::SetType)
.def("set_cond_var_name", &TaskNode::SetCondVarName)
.def("role", &TaskNode::role) .def("role", &TaskNode::role)
.def("set_vars_to_shape", &TaskNode::SetVarsToShape)
.def("set_vars_to_dtype", &TaskNode::SetVarsToDtype)
.def("init", [](TaskNode& self) { self.Init(); }) .def("init", [](TaskNode& self) { self.Init(); })
.def("set_program", &TaskNode::SetProgram); .def("set_program", &TaskNode::SetProgram);
......
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/embedding_util.h" #include "paddle/phi/kernels/funcs/embedding_util.h"
DECLARE_bool(cudnn_deterministic);
namespace phi { namespace phi {
template <typename InT, typename OutT> template <typename InT, typename OutT>
...@@ -101,6 +103,12 @@ struct EmbeddingGradCUDAFunctor { ...@@ -101,6 +103,12 @@ struct EmbeddingGradCUDAFunctor {
const int gridx = 2 * dev_ctx_.GetSMCount(); const int gridx = 2 * dev_ctx_.GetSMCount();
dim3 threads(128, 8); dim3 threads(128, 8);
dim3 grids(gridx, 1); dim3 grids(gridx, 1);
if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of embedding with single thread.";
grids.x = 1;
threads.y = 1;
}
EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>( EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
d_table, d_output, ids, N, K, D); d_table, d_output, ids, N, K, D);
} }
......
...@@ -94,6 +94,16 @@ set_field_default_config(GRADIENT_MERGE, "enable", False) ...@@ -94,6 +94,16 @@ set_field_default_config(GRADIENT_MERGE, "enable", False)
set_field_default_config(GRADIENT_MERGE, "k_steps", 1) set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
set_field_default_config(GRADIENT_MERGE, "avg", True) set_field_default_config(GRADIENT_MERGE, "avg", True)
#########################################
# pipeline configuration
#########################################
PIPELINE = "pipeline"
set_field_default_config(PIPELINE, "enable", False)
set_field_default_config(PIPELINE, "schedule_mode", "1F1B")
set_field_default_config(PIPELINE, "micro_batch_size", 1)
set_field_default_config(PIPELINE, "accumulate_steps", 1)
set_field_default_config(PIPELINE, "generation_batch_size", 1)
######################################### #########################################
# quantization configuration # quantization configuration
######################################### #########################################
......
...@@ -556,8 +556,8 @@ def get_cost_from_engine(engine, mode): ...@@ -556,8 +556,8 @@ def get_cost_from_engine(engine, mode):
) )
serial_startup_prog = ( serial_startup_prog = (
engine._serial_startup_progs[mode].clone() engine._fwd_dist_contexts[mode]._original_serial_main_program.clone()
if mode in engine._serial_startup_progs if mode in engine._fwd_dist_contexts
else engine._orig_startup_prog.clone() else engine._orig_startup_prog.clone()
) )
losses = ( losses = (
......
...@@ -27,7 +27,6 @@ from .utils import convert_to_shard_spec, verify_shard_spec ...@@ -27,7 +27,6 @@ from .utils import convert_to_shard_spec, verify_shard_spec
class DistributedOperator: class DistributedOperator:
def __init__(self, serial_op, dist_attr=None): def __init__(self, serial_op, dist_attr=None):
self._serial_op = serial_op self._serial_op = serial_op
self._serial_inputs = {} self._serial_inputs = {}
...@@ -78,28 +77,34 @@ class DistributedOperator: ...@@ -78,28 +77,34 @@ class DistributedOperator:
if tensor is None: if tensor is None:
tensor_shape = [] tensor_shape = []
else: else:
if tensor.type == core.VarDesc.VarType.READER \ if (
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY: tensor.type == core.VarDesc.VarType.READER
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
):
tensor_shape = [] tensor_shape = []
else: else:
tensor_shape = tensor.shape tensor_shape = tensor.shape
if self._dist_attr.get_input_dims_mapping(tensor_name) is None: if self._dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))] tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_input_dims_mapping(tensor_name, self._dist_attr.set_input_dims_mapping(
tensor_dims_mapping) tensor_name, tensor_dims_mapping
)
for tensor_name in self._serial_op.output_arg_names: for tensor_name in self._serial_op.output_arg_names:
tensor = self._serial_op.block._var_recursive(tensor_name) tensor = self._serial_op.block._var_recursive(tensor_name)
if tensor.type == core.VarDesc.VarType.READER \ if (
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ tensor.type == core.VarDesc.VarType.READER
or tensor.type == core.VarDesc.VarType.STEP_SCOPES: or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
or tensor.type == core.VarDesc.VarType.STEP_SCOPES
):
tensor_shape = [] tensor_shape = []
else: else:
tensor_shape = tensor.shape tensor_shape = tensor.shape
self._serial_outputs[tensor_name] = tensor self._serial_outputs[tensor_name] = tensor
if self._dist_attr.get_output_dims_mapping(tensor_name) is None: if self._dist_attr.get_output_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))] tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_output_dims_mapping(tensor_name, self._dist_attr.set_output_dims_mapping(
tensor_dims_mapping) tensor_name, tensor_dims_mapping
)
if self._dist_attr.op_type is None: if self._dist_attr.op_type is None:
self._dist_attr.op_type = self.serial_op.type self._dist_attr.op_type = self.serial_op.type
if self._dist_attr.impl_type is None: if self._dist_attr.impl_type is None:
...@@ -117,8 +122,10 @@ class DistributedOperator: ...@@ -117,8 +122,10 @@ class DistributedOperator:
new_dist_attr = {} new_dist_attr = {}
for key, value in dist_attr.items(): for key, value in dist_attr.items():
if isinstance(key, Variable): if isinstance(key, Variable):
if key.name in self._serial_op.input_arg_names \ if (
or key.name in self._serial_op.output_arg_names: key.name in self._serial_op.input_arg_names
or key.name in self._serial_op.output_arg_names
):
new_dist_attr[key] = value new_dist_attr[key] = value
else: else:
new_dist_attr[key] = value new_dist_attr[key] = value
...@@ -129,13 +136,15 @@ class DistributedOperator: ...@@ -129,13 +136,15 @@ class DistributedOperator:
for tensor_name in self._serial_op.input_arg_names: for tensor_name in self._serial_op.input_arg_names:
tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name) tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name)
if tensor_dist_attr: if tensor_dist_attr:
new_dist_attr.set_input_dist_attr(tensor_name, new_dist_attr.set_input_dist_attr(
tensor_dist_attr) tensor_name, tensor_dist_attr
)
for tensor_name in self._serial_op.output_arg_names: for tensor_name in self._serial_op.output_arg_names:
tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name) tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name)
if tensor_dist_attr: if tensor_dist_attr:
new_dist_attr.set_output_dist_attr(tensor_name, new_dist_attr.set_output_dist_attr(
tensor_dist_attr) tensor_name, tensor_dist_attr
)
else: else:
assert False, "Cannot recognize the {} parameter.".format(dist_attr) assert False, "Cannot recognize the {} parameter.".format(dist_attr)
return new_dist_attr return new_dist_attr
...@@ -146,8 +155,10 @@ class DistributedOperator: ...@@ -146,8 +155,10 @@ class DistributedOperator:
for name in self.serial_op.input_arg_names: for name in self.serial_op.input_arg_names:
input_dist_attr = self.dist_attr.get_input_dist_attr(name) input_dist_attr = self.dist_attr.get_input_dist_attr(name)
dims_mapping = input_dist_attr.dims_mapping dims_mapping = input_dist_attr.dims_mapping
if self.get_serial_input( if (
name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY: self.get_serial_input(name).type
== core.VarDesc.VarType.LOD_TENSOR_ARRAY
):
shape = [] shape = []
else: else:
shape = self.get_serial_input(name).shape shape = self.get_serial_input(name).shape
...@@ -155,7 +166,8 @@ class DistributedOperator: ...@@ -155,7 +166,8 @@ class DistributedOperator:
return False return False
for i in range(len(dims_mapping)): for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len( if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology): self.dist_attr.process_mesh.topology
):
return False return False
for i in range(len(self.dist_attr.process_mesh.topology)): for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1: if dims_mapping.count(i) > 1:
...@@ -166,8 +178,12 @@ class DistributedOperator: ...@@ -166,8 +178,12 @@ class DistributedOperator:
for name in self.serial_op.output_arg_names: for name in self.serial_op.output_arg_names:
output_dist_attr = self.dist_attr.get_output_dist_attr(name) output_dist_attr = self.dist_attr.get_output_dist_attr(name)
dims_mapping = output_dist_attr.dims_mapping dims_mapping = output_dist_attr.dims_mapping
if self.get_serial_output(name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY\ if (
or self.get_serial_output(name).type == core.VarDesc.VarType.STEP_SCOPES: self.get_serial_output(name).type
== core.VarDesc.VarType.LOD_TENSOR_ARRAY
or self.get_serial_output(name).type
== core.VarDesc.VarType.STEP_SCOPES
):
shape = [] shape = []
else: else:
shape = self.get_serial_output(name).shape shape = self.get_serial_output(name).shape
...@@ -175,7 +191,8 @@ class DistributedOperator: ...@@ -175,7 +191,8 @@ class DistributedOperator:
return False return False
for i in range(len(dims_mapping)): for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len( if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology): self.dist_attr.process_mesh.topology
):
return False return False
for i in range(len(self.dist_attr.process_mesh.topology)): for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1: if dims_mapping.count(i) > 1:
...@@ -185,8 +202,9 @@ class DistributedOperator: ...@@ -185,8 +202,9 @@ class DistributedOperator:
return True return True
def __str__(self): def __str__(self):
str = "{{op type: {}, op id: {}".format(self.serial_op.desc.type(), str = "{{op type: {}, op id: {}".format(
self.serial_op.desc.id()) self.serial_op.desc.type(), self.serial_op.desc.id()
)
# str += ", {}".format(self.dist_attr) # str += ", {}".format(self.dist_attr)
# return str # return str
...@@ -195,8 +213,9 @@ class DistributedOperator: ...@@ -195,8 +213,9 @@ class DistributedOperator:
annotated_str = "annotated" annotated_str = "annotated"
else: else:
annotated_str = "non-annotated" annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str, str += ", process_mesh ({}): {}".format(
self.dist_attr.process_mesh) annotated_str, self.dist_attr.process_mesh
)
for arg_name in self.serial_op.desc.input_arg_names(): for arg_name in self.serial_op.desc.input_arg_names():
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name) dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
...@@ -212,7 +231,8 @@ class DistributedOperator: ...@@ -212,7 +231,8 @@ class DistributedOperator:
else: else:
is_parameter_str = "non-parameter" is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format( str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping) arg_name, annotated_str, is_parameter_str, dims_mapping
)
for arg_name in self.serial_op.desc.output_arg_names(): for arg_name in self.serial_op.desc.output_arg_names():
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name) dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
...@@ -228,12 +248,14 @@ class DistributedOperator: ...@@ -228,12 +248,14 @@ class DistributedOperator:
else: else:
is_parameter_str = "non-parameter" is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format( str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping) arg_name, annotated_str, is_parameter_str, dims_mapping
)
str += ", pipeline stage: {}".format(None) str += ", pipeline stage: {}".format(None)
str += ", dist_impl idx: {} , dist_impl type {} }}".format( str += ", dist_impl idx: {} , dist_impl type {} }}".format(
self.dist_attr._impl_idx, self.dist_attr._impl_type) self.dist_attr._impl_idx, self.dist_attr._impl_type
)
return str return str
...@@ -242,7 +264,11 @@ class DistributedOperator: ...@@ -242,7 +264,11 @@ class DistributedOperator:
result = cls.__new__(cls) result = cls.__new__(cls)
memo[id(self)] = result memo[id(self)] = result
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k == "_serial_op" or k == "_serial_inputs" or k == "_serial_outputs": if (
k == "_serial_op"
or k == "_serial_inputs"
or k == "_serial_outputs"
):
setattr(result, k, v) setattr(result, k, v)
else: else:
setattr(result, k, copy.deepcopy(v, memo)) setattr(result, k, copy.deepcopy(v, memo))
...@@ -250,9 +276,9 @@ class DistributedOperator: ...@@ -250,9 +276,9 @@ class DistributedOperator:
class DistributedOperatorHelper: class DistributedOperatorHelper:
def __init__(
def __init__(self, serial_op, process_mesh, in_dims_mappings, self, serial_op, process_mesh, in_dims_mappings, out_dims_mappings
out_dims_mappings): ):
self._serial_op = serial_op self._serial_op = serial_op
self._process_mesh = process_mesh self._process_mesh = process_mesh
self._in_dims_mappings = in_dims_mappings self._in_dims_mappings = in_dims_mappings
...@@ -262,8 +288,11 @@ class DistributedOperatorHelper: ...@@ -262,8 +288,11 @@ class DistributedOperatorHelper:
tensor_to_dims_mapping = {} tensor_to_dims_mapping = {}
index = 0 index = 0
if self._in_dims_mappings: if self._in_dims_mappings:
assert len(args) + len(kwargs) == len(self._in_dims_mappings), \ assert len(args) + len(kwargs) == len(
"The length of dims_mapping {} does not matching the length output {}.".format(len(self._in_dims_mappings), len(args) + len(kwargs)) self._in_dims_mappings
), "The length of dims_mapping {} does not matching the length output {}.".format(
len(self._in_dims_mappings), len(args) + len(kwargs)
)
for arg in args: for arg in args:
if isinstance(arg, Variable) and self._in_dims_mappings: if isinstance(arg, Variable) and self._in_dims_mappings:
tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index] tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
...@@ -287,13 +316,17 @@ class DistributedOperatorHelper: ...@@ -287,13 +316,17 @@ class DistributedOperatorHelper:
raise ValueError("Unrecognized outpout.") raise ValueError("Unrecognized outpout.")
if self._out_dims_mappings: if self._out_dims_mappings:
assert len(new_output) == len(self._out_dims_mappings), \ assert len(new_output) == len(
"The length of dims_mapping {} does not matching the length output {}.".format(len(self._out_dims_mappings), len(new_output)) self._out_dims_mappings
), "The length of dims_mapping {} does not matching the length output {}.".format(
len(self._out_dims_mappings), len(new_output)
)
for i, item in enumerate(new_output): for i, item in enumerate(new_output):
if isinstance(item, Variable) and self._out_dims_mappings: if isinstance(item, Variable) and self._out_dims_mappings:
tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i] tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i]
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
default_dist_ctx = get_default_distributed_context() default_dist_ctx = get_default_distributed_context()
for idx in range(op_size, new_op_size): for idx in range(op_size, new_op_size):
op = cur_block.ops[idx] op = cur_block.ops[idx]
...@@ -302,53 +335,68 @@ class DistributedOperatorHelper: ...@@ -302,53 +335,68 @@ class DistributedOperatorHelper:
if name in tensor_to_dims_mapping.keys(): if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_input(name) tensor = dist_op.get_serial_input(name)
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr( tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
name) name
)
dims_mapping = tensor_to_dims_mapping[name] dims_mapping = tensor_to_dims_mapping[name]
if tensor is None: if tensor is None:
tensor_shape = [] tensor_shape = []
else: else:
if tensor.type == core.VarDesc.VarType.READER \ if (
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ tensor.type == core.VarDesc.VarType.READER
or tensor.type == core.VarDesc.VarType.STEP_SCOPES: or tensor.type
== core.VarDesc.VarType.LOD_TENSOR_ARRAY
or tensor.type == core.VarDesc.VarType.STEP_SCOPES
):
tensor_shape = [] tensor_shape = []
else: else:
tensor_shape = tensor.shape tensor_shape = tensor.shape
if dims_mapping is not None: if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name] dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec( shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh) dims_mapping, self._process_mesh
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \ )
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( assert verify_shard_spec(
name, shard_spec, tensor_shape, self._process_mesh) shard_spec, tensor_shape, self._process_mesh
), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh
)
tensor_dist_attr.dims_mapping = dims_mapping tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping") tensor_dist_attr.mark_annotated("dims_mapping")
for name in dist_op.serial_op.output_arg_names: for name in dist_op.serial_op.output_arg_names:
if name in tensor_to_dims_mapping.keys(): if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_output(name) tensor = dist_op.get_serial_output(name)
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr( tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
name) name
)
dims_mapping = tensor_to_dims_mapping[name] dims_mapping = tensor_to_dims_mapping[name]
if tensor is None: if tensor is None:
tensor_shape = [] tensor_shape = []
else: else:
if tensor.type == core.VarDesc.VarType.READER \ if (
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ tensor.type == core.VarDesc.VarType.READER
or tensor.type == core.VarDesc.VarType.STEP_SCOPES: or tensor.type
== core.VarDesc.VarType.LOD_TENSOR_ARRAY
or tensor.type == core.VarDesc.VarType.STEP_SCOPES
):
tensor_shape = [] tensor_shape = []
else: else:
tensor_shape = tensor.shape tensor_shape = tensor.shape
if dims_mapping is not None: if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name] dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec( shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh) dims_mapping, self._process_mesh
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \ )
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( assert verify_shard_spec(
name, shard_spec, tensor_shape, self._process_mesh) shard_spec, tensor_shape, self._process_mesh
), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh
)
tensor_dist_attr.dims_mapping = dims_mapping tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping") tensor_dist_attr.mark_annotated("dims_mapping")
dist_op.dist_attr.process_mesh = self._process_mesh dist_op.dist_attr.process_mesh = self._process_mesh
if self._process_mesh is not None: if self._process_mesh is not None:
dist_op.dist_attr.mark_annotated("process_mesh") dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op) default_dist_ctx.add_dist_op_for_program(dist_op)
default_dist_ctx.add_process_mesh(self._process_mesh)
return output return output
...@@ -34,6 +34,7 @@ from paddle.fluid.framework import Operator, _non_static_mode ...@@ -34,6 +34,7 @@ from paddle.fluid.framework import Operator, _non_static_mode
from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.parallel import _is_global_parallel_initialize
from .callbacks import config_callbacks from .callbacks import config_callbacks
from .converter import Converter from .converter import Converter
...@@ -160,7 +161,6 @@ class Engine: ...@@ -160,7 +161,6 @@ class Engine:
" or `paddle.fluid.optimizer.Optimizer`." " or `paddle.fluid.optimizer.Optimizer`."
) )
self._optimizer = validate_opt(optimizer) self._optimizer = validate_opt(optimizer)
self._orig_optimizer = copy.deepcopy(self._optimizer)
metrics = metrics or [] metrics = metrics or []
for metric in to_list(metrics): for metric in to_list(metrics):
...@@ -185,12 +185,18 @@ class Engine: ...@@ -185,12 +185,18 @@ class Engine:
self._strategy = strategy or Strategy() self._strategy = strategy or Strategy()
self._logger = get_logger(logging.INFO) self._logger = get_logger(logging.INFO)
if os.getenv("POD_NAME"): if os.getenv("POD_NAME") and not _is_global_parallel_initialize():
self._logger.info( self._logger.info(
"Distribute training by paddle.distributed.launch" "Distribute training by paddle.distributed.launch"
) )
fleet.init(is_collective=True) fleet.init(is_collective=True)
# for compute cost
# TODO: remove _fwd_main_progs and _orig_optimizer
self._fwd_dist_contexts = {}
self._fwd_main_progs = {}
self._orig_optimizer = copy.deepcopy(self._optimizer)
self._executor = None self._executor = None
self._cur_rank = paddle.distributed.get_rank() self._cur_rank = paddle.distributed.get_rank()
self._nranks = paddle.distributed.get_world_size() self._nranks = paddle.distributed.get_world_size()
...@@ -200,14 +206,6 @@ class Engine: ...@@ -200,14 +206,6 @@ class Engine:
self._orig_startup_prog = static.default_startup_program() self._orig_startup_prog = static.default_startup_program()
self._orig_dist_context = get_default_distributed_context() self._orig_dist_context = get_default_distributed_context()
self._dist_contexts = {} self._dist_contexts = {}
self._fwd_main_progs = {}
self._fwd_dist_contexts = {}
self._serial_main_progs = {}
self._serial_startup_progs = {}
self._dist_main_progs = defaultdict(dict) # dist main programs
self._dist_startup_progs = defaultdict(dict) # dist startup programs
self._feed_vars = {}
self._fetch_vars = {}
self._planners = {} self._planners = {}
self._has_prepared = {"train": False, "eval": False, "predict": False} self._has_prepared = {"train": False, "eval": False, "predict": False}
self._has_prepared_reader = { self._has_prepared_reader = {
...@@ -338,9 +336,9 @@ class Engine: ...@@ -338,9 +336,9 @@ class Engine:
return inputs, labels return inputs, labels
def _prepare_reader(self): def _prepare_reader(self, feed_list=[]):
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode] dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_main_block = dist_main_prog.global_block() dist_main_block = dist_main_prog.global_block()
# NOTE: this list may be changed if Paddle changes the existing rules. # NOTE: this list may be changed if Paddle changes the existing rules.
...@@ -361,10 +359,13 @@ class Engine: ...@@ -361,10 +359,13 @@ class Engine:
if op.type in related_reader_ops: if op.type in related_reader_ops:
reader_op_indices.append(idx) reader_op_indices.append(idx)
# Step 2: insert the new reader ops to cpp # Step 2: insert the new reader ops to cpp
# record the read ops' desc to insert to program of forward task_node
read_ops_desc = []
new_reader_ops = [] new_reader_ops = []
for idx in reversed(reader_op_indices): for idx in reversed(reader_op_indices):
new_op_desc = dist_main_block.desc._prepend_op() new_op_desc = dist_main_block.desc._prepend_op()
new_op_desc.copy_from(dist_main_block.ops[idx].desc) new_op_desc.copy_from(dist_main_block.ops[idx].desc)
read_ops_desc.append(new_op_desc)
new_op = Operator( new_op = Operator(
dist_main_block, new_op_desc, type=new_op_desc.type() dist_main_block, new_op_desc, type=new_op_desc.type()
) )
...@@ -383,6 +384,29 @@ class Engine: ...@@ -383,6 +384,29 @@ class Engine:
dist_main_block._sync_with_cpp() dist_main_block._sync_with_cpp()
self._has_prepared_reader[self._mode] = True self._has_prepared_reader[self._mode] = True
# Insert read op to forward TaskNode if 1F1B pass is setted
if self.main_program._pipeline_opt:
assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
fwd_task = fleet_opt["tasks"][0]
fwd_prog = fwd_task.get_program()
fwd_block = fwd_prog.global_block()
for var in feed_list:
if var.name not in fwd_block.vars:
fwd_block._clone_variable(var)
for op_desc in read_ops_desc:
new_op_desc = fwd_block.desc._prepend_op()
new_op_desc.copy_from(op_desc)
new_op = Operator(
fwd_block, new_op_desc, type=new_op_desc.type()
)
fwd_block.ops.insert(0, new_op)
fwd_block._sync_with_cpp()
fwd_task.set_program(fwd_prog)
def _prepare_feed(self, data, user_feeds, mode): def _prepare_feed(self, data, user_feeds, mode):
feeds = {} feeds = {}
if data is not None: if data is not None:
...@@ -430,14 +454,16 @@ class Engine: ...@@ -430,14 +454,16 @@ class Engine:
fetch_names.append([]) fetch_names.append([])
fetch_indices.append(group_indices) fetch_indices.append(group_indices)
dist_context = self._dist_contexts[mode]
fetch_vars = dist_context.serial_fetch_vars
if mode != "predict": if mode != "predict":
_process_fetch_group("loss", self._fetch_vars[mode]["loss"]) _process_fetch_group("loss", fetch_vars["loss"])
if mode != "predict": if mode != "predict":
metrics = self._fetch_vars[mode]["metrics"] metrics = fetch_vars["metrics"]
for i, var_list in enumerate(metrics): for i, var_list in enumerate(metrics):
_process_fetch_group("metrics_" + str(i), var_list) _process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict": if mode == "predict":
_process_fetch_group("outputs", self._fetch_vars[mode]["outputs"]) _process_fetch_group("outputs", fetch_vars["outputs"])
user_fetches_collection = [ user_fetches_collection = [
item[1] for item in get_collection(CollectionNames.FETCHES) item[1] for item in get_collection(CollectionNames.FETCHES)
] ]
...@@ -471,7 +497,8 @@ class Engine: ...@@ -471,7 +497,8 @@ class Engine:
logs["loss"] = outs[idx][0] logs["loss"] = outs[idx][0]
group_idx += 1 group_idx += 1
# logging metrics # logging metrics
metric_vars = self._fetch_vars[mode]["metrics"] dist_context = self._dist_contexts[mode]
metric_vars = dist_context.serial_fetch_vars["metrics"]
if metric_vars: if metric_vars:
for metric in self._metrics: for metric in self._metrics:
metrics_indices = fetch_indices[group_idx] metrics_indices = fetch_indices[group_idx]
...@@ -502,15 +529,18 @@ class Engine: ...@@ -502,15 +529,18 @@ class Engine:
logs["fetches"] = logs_fetch logs["fetches"] = logs_fetch
return logs return logs
def _prepare_program(self, mode): def _prepare_program(self, mode, init_parameters=True):
# Do the build process # Do the build process
self._build(mode) self._build(mode)
# Do the planning process # Do the planning process
self._plan(mode) self._plan(mode)
# Do the parallel process # Do the parallel process
self._parallel(mode) self._parallel(mode)
# Init comm and startup program # Init comm
self._initialize(mode) self._init_comm()
if init_parameters:
# startup program
self._initialize(mode)
self._has_prepared[mode] = True self._has_prepared[mode] = True
def _build(self, mode): def _build(self, mode):
...@@ -542,8 +572,8 @@ class Engine: ...@@ -542,8 +572,8 @@ class Engine:
paddle.enable_static() paddle.enable_static()
else: else:
# build program in static mode # build program in static mode
serial_main_prog = self._serial_main_progs.get(mode, None) dist_context = self._dist_contexts.get(mode, None)
if serial_main_prog is not None: if dist_context is not None:
return return
outputs = [] outputs = []
...@@ -581,7 +611,7 @@ class Engine: ...@@ -581,7 +611,7 @@ class Engine:
metric.compute(*(outputs + self._labels)) metric.compute(*(outputs + self._labels))
) )
) )
else: elif mode == "train":
assert isinstance( assert isinstance(
self._loss, Variable self._loss, Variable
), "the type of `loss` of the Engine arguments should be Variable." ), "the type of `loss` of the Engine arguments should be Variable."
...@@ -724,37 +754,21 @@ class Engine: ...@@ -724,37 +754,21 @@ class Engine:
) )
dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr) dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)
def _initialize(self, mode): def _init_comm(self):
# Get the current content from the distributed context
self._serial_main_progs[mode] = self._dist_contexts[
mode
].serial_main_program
self._serial_startup_progs[mode] = self._dist_contexts[
mode
].serial_startup_program
self._dist_main_progs[mode] = self._dist_contexts[
mode
].dist_main_programs
self._dist_startup_progs[mode] = self._dist_contexts[
mode
].dist_startup_programs
self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
self._optimizer = self._dist_contexts[mode]._serial_optimizer
if self._nranks > 1: if self._nranks > 1:
# Traverse different rank programs and traverse each op of them, # Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping. # instantiate communication by process_mapping.
all_process_groups = get_all_process_groups() all_process_groups = get_all_process_groups()
if self._strategy.auto_mode == "full": if self._strategy.auto_mode == "full":
initialize_pg_in_full_mode(all_process_groups, cur_rank) initialize_pg_in_full_mode(all_process_groups, self._cur_rank)
else: else:
for process_group in all_process_groups: for process_group in all_process_groups:
if self._cur_rank not in process_group.ranks: if self._cur_rank not in process_group.ranks:
continue continue
process_group.instantiate() process_group.instantiate()
def _initialize(self, mode):
place = _get_device() place = _get_device()
if isinstance(place, fluid.CUDAPlace): if isinstance(place, fluid.CUDAPlace):
place = fluid.CUDAPlace(ParallelEnv().dev_id) place = fluid.CUDAPlace(ParallelEnv().dev_id)
...@@ -764,15 +778,17 @@ class Engine: ...@@ -764,15 +778,17 @@ class Engine:
np.random.seed(self._strategy.seed + self._dp_ranks[0]) np.random.seed(self._strategy.seed + self._dp_ranks[0])
random.seed(self._strategy.seed + self._dp_ranks[0]) random.seed(self._strategy.seed + self._dp_ranks[0])
dist_context = self._dist_contexts[mode]
if self._dygraph_mode: if self._dygraph_mode:
dist_context = self._dist_contexts[mode] dist_main_program = dist_context.dist_main_programs[self._cur_rank]
dist_main_program = self._dist_main_progs[mode][self._cur_rank]
self.program_helper.init(dist_main_program, place, dist_context) self.program_helper.init(dist_main_program, place, dist_context)
if self._executor is None: if self._executor is None:
self._executor = paddle.static.Executor(place) self._executor = paddle.static.Executor(place)
uninitialized = [] uninitialized = []
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] dist_startup_prog = dist_context.dist_startup_programs[
self._cur_rank
]
for var in dist_startup_prog.list_vars(): for var in dist_startup_prog.list_vars():
scope_var = global_scope().find_var(var.name) scope_var = global_scope().find_var(var.name)
if scope_var and scope_var.get_tensor()._is_initialized(): if scope_var and scope_var.get_tensor()._is_initialized():
...@@ -789,7 +805,9 @@ class Engine: ...@@ -789,7 +805,9 @@ class Engine:
if self._strategy.reinit: if self._strategy.reinit:
self._logger.info("NOTE: parameters will be re-initialized.") self._logger.info("NOTE: parameters will be re-initialized.")
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] dist_startup_prog = dist_context.dist_startup_programs[
self._cur_rank
]
self._executor.run(dist_startup_prog) self._executor.run(dist_startup_prog)
def fit( def fit(
...@@ -926,7 +944,7 @@ class Engine: ...@@ -926,7 +944,7 @@ class Engine:
) )
except core.EOFException: except core.EOFException:
break break
lr = get_lr(self._optimizer) lr = get_lr(self.optimizer)
logs = self._prepare_logger( logs = self._prepare_logger(
outs, outs,
epoch, epoch,
...@@ -1262,6 +1280,7 @@ class Engine: ...@@ -1262,6 +1280,7 @@ class Engine:
main_program=None, main_program=None,
startup_program=None, startup_program=None,
mode=None, mode=None,
init_parameters=True,
): ):
if mode is not None: if mode is not None:
self.to_mode(mode) self.to_mode(mode)
...@@ -1304,7 +1323,7 @@ class Engine: ...@@ -1304,7 +1323,7 @@ class Engine:
self._inputs_spec, self._labels_spec = inputs_spec, labels_spec self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
self._inputs, self._labels = inputs, labels self._inputs, self._labels = inputs, labels
if not self._has_prepared[self._mode]: if not self._has_prepared[self._mode]:
self._prepare_program(self._mode) self._prepare_program(self._mode, init_parameters)
else: else:
self._switch_mode(self._mode) self._switch_mode(self._mode)
...@@ -1355,16 +1374,17 @@ class Engine: ...@@ -1355,16 +1374,17 @@ class Engine:
) )
batch_size //= self._k_steps batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] dist_context = self._dist_contexts[self._mode]
dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
dist_main_block = dist_main_prog.global_block() dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list, then insert dataloader op with sharded var shape. # NOTE: Get feed_list, then insert dataloader op with sharded var shape.
# Cause predict_program does not contain labels var, # Cause predict_program does not contain labels var,
# then we will add labels var from serial_program to dist_program, # then we will add labels var from serial_program to dist_program,
# that maintains the length of feed_list equal to the length of dataset's values. # that maintains the length of feed_list equal to the length of dataset's values.
inputs_var = self._feed_vars[self._mode]["inputs"] inputs_var = dist_context.serial_feed_vars["inputs"]
labels_var = self._feed_vars[self._mode]["labels"] labels_var = dist_context.serial_feed_vars["labels"]
feed_list = [] feed_list = []
for var in inputs_var + labels_var: for var in inputs_var + labels_var:
if var.name in dist_main_block.vars: if var.name in dist_main_block.vars:
...@@ -1423,16 +1443,17 @@ class Engine: ...@@ -1423,16 +1443,17 @@ class Engine:
) )
batch_size //= self._k_steps batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] dist_context = self._dist_contexts[self._mode]
dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
dist_main_block = dist_main_prog.global_block() dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list, then insert dataloader op with sharded var shape. # NOTE: Get feed_list, then insert dataloader op with sharded var shape.
# Cause predict_program does not contain labels var, # Cause predict_program does not contain labels var,
# then we will add labels var from serial_program to dist_program, # then we will add labels var from serial_program to dist_program,
# that maintains the length of feed_list equal to the length of dataset's values. # that maintains the length of feed_list equal to the length of dataset's values.
inputs_var = self._feed_vars[self._mode]["inputs"] inputs_var = dist_context.serial_feed_vars["inputs"]
labels_var = self._feed_vars[self._mode]["labels"] labels_var = dist_context.serial_feed_vars["labels"]
feed_list = [] feed_list = []
for var in inputs_var + labels_var: for var in inputs_var + labels_var:
if var.name in dist_main_block.vars: if var.name in dist_main_block.vars:
...@@ -1462,7 +1483,7 @@ class Engine: ...@@ -1462,7 +1483,7 @@ class Engine:
data_parallel_world_size=self._dp_world_sizes, data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks, data_parallel_rank=self._dp_ranks,
) )
self._prepare_reader() self._prepare_reader(feed_list)
return dataloader return dataloader
def _tune(self, tune_data, tune_sample_split=None, batch_size=1): def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
...@@ -1551,10 +1572,9 @@ class Engine: ...@@ -1551,10 +1572,9 @@ class Engine:
def _switch_mode(self, mode): def _switch_mode(self, mode):
assert ( assert (
mode in self._dist_main_progs mode in self._dist_contexts
), "{} model is not ready, please call `prepare()` first.".format(mode) ), "{} model is not ready, please call `prepare()` first.".format(mode)
self.to_mode(mode) self.to_mode(mode)
self._optimizer = self._dist_contexts[mode]._serial_optimizer
def to_mode(self, mode): def to_mode(self, mode):
assert mode in [ assert mode in [
...@@ -1565,8 +1585,8 @@ class Engine: ...@@ -1565,8 +1585,8 @@ class Engine:
self._mode = mode self._mode = mode
def _set_state_dict(self, mode, strict, state_dict, dist_attr): def _set_state_dict(self, mode, strict, state_dict, dist_attr):
program = self._dist_main_progs[mode][self._cur_rank]
dist_context = self._dist_contexts[mode] dist_context = self._dist_contexts[mode]
program = dist_context.dist_main_programs[self._cur_rank]
cur_dist_attr = get_dist_attr(program, dist_context) cur_dist_attr = get_dist_attr(program, dist_context)
converter = Converter(state_dict, dist_attr, cur_dist_attr) converter = Converter(state_dict, dist_attr, cur_dist_attr)
state_dict = converter.convert(strict=strict) state_dict = converter.convert(strict=strict)
...@@ -1618,10 +1638,10 @@ class Engine: ...@@ -1618,10 +1638,10 @@ class Engine:
""" """
if training: if training:
assert self._mode in self._serial_main_progs assert self._mode in self._dist_contexts
serial_program = self._serial_main_progs[self._mode]
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode] dist_context = self._dist_contexts[self._mode]
serial_program = dist_context.serial_main_program
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
self._saver.save( self._saver.save(
path, path,
serial_program=serial_program, serial_program=serial_program,
...@@ -1629,10 +1649,11 @@ class Engine: ...@@ -1629,10 +1649,11 @@ class Engine:
dist_context=dist_context, dist_context=dist_context,
) )
else: else:
assert "predict" in self._dist_main_progs assert "predict" in self._dist_contexts
feed_vars = self._feed_vars["predict"]['inputs'] dist_context = self._dist_contexts["predict"]
fetch_vars = self._fetch_vars["predict"]['outputs'] feed_vars = dist_context.serial_feed_vars['inputs']
dist_main_prog = self._dist_main_progs["predict"][self._cur_rank] fetch_vars = dist_context.serial_fetch_vars['outputs']
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
self._saver.save_inference_model( self._saver.save_inference_model(
path, path,
feed_vars, feed_vars,
...@@ -1758,11 +1779,13 @@ class Engine: ...@@ -1758,11 +1779,13 @@ class Engine:
@property @property
def main_program(self): def main_program(self):
return self._dist_main_progs[self._mode][self._cur_rank] dist_context = self._dist_contexts[self._mode]
return dist_context.dist_main_programs[self._cur_rank]
@property @property
def startup_program(self): def startup_program(self):
return self._dist_startup_progs[self._mode][self._cur_rank] dist_context = self._dist_contexts[self._mode]
return dist_context.dist_startup_programs[self._cur_rank]
@property @property
def dist_context(self): def dist_context(self):
...@@ -1770,15 +1793,30 @@ class Engine: ...@@ -1770,15 +1793,30 @@ class Engine:
@property @property
def serial_main_program(self): def serial_main_program(self):
return self._serial_main_progs[self._mode] dist_context = self._dist_contexts[self._mode]
return dist_context.serial_main_program
@property @property
def serial_startup_program(self): def serial_startup_program(self):
return self._serial_startup_progs[self._mode] dist_context = self._dist_contexts[self._mode]
return dist_context.serial_startup_program
@property
def feed_vars(self):
dist_context = self._dist_contexts[self._mode]
return dist_context.serial_feed_vars
@property @property
def fetch_vars(self): def fetch_vars(self):
return self._fetch_vars[self._mode] dist_context = self._dist_contexts[self._mode]
return dist_context.serial_fetch_vars
@property
def optimizer(self):
dist_context = self._dist_contexts[self._mode]
if dist_context._serial_optimizer:
return dist_context._serial_optimizer
return self._optimizer
@property @property
def inputs(self): def inputs(self):
......
...@@ -67,29 +67,43 @@ def shard_tensor(x, process_mesh=None, shard_spec=None): ...@@ -67,29 +67,43 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
""" """
if process_mesh is not None: if process_mesh is not None:
assert isinstance(process_mesh, ProcessMesh), \ assert isinstance(
"Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh) process_mesh, ProcessMesh
), "Argument process_mesh {} is not an instance of ProcessMesh".format(
process_mesh
)
else: else:
process_mesh = get_current_process_mesh() process_mesh = get_current_process_mesh()
assert process_mesh is not None, \ assert (
"Specify the process mesh argument or use ProcessMesh context manager first." process_mesh is not None
assert isinstance(shard_spec, list), \ ), "Specify the process mesh argument or use ProcessMesh context manager first."
"Argument shard_spec {} is not an instance of list".format(shard_spec) assert isinstance(
dist_tensor = DistributedTensor(x) shard_spec, list
), "Argument shard_spec {} is not an instance of list".format(shard_spec)
if isinstance(x, str):
x = paddle.fluid.default_main_program().global_block()._var_recursive(x)
dist_tensor = DistributedTensor(x)
else:
dist_tensor = DistributedTensor(x)
serial_tensor = dist_tensor.serial_tensor serial_tensor = dist_tensor.serial_tensor
dist_tensor.dist_attr.process_mesh = process_mesh dist_tensor.dist_attr.process_mesh = process_mesh
if serial_tensor.type == core.VarDesc.VarType.READER \ if (
or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ serial_tensor.type == core.VarDesc.VarType.READER
or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES: or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
):
tensor_shape = [] tensor_shape = []
else: else:
tensor_shape = serial_tensor.shape tensor_shape = serial_tensor.shape
if shard_spec is not None: if shard_spec is not None:
assert verify_shard_spec(shard_spec, tensor_shape, process_mesh), \ assert verify_shard_spec(
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( shard_spec, tensor_shape, process_mesh
serial_tensor.name, shard_spec, tensor_shape, process_mesh) ), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
serial_tensor.name, shard_spec, tensor_shape, process_mesh
)
dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping( dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping(
shard_spec, process_mesh) shard_spec, process_mesh
)
if process_mesh is not None: if process_mesh is not None:
dist_tensor.dist_attr.mark_annotated("process_mesh") dist_tensor.dist_attr.mark_annotated("process_mesh")
if shard_spec is not None: if shard_spec is not None:
...@@ -97,6 +111,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None): ...@@ -97,6 +111,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
default_dist_ctx = get_default_distributed_context() default_dist_ctx = get_default_distributed_context()
default_dist_ctx.add_dist_tensor_for_program(dist_tensor) default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
dist_tensor = default_dist_ctx.get_dist_tensor_for_program(x) dist_tensor = default_dist_ctx.get_dist_tensor_for_program(x)
default_dist_ctx.add_process_mesh(process_mesh)
return x return x
...@@ -144,41 +159,54 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None): ...@@ -144,41 +159,54 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
""" """
if process_mesh is not None: if process_mesh is not None:
assert isinstance(process_mesh, ProcessMesh), \ assert isinstance(
"Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh) process_mesh, ProcessMesh
), "Argument process_mesh {} is not an instance of ProcessMesh".format(
process_mesh
)
else: else:
process_mesh = get_current_process_mesh() process_mesh = get_current_process_mesh()
assert process_mesh is not None, \ assert (
"Specify the process mesh argument or use ProcessMesh context manager first." process_mesh is not None
), "Specify the process mesh argument or use ProcessMesh context manager first."
in_dims_mappings = [] in_dims_mappings = []
if in_shard_specs is not None: if in_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in in_shard_specs), \ assert all(
"in_shard_spec {} is not a list of list or None".format(in_shard_specs) (isinstance(shard_spec, list) or shard_spec is None)
for shard_spec in in_shard_specs
), "in_shard_spec {} is not a list of list or None".format(
in_shard_specs
)
for shard_spec in in_shard_specs: for shard_spec in in_shard_specs:
if shard_spec is not None: if shard_spec is not None:
in_dims_mappings.append( in_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh)) convert_to_dims_mapping(shard_spec, process_mesh)
)
else: else:
in_dims_mappings.append(None) in_dims_mappings.append(None)
out_dims_mappings = [] out_dims_mappings = []
if out_shard_specs is not None: if out_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in out_shard_specs), \ assert all(
"out_shard_spec {} is not a list of list or None".format(out_shard_specs) (isinstance(shard_spec, list) or shard_spec is None)
for shard_spec in out_shard_specs
), "out_shard_spec {} is not a list of list or None".format(
out_shard_specs
)
for shard_spec in out_shard_specs: for shard_spec in out_shard_specs:
if shard_spec is not None: if shard_spec is not None:
out_dims_mappings.append( out_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh)) convert_to_dims_mapping(shard_spec, process_mesh)
)
else: else:
out_dims_mappings.append(None) out_dims_mappings.append(None)
op = DistributedOperatorHelper(op, process_mesh, in_dims_mappings, op = DistributedOperatorHelper(
out_dims_mappings) op, process_mesh, in_dims_mappings, out_dims_mappings
)
return op return op
def recompute(op): def recompute(op):
class RecomputeOperator: class RecomputeOperator:
def __init__(self, op): def __init__(self, op):
self._op = op self._op = op
...@@ -219,11 +247,13 @@ def add_to_collection(collection_name, value, name=None): ...@@ -219,11 +247,13 @@ def add_to_collection(collection_name, value, name=None):
_g_collections[collection_name] = [] _g_collections[collection_name] = []
if name is not None: if name is not None:
for _, v in _g_collections[collection_name]: for _, v in _g_collections[collection_name]:
if v == value: return if v == value:
return
_g_collections[collection_name].append((name, value)) _g_collections[collection_name].append((name, value))
else: else:
for _, v in _g_collections[collection_name]: for _, v in _g_collections[collection_name]:
if v == value: return if v == value:
return
_g_collections[collection_name].append((None, value)) _g_collections[collection_name].append((None, value))
......
...@@ -35,3 +35,4 @@ from . import dist_fused_attention ...@@ -35,3 +35,4 @@ from . import dist_fused_attention
from . import dist_reduce_sum_p from . import dist_reduce_sum_p
from . import dist_shape from . import dist_shape
from . import dist_assign from . import dist_assign
from . import dist_scale
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册