Unverified commit 5866a7a5, authored by chengduo, committed by GitHub

Enable fused_all_reduce_op_handle support GPU and CPU Gradients (#19418)

* Enable fused_all_reduce_op_handle support GPU and CPU Gradients
Parent 3e5fb636
@@ -40,11 +40,124 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
                                     const std::vector<Scope *> &local_scopes,
                                     const std::vector<platform::Place> &places)
-    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
+    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
+  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
+}
#endif
void AllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
std::vector<VarHandleBase *> inputs = this->Inputs();
std::vector<VarHandleBase *> outputs = this->Outputs();
auto in_var_handles = DynamicCast<VarHandle>(inputs);
auto out_var_handles = DynamicCast<VarHandle>(outputs);
AllReduceImpl(in_var_handles, out_var_handles);
}
void AllReduceOpHandle::AllReduceImpl(
const std::vector<VarHandle *> &in_var_handles,
const std::vector<VarHandle *> &out_var_handles) {
size_t num_places = places_.size();
PADDLE_ENFORCE_EQ(
in_var_handles.size(), num_places,
"The NoDummyInputSize should be equal to the number of places.");
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
PADDLE_ENFORCE_EQ(local_exec_scopes_.size(), num_places);
std::vector<const void *> lod_tensor_data;
std::vector<platform::Place> places;
lod_tensor_data.reserve(num_places);
places.reserve(num_places);
int64_t numel = -1;
bool is_gpu_place = false;
auto dtype = static_cast<framework::proto::VarType::Type>(0);
for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
auto &local_scope = local_exec_scopes_[i];
auto var = local_scope->FindVar(in_var_handles[i]->name());
PADDLE_ENFORCE_NOT_NULL(var, "%s is not found int scope.",
in_var_handles[i]->name());
auto &lod_tensor = var->Get<LoDTensor>();
if (i == 0) {
numel = static_cast<int64_t>(lod_tensor.numel());
dtype = lod_tensor.type();
is_gpu_place = platform::is_gpu_place(lod_tensor.place());
}
PADDLE_ENFORCE_EQ(numel, static_cast<int64_t>(lod_tensor.numel()));
PADDLE_ENFORCE_EQ(dtype, lod_tensor.type());
PADDLE_ENFORCE_EQ(is_gpu_place, platform::is_gpu_place(lod_tensor.place()));
lod_tensor_data.emplace_back(lod_tensor.data<void>());
places.emplace_back(lod_tensor.place());
VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
<< ", out_name:" << out_var_handles[i]->name();
PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(),
"The name of input and output should be equal.");
}
std::vector<std::string> grad_var_names;
grad_var_names.reserve(num_places);
for (auto &out_var : out_var_handles) {
grad_var_names.emplace_back(out_var->Name());
}
AllReduceFunc(lod_tensor_data, dtype, numel, places, grad_var_names);
}
void AllReduceOpHandle::AllReduceFunc(
std::vector<const void *> lod_tensor_data,
const framework::proto::VarType::Type &dtype, int64_t numel,
const std::vector<platform::Place> &places,
const std::vector<std::string> &out_var_names) {
if (is_gpu_place(places[0])) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-void AllReduceOpHandle::RunAllReduceFuncs(
+    PADDLE_ENFORCE_NOT_NULL(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
ncclDataType_t nccl_dtype = platform::ToNCCLDataType(dtype);
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
auto &p = places[i];
void *buffer = const_cast<void *>(lod_tensor_data.at(i));
all_reduce_calls.emplace_back([=] {
NCCLAllReduce(p, buffer, buffer, numel, nccl_dtype, ncclSum);
});
}
NCCLAllReduceFunc(all_reduce_calls);
#else
PADDLE_THROW("Not compiled with CUDA.");
#endif
} else { // Special handle CPU only Operator's gradient. Like CRF
auto &trg = *local_exec_scopes_[0]
->FindVar(out_var_names[0])
->GetMutable<LoDTensor>();
// Reduce All Tensor to trg in CPU
ReduceBufferData func(lod_tensor_data, trg.data<void>(), numel);
VisitDataType(trg.type(), func);
for (size_t i = 1; i < local_exec_scopes_.size(); ++i) {
auto &scope = local_exec_scopes_[i];
auto &p = places[i];
auto *var = scope->FindVar(out_var_names[i]);
size_t size = numel * SizeOfType(trg.type());
RunAndRecordEvent(p, [&trg, var, p, size] {
auto dst_ptr = var->GetMutable<framework::LoDTensor>()->data<void>();
platform::CPUPlace cpu_place;
memory::Copy(cpu_place, dst_ptr, cpu_place, trg.data<void>(), size);
});
}
}
VLOG(10) << Name() << " size:" << numel * SizeOfType(dtype);
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void AllReduceOpHandle::NCCLAllReduceFunc(
    const std::vector<std::function<void()>> &all_reduce_calls) {
  this->RunAndRecordEvent([&] {
    if (all_reduce_calls.size() == 1UL) {
@@ -80,85 +193,6 @@ void AllReduceOpHandle::RunAllReduceFuncs(
}
#endif
void AllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The NoDummyInputSize should be equal to the number of places.");
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
std::vector<const LoDTensor *> lod_tensors;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &local_scope = local_exec_scopes_[i];
auto &lod_tensor =
local_scope->FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
lod_tensors.emplace_back(&lod_tensor);
VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
<< ", out_name:" << out_var_handles[i]->name();
PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(),
"The name of input and output should be equal.");
}
if (platform::is_gpu_place(lod_tensors[0]->place())) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
int dtype = -1;
size_t numel = 0;
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &p = places_[i];
auto &lod_tensor = *lod_tensors[i];
void *buffer = const_cast<void *>(lod_tensor.data<void>());
if (dtype == -1) {
dtype = platform::ToNCCLDataType(lod_tensor.type());
}
if (numel == 0) {
numel = static_cast<size_t>(lod_tensor.numel());
}
all_reduce_calls.emplace_back([=] {
NCCLAllReduce(p, buffer, buffer, numel,
static_cast<ncclDataType_t>(dtype), ncclSum);
});
}
VLOG(10) << "allreduce size:" << numel * SizeOfType(lod_tensors[0]->type());
RunAllReduceFuncs(all_reduce_calls);
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
} else { // Special handle CPU only Operator's gradient. Like CRF
auto &trg = *this->local_exec_scopes_[0]
->FindVar(out_var_handles[0]->name())
->GetMutable<framework::LoDTensor>();
// Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(lod_tensors[0]->type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope = local_exec_scopes_[i];
auto &p = places_[i];
auto *var = scope->FindVar(out_var_handles[i]->name());
auto *dev_ctx = dev_ctxes_.at(p);
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>();
auto &tensor_cpu = trg;
TensorCopy(tensor_cpu, p, *dev_ctx, &tensor_gpu);
});
}
}
}
std::string AllReduceOpHandle::Name() const { return "all_reduce"; }
}  // namespace details
}  // namespace framework
......
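The new AllReduceOpHandle::AllReduceFunc above dispatches on the gradient's place: GPU gradients go through per-place NCCL all-reduce calls, while CPU gradients are reduced into the first scope's tensor with ReduceBufferData and then copied back to the remaining scopes. A minimal standalone sketch of that CPU fall-back path (plain C++, not Paddle code; buffer sizes and values are made up):

// Accumulate every replica's buffer into replica 0, then copy the reduced
// result back to the other replicas, mirroring the memory::Copy loop above.
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  const int num_replicas = 4;
  const int numel = 8;
  std::vector<std::vector<float>> grads(num_replicas, std::vector<float>(numel));
  for (int r = 0; r < num_replicas; ++r)
    for (int i = 0; i < numel; ++i) grads[r][i] = static_cast<float>(r + 1);

  // Step 1: reduce every buffer into replica 0 (plays the role of `trg`).
  for (int r = 1; r < num_replicas; ++r)
    for (int i = 0; i < numel; ++i) grads[0][i] += grads[r][i];

  // Step 2: broadcast the reduced tensor back to the other replicas.
  for (int r = 1; r < num_replicas; ++r)
    std::memcpy(grads[r].data(), grads[0].data(), numel * sizeof(float));

  std::cout << "reduced value: " << grads[2][0] << "\n";  // 1+2+3+4 = 10
  return 0;
}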
@@ -61,9 +61,17 @@ class AllReduceOpHandle : public OpHandleBase {
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  void RunAllReduceFuncs(
+  void NCCLAllReduceFunc(
      const std::vector<std::function<void()>> &all_reduce_calls);
#endif
void AllReduceImpl(const std::vector<VarHandle *> &in_var_handles,
const std::vector<VarHandle *> &out_var_handles);
void AllReduceFunc(std::vector<const void *> lod_tensor_data,
const framework::proto::VarType::Type &dtype,
int64_t numel, const std::vector<platform::Place> &places,
const std::vector<std::string> &out_var_handles);
};
}  // namespace details
......
@@ -83,12 +83,20 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
        << "Currently, fuse_all_optimizer_ops doesn't work under "
           "parallel_graph.";
    strategy_.fuse_all_optimizer_ops_ = false;
VLOG_IF(3, strategy_.fuse_all_reduce_ops_)
<< "fuse_all_reduce_ops doesn't work under "
"parallel_graph.";
strategy_.fuse_all_reduce_ops_ = false;
  }
  if (strategy_.is_distribution_) {
    VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
        << "Currently, fuse_all_optimizer_ops only works under "
           "Non-distributed mode.";
    strategy_.fuse_all_optimizer_ops_ = false;
VLOG_IF(3, strategy_.fuse_all_reduce_ops_)
<< "Currently, fuse_all_reduce_ops_ only works under "
"Non-distributed mode.";
strategy_.fuse_all_reduce_ops_ = false;
  }
  if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
    VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
@@ -284,8 +292,8 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
-      pass->Erase(ir::kNRanks);
-      pass->Set<size_t>(ir::kNRanks, new size_t(nranks));
+      pass->Erase(kNRanks);
+      pass->Set<size_t>(kNRanks, new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
@@ -293,6 +301,8 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
      pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
#endif
    } else if (pass->Type() == "fuse_all_reduce_op_pass") {
pass->Erase(kNRanks);
pass->Set<size_t>(kNRanks, new size_t(nranks));
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLocalScopes);
@@ -307,11 +317,8 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                      new bool(use_hierarchical_allreduce_));
#endif
    } else if (pass->Type() == "coalesce_grad_tensor_pass") {
-      pass->Erase(kPlaces);
-      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
-      pass->Erase(kLocalScopes);
-      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
-                                                    &local_scopes);
+      pass->Erase(kNRanks);
+      pass->Set<size_t>(kNRanks, new size_t(nranks));
    } else if (pass->Type() == "sequential_execution_pass") {
      LOG(INFO) << "set enable_sequential_execution:"
                << enable_sequential_execution_;
......
@@ -33,28 +33,18 @@ FusedAllReduceOpHandle::FusedAllReduceOpHandle(
    ir::Node *node, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
    const platform::NCCLCommunicator *ctxs)
-    : NCCLOpHandleBase(node, places, ctxs),
-      local_scopes_(local_scopes),
-      num_of_all_reduce_(num_of_all_reduce) {
-  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
-}
+    : AllReduceOpHandle(node, local_scopes, places, ctxs),
+      num_of_all_reduce_(num_of_all_reduce) {}
#else
FusedAllReduceOpHandle::FusedAllReduceOpHandle(
    ir::Node *node, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places, const size_t num_of_all_reduce)
-    : OpHandleBase(node),
-      local_scopes_(local_scopes),
-      places_(places),
-      num_of_all_reduce_(num_of_all_reduce) {
-  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
-}
+    : AllReduceOpHandle(node, local_scopes, places),
+      num_of_all_reduce_(num_of_all_reduce) {}
#endif

void FusedAllReduceOpHandle::RunImpl() {
  platform::RecordEvent record_event(Name());
  VLOG(4) << this->DebugString();
  WaitInputVarGenerated();
@@ -71,6 +61,30 @@ void FusedAllReduceOpHandle::RunImpl() {
      in_var_handles.size(), out_var_handles.size(),
      "The NoDummyInputSize and NoDummyOutputSize should be equal.");
// Note: some gradient op doesn't have CUDAKernel, so the gradients of
// those op are in CPUPlace, in this case, the all reduce should not be fused.
if (InputIsInDifferentPlace(in_var_handles)) {
for (size_t j = 0; j < num_of_all_reduce_; ++j) {
std::vector<VarHandle *> dev_inputs;
std::vector<VarHandle *> dev_outputs;
dev_inputs.reserve(place_num);
dev_outputs.reserve(place_num);
for (size_t idx = 0; idx < place_num; ++idx) {
dev_inputs.emplace_back(in_var_handles.at(j * place_num + idx));
dev_outputs.emplace_back(out_var_handles.at(j * place_num + idx));
}
AllReduceImpl(dev_inputs, dev_outputs);
}
} else {
FusedAllReduceFunc(in_var_handles, out_var_handles);
}
}
void FusedAllReduceOpHandle::FusedAllReduceFunc(
const std::vector<VarHandle *> &in_var_handles,
const std::vector<VarHandle *> &out_var_handles) {
size_t place_num = places_.size();
  GradientAndLoDTensor grads_tensor;
  grads_tensor.resize(place_num);
@@ -87,14 +101,11 @@ void FusedAllReduceOpHandle::RunImpl() {
        static_cast<framework::proto::VarType::Type>(0);
    GetDTypeAndNumel(g_tensor, &ele_dtype, &element_num);
-    if (numel == -1) {
+    if (scope_idx == 0) {
      numel = element_num;
-    }
-    if (dtype == static_cast<framework::proto::VarType::Type>(0)) {
      dtype = ele_dtype;
-      PADDLE_ENFORCE_NE(ele_dtype,
-                        static_cast<framework::proto::VarType::Type>(0));
    }
    PADDLE_ENFORCE_EQ(ele_dtype, dtype);
    // Check whether the address space is contiguous.
@@ -134,66 +145,36 @@ void FusedAllReduceOpHandle::RunImpl() {
  }
  std::vector<const void *> lod_tensor_data;
lod_tensor_data.reserve(place_num);
  for (size_t scope_idx = 0; scope_idx < place_num; ++scope_idx) {
    auto data = grads_tensor.at(scope_idx).at(0).second->data<void>();
    lod_tensor_data.emplace_back(data);
  }
std::vector<std::string> grad_var_names;
grad_var_names.reserve(place_num);
for (auto &grad_t : grads_tensor) {
grad_var_names.emplace_back(grad_t.at(0).first);
}
-  if (platform::is_gpu_place(places_[0])) {
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
-    int nccl_dtype = platform::ToNCCLDataType(dtype);
-    std::vector<std::function<void()>> all_reduce_calls;
-    for (size_t i = 0; i < local_scopes_.size(); ++i) {
-      auto &p = places_[i];
-      void *buffer = const_cast<void *>(lod_tensor_data.at(i));
-      all_reduce_calls.emplace_back([=] {
-        NCCLAllReduce(p, buffer, buffer, numel,
-                      static_cast<ncclDataType_t>(nccl_dtype), ncclSum);
-      });
-    }
-    VLOG(10) << "fusedallreduce size:" << numel * SizeOfType(dtype);
-    this->RunAndRecordEvent([&] {
-      if (all_reduce_calls.size() == 1UL) {
-        // Do not use NCCLGroup when manage NCCL by per thread per device
-        all_reduce_calls[0]();
-      } else {
-        platform::NCCLGroupGuard guard;
-        for (auto &call : all_reduce_calls) {
-          call();
-        }
-      }
-    });
-#else
-    PADDLE_THROW("Not compiled with CUDA");
-#endif
-  } else {
-    // Special handle CPU only Operator's gradient. Like CRF
-    auto grad_name = grads_tensor.at(0).at(0).first;
-    auto &trg = *this->local_exec_scopes_[0]
-                     ->FindVar(grad_name)
-                     ->GetMutable<framework::LoDTensor>();
-    // Reduce All data to trg in CPU
-    ReduceBufferData func(lod_tensor_data, trg.data<void>(), numel);
-    VisitDataType(trg.type(), func);
-    for (size_t i = 1; i < local_exec_scopes_.size(); ++i) {
-      auto &scope = *local_exec_scopes_[i];
-      auto &p = places_[i];
-      auto *var = scope.FindVar(grad_name);
-      auto *dev_ctx = dev_ctxes_.at(p);
-      size_t size = numel * SizeOfType(trg.type());
-      RunAndRecordEvent(p, [&trg, var, dev_ctx, p, size] {
-        auto dst_ptr = var->GetMutable<framework::LoDTensor>()->data<void>();
-        platform::CPUPlace cpu_place;
-        memory::Copy(cpu_place, dst_ptr, cpu_place, trg.data<void>(), size);
-      });
-    }
-  }
-}
+  AllReduceFunc(lod_tensor_data, dtype, numel, this->places_, grad_var_names);
+}
+
+bool FusedAllReduceOpHandle::InputIsInDifferentPlace(
+    const std::vector<VarHandle *> &in_var_handles) const {
+  for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
+    auto *local_scope = local_exec_scopes_[scope_idx];
+    size_t place_num = places_.size();
+    for (size_t j = 0; j < in_var_handles.size(); j += place_num) {
+      auto var_name = in_var_handles[j]->name();
+      auto var = local_scope->FindVar(var_name);
+      PADDLE_ENFORCE_NOT_NULL(var, "%s is not found in local scope.", var_name);
+      auto &lod_tensor = var->Get<LoDTensor>();
+      if (!is_same_place(lod_tensor.place(), places_.at(scope_idx))) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
void FusedAllReduceOpHandle::GetGradLoDTensor(
@@ -202,12 +183,14 @@ void FusedAllReduceOpHandle::GetGradLoDTensor(
    std::vector<std::pair<std::string, const LoDTensor *>> *grad_tensor) const {
  auto *local_scope = local_exec_scopes_[scope_idx];
  size_t place_num = places_.size();
  for (size_t j = 0; j < in_var_handles.size(); j += place_num) {
    auto var_name = in_var_handles[j]->name();
    PADDLE_ENFORCE_EQ(var_name, out_var_handles[j]->name());
-    auto &lod_tensor = local_scope->FindVar(var_name)->Get<LoDTensor>();
-    PADDLE_ENFORCE_EQ(lod_tensor.place(), places_.at(scope_idx));
+    auto var = local_scope->FindVar(var_name);
+    PADDLE_ENFORCE_NOT_NULL(var, "%s is not found in local scope.", var_name);
+    auto &lod_tensor = var->Get<LoDTensor>();
+    PADDLE_ENFORCE_EQ(lod_tensor.place(), places_.at(scope_idx),
+                      "%s(%d) is not in the right place.", var_name, scope_idx);
    grad_tensor->emplace_back(std::make_pair(var_name, &lod_tensor));
  }
}
......
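FusedAllReduceOpHandle::RunImpl above now checks InputIsInDifferentPlace first: if any gradient is not on its expected place (for example a CPU gradient produced by an op without a CUDA kernel), the fused group is split and each gradient is all-reduced separately through AllReduceImpl; otherwise the single fused path runs. A standalone sketch of that dispatch over the flattened handle layout handles[j * place_num + idx] (illustrative names and callbacks only, not Paddle code):

#include <functional>
#include <iostream>
#include <string>
#include <vector>

void Dispatch(const std::vector<std::string> &handles, size_t place_num,
              bool input_in_different_place,
              const std::function<void(std::vector<std::string>)> &all_reduce,
              const std::function<void(std::vector<std::string>)> &fused) {
  const size_t num_of_all_reduce = handles.size() / place_num;
  if (input_in_different_place) {
    for (size_t j = 0; j < num_of_all_reduce; ++j) {
      std::vector<std::string> group(handles.begin() + j * place_num,
                                     handles.begin() + (j + 1) * place_num);
      all_reduce(group);  // one un-fused all-reduce per gradient
    }
  } else {
    fused(handles);  // single fused all-reduce over every gradient
  }
}

int main() {
  std::vector<std::string> handles = {"w0@GRAD/p0", "w0@GRAD/p1",
                                      "w1@GRAD/p0", "w1@GRAD/p1"};
  auto print = [](const char *tag) {
    return [tag](std::vector<std::string> g) {
      std::cout << tag << " over " << g.size() << " handles\n";
    };
  };
  Dispatch(handles, /*place_num=*/2, /*input_in_different_place=*/true,
           print("all_reduce"), print("fused"));
  return 0;
}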
@@ -17,6 +17,7 @@
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -30,14 +31,14 @@ namespace framework { ...@@ -30,14 +31,14 @@ namespace framework {
namespace details { namespace details {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
struct FusedAllReduceOpHandle : public NCCLOpHandleBase { struct FusedAllReduceOpHandle : public AllReduceOpHandle {
FusedAllReduceOpHandle(ir::Node *node, FusedAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const size_t num_of_all_reduce, const size_t num_of_all_reduce,
const platform::NCCLCommunicator *ctxs); const platform::NCCLCommunicator *ctxs);
#else #else
struct FusedAllReduceOpHandle : public OpHandleBase { struct FusedAllReduceOpHandle : public AllReduceOpHandle {
FusedAllReduceOpHandle(ir::Node *node, FusedAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
...@@ -45,22 +46,10 @@ struct FusedAllReduceOpHandle : public OpHandleBase { ...@@ -45,22 +46,10 @@ struct FusedAllReduceOpHandle : public OpHandleBase {
#endif #endif
std::string Name() const override; std::string Name() const override;
// Delay and buffer nccl_all_reduce together can significantly increase
// performance. Disable this feature by returning false.
bool IsMultiDeviceTransfer() override { return true; };
 protected:
  void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
 private:
std::vector<Scope *> local_scopes_;
#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
// NCCLOpHandleBase already have these attributes.
// Will polish it by class inheritance framework.
std::vector<platform::Place> places_;
#endif
  size_t num_of_all_reduce_;
  // Check the dtype of the input
@@ -74,6 +63,12 @@ struct FusedAllReduceOpHandle : public OpHandleBase {
                         const std::vector<VarHandle *> &out_var_handles,
                         std::vector<std::pair<std::string, const LoDTensor *>>
                             *grad_tensor) const;
bool InputIsInDifferentPlace(
const std::vector<VarHandle *> &in_var_handles) const;
void FusedAllReduceFunc(const std::vector<VarHandle *> &in_var_handles,
const std::vector<VarHandle *> &out_var_handles);
};
}  // namespace details
......
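The header change above makes FusedAllReduceOpHandle derive from AllReduceOpHandle, dropping its duplicated local_scopes_/places_ members and the per-constructor size check. A simplified sketch of that kind of base-class refactoring (illustrative class and member names only, not Paddle code):

#include <cassert>
#include <string>
#include <vector>

class AllReduceHandleBase {
 public:
  AllReduceHandleBase(std::vector<std::string> scopes,
                      std::vector<std::string> places)
      : scopes_(std::move(scopes)), places_(std::move(places)) {
    assert(scopes_.size() == places_.size());  // the check now lives in one place
  }
  virtual ~AllReduceHandleBase() = default;

 protected:
  std::vector<std::string> scopes_;  // stand-ins for Scope* / platform::Place
  std::vector<std::string> places_;
};

class FusedAllReduceHandle : public AllReduceHandleBase {
 public:
  FusedAllReduceHandle(std::vector<std::string> scopes,
                       std::vector<std::string> places, size_t num_of_all_reduce)
      : AllReduceHandleBase(std::move(scopes), std::move(places)),
        num_of_all_reduce_(num_of_all_reduce) {}

 private:
  size_t num_of_all_reduce_;  // only fused-specific state stays in the subclass
};

int main() {
  FusedAllReduceHandle handle({"s0", "s1"}, {"p0", "p1"}, 4);
  (void)handle;
  return 0;
}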
@@ -42,6 +42,8 @@ typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>>
    GraphVars;
constexpr char kGraphVars[] = "vars";
constexpr char kNRanks[] = "nranks";
constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
@@ -68,6 +70,9 @@ constexpr char kParamsAndSparseGrads[] = "params_and_sparse_grads";
typedef std::vector<ProgramDesc> ProgramDescs;
constexpr char kProgramDescs[] = "program_descs";
typedef std::unordered_set<std::string> PinnedVars;
constexpr char kPinnedVars[] = "pinned_vars";
typedef std::vector<std::vector<std::pair<std::string, std::string>>>
    GroupParamsAndGrads;
constexpr char kGroupParamsAndDenseGrads[] = "group_params_dense_grads";
......
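kPinnedVars, introduced above, is a graph-level set of variable names that the memory-optimization passes must not reuse; the gradient-fusing passes now record gradient names here instead of marking them persistable. A small self-contained sketch of the idea (illustrative names, not Paddle code):

#include <iostream>
#include <string>
#include <unordered_set>

struct VarInfo {
  std::string name;
  bool persistable = false;
};

using PinnedVars = std::unordered_set<std::string>;

bool IsReusable(const VarInfo &var, const PinnedVars *pinned) {
  // Mirrors the IsPinnedVar check added in this commit: persistable vars and
  // pinned vars are never reused.
  return !(var.persistable || (pinned && pinned->count(var.name)));
}

int main() {
  PinnedVars pinned;
  pinned.insert("fc_0.w_0@GRAD");  // recorded by a gradient-fusing pass

  VarInfo grad{"fc_0.w_0@GRAD", /*persistable=*/false};
  VarInfo tmp{"tmp_1", /*persistable=*/false};

  std::cout << std::boolalpha << IsReusable(grad, &pinned) << "\n";  // false
  std::cout << std::boolalpha << IsReusable(tmp, &pinned) << "\n";   // true
  return 0;
}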
@@ -126,7 +126,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
    });
  }
-  RunAllReduceFuncs(all_reduce_calls);
+  NCCLAllReduceFunc(all_reduce_calls);
}
int SparseAllReduceOpHandle::GetKValue(const std::string &grad_name) {
......
@@ -65,28 +65,33 @@ double GetFuseParameterMemorySize() { return FLAGS_fuse_parameter_memory_size; }
class CoalesceGradTensorPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const {
if (Get<size_t>(details::kNRanks) <= 1) {
VLOG(6) << "The number of place is" << Get<size_t>(details::kNRanks)
<< ", there doesn't need apply FuseAllReduceOpPass.";
return;
}
    ir::Graph &result = *graph;
    details::ParamsAndGrads params_grads;
    RecordParamsAndGrads(result, &params_grads);
VLOG(10) << "The number of params and grads is:" << params_grads.size();
if (params_grads.size() == 0) {
return;
}
auto vars_info = GetVarInfo(result);
    ResetAttribute<details::ParamsAndGrads>(details::kParamsAndDenseGrads,
                                            &result);
    ResetAttribute<details::ParamsAndGrads>(details::kParamsAndSparseGrads,
                                            &result);
    ResetAttribute<details::GroupParamsAndGrads>(
        details::kGroupParamsAndDenseGrads, &result);
VLOG(10) << "The number of params and grads is:" << params_grads.size();
if (params_grads.size() == 0) {
return;
}
    auto &p_g_dense_grad =
        result.Get<details::ParamsAndGrads>(details::kParamsAndDenseGrads);
    auto &p_g_sparse_grad =
        result.Get<details::ParamsAndGrads>(details::kParamsAndSparseGrads);
auto vars_info = GetVarInfo(result);
    for (auto &param_grad : params_grads) {
      if (IsLoDTensorType(GetTypeOfVar(vars_info, param_grad.second))) {
        p_g_dense_grad.emplace_back(param_grad);
@@ -118,33 +123,37 @@ class CoalesceGradTensorPass : public ir::Pass {
        p_g_dense_grad.size(), num_of_p_g_dense_grad,
        "The number of p_g_dense_grad is not consistent with before.");
auto &pinned_var_set =
graph->GetOrInit<details::PinnedVars>(details::kPinnedVars);
    if (IsUnifiedDtype(p_g_dense_grad, vars_info)) {
-      SetGradientPersistable(p_g_dense_grad, vars_info);
+      RecordGradients(p_g_dense_grad, vars_info, &pinned_var_set);
      CoalesceTensors(vars_info, p_g_dense_grad, &result);
    } else {
      for (auto &sub_param_grad : group_params_grads) {
-        SetGradientPersistable(p_g_dense_grad, vars_info);
-        PADDLE_ENFORCE(IsUnifiedDtype(sub_param_grad, vars_info),
-                       "The data type of the same group is not consistent.");
+        RecordGradients(p_g_dense_grad, vars_info, &pinned_var_set);
+        PADDLE_ENFORCE_EQ(IsUnifiedDtype(sub_param_grad, vars_info), true,
+                          "The data type of the same group is not consistent.");
        CoalesceTensors(vars_info, sub_param_grad, &result);
      }
    }
  }
-  void SetGradientPersistable(
+  void RecordGradients(
      const std::vector<std::pair<std::string, std::string>> &sub_param_grad,
-      const std::unordered_map<std::string, std::vector<ir::Node *>> &vars_info)
-      const {
+      const std::unordered_map<std::string, std::vector<ir::Node *>> &vars_info,
+      std::unordered_set<std::string> *pinned_var_set) const {
+    // The Gradients should not be reused during memory optimization.
    for (auto &p_g : sub_param_grad) {
      auto iter = vars_info.find(p_g.second);
-      PADDLE_ENFORCE(iter != vars_info.end(), "%s is not found.", p_g.second);
-      PADDLE_ENFORCE(!iter->second.empty());
-      // Set persistable
+      PADDLE_ENFORCE_EQ(iter != vars_info.end(), true, "%s is not found.",
+                        p_g.second);
+      PADDLE_ENFORCE_EQ(!iter->second.empty(), true);
      for (auto it : iter->second) {
        PADDLE_ENFORCE_NOT_NULL(it->Var());
-        it->Var()->SetPersistable(true);
+        pinned_var_set->insert(it->Var()->Name());
      }
-      PADDLE_ENFORCE(IsLoDTensorType(GetTypeOfVar(vars_info, p_g.second)));
+      PADDLE_ENFORCE_EQ(IsLoDTensorType(GetTypeOfVar(vars_info, p_g.second)),
+                        true);
    }
  }
@@ -411,8 +420,10 @@ class CoalesceGradTensorPass : public ir::Pass {
      const std::unordered_map<std::string, std::vector<Node *>> &vars_info,
      const std::string &var_name) const {
    auto grad_iter = vars_info.find(var_name);
-    PADDLE_ENFORCE(grad_iter != vars_info.end(), "%s is not found.", var_name);
-    PADDLE_ENFORCE(!grad_iter->second.empty());
+    PADDLE_ENFORCE_EQ(grad_iter != vars_info.end(), true, "%s is not found.",
+                      var_name);
+    PADDLE_ENFORCE_EQ(!grad_iter->second.empty(), true, "%s is not found.",
+                      var_name);
    PADDLE_ENFORCE_NOT_NULL(grad_iter->second.front()->Var());
    return grad_iter->second.front()->Var();
  }
@@ -483,4 +494,5 @@ class CoalesceGradTensorPass : public ir::Pass {
}  // namespace paddle
REGISTER_PASS(coalesce_grad_tensor_pass,
-              paddle::framework::ir::CoalesceGradTensorPass);
+              paddle::framework::ir::CoalesceGradTensorPass)
+    .RequirePassAttr(paddle::framework::details::kNRanks);
@@ -106,6 +106,7 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_LE(
      params_and_dense_grads.size(), aux_var_set.at(kGrad).size(),
      "The number of dense gradients should be little than optimizer ops.");
  std::unordered_set<std::string> opt_grad_set(aux_var_set.at(kGrad).size());
  for (auto &p_g : params_and_dense_grads) {
    opt_grad_set.insert(p_g.second);
@@ -138,7 +139,8 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
    auto &fused_vars = result.Get<details::FusedVars>(details::kFusedVars);
    auto iter =
        std::find(fused_vars.begin(), fused_vars.end(), fused_grad.front());
-    PADDLE_ENFORCE(iter != fused_vars.end(), "Not find the fused_grad.");
+    PADDLE_ENFORCE_EQ(iter != fused_vars.end(), true,
+                      "Not find the fused_grad.");
    fused_vars_name[kGrad] = fused_grad.front();
    // Sort the parameters and auxiliary variables according
@@ -246,18 +248,24 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
    const std::vector<std::string> &params,
    const std::vector<std::string> &grads, const std::string &fused_grad_name,
    ir::Graph *result) const {
auto &pinned_var_set =
result->GetOrInit<details::PinnedVars>(details::kPinnedVars);
  auto vars_info = GetVarInfo(*result);
-  // Set Gradients as Persistable to prevent this var becoming reusable.
+  // The Gradients should not be reused during memory optimization.
  for (auto &grad_var_name : grads) {
    auto iter = vars_info.find(grad_var_name);
-    PADDLE_ENFORCE(iter != vars_info.end());
-    PADDLE_ENFORCE(!iter->second.empty());
+    PADDLE_ENFORCE_EQ(iter != vars_info.end(), true, "%s is not found.",
+                      grad_var_name);
+    PADDLE_ENFORCE_EQ(!iter->second.empty(), true, "%s is not found.",
+                      grad_var_name);
    PADDLE_ENFORCE_NOT_NULL(iter->second.front()->Var());
-    PADDLE_ENFORCE(IsLoDTensorType(iter->second.front()->Var()->GetType()),
-                   "Currently the gradient type only should be LoDTensor when "
-                   "fusing optimizer ops.");
+    PADDLE_ENFORCE_EQ(
+        IsLoDTensorType(iter->second.front()->Var()->GetType()), true,
+        "Currently the gradient type only should be LoDTensor when "
+        "fusing optimizer ops.");
    for (auto var : iter->second) {
-      var->Var()->SetPersistable(true);
+      pinned_var_set.insert(var->Var()->Name());
    }
  }
@@ -293,8 +301,9 @@ proto::VarType::Type FuseOptimizerOpPass::GetTypeOfVar(
    const std::unordered_map<std::string, std::vector<Node *>> &var_nodes,
    const std::string &name) const {
  auto grad_iter = var_nodes.find(name);
-  PADDLE_ENFORCE(grad_iter != var_nodes.end());
-  PADDLE_ENFORCE(grad_iter->second.size() > 0);
+  PADDLE_ENFORCE_EQ(grad_iter != var_nodes.end(), true, "%s is not found.",
+                    name);
+  PADDLE_ENFORCE_GT(grad_iter->second.size(), 0);
  PADDLE_ENFORCE_NOT_NULL(grad_iter->second.front()->Var());
  return grad_iter->second.front()->Var()->GetType();
}
@@ -321,24 +330,25 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars(
    const std::vector<std::pair<std::string, std::string>> &params_grads,
    std::unordered_map<std::string, std::vector<std::string>> *aux_vars_set,
    std::vector<ir::Node *> *ops) const {
-  PADDLE_ENFORCE_NE(aux_vars_set->count(kParam), static_cast<size_t>(0));
-  auto &param_vec = aux_vars_set->at(kParam);
-  std::vector<size_t> param_sort_idx;
-  param_sort_idx.reserve(param_vec.size());
+  PADDLE_ENFORCE_NE(aux_vars_set->count(kGrad), static_cast<size_t>(0));
+  auto &grad_vec = aux_vars_set->at(kGrad);
+  std::vector<size_t> grad_sort_idx;
+  grad_sort_idx.reserve(grad_vec.size());
  for (auto &p_g : params_grads) {
-    auto iter = std::find(param_vec.begin(), param_vec.end(), p_g.first);
-    PADDLE_ENFORCE(iter != param_vec.end());
-    auto idx = std::distance(param_vec.begin(), iter);
-    param_sort_idx.emplace_back(idx);
+    auto iter = std::find(grad_vec.begin(), grad_vec.end(), p_g.second);
+    PADDLE_ENFORCE_EQ(iter != grad_vec.end(), true,
+                      "%s is not found in grad_vec", p_g.second);
+    auto idx = std::distance(grad_vec.begin(), iter);
+    grad_sort_idx.emplace_back(idx);
  }
  for (auto &aux_vars : *aux_vars_set) {
    std::vector<std::string> sorted_vars;
    sorted_vars.reserve(aux_vars.second.size());
    for (size_t i = 0; i < aux_vars.second.size(); ++i) {
-      sorted_vars.emplace_back(aux_vars.second.at(param_sort_idx[i]));
+      sorted_vars.emplace_back(aux_vars.second.at(grad_sort_idx[i]));
    }
    std::swap(aux_vars.second, sorted_vars);
@@ -354,7 +364,7 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars(
  std::vector<ir::Node *> sorted_ops;
  sorted_ops.reserve(ops->size());
  for (size_t i = 0; i < ops->size(); ++i) {
-    sorted_ops.emplace_back(ops->at(param_sort_idx[i]));
+    sorted_ops.emplace_back(ops->at(grad_sort_idx[i]));
  }
  std::swap(*ops, sorted_ops);
}
......
@@ -85,10 +85,18 @@ class Graph {
    return attrs_.count(attr_name) > 0;
  }
template <typename AttrType>
AttrType &GetOrInit(const std::string &attr_name) {
if (!Has(attr_name)) {
Set(attr_name, new AttrType);
}
return Get<AttrType>(attr_name);
}
  template <typename AttrType>
  AttrType &Get(const std::string &attr_name) const {
-    PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
-                   attr_name);
+    PADDLE_ENFORCE_EQ(Has(attr_name), true, "%s attr not registered for graph.",
+                      attr_name);
    try {
      return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
    } catch (boost::bad_any_cast &) {
@@ -101,8 +109,8 @@ class Graph {
  template <typename AttrType>
  void Set(const std::string &attr_name, AttrType *attr) {
-    PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph",
-                   attr_name);
+    PADDLE_ENFORCE_EQ(attrs_.count(attr_name), 0, "%s already set in the graph",
+                      attr_name);
    attrs_[attr_name] = attr;
    attr_dels_[attr_name] = [attr, attr_name]() {
      VLOG(3) << "deleting " << attr_name;
@@ -112,15 +120,15 @@ class Graph {
  template <typename AttrType>
  void SetNotOwned(const std::string &attr_name, AttrType *attr) {
-    PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph",
-                   attr_name);
+    PADDLE_ENFORCE_EQ(attrs_.count(attr_name), 0, "%s already set in the graph",
+                      attr_name);
    attrs_[attr_name] = attr;
    attr_dels_[attr_name] = []() {};
  }
  void Erase(const std::string &attr_name) {
-    PADDLE_ENFORCE(attrs_.count(attr_name) != 0, "%s not set in the graph",
-                   attr_name);
+    PADDLE_ENFORCE_NE(attrs_.count(attr_name), 0, "%s not set in the graph",
+                      attr_name);
    attr_dels_[attr_name]();
    attrs_.erase(attr_name);
    attr_dels_.erase(attr_name);
@@ -130,7 +138,7 @@ class Graph {
  // Create a normal variable with non-null VarDesc.
  ir::Node *CreateVarNode(VarDesc *var_desc) {
-    PADDLE_ENFORCE(var_desc);
+    PADDLE_ENFORCE_NOT_NULL(var_desc);
    auto *x = AddNode(new ir::Node(var_desc));
    x->SetId(num_node_created_++);
    return x;
@@ -138,7 +146,7 @@ class Graph {
  // Create a normal runnable operator with OpDesc.
  ir::Node *CreateOpNode(OpDesc *op_desc) {
-    PADDLE_ENFORCE(op_desc);
+    PADDLE_ENFORCE_NOT_NULL(op_desc);
    auto *x = AddNode(new ir::Node(op_desc));
    x->SetId(num_node_created_++);
    return x;
@@ -178,7 +186,7 @@ class Graph {
  }
  std::unique_ptr<ir::Node> RemoveNode(ir::Node *node) {
-    PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
+    PADDLE_ENFORCE_EQ(node_set_.find(node) != node_set_.end(), true);
    std::unique_ptr<ir::Node> ret;
    ret.reset(nodes_.at(node).release());
    nodes_.erase(node);
@@ -204,7 +212,7 @@ class Graph {
  // This method takes ownership of `node`.
  ir::Node *AddNode(ir::Node *node) {
-    PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
+    PADDLE_ENFORCE_EQ(node_set_.find(node) == node_set_.end(), true);
    nodes_[node].reset(node);
    node_set_.insert(node);
    return node;
......
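Graph::GetOrInit, added above, lets a pass fetch an attribute such as the pinned-variable set and have it default-constructed on first access instead of failing the "not registered" enforce. A minimal standalone mimic of the pattern (not the actual ir::Graph class; names are illustrative):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>

class AttrMap {
 public:
  template <typename AttrType>
  AttrType &GetOrInit(const std::string &name) {
    auto it = attrs_.find(name);
    if (it == attrs_.end()) {
      // First access creates a default-constructed attribute.
      it = attrs_.emplace(name, std::make_shared<Holder<AttrType>>()).first;
    }
    return static_cast<Holder<AttrType> *>(it->second.get())->value;
  }

 private:
  struct HolderBase {
    virtual ~HolderBase() = default;
  };
  template <typename AttrType>
  struct Holder : HolderBase {
    AttrType value;
  };
  std::unordered_map<std::string, std::shared_ptr<HolderBase>> attrs_;
};

int main() {
  AttrMap graph_attrs;
  using PinnedVars = std::unordered_set<std::string>;
  // First call creates the set; later calls return the same instance.
  graph_attrs.GetOrInit<PinnedVars>("pinned_vars").insert("fc_0.w_0@GRAD");
  std::cout << graph_attrs.GetOrInit<PinnedVars>("pinned_vars").size() << "\n";
  return 0;
}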
@@ -206,5 +206,51 @@ TEST(GraphTest, WriteAfterWrite) {
  ASSERT_NE(control_dep2, nullptr);
  ASSERT_EQ(control_dep1, control_dep2);
}
TEST(GraphTest, TestException) {
ProgramDesc prog;
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
bool not_met_exception = false;
try {
g->Erase("no_attr");
} catch (const platform::EnforceNotMet &e) {
not_met_exception = true;
}
ASSERT_TRUE(not_met_exception);
not_met_exception = false;
try {
g->CreateVarNode(nullptr);
} catch (const platform::EnforceNotMet &e) {
not_met_exception = true;
}
ASSERT_TRUE(not_met_exception);
not_met_exception = false;
try {
g->CreateOpNode(nullptr);
} catch (const platform::EnforceNotMet &e) {
not_met_exception = true;
}
ASSERT_TRUE(not_met_exception);
not_met_exception = false;
try {
g->RemoveNode(nullptr);
} catch (const platform::EnforceNotMet &e) {
not_met_exception = true;
}
ASSERT_TRUE(not_met_exception);
not_met_exception = false;
try {
g->AddNode(nullptr);
g->AddNode(nullptr);
} catch (const platform::EnforceNotMet &e) {
not_met_exception = true;
}
ASSERT_TRUE(not_met_exception);
}
}  // namespace framework
}  // namespace paddle
@@ -36,6 +36,11 @@ void MemoryReusePass::ApplyImpl(Graph *graph) const {
  reused_out_var_names_.resize(all_vars_->size());
  var_descs_.resize(all_vars_->size());
pinned_var_set_ = nullptr;
if (graph->Has(details::kPinnedVars)) {
pinned_var_set_ = &graph->Get<details::PinnedVars>(details::kPinnedVars);
}
  // Collect the existing ShareTensorBufferOpHandles.
  // This is because (1) we want to reuse the existing
  // ShareTensorBufferOpHandles to avoid inserting too many ops;
@@ -195,7 +200,7 @@ bool MemoryReusePass::IsInVarReusable(const details::VarHandle &in_var) const {
  const VarDesc *in_var_desc = GetVarDesc(in_var);
-  if (in_var_desc->Persistable()) {
+  if (IsPinnedVar(*in_var_desc)) {
    return false;
  }
@@ -244,7 +249,7 @@ bool MemoryReusePass::IsOutVarReusable(
  }
  const VarDesc *out_var_desc = GetVarDesc(out_var);
-  if (out_var_desc->Persistable()) {
+  if (IsPinnedVar(*out_var_desc)) {
    return false;
  }
@@ -261,6 +266,11 @@ bool MemoryReusePass::IsOutVarReusable(
  return true;
}
bool MemoryReusePass::IsPinnedVar(const VarDesc &var_desc) const {
return var_desc.Persistable() ||
(pinned_var_set_ && pinned_var_set_->count(var_desc.Name()));
}
/**
 * Input-Output pair can be reused only when:
 * - they are not the same var.
......
@@ -133,6 +133,9 @@ class MemoryReusePass : public Pass {
  mutable std::vector<std::unordered_set<std::string>> reused_out_var_names_;
  mutable std::vector<std::unordered_map<std::string, VarDesc *>> var_descs_;
mutable details::PinnedVars *pinned_var_set_;
bool IsPinnedVar(const VarDesc &out_var_desc) const;
};
}  // namespace ir
......
@@ -312,13 +312,22 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
  ShrinkDepsOpFunctor shrink_func(
      ir::FilterByNodeWrapper<details::OpHandleBase>(*graph));
details::PinnedVars *pinned_var_set = nullptr;
if (graph->Has(details::kPinnedVars)) {
pinned_var_set = &graph->Get<details::PinnedVars>(details::kPinnedVars);
}
auto is_pinned_var = [&pinned_var_set](const VarDesc &var_desc) {
return var_desc.Persistable() ||
(pinned_var_set && pinned_var_set->count(var_desc.Name()));
};
VLOG(1) << "Place number: " << vars.size(); VLOG(1) << "Place number: " << vars.size();
for (size_t i = 0; i < vars.size(); ++i) { for (size_t i = 0; i < vars.size(); ++i) {
for (auto &name_var_pair : vars[i]) { for (auto &name_var_pair : vars[i]) {
// Whether this variable can be reused or deleted? If not, we do not // Whether this variable can be reused or deleted? If not, we do not
// compute reference counts and dependencies. // compute reference counts and dependencies.
VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second); VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second);
if (var_desc == nullptr || var_desc->Persistable()) { if (var_desc == nullptr || is_pinned_var(*var_desc)) {
continue; continue;
} }
......
@@ -29,14 +29,21 @@ namespace ir {
class FuseAllReduceOpPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override {
-    ir::Graph &result = *graph;
+    if (Get<size_t>(details::kNRanks) <= 1) {
VLOG(6) << "The number of place is" << Get<size_t>(details::kNRanks)
<< ", there doesn't need apply FuseAllReduceOpPass.";
return;
}
    auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
    auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    auto *multi_nccl_ctxs =
        &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
#endif
ir::Graph &result = *graph;
    auto &params_grads =
        result.Get<details::ParamsAndGrads>(details::kParamsAndDenseGrads);
    size_t num_of_all_reduce = params_grads.size();
@@ -203,4 +210,5 @@ class FuseAllReduceOpPass : public ir::Pass {
}  // namespace paddle
REGISTER_PASS(fuse_all_reduce_op_pass,
-              paddle::framework::ir::FuseAllReduceOpPass);
+              paddle::framework::ir::FuseAllReduceOpPass)
.RequirePassAttr(paddle::framework::details::kNRanks);
@@ -205,7 +205,7 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
  }
  // Insert collective ops if nranks > 1
-  if (!is_forwarding && Get<size_t>(kNRanks) > 1) {
+  if (!is_forwarding && Get<size_t>(details::kNRanks) > 1) {
    try {
      bool is_bk_op =
          static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
@@ -273,7 +273,7 @@ void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
      loss_scale = 1;
      break;
    case details::BuildStrategy::GradientScaleStrategy::kCoeffNumDevice:
-      loss_scale = Get<size_t>(kNRanks);
+      loss_scale = Get<size_t>(details::kNRanks);
      break;
    case details::BuildStrategy::GradientScaleStrategy::kCustomized:
      loss_scale = 0;
@@ -1106,7 +1106,7 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
      .RequirePassAttr(paddle::framework::details::kPlaces)      \
      .RequirePassAttr(paddle::framework::details::kLocalScopes) \
      .RequirePassAttr(paddle::framework::ir::kStrategy)         \
-      .RequirePassAttr(paddle::framework::ir::kNRanks)
+      .RequirePassAttr(paddle::framework::details::kNRanks)
REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
                            paddle::framework::ir::ReduceSSAGraphBuilder);
......
@@ -35,7 +35,6 @@ namespace ir {
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
class MultiDevSSAGraphBuilderBase : public ir::Pass {
 protected:
......