提交 7722baa8 编写于 作者: C chengduoZH

follow comments and clean code

上级 c8911895
...@@ -22,9 +22,9 @@ namespace details { ...@@ -22,9 +22,9 @@ namespace details {
void BroadcastOpHandle::RunImpl() { void BroadcastOpHandle::RunImpl() {
if (places_.size() == 1) return; if (places_.size() == 1) return;
// the input and output may have dummy var.
VarHandle *in_var_handle;
// The input and output may have dummy vars.
VarHandle *in_var_handle;
{ {
auto in_var_handles = DynamicCast<VarHandle>(inputs_); auto in_var_handles = DynamicCast<VarHandle>(inputs_);
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1, PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
...@@ -53,23 +53,39 @@ void BroadcastOpHandle::RunImpl() { ...@@ -53,23 +53,39 @@ void BroadcastOpHandle::RunImpl() {
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
if (platform::is_cpu_place(in_tensor.place())) { // NOTE(zcd): the Place of input can be get from in_tensor and in_var_handle ,
for (auto *out : out_var_handles) { // maybe they are different, because the Place that getting from in_tensor is
if (*out == *in_var_handle) { // determined at runtime, the other is determined at building SSA graph stage.
// If they are different, DataTransform should be applied. Currently, it has
// not been done yet.
for (auto *out_var_handle : out_var_handles) {
if (*out_var_handle == *in_var_handle) {
continue; continue;
} }
auto &out_p = out_var_handle->place_;
auto &out_p = out->place_; auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_); ->FindVar(out_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(out_var); PADDLE_ENFORCE_NOT_NULL(out_var);
PADDLE_ENFORCE_EQ(out_p.which(), in_tensor.place().which(), PADDLE_ENFORCE_EQ(
"Places must be all on CPU or all on CUDA."); out_p.which(), in_tensor.place().which(),
"Currently, Places of input and output must be all on CPU "
"or all on GPU.");
VariableVisitor::ShareDimsAndLoD(*in_var, out_var); VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p, VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
in_tensor.type()); in_tensor.type());
}
if (platform::is_cpu_place(in_tensor.place())) {
for (auto *out_var_handle : out_var_handles) {
if (*out_var_handle == *in_var_handle) {
continue;
}
auto &out_p = out_var_handle->place_;
auto dev_ctx = dev_ctxes_.at(out_p); auto dev_ctx = dev_ctxes_.at(out_p);
auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_);
RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] { RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
paddle::framework::TensorCopy( paddle::framework::TensorCopy(
in_tensor, out_p, *dev_ctx, in_tensor, out_p, *dev_ctx,
...@@ -78,35 +94,21 @@ void BroadcastOpHandle::RunImpl() { ...@@ -78,35 +94,21 @@ void BroadcastOpHandle::RunImpl() {
} }
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(in_tensor.place())); VarHandle *out_handle = nullptr;
VarHandle *out_handle; int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
int root = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
std::vector<std::function<void()>> broadcast_calls; std::vector<std::function<void()>> broadcast_calls;
for (size_t j = 0; j < out_var_handles.size(); ++j) { for (auto out_var_handle : out_var_handles) {
VarHandle *out_var_handle = out_var_handles[j];
Variable *out_var = var_scopes.at(out_var_handle->scope_idx_) Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_); ->FindVar(out_var_handle->name_);
if (*out_var_handle != *in_var_handle) { int dst_id =
PADDLE_ENFORCE_NOT_NULL(out_var); boost::get<platform::CUDAPlace>(out_var_handle->place_).device;
PADDLE_ENFORCE_EQ(out_var_handle->place_.which(),
in_tensor.place().which(),
"Places must be all on CPU or all on CUDA.");
VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
VariableVisitor::GetMutableTensor(out_var).mutable_data(
out_var_handle->place_, in_tensor.type());
}
auto out_p = out_var_handle->place_;
int dev_id = boost::get<platform::CUDAPlace>(out_p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id); auto &nccl_ctx = nccl_ctxs_->at(dst_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
void *send_recv_buffer = nullptr; void *send_recv_buffer = nullptr;
if (root == dev_id) { if (root_id == dst_id) {
send_recv_buffer = const_cast<void *>(in_tensor.data<void>()); send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
out_handle = out_var_handle; out_handle = out_var_handle;
} else { } else {
...@@ -116,10 +118,12 @@ void BroadcastOpHandle::RunImpl() { ...@@ -116,10 +118,12 @@ void BroadcastOpHandle::RunImpl() {
} }
int type = platform::ToNCCLDataType(in_tensor.type()); int type = platform::ToNCCLDataType(in_tensor.type());
broadcast_calls.emplace_back([=] { size_t numel = static_cast<size_t>(in_tensor.numel());
broadcast_calls.emplace_back(
[send_recv_buffer, numel, type, root_id, &nccl_ctx] {
PADDLE_ENFORCE(platform::dynload::ncclBcast( PADDLE_ENFORCE(platform::dynload::ncclBcast(
send_recv_buffer, in_tensor.numel(), send_recv_buffer, numel, static_cast<ncclDataType_t>(type),
static_cast<ncclDataType_t>(type), root, comm, stream)); root_id, nccl_ctx.comm_, nccl_ctx.stream()));
}); });
} }
...@@ -130,6 +134,7 @@ void BroadcastOpHandle::RunImpl() { ...@@ -130,6 +134,7 @@ void BroadcastOpHandle::RunImpl() {
call(); call();
} }
} }
// TODO(zcd): Maybe the unequal operator is not appropriate here.
if (*out_handle != *in_var_handle) { if (*out_handle != *in_var_handle) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_) auto out_var = var_scopes.at(in_var_handle->scope_idx_)
->FindVar(out_var_handles[0]->name_); ->FindVar(out_var_handles[0]->name_);
...@@ -140,7 +145,7 @@ void BroadcastOpHandle::RunImpl() { ...@@ -140,7 +145,7 @@ void BroadcastOpHandle::RunImpl() {
} }
}); });
#else #else
PADDLE_THROW("CUDA is not support."); PADDLE_THROW("CUDA is not enabled.");
#endif #endif
} }
} }
......
...@@ -36,7 +36,6 @@ void GatherOpHandle::RunImpl() { ...@@ -36,7 +36,6 @@ void GatherOpHandle::RunImpl() {
VarHandle *out_var_handle; VarHandle *out_var_handle;
{ {
auto out_var_handles = DynamicCast<VarHandle>(outputs_); auto out_var_handles = DynamicCast<VarHandle>(outputs_);
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1, PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one."); "The number of output should be one.");
out_var_handle = out_var_handles.front(); out_var_handle = out_var_handles.front();
...@@ -51,43 +50,39 @@ void GatherOpHandle::RunImpl() { ...@@ -51,43 +50,39 @@ void GatherOpHandle::RunImpl() {
auto pre_in_var = auto pre_in_var =
var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_); var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_);
PADDLE_ENFORCE_NOT_NULL(pre_in_var); PADDLE_ENFORCE_NOT_NULL(pre_in_var);
PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(), PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
"Currently, gather_op only can gather SelectedRows."); "Currently, gather_op only can gather SelectedRows.");
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles); WaitInputVarGenerated(in_var_handles);
auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>();
std::vector<int64_t> out_rows; std::vector<int64_t> out_rows;
std::vector<Tensor> in_tensors; std::vector<Tensor> in_tensors;
auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>(); // Gather the inputs
// gather the inputs
for (auto *in_handle : in_var_handles) { for (auto *in_handle : in_var_handles) {
auto *in_var = auto *in_var =
var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var);
VariableVisitor::EnforceShapeAndDTypeEQ(*in_var, *pre_in_var);
auto &in_sr_value = in_var->Get<framework::SelectedRows>(); auto &in_sr_value = in_var->Get<framework::SelectedRows>();
PADDLE_ENFORCE_EQ(in_sr_value.place().which(), pre_in_value.place().which(),
"Places must be all on CPU or all on GPU.");
PADDLE_ENFORCE_EQ(in_sr_value.value().type(), pre_in_value.value().type(),
"The type of input is not consistent.");
PADDLE_ENFORCE_EQ(in_sr_value.height(), pre_in_value.height(),
"The height of inputs is not consistent.");
PADDLE_ENFORCE_EQ(in_sr_value.GetCompleteDims(),
pre_in_value.GetCompleteDims(),
"The dims of inputs is not consistent.");
auto &in_sr_rows = in_sr_value.rows(); auto &in_sr_rows = in_sr_value.rows();
out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end()); out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());
in_tensors.emplace_back(in_sr_value.value()); in_tensors.emplace_back(in_sr_value.value());
} }
// write the output // TODO(zcd): The Place of var_handle is determined at building SSA graph
// stage, while the Place of var is determined at runtime. If they are
// different, DataTransform should be applied. Currently, it has not been done
// yet.
auto &out_place = out_var_handle->place_; auto &out_place = out_var_handle->place_;
PADDLE_ENFORCE_EQ(out_place.which(), pre_in_value.place().which(), PADDLE_ENFORCE_EQ(out_place.which(), pre_in_value.place().which(),
"Places must be all on CPU or all on GPU."); "Currently, Places of input and output must be all on CPU "
"or all on GPU.");
auto out_var = auto out_var =
var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_); var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(out_var); PADDLE_ENFORCE_NOT_NULL(out_var);
...@@ -97,19 +92,18 @@ void GatherOpHandle::RunImpl() { ...@@ -97,19 +92,18 @@ void GatherOpHandle::RunImpl() {
size_t rows = out_rows.size(); size_t rows = out_rows.size();
DDim out_dim = pre_in_value.GetCompleteDims(); DDim out_dim = pre_in_value.GetCompleteDims();
out_dim[0] = static_cast<int64_t>(rows); out_dim[0] = static_cast<int64_t>(rows);
out_value->mutable_value()->Resize(out_dim); out_value->mutable_value()->Resize(out_dim).mutable_data(
out_value->mutable_value()->mutable_data(out_place, out_place, pre_in_value.value().type());
pre_in_value.value().type());
Tensor *out_tensor = out_value->mutable_value(); Tensor *out_tensor = out_value->mutable_value();
// copy // copy
auto dev_ctx = dev_ctxes_[out_place]; auto dev_ctx = dev_ctxes_[out_place];
RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] { RunAndRecordEvent(out_place, [in_tensors, out_tensor, &dev_ctx, out_place] {
int s = 0, e = 0; int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) { for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0]; e += in_tensors[j].dims()[0];
auto sub_out = out_tensor->Slice(s, e); auto sub_out = out_tensor->Slice(s, e);
paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx), paddle::framework::TensorCopy(in_tensors[j], out_place, *dev_ctx,
&sub_out); &sub_out);
s = e; s = e;
} }
......
...@@ -116,13 +116,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -116,13 +116,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>( std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size()); places_.size());
// size_t cur_device_id = 0; size_t cur_update_sparse_gp_dev_id = 0;
size_t update_sparse_gp_device_id = 0; std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices;
std::vector<std::unordered_set<std::string>> var_name_on_devices; std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set;
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size()); sparse_var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size()); bcast_sparse_var_name_set.resize(places_.size());
// Find "send" op first for split is in front of send. // Find "send" op first for split is in front of send.
OpDesc *send_op = GetSendOpDesc(program); OpDesc *send_op = GetSendOpDesc(program);
...@@ -142,13 +141,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -142,13 +141,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); int op_dev_id = GetOpDeviceID(sparse_var_name_on_devices, *op);
if (op_dev_id == -1) { // var on all device if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, *op, places_.size());
} else { } else {
CreateComputationalOp(&result, *op, op_dev_id); CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) { for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices[op_dev_id].emplace(var_name); sparse_var_name_on_devices[op_dev_id].emplace(var_name);
} }
} }
...@@ -158,10 +157,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -158,10 +157,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
for (auto &og : op->OutputArgumentNames()) { for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
if (IsSparseGradient(og)) { if (IsSparseGradient(og)) {
CreateReduceOp(&result, update_sparse_gp_device_id, og); CreateReduceOp(&result, cur_update_sparse_gp_dev_id, og);
var_name_on_devices[update_sparse_gp_device_id].emplace(og); sparse_var_name_on_devices[cur_update_sparse_gp_dev_id].emplace(
bcast_var_name_set[update_sparse_gp_device_id].emplace( og);
bcast_sparse_var_name_set[cur_update_sparse_gp_dev_id].emplace(
og.substr(0, og.size() - strlen(kGradVarSuffix))); og.substr(0, og.size() - strlen(kGradVarSuffix)));
cur_update_sparse_gp_dev_id =
(cur_update_sparse_gp_dev_id + 1) % places_.size();
} else { } else {
InsertNCCLAllReduceOp(&result, og); InsertNCCLAllReduceOp(&result, og);
} }
...@@ -172,8 +174,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -172,8 +174,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
// Insert BCast Ops // Insert BCast Ops
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { for (size_t dev_id = 0; dev_id < bcast_sparse_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id]; auto &to_bcast_set = bcast_sparse_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) { for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id); CreateBroadcastOp(&result, bcast_name, dev_id);
} }
...@@ -206,13 +208,14 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { ...@@ -206,13 +208,14 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID( int MultiDevSSAGraphBuilder::GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices, const std::vector<std::unordered_set<std::string>>
&sparse_var_name_on_devices,
const OpDesc &op) const { const OpDesc &op) const {
int var_dev_id = -1; int var_dev_id = -1;
for (auto &var_name : op.InputArgumentNames()) { for (auto &var_name : op.InputArgumentNames()) {
if (var_dev_id != -1) break; if (var_dev_id != -1) break;
for (size_t i = 0; i < var_name_on_devices.size(); ++i) { for (size_t i = 0; i < sparse_var_name_on_devices.size(); ++i) {
if (var_name_on_devices[i].count(var_name)) { if (sparse_var_name_on_devices[i].count(var_name)) {
var_dev_id = static_cast<int>(i); var_dev_id = static_cast<int>(i);
break; break;
} }
......
...@@ -52,27 +52,30 @@ void ReduceOpHandle::RunImpl() { ...@@ -52,27 +52,30 @@ void ReduceOpHandle::RunImpl() {
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles); WaitInputVarGenerated(in_var_handles);
auto pre_place = in_0_handle->place_;
std::vector<platform::Place> in_places; // used to get dev_ctx std::vector<platform::Place> in_places; // used to get dev_ctx
auto pre_in_tensor = VariableVisitor::GetMutableTensor(pre_in_var);
for (auto *in_handle : in_var_handles) { for (auto *in_handle : in_var_handles) {
in_places.emplace_back(in_handle->place_); in_places.emplace_back(in_handle->place_);
auto in_var = auto in_var =
var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var);
VariableVisitor::EnforceShapeAndDTypeEQ(*pre_in_var, *in_var);
auto in_tensor = VariableVisitor::GetMutableTensor(in_var);
PADDLE_ENFORCE_EQ(pre_in_tensor.place().which(), in_tensor.place().which(),
"Places must be all on CPU or all on GPU.");
PADDLE_ENFORCE_EQ(in_tensor.type(), pre_in_tensor.type(),
"The type of input is not consistent.");
} }
auto out_var = auto out_var =
var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_); var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(out_var); PADDLE_ENFORCE_NOT_NULL(out_var);
// TODO(zcd): The Place of var_handle is determined at building SSA graph
// stage, while the Place of var is determined at runtime. If they are
// different, DataTransform should be applied. Currently, it has not been done
// yet.
PADDLE_ENFORCE_EQ(
VariableVisitor::GetMutableTensor(pre_in_var).place().which(),
out_var_handle->place_.which(),
"Currently, Places of input and output must be all on CPU or all on "
"GPU.");
if (pre_in_var->IsType<framework::SelectedRows>()) { if (pre_in_var->IsType<framework::SelectedRows>()) {
std::vector<const SelectedRows *> in_selected_rows = std::vector<const SelectedRows *> in_selected_rows =
GetInputValues<SelectedRows>(in_var_handles, var_scopes); GetInputValues<SelectedRows>(in_var_handles, var_scopes);
...@@ -96,7 +99,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -96,7 +99,7 @@ void ReduceOpHandle::RunImpl() {
out_var_handle->place_, pre_in.type()); out_var_handle->place_, pre_in.type());
auto out_p = out_var_handle->place_; auto out_p = out_var_handle->place_;
int root = boost::get<platform::CUDAPlace>(out_p).device; int root_id = boost::get<platform::CUDAPlace>(out_p).device;
std::vector<std::function<void()>> all_reduce_calls; std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < var_scopes.size(); ++i) { for (size_t i = 0; i < var_scopes.size(); ++i) {
auto &p = in_places[i]; auto &p = in_places[i];
...@@ -104,22 +107,22 @@ void ReduceOpHandle::RunImpl() { ...@@ -104,22 +107,22 @@ void ReduceOpHandle::RunImpl() {
int dev_id = boost::get<platform::CUDAPlace>(p).device; int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id); auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
void *buffer = const_cast<void *>(lod_tensor.data<void>()); void *buffer = const_cast<void *>(lod_tensor.data<void>());
void *recvbuffer = nullptr; void *recvbuffer = nullptr;
if (root == dev_id) { if (root_id == dev_id) {
recvbuffer = recvbuffer =
out_var->GetMutable<framework::LoDTensor>()->mutable_data( out_var->GetMutable<framework::LoDTensor>()->mutable_data(
out_var_handle->place_); out_var_handle->place_);
} }
int type = platform::ToNCCLDataType(lod_tensor.type()); int type = platform::ToNCCLDataType(lod_tensor.type());
all_reduce_calls.emplace_back([=] { size_t numel = static_cast<size_t>(lod_tensor.numel());
all_reduce_calls.emplace_back(
[buffer, recvbuffer, type, numel, root_id, &nccl_ctx] {
PADDLE_ENFORCE(platform::dynload::ncclReduce( PADDLE_ENFORCE(platform::dynload::ncclReduce(
buffer, recvbuffer, static_cast<size_t>(lod_tensor.numel()), buffer, recvbuffer, numel, static_cast<ncclDataType_t>(type),
static_cast<ncclDataType_t>(type), ncclSum, root, comm, stream)); ncclSum, root_id, nccl_ctx.comm_, nccl_ctx.stream()));
}); });
} }
...@@ -130,7 +133,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -130,7 +133,7 @@ void ReduceOpHandle::RunImpl() {
} }
}); });
#else #else
PADDLE_THROW("CUDA is not support."); PADDLE_THROW("CUDA is not enabled.");
#endif #endif
} else { } else {
PADDLE_THROW("Place should be CPUPlace or CUDAPlace."); PADDLE_THROW("Place should be CPUPlace or CUDAPlace.");
......
...@@ -47,17 +47,6 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { ...@@ -47,17 +47,6 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
} }
} }
VarHandle *SSAGraphBuilder::GetLatestVarHandle(SSAGraph *graph,
const std::string &each_var_name,
size_t place_offset) {
auto &var_holders = graph->vars_[place_offset];
auto &var_holder = var_holders[each_var_name];
if (var_holder.empty()) {
return nullptr;
}
return var_holder.rbegin()->get();
}
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
SSAGraph *graph, const std::string &each_var_name, SSAGraph *graph, const std::string &each_var_name,
const platform::Place &place, size_t place_offset) { const platform::Place &place, size_t place_offset) {
......
...@@ -62,6 +62,16 @@ struct VarHandle : public VarHandleBase { ...@@ -62,6 +62,16 @@ struct VarHandle : public VarHandleBase {
std::string name_; std::string name_;
platform::Place place_; platform::Place place_;
// NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four
// member
// variables(version_, scope_id_, name_, place_) must be equal. But sometimes
// judging whether the two var_handle is equal is actually to determine
// whether
// the two Variables that represented by var_handle is the same. And the same
// Variable may have many different var_handles, the version_ of these
// var_handles
// is different. So I don't take care of version_ temporarily when overloading
// equal.
bool operator==(const VarHandle& o) const { bool operator==(const VarHandle& o) const {
return o.generated_op_ == generated_op_ && o.name_ == name_ && return o.generated_op_ == generated_op_ && o.name_ == name_ &&
o.scope_idx_ == scope_idx_; o.scope_idx_ == scope_idx_;
......
...@@ -88,6 +88,52 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) { ...@@ -88,6 +88,52 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
VisitVariable(src, &visitor); VisitVariable(src, &visitor);
} }
struct EnforceEqualShapeAndDTypeVisitor {
const Variable* trg_;
void operator()(const LoDTensor& src) {
auto& tensor = trg_->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
src.place().which(), tensor.place().which(),
"The Places of the two Variable must be all on CPU or all on GPU.");
PADDLE_ENFORCE_EQ(src.type(), tensor.type(),
"The dtype of the two Variable is not equal.");
PADDLE_ENFORCE_EQ(src.dims(), tensor.dims(),
"The dims of the two Variable is not equal.");
PADDLE_ENFORCE_EQ(src.lod(), tensor.lod(),
"The lod of the two Variable is not equal.");
PADDLE_ENFORCE_EQ(src.layout(), tensor.layout(),
"The layout of the two Variable's tensor is not equal.");
}
void operator()(const SelectedRows& src) {
auto& selected_rows = trg_->Get<SelectedRows>();
PADDLE_ENFORCE_EQ(
src.place().which(), selected_rows.place().which(),
"The Places of the two Variable must be all on CPU or all on GPU.");
PADDLE_ENFORCE_EQ(src.value().type(), selected_rows.value().type(),
"The dtype of the two Variable is not equal.");
PADDLE_ENFORCE_EQ(src.value().layout(), selected_rows.value().layout(),
"The layout of the two Variable's tensor is not equal.");
PADDLE_ENFORCE_EQ(src.height(), selected_rows.height(),
"The height of the two Variable is not equal.");
PADDLE_ENFORCE_EQ(src.GetCompleteDims(), selected_rows.GetCompleteDims(),
"The dims of the two Variable is not equal.");
}
template <typename T>
void operator()(const T&) {
PADDLE_ENFORCE("EnforceShapeAndDTypeEQ is not supported by type %s",
typeid(T).name());
}
};
void VariableVisitor::EnforceShapeAndDTypeEQ(const Variable& var1,
const Variable& var2) {
EnforceEqualShapeAndDTypeVisitor visitor{&var1};
VisitVariable(var2, &visitor);
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -26,6 +26,9 @@ class VariableVisitor { ...@@ -26,6 +26,9 @@ class VariableVisitor {
static Tensor &GetMutableTensor(Variable *var); static Tensor &GetMutableTensor(Variable *var);
static void ShareDimsAndLoD(const Variable &src, Variable *trg); static void ShareDimsAndLoD(const Variable &src, Variable *trg);
static void EnforceShapeAndDTypeEQ(const Variable &var1,
const Variable &var2);
}; };
} // namespace details } // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册