From 7722baa8e328cc0d34ff30731442ba93993a2ec5 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 4 May 2018 15:05:37 +0800 Subject: [PATCH] follow comments and clean code --- .../framework/details/broadcast_op_handle.cc | 87 ++++++++++--------- .../framework/details/gather_op_handle.cc | 34 +++----- .../details/multi_devices_graph_builder.cc | 35 ++++---- .../framework/details/reduce_op_handle.cc | 41 +++++---- .../framework/details/ssa_graph_builder.cc | 11 --- paddle/fluid/framework/details/var_handle.h | 10 +++ .../framework/details/variable_visitor.cc | 46 ++++++++++ .../framework/details/variable_visitor.h | 3 + 8 files changed, 160 insertions(+), 107 deletions(-) diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 4f41579027b..756fd2afd62 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -22,9 +22,9 @@ namespace details { void BroadcastOpHandle::RunImpl() { if (places_.size() == 1) return; - // the input and output may have dummy var. - VarHandle *in_var_handle; + // The input and output may have dummy vars. + VarHandle *in_var_handle; { auto in_var_handles = DynamicCast(inputs_); PADDLE_ENFORCE_EQ(in_var_handles.size(), 1, @@ -53,23 +53,39 @@ void BroadcastOpHandle::RunImpl() { Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); + // NOTE(zcd): the Place of input can be get from in_tensor and in_var_handle , + // maybe they are different, because the Place that getting from in_tensor is + // determined at runtime, the other is determined at building SSA graph stage. + // If they are different, DataTransform should be applied. Currently, it has + // not been done yet. + for (auto *out_var_handle : out_var_handles) { + if (*out_var_handle == *in_var_handle) { + continue; + } + auto &out_p = out_var_handle->place_; + auto *out_var = var_scopes.at(out_var_handle->scope_idx_) + ->FindVar(out_var_handle->name_); + PADDLE_ENFORCE_NOT_NULL(out_var); + PADDLE_ENFORCE_EQ( + out_p.which(), in_tensor.place().which(), + "Currently, Places of input and output must be all on CPU " + "or all on GPU."); + VariableVisitor::ShareDimsAndLoD(*in_var, out_var); + VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p, + in_tensor.type()); + } + if (platform::is_cpu_place(in_tensor.place())) { - for (auto *out : out_var_handles) { - if (*out == *in_var_handle) { + for (auto *out_var_handle : out_var_handles) { + if (*out_var_handle == *in_var_handle) { continue; } - auto &out_p = out->place_; - auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_); - PADDLE_ENFORCE_NOT_NULL(out_var); - PADDLE_ENFORCE_EQ(out_p.which(), in_tensor.place().which(), - "Places must be all on CPU or all on CUDA."); - - VariableVisitor::ShareDimsAndLoD(*in_var, out_var); - VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p, - in_tensor.type()); - + auto &out_p = out_var_handle->place_; auto dev_ctx = dev_ctxes_.at(out_p); + auto *out_var = var_scopes.at(out_var_handle->scope_idx_) + ->FindVar(out_var_handle->name_); + RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] { paddle::framework::TensorCopy( in_tensor, out_p, *dev_ctx, @@ -78,35 +94,21 @@ void BroadcastOpHandle::RunImpl() { } } else { #ifdef PADDLE_WITH_CUDA - PADDLE_ENFORCE(platform::is_gpu_place(in_tensor.place())); - VarHandle *out_handle; - int root = boost::get(in_tensor.place()).device; + VarHandle *out_handle = nullptr; + int root_id = boost::get(in_tensor.place()).device; std::vector> broadcast_calls; - for (size_t j = 0; j < out_var_handles.size(); ++j) { - VarHandle *out_var_handle = out_var_handles[j]; + for (auto out_var_handle : out_var_handles) { Variable *out_var = var_scopes.at(out_var_handle->scope_idx_) ->FindVar(out_var_handle->name_); - if (*out_var_handle != *in_var_handle) { - PADDLE_ENFORCE_NOT_NULL(out_var); - PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), - in_tensor.place().which(), - "Places must be all on CPU or all on CUDA."); - VariableVisitor::ShareDimsAndLoD(*in_var, out_var); - VariableVisitor::GetMutableTensor(out_var).mutable_data( - out_var_handle->place_, in_tensor.type()); - } + int dst_id = + boost::get(out_var_handle->place_).device; - auto out_p = out_var_handle->place_; - int dev_id = boost::get(out_p).device; - - auto &nccl_ctx = nccl_ctxs_->at(dev_id); - auto stream = nccl_ctx.stream(); - auto comm = nccl_ctx.comm_; + auto &nccl_ctx = nccl_ctxs_->at(dst_id); void *send_recv_buffer = nullptr; - if (root == dev_id) { + if (root_id == dst_id) { send_recv_buffer = const_cast(in_tensor.data()); out_handle = out_var_handle; } else { @@ -116,11 +118,13 @@ void BroadcastOpHandle::RunImpl() { } int type = platform::ToNCCLDataType(in_tensor.type()); - broadcast_calls.emplace_back([=] { - PADDLE_ENFORCE(platform::dynload::ncclBcast( - send_recv_buffer, in_tensor.numel(), - static_cast(type), root, comm, stream)); - }); + size_t numel = static_cast(in_tensor.numel()); + broadcast_calls.emplace_back( + [send_recv_buffer, numel, type, root_id, &nccl_ctx] { + PADDLE_ENFORCE(platform::dynload::ncclBcast( + send_recv_buffer, numel, static_cast(type), + root_id, nccl_ctx.comm_, nccl_ctx.stream())); + }); } this->RunAndRecordEvent([&] { @@ -130,6 +134,7 @@ void BroadcastOpHandle::RunImpl() { call(); } } + // TODO(zcd): Maybe the unequal operator is not appropriate here. if (*out_handle != *in_var_handle) { auto out_var = var_scopes.at(in_var_handle->scope_idx_) ->FindVar(out_var_handles[0]->name_); @@ -140,7 +145,7 @@ void BroadcastOpHandle::RunImpl() { } }); #else - PADDLE_THROW("CUDA is not support."); + PADDLE_THROW("CUDA is not enabled."); #endif } } diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 43145f44c27..021703f1e91 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -36,7 +36,6 @@ void GatherOpHandle::RunImpl() { VarHandle *out_var_handle; { auto out_var_handles = DynamicCast(outputs_); - PADDLE_ENFORCE_EQ(out_var_handles.size(), 1, "The number of output should be one."); out_var_handle = out_var_handles.front(); @@ -51,43 +50,39 @@ void GatherOpHandle::RunImpl() { auto pre_in_var = var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_); PADDLE_ENFORCE_NOT_NULL(pre_in_var); + PADDLE_ENFORCE(pre_in_var->IsType(), "Currently, gather_op only can gather SelectedRows."); // Wait input done, this Wait is asynchronous operation WaitInputVarGenerated(in_var_handles); + auto &pre_in_value = pre_in_var->Get(); std::vector out_rows; std::vector in_tensors; - auto &pre_in_value = pre_in_var->Get(); - // gather the inputs + // Gather the inputs for (auto *in_handle : in_var_handles) { auto *in_var = var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); PADDLE_ENFORCE_NOT_NULL(in_var); + VariableVisitor::EnforceShapeAndDTypeEQ(*in_var, *pre_in_var); auto &in_sr_value = in_var->Get(); - PADDLE_ENFORCE_EQ(in_sr_value.place().which(), pre_in_value.place().which(), - "Places must be all on CPU or all on GPU."); - PADDLE_ENFORCE_EQ(in_sr_value.value().type(), pre_in_value.value().type(), - "The type of input is not consistent."); - PADDLE_ENFORCE_EQ(in_sr_value.height(), pre_in_value.height(), - "The height of inputs is not consistent."); - PADDLE_ENFORCE_EQ(in_sr_value.GetCompleteDims(), - pre_in_value.GetCompleteDims(), - "The dims of inputs is not consistent."); - auto &in_sr_rows = in_sr_value.rows(); out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end()); in_tensors.emplace_back(in_sr_value.value()); } - // write the output + // TODO(zcd): The Place of var_handle is determined at building SSA graph + // stage, while the Place of var is determined at runtime. If they are + // different, DataTransform should be applied. Currently, it has not been done + // yet. auto &out_place = out_var_handle->place_; PADDLE_ENFORCE_EQ(out_place.which(), pre_in_value.place().which(), - "Places must be all on CPU or all on GPU."); + "Currently, Places of input and output must be all on CPU " + "or all on GPU."); auto out_var = var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_); PADDLE_ENFORCE_NOT_NULL(out_var); @@ -97,19 +92,18 @@ void GatherOpHandle::RunImpl() { size_t rows = out_rows.size(); DDim out_dim = pre_in_value.GetCompleteDims(); out_dim[0] = static_cast(rows); - out_value->mutable_value()->Resize(out_dim); - out_value->mutable_value()->mutable_data(out_place, - pre_in_value.value().type()); + out_value->mutable_value()->Resize(out_dim).mutable_data( + out_place, pre_in_value.value().type()); Tensor *out_tensor = out_value->mutable_value(); // copy auto dev_ctx = dev_ctxes_[out_place]; - RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] { + RunAndRecordEvent(out_place, [in_tensors, out_tensor, &dev_ctx, out_place] { int s = 0, e = 0; for (size_t j = 0; j < in_tensors.size(); ++j) { e += in_tensors[j].dims()[0]; auto sub_out = out_tensor->Slice(s, e); - paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx), + paddle::framework::TensorCopy(in_tensors[j], out_place, *dev_ctx, &sub_out); s = e; } diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 37d69c4b56c..da524cc7928 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -116,13 +116,12 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( std::unordered_map>>>( places_.size()); - // size_t cur_device_id = 0; - size_t update_sparse_gp_device_id = 0; - std::vector> var_name_on_devices; - std::vector> bcast_var_name_set; + size_t cur_update_sparse_gp_dev_id = 0; + std::vector> sparse_var_name_on_devices; + std::vector> bcast_sparse_var_name_set; - var_name_on_devices.resize(places_.size()); - bcast_var_name_set.resize(places_.size()); + sparse_var_name_on_devices.resize(places_.size()); + bcast_sparse_var_name_set.resize(places_.size()); // Find "send" op first for split is in front of send. OpDesc *send_op = GetSendOpDesc(program); @@ -142,13 +141,13 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } is_forwarding = false; } else { - int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); + int op_dev_id = GetOpDeviceID(sparse_var_name_on_devices, *op); if (op_dev_id == -1) { // var on all device CreateComputationalOps(&result, *op, places_.size()); } else { CreateComputationalOp(&result, *op, op_dev_id); for (auto &var_name : op->OutputArgumentNames()) { - var_name_on_devices[op_dev_id].emplace(var_name); + sparse_var_name_on_devices[op_dev_id].emplace(var_name); } } @@ -158,10 +157,13 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( for (auto &og : op->OutputArgumentNames()) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { if (IsSparseGradient(og)) { - CreateReduceOp(&result, update_sparse_gp_device_id, og); - var_name_on_devices[update_sparse_gp_device_id].emplace(og); - bcast_var_name_set[update_sparse_gp_device_id].emplace( + CreateReduceOp(&result, cur_update_sparse_gp_dev_id, og); + sparse_var_name_on_devices[cur_update_sparse_gp_dev_id].emplace( + og); + bcast_sparse_var_name_set[cur_update_sparse_gp_dev_id].emplace( og.substr(0, og.size() - strlen(kGradVarSuffix))); + cur_update_sparse_gp_dev_id = + (cur_update_sparse_gp_dev_id + 1) % places_.size(); } else { InsertNCCLAllReduceOp(&result, og); } @@ -172,8 +174,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } // Insert BCast Ops - for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { - auto &to_bcast_set = bcast_var_name_set[dev_id]; + for (size_t dev_id = 0; dev_id < bcast_sparse_var_name_set.size(); ++dev_id) { + auto &to_bcast_set = bcast_sparse_var_name_set[dev_id]; for (auto &bcast_name : to_bcast_set) { CreateBroadcastOp(&result, bcast_name, dev_id); } @@ -206,13 +208,14 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { } int MultiDevSSAGraphBuilder::GetOpDeviceID( - const std::vector> &var_name_on_devices, + const std::vector> + &sparse_var_name_on_devices, const OpDesc &op) const { int var_dev_id = -1; for (auto &var_name : op.InputArgumentNames()) { if (var_dev_id != -1) break; - for (size_t i = 0; i < var_name_on_devices.size(); ++i) { - if (var_name_on_devices[i].count(var_name)) { + for (size_t i = 0; i < sparse_var_name_on_devices.size(); ++i) { + if (sparse_var_name_on_devices[i].count(var_name)) { var_dev_id = static_cast(i); break; } diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index f06cb024cf0..5ee7008b5b6 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -52,27 +52,30 @@ void ReduceOpHandle::RunImpl() { // Wait input done, this Wait is asynchronous operation WaitInputVarGenerated(in_var_handles); - auto pre_place = in_0_handle->place_; + std::vector in_places; // used to get dev_ctx - auto pre_in_tensor = VariableVisitor::GetMutableTensor(pre_in_var); for (auto *in_handle : in_var_handles) { in_places.emplace_back(in_handle->place_); - auto in_var = var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); PADDLE_ENFORCE_NOT_NULL(in_var); - - auto in_tensor = VariableVisitor::GetMutableTensor(in_var); - PADDLE_ENFORCE_EQ(pre_in_tensor.place().which(), in_tensor.place().which(), - "Places must be all on CPU or all on GPU."); - PADDLE_ENFORCE_EQ(in_tensor.type(), pre_in_tensor.type(), - "The type of input is not consistent."); + VariableVisitor::EnforceShapeAndDTypeEQ(*pre_in_var, *in_var); } auto out_var = var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_); PADDLE_ENFORCE_NOT_NULL(out_var); + // TODO(zcd): The Place of var_handle is determined at building SSA graph + // stage, while the Place of var is determined at runtime. If they are + // different, DataTransform should be applied. Currently, it has not been done + // yet. + PADDLE_ENFORCE_EQ( + VariableVisitor::GetMutableTensor(pre_in_var).place().which(), + out_var_handle->place_.which(), + "Currently, Places of input and output must be all on CPU or all on " + "GPU."); + if (pre_in_var->IsType()) { std::vector in_selected_rows = GetInputValues(in_var_handles, var_scopes); @@ -96,7 +99,7 @@ void ReduceOpHandle::RunImpl() { out_var_handle->place_, pre_in.type()); auto out_p = out_var_handle->place_; - int root = boost::get(out_p).device; + int root_id = boost::get(out_p).device; std::vector> all_reduce_calls; for (size_t i = 0; i < var_scopes.size(); ++i) { auto &p = in_places[i]; @@ -104,23 +107,23 @@ void ReduceOpHandle::RunImpl() { int dev_id = boost::get(p).device; auto &nccl_ctx = nccl_ctxs_->at(dev_id); - auto stream = nccl_ctx.stream(); - auto comm = nccl_ctx.comm_; void *buffer = const_cast(lod_tensor.data()); void *recvbuffer = nullptr; - if (root == dev_id) { + if (root_id == dev_id) { recvbuffer = out_var->GetMutable()->mutable_data( out_var_handle->place_); } int type = platform::ToNCCLDataType(lod_tensor.type()); - all_reduce_calls.emplace_back([=] { - PADDLE_ENFORCE(platform::dynload::ncclReduce( - buffer, recvbuffer, static_cast(lod_tensor.numel()), - static_cast(type), ncclSum, root, comm, stream)); - }); + size_t numel = static_cast(lod_tensor.numel()); + all_reduce_calls.emplace_back( + [buffer, recvbuffer, type, numel, root_id, &nccl_ctx] { + PADDLE_ENFORCE(platform::dynload::ncclReduce( + buffer, recvbuffer, numel, static_cast(type), + ncclSum, root_id, nccl_ctx.comm_, nccl_ctx.stream())); + }); } this->RunAndRecordEvent([&] { @@ -130,7 +133,7 @@ void ReduceOpHandle::RunImpl() { } }); #else - PADDLE_THROW("CUDA is not support."); + PADDLE_THROW("CUDA is not enabled."); #endif } else { PADDLE_THROW("Place should be CPUPlace or CUDAPlace."); diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 1538744711d..6a567527550 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -47,17 +47,6 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { } } -VarHandle *SSAGraphBuilder::GetLatestVarHandle(SSAGraph *graph, - const std::string &each_var_name, - size_t place_offset) { - auto &var_holders = graph->vars_[place_offset]; - auto &var_holder = var_holders[each_var_name]; - if (var_holder.empty()) { - return nullptr; - } - return var_holder.rbegin()->get(); -} - VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( SSAGraph *graph, const std::string &each_var_name, const platform::Place &place, size_t place_offset) { diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index 99e5eb2b438..2ccd76df852 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -62,6 +62,16 @@ struct VarHandle : public VarHandleBase { std::string name_; platform::Place place_; + // NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four + // member + // variables(version_, scope_id_, name_, place_) must be equal. But sometimes + // judging whether the two var_handle is equal is actually to determine + // whether + // the two Variables that represented by var_handle is the same. And the same + // Variable may have many different var_handles, the version_ of these + // var_handles + // is different. So I don't take care of version_ temporarily when overloading + // equal. bool operator==(const VarHandle& o) const { return o.generated_op_ == generated_op_ && o.name_ == name_ && o.scope_idx_ == scope_idx_; diff --git a/paddle/fluid/framework/details/variable_visitor.cc b/paddle/fluid/framework/details/variable_visitor.cc index 10bac0fae95..99487a304fa 100644 --- a/paddle/fluid/framework/details/variable_visitor.cc +++ b/paddle/fluid/framework/details/variable_visitor.cc @@ -88,6 +88,52 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) { VisitVariable(src, &visitor); } +struct EnforceEqualShapeAndDTypeVisitor { + const Variable* trg_; + + void operator()(const LoDTensor& src) { + auto& tensor = trg_->Get(); + PADDLE_ENFORCE_EQ( + src.place().which(), tensor.place().which(), + "The Places of the two Variable must be all on CPU or all on GPU."); + PADDLE_ENFORCE_EQ(src.type(), tensor.type(), + "The dtype of the two Variable is not equal."); + PADDLE_ENFORCE_EQ(src.dims(), tensor.dims(), + "The dims of the two Variable is not equal."); + PADDLE_ENFORCE_EQ(src.lod(), tensor.lod(), + "The lod of the two Variable is not equal."); + PADDLE_ENFORCE_EQ(src.layout(), tensor.layout(), + "The layout of the two Variable's tensor is not equal."); + } + + void operator()(const SelectedRows& src) { + auto& selected_rows = trg_->Get(); + PADDLE_ENFORCE_EQ( + src.place().which(), selected_rows.place().which(), + "The Places of the two Variable must be all on CPU or all on GPU."); + PADDLE_ENFORCE_EQ(src.value().type(), selected_rows.value().type(), + "The dtype of the two Variable is not equal."); + PADDLE_ENFORCE_EQ(src.value().layout(), selected_rows.value().layout(), + "The layout of the two Variable's tensor is not equal."); + PADDLE_ENFORCE_EQ(src.height(), selected_rows.height(), + "The height of the two Variable is not equal."); + PADDLE_ENFORCE_EQ(src.GetCompleteDims(), selected_rows.GetCompleteDims(), + "The dims of the two Variable is not equal."); + } + + template + void operator()(const T&) { + PADDLE_ENFORCE("EnforceShapeAndDTypeEQ is not supported by type %s", + typeid(T).name()); + } +}; + +void VariableVisitor::EnforceShapeAndDTypeEQ(const Variable& var1, + const Variable& var2) { + EnforceEqualShapeAndDTypeVisitor visitor{&var1}; + VisitVariable(var2, &visitor); +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/variable_visitor.h b/paddle/fluid/framework/details/variable_visitor.h index 67baa1895e4..ca9a19bdcf1 100644 --- a/paddle/fluid/framework/details/variable_visitor.h +++ b/paddle/fluid/framework/details/variable_visitor.h @@ -26,6 +26,9 @@ class VariableVisitor { static Tensor &GetMutableTensor(Variable *var); static void ShareDimsAndLoD(const Variable &src, Variable *trg); + + static void EnforceShapeAndDTypeEQ(const Variable &var1, + const Variable &var2); }; } // namespace details -- GitLab