From d24ef931b5beb78b16e9e6eae1b692c11dad2271 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 18 Apr 2018 14:13:38 +0800 Subject: [PATCH] Clean Code --- paddle/fluid/framework/details/CMakeLists.txt | 3 + .../framework/details/broadcast_op_handle.cc | 98 +++++++------------ .../framework/details/broadcast_op_handle.h | 6 +- .../fluid/framework/details/container_cast.h | 40 ++++++++ .../framework/details/gather_op_handle.cc | 54 +++++----- .../framework/details/gather_op_handle.h | 5 +- paddle/fluid/framework/details/var_handle.h | 5 + .../framework/details/variable_visitor.cc | 93 ++++++++++++++++++ .../framework/details/variable_visitor.h | 33 +++++++ 9 files changed, 234 insertions(+), 103 deletions(-) create mode 100644 paddle/fluid/framework/details/container_cast.h create mode 100644 paddle/fluid/framework/details/variable_visitor.cc create mode 100644 paddle/fluid/framework/details/variable_visitor.h diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 897e41f79f..5d1b34537c 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -24,6 +24,9 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory) +cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) + + cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory device_context broadcast_op_handle) cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 05bab5334a..0fb54a1d3e 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -13,102 +13,70 @@ // limitations under the License. #include "paddle/fluid/framework/details/broadcast_op_handle.h" +#include "paddle/fluid/framework/details/container_cast.h" +#include "paddle/fluid/framework/details/variable_visitor.h" namespace paddle { namespace framework { namespace details { - -Tensor *GetTensorFromVar(Variable *in_var) { - if (in_var->IsType()) { - return in_var->GetMutable(); - } else if (in_var->IsType()) { - return in_var->GetMutable()->mutable_value(); - } else { - PADDLE_THROW("Var should be LoDTensor or SelectedRows"); - } - return nullptr; -} - BroadcastOpHandle::BroadcastOpHandle(const std::vector &local_scopes, const std::vector &places) : local_scopes_(local_scopes), places_(places) {} void BroadcastOpHandle::RunImpl() { // the input and output may have dummy var. - std::vector in_var_handle = GetValidVarHandles(inputs_); - std::vector out_var_handles = GetValidVarHandles(outputs_); + VarHandle *in_var_handle; + + { + auto in_var_handles = DynamicCast(inputs_); + PADDLE_ENFORCE_EQ(in_var_handles.size(), 1, + "The number of input should be one."); + in_var_handle = in_var_handles[0]; + } + + auto out_var_handles = DynamicCast(outputs_); - PADDLE_ENFORCE_EQ(in_var_handle.size(), 1, - "The number of input should be one."); PADDLE_ENFORCE_EQ( out_var_handles.size(), places_.size(), "The number of output should equal to the number of places."); - // Wait input done, this Wait is asynchronous operationplatform::Place + // Wait input done, this Wait is asynchronous operation platform::Place // &in_place; - WaitEvents(out_var_handles, in_var_handle); + WaitInputVarGenerated(*in_var_handle); - auto in_place = in_var_handle[0]->place_; - auto in_scope_idx = in_var_handle[0]->scope_idx_; - auto in_var = - local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_); - Tensor *in_tensor = GetTensorFromVar(in_var); + auto *in_var = local_scopes_.at(in_var_handle->scope_idx_) + ->FindVar(in_var_handle->name_); + PADDLE_ENFORCE_NOT_NULL(in_var); + Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); for (auto *out : out_var_handles) { + if (*out == *in_var_handle) { + continue; + } + auto &out_p = out->place_; - auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_); + auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_); - PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(), + PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(), "Places must be all on CPU or all on CUDA."); - if (in_var->IsType()) { - auto &in_sr = in_var->Get(); - auto out_sr = out_var->GetMutable(); - if (&in_sr == out_sr) continue; - out_sr->set_height(in_sr.height()); - out_sr->set_rows(in_sr.rows()); - out_sr->mutable_value()->Resize(in_sr.value().dims()); - out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type()); - } else if (in_var->IsType()) { - auto in_lod = in_var->Get(); - auto out_lod = out_var->GetMutable(); - if (&in_lod == out_lod) continue; - out_lod->set_lod(in_lod.lod()); - out_lod->Resize(in_lod.dims()); - out_lod->mutable_data(out_p, in_lod.type()); - } else { - PADDLE_THROW("Var should be LoDTensor or SelectedRows."); - } + VariableVisitor::ShareDimsAndLoD(*in_var, out_var); + VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p, + in_tensor.type()); auto dev_ctx = dev_ctxes_[out_p]; RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] { - Tensor *out_tensor = GetTensorFromVar(out_var); - paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctx), out_tensor); + paddle::framework::TensorCopy( + in_tensor, out_p, *(dev_ctx), + &VariableVisitor::GetMutableTensor(out_var)); }); } } -void BroadcastOpHandle::WaitEvents( - const std::vector &out_var_handles, - const std::vector &in_var_handle) { - if (in_var_handle[0]->generated_op_) { - for (auto *out : out_var_handles) { - auto &out_p = out->place_; - in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]); - } - } -} - -std::vector BroadcastOpHandle::GetValidVarHandles( - const std::vector &inputs) { - std::vector in_var_handle; - for (auto *in : inputs) { - auto *out_handle = dynamic_cast(in); - if (out_handle) { - in_var_handle.push_back(out_handle); - } +void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) { + for (auto &pair : dev_ctxes_) { + in_var.generated_op_->Wait(pair.second); } - return in_var_handle; } std::string BroadcastOpHandle::Name() const { return "broadcast"; } diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index e1311aceaf..2a0d70f8ea 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -42,11 +42,7 @@ struct BroadcastOpHandle : public OpHandleBase { protected: void RunImpl() override; - std::vector GetValidVarHandles( - const std::vector &inputs); - - void WaitEvents(const std::vector &out_var_handles, - const std::vector &in_var_handle); + void WaitInputVarGenerated(const VarHandle &in_var); }; } // namespace details diff --git a/paddle/fluid/framework/details/container_cast.h b/paddle/fluid/framework/details/container_cast.h new file mode 100644 index 0000000000..a42ae78dc4 --- /dev/null +++ b/paddle/fluid/framework/details/container_cast.h @@ -0,0 +1,40 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace paddle { +namespace framework { +namespace details { + +template +std::vector DynamicCast(const std::vector& container) { + static_assert(std::is_base_of::value, + "ElementType must be a base class of ResultType"); + std::vector res; + for (auto* ptr : container) { + auto* derived = dynamic_cast(ptr); + if (derived) { + res.emplace_back(derived); + } + } + return res; +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index df55e4dad1..511fd941dc 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "paddle/fluid/framework/details/gather_op_handle.h" +#include "paddle/fluid/framework/details/container_cast.h" +#include "paddle/fluid/framework/details/variable_visitor.h" namespace paddle { namespace framework { @@ -24,16 +26,22 @@ GatherOpHandle::GatherOpHandle(const std::vector &local_scopes, void GatherOpHandle::RunImpl() { // the input and output may have dummy var. - std::vector in_var_handles = GetValidVarHandles(inputs_); - std::vector out_var_handles = GetValidVarHandles(outputs_); + auto in_var_handles = DynamicCast(inputs_); PADDLE_ENFORCE_EQ( in_var_handles.size(), places_.size(), "The number of output should equal to the number of places."); - PADDLE_ENFORCE_EQ(out_var_handles.size(), 1, - "The number of output should be one."); - auto in_0_handle = static_cast(in_var_handles[0]); + VarHandle *out_var_handle; + { + auto out_var_handles = DynamicCast(outputs_); + + PADDLE_ENFORCE_EQ(out_var_handles.size(), 1, + "The number of output should be one."); + out_var_handle = out_var_handles.front(); + } + + auto in_0_handle = in_var_handles[0]; auto pre_in_var = local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_); auto pre_place = in_0_handle->place_; @@ -41,11 +49,11 @@ void GatherOpHandle::RunImpl() { PADDLE_ENFORCE(pre_in_var->IsType(), "Currently, gather_op only can gather SelectedRows."); - PADDLE_ENFORCE_EQ(out_var_handles[0]->place_.which(), pre_place.which(), + PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(), "The place of input and output should be the same."); // Wait input done, this Wait is asynchronous operation - WaitEvents(in_var_handles); + WaitInputVarGenerated(in_var_handles); std::vector out_rows; std::vector in_tensors; @@ -53,13 +61,12 @@ void GatherOpHandle::RunImpl() { auto &pre_in = pre_in_var->Get(); // gather the inputs - for (auto *in : in_var_handles) { - auto in_handle = static_cast(in); + for (auto *in_handle : in_var_handles) { auto in_p = in_handle->place_; in_places.push_back(in_p); PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(), "Places must be all on CPU or all on CUDA."); - auto in_var = + auto *in_var = local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_); auto &in_sr = in_var->Get(); @@ -70,17 +77,16 @@ void GatherOpHandle::RunImpl() { PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), "The dims of inputs is not consistent."); - auto in_sr_rows = in_sr.rows(); + auto &in_sr_rows = in_sr.rows(); out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end()); in_tensors.emplace_back(in_sr.value()); } // write the output - auto &out_place = out_var_handles[0]->place_; - auto out_scope_idx = out_var_handles[0]->scope_idx_; - auto out_var = - local_scopes_[out_scope_idx]->FindVar(out_var_handles[0]->name_); + auto &out_place = out_var_handle->place_; + auto out_scope_idx = out_var_handle->scope_idx_; + auto out_var = local_scopes_[out_scope_idx]->FindVar(out_var_handle->name_); auto out = out_var->GetMutable(); out->set_height(pre_in.height()); @@ -106,25 +112,15 @@ void GatherOpHandle::RunImpl() { }); } -void GatherOpHandle::WaitEvents( +void GatherOpHandle::WaitInputVarGenerated( const std::vector &in_var_handles) { for (auto *in : in_var_handles) { if (in->generated_op_) { - in->generated_op_->Wait(dev_ctxes_[in->place_]); - } - } -} - -std::vector GatherOpHandle::GetValidVarHandles( - const std::vector &inputs) { - std::vector in_var_handles; - for (auto *in : inputs) { - auto *in_handle = dynamic_cast(in); - if (in_handle) { - in_var_handles.push_back(in_handle); + for (auto pair : dev_ctxes_) { + in->generated_op_->Wait(pair.second); + } } } - return in_var_handles; } std::string GatherOpHandle::Name() const { return "gather"; } diff --git a/paddle/fluid/framework/details/gather_op_handle.h b/paddle/fluid/framework/details/gather_op_handle.h index b13dc4ceb3..f576f047f3 100644 --- a/paddle/fluid/framework/details/gather_op_handle.h +++ b/paddle/fluid/framework/details/gather_op_handle.h @@ -42,10 +42,7 @@ struct GatherOpHandle : public OpHandleBase { protected: void RunImpl() override; - std::vector GetValidVarHandles( - const std::vector &); - - void WaitEvents(const std::vector &in_var_handles); + void WaitInputVarGenerated(const std::vector &in_var_handles); }; } // namespace details diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index 871e41343f..68116aca93 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -53,6 +53,11 @@ struct VarHandle : public VarHandleBase { size_t scope_idx_; std::string name_; platform::Place place_; + + bool operator==(const VarHandle &o) const { + return o.generated_op_ == generated_op_ && o.name_ == name_ && + o.scope_idx_ == scope_idx_; + } }; // Dummy Variable. It is used to represent dependencies between operators diff --git a/paddle/fluid/framework/details/variable_visitor.cc b/paddle/fluid/framework/details/variable_visitor.cc new file mode 100644 index 0000000000..f5f62ed8c4 --- /dev/null +++ b/paddle/fluid/framework/details/variable_visitor.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/variable_visitor.h" +#include "paddle/fluid/framework/selected_rows.h" +namespace paddle { +namespace framework { +namespace details { +template +static void VisitVariable(Variable* var, Func func) { + if (var->IsType()) { + func(var->GetMutable()); + } else if (var->IsType()) { + func(var->GetMutable()); + } else { + PADDLE_THROW("Not supported type %s", var->Type().name()); + } +} + +template +static void VisitVariable(const Variable& var, Func func) { + if (var.IsType()) { + func(var.Get()); + } else if (var.IsType()) { + func(var.Get()); + } else { + PADDLE_THROW("Not supported type %s", var.Type().name()); + } +} + +struct TensorVisitor { + Tensor* result_{nullptr}; + + void operator()(LoDTensor* tensor) { result_ = tensor; } + + void operator()(SelectedRows* selected_rows) { + result_ = selected_rows->mutable_value(); + } + + template + void operator()() { + PADDLE_THROW("Not Support to get LoDTensor from %s", typeid(T).name()); + } +}; + +Tensor& VariableVisitor::GetMutableTensor(Variable* var) { + TensorVisitor vistor; + VisitVariable(var, vistor); + return *vistor.result_; +} + +struct ShareDimsAndLoDVisitor { + Variable* trg_; + void operator()(const LoDTensor& val) { + auto* tensor = trg_->GetMutable(); + tensor->set_layout(val.layout()); + tensor->set_lod(val.lod()); + tensor->Resize(val.dims()); + } + + void operator()(const SelectedRows& val) { + auto* selected_rows = trg_->GetMutable(); + selected_rows->set_rows(val.rows()); + selected_rows->set_height(val.height()); + selected_rows->mutable_value()->Resize(val.value().dims()); + } + + template + void operator()(const T&) { + PADDLE_ENFORCE("ShareDimsAndLoD is not supported by type %s", + typeid(T).name()); + } +}; + +void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) { + ShareDimsAndLoDVisitor visitor{trg}; + VisitVariable(src, visitor); +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/variable_visitor.h b/paddle/fluid/framework/details/variable_visitor.h new file mode 100644 index 0000000000..67baa1895e --- /dev/null +++ b/paddle/fluid/framework/details/variable_visitor.h @@ -0,0 +1,33 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/variable.h" + +namespace paddle { +namespace framework { +namespace details { + +class VariableVisitor { + public: + static Tensor &GetMutableTensor(Variable *var); + + static void ShareDimsAndLoD(const Variable &src, Variable *trg); +}; + +} // namespace details +} // namespace framework +} // namespace paddle -- GitLab