diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 897e41f79f4e3bb9cecbe7b42fc6c4fd3401d839..181f08d028919d3d55821186d777f3a8a636ae3a 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -21,8 +21,10 @@ cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framewor
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle
         ssa_graph_executor scope simple_threadpool device_context)
 
-cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
-cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)
+cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
+
+cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory)
+cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory)
 
 cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
         device_context broadcast_op_handle)
diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc
index 7d29012380e1b1710704d71a28d21dcc3097eb51..0bc3ee78d67e8548f093ff7086cf06a1ffb1c58b 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -13,95 +13,72 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"
 
 namespace paddle {
 namespace framework {
 namespace details {
-
-Tensor *GetTensorFromVar(Variable *in_var) {
-  if (in_var->IsType<framework::LoDTensor>()) {
-    return in_var->GetMutable<framework::LoDTensor>();
-  } else if (in_var->IsType<framework::SelectedRows>()) {
-    return in_var->GetMutable<framework::SelectedRows>()->mutable_value();
-  } else {
-    PADDLE_THROW("Var should be LoDTensor or SelectedRows");
-  }
-  return nullptr;
-}
-
 BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places)
     : local_scopes_(local_scopes), places_(places) {}
 
 void BroadcastOpHandle::RunImpl() {
-  // the input may have dummy var.
-  std::vector<VarHandle *> in_var_handle;
-  for (auto *in : inputs_) {
-    auto *out_handle = dynamic_cast<VarHandle *>(in);
-    if (out_handle) {
-      in_var_handle.push_back(out_handle);
-    }
-  }
-  PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
-                    "The number of input should be one.");
+  // the input and output may have dummy var.
+  VarHandle *in_var_handle;
 
-  // the output may have dummy var.
-  std::vector<VarHandle *> out_var_handles;
-  for (auto *out : outputs_) {
-    auto *out_handle = dynamic_cast<VarHandle *>(out);
-    if (out_handle) {
-      out_var_handles.push_back(out_handle);
-    }
+  {
+    auto in_var_handles = DynamicCast<VarHandle>(inputs_);
+    PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
+                      "The number of input should be one.");
+    in_var_handle = in_var_handles[0];
   }
 
+  auto out_var_handles = DynamicCast<VarHandle>(outputs_);
+
   PADDLE_ENFORCE_EQ(
       out_var_handles.size(), places_.size(),
       "The number of output should equal to the number of places.");
 
-  // Wait input done, this Wait is asynchronous operation
-  auto &in_place = in_var_handle[0]->place_;
-  if (in_var_handle[0]->generated_op_) {
-    for (auto *out : out_var_handles) {
-      auto &out_p = out->place_;
-      in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
-    }
-  }
+  // Wait input done, this Wait is asynchronous operation
+  WaitInputVarGenerated(*in_var_handle);
 
-  //
-  auto in_scope_idx = in_var_handle[0]->scope_idx_;
-  auto in_var =
-      local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
-  Tensor *in_tensor = GetTensorFromVar(in_var);
+  auto *in_var = local_scopes_.at(in_var_handle->scope_idx_)
+                     ->FindVar(in_var_handle->name_);
+  PADDLE_ENFORCE_NOT_NULL(in_var);
+  Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
   for (auto *out : out_var_handles) {
+    if (*out == *in_var_handle) {
+      continue;
+    }
+
     auto &out_p = out->place_;
-    auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
+    auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
 
-    PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
+    PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
                       "Places must be all on CPU or all on CUDA.");
 
-    if (in_var->IsType<framework::SelectedRows>()) {
-      auto &in_sr = in_var->Get<framework::SelectedRows>();
-      auto out_sr = out_var->GetMutable<framework::SelectedRows>();
-      if (&in_sr == out_sr) continue;
-      out_sr->set_height(in_sr.height());
-      out_sr->set_rows(in_sr.rows());
-      out_sr->mutable_value()->Resize(in_sr.value().dims());
-      out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type());
-    } else if (in_var->IsType<framework::LoDTensor>()) {
-      auto in_lod = in_var->Get<framework::LoDTensor>();
-      auto out_lod = out_var->GetMutable<framework::LoDTensor>();
-      if (&in_lod == out_lod) continue;
-      out_lod->set_lod(in_lod.lod());
-      out_lod->Resize(in_lod.dims());
-      out_lod->mutable_data(out_p, in_lod.type());
-    } else {
-      PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
-    }
+    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
+    VariableVisitor::GetMutableTensor(out_var)
+        .Resize(in_tensor.dims())
+        .mutable_data(out_p, in_tensor.type());
 
-    Tensor *out_tensor = GetTensorFromVar(out_var);
-    paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
-                                  out_tensor);
+    auto dev_ctx = dev_ctxes_[out_p];
+    RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
+      paddle::framework::TensorCopy(
+          in_tensor, out_p, *(dev_ctx),
+          &VariableVisitor::GetMutableTensor(out_var));
+    });
+  }
+}
+
+void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
+  if (in_var.generated_op_) {
+    for (auto &pair : dev_ctxes_) {
+      in_var.generated_op_->Wait(pair.second);
+    }
   }
 }
 
diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h
index bc3e373488c9899e6e6d46d090b083332ff40666..92420f10ac5972b7924d83b43bb28234079e5072 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.h
+++ b/paddle/fluid/framework/details/broadcast_op_handle.h
@@ -39,12 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
+  void WaitInputVarGenerated(const VarHandle &in_var);
 
  private:
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
 };
-
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/container_cast.h b/paddle/fluid/framework/details/container_cast.h
new file mode 100644
index 0000000000000000000000000000000000000000..a42ae78dc45c2a885f98315a21f1d5558725bca3
--- /dev/null
+++ b/paddle/fluid/framework/details/container_cast.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <type_traits>
+#include <vector>
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+template <typename ResultType, typename ElemType>
+std::vector<ResultType*> DynamicCast(const std::vector<ElemType*>& container) {
+  static_assert(std::is_base_of<ElemType, ResultType>::value,
+                "ElementType must be a base class of ResultType");
+  std::vector<ResultType*> res;
+  for (auto* ptr : container) {
+    auto* derived = dynamic_cast<ResultType*>(ptr);
+    if (derived) {
+      res.emplace_back(derived);
+    }
+  }
+  return res;
+}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc
index 8dd85be567d33991ac003707fec939a61a2d0962..511fd941dc7270d79f9a565f03d233b6fdf41d37 100644
--- a/paddle/fluid/framework/details/gather_op_handle.cc
+++ b/paddle/fluid/framework/details/gather_op_handle.cc
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/gather_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"
 
 namespace paddle {
 namespace framework {
@@ -23,30 +25,23 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
     : local_scopes_(local_scopes), places_(places) {}
 
 void GatherOpHandle::RunImpl() {
-  // the input may have dummy var.
-  std::vector<VarHandle *> in_var_handles;
-  for (auto *in : inputs_) {
-    auto *in_handle = dynamic_cast<VarHandle *>(in);
-    if (in_handle) {
-      in_var_handles.push_back(in_handle);
-    }
-  }
+  // the input and output may have dummy var.
+  auto in_var_handles = DynamicCast<VarHandle>(inputs_);
+
   PADDLE_ENFORCE_EQ(
       in_var_handles.size(), places_.size(),
       "The number of output should equal to the number of places.");
 
-  // the output may have dummy var.
-  std::vector<VarHandle *> out_var_handles;
-  for (auto *out : outputs_) {
-    auto *out_handle = dynamic_cast<VarHandle *>(out);
-    if (out_handle) {
-      out_var_handles.push_back(out_handle);
-    }
+  VarHandle *out_var_handle;
+  {
+    auto out_var_handles = DynamicCast<VarHandle>(outputs_);
+
+    PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
+                      "The number of output should be one.");
+    out_var_handle = out_var_handles.front();
   }
-  PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
-                    "The number of output should be one.");
 
-  auto in_0_handle = static_cast<VarHandle *>(in_var_handles[0]);
+  auto in_0_handle = in_var_handles[0];
   auto pre_in_var =
       local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
   auto pre_place = in_0_handle->place_;
@@ -54,15 +49,11 @@ void GatherOpHandle::RunImpl() {
   PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
                  "Currently, gather_op only can gather SelectedRows.");
 
-  PADDLE_ENFORCE_EQ(out_var_handles[0]->place_.which(), pre_place.which(),
+  PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
                     "The place of input and output should be the same.");
 
   // Wait input done, this Wait is asynchronous operation
-  for (auto *in : in_var_handles) {
-    if (in->generated_op_) {
-      in->generated_op_->Wait(dev_ctxes_[in->place_]);
-    }
-  }
+  WaitInputVarGenerated(in_var_handles);
 
   std::vector<int64_t> out_rows;
   std::vector<Tensor> in_tensors;
@@ -70,13 +61,12 @@ void GatherOpHandle::RunImpl() {
 
   auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
   // gather the inputs
-  for (auto *in : in_var_handles) {
-    auto in_handle = static_cast<VarHandle *>(in);
+  for (auto *in_handle : in_var_handles) {
     auto in_p = in_handle->place_;
     in_places.push_back(in_p);
     PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
                       "Places must be all on CPU or all on CUDA.");
-    auto in_var =
+    auto *in_var =
         local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
     auto &in_sr = in_var->Get<framework::SelectedRows>();
 
@@ -84,20 +74,19 @@ void GatherOpHandle::RunImpl() {
                       "The type of input is not consistent.");
     PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(),
                       "The height of inputs is not consistent.");
-    PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), ,
+    PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
                       "The dims of inputs is not consistent.");
 
-    auto in_sr_rows = in_sr.rows();
+    auto &in_sr_rows = in_sr.rows();
     out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());
 
     in_tensors.emplace_back(in_sr.value());
   }
 
   // write the output
-  auto &out_place = out_var_handles[0]->place_;
-  auto out_scope_idx = out_var_handles[0]->scope_idx_;
-  auto out_var =
-      local_scopes_[out_scope_idx]->FindVar(out_var_handles[0]->name_);
+  auto &out_place = out_var_handle->place_;
+  auto out_scope_idx = out_var_handle->scope_idx_;
+  auto out_var = local_scopes_[out_scope_idx]->FindVar(out_var_handle->name_);
 
   auto out = out_var->GetMutable<framework::SelectedRows>();
   out->set_height(pre_in.height());
@@ -110,13 +99,27 @@ void GatherOpHandle::RunImpl() {
   Tensor *out_tensor = out->mutable_value();
 
   // copy
-  int s = 0, e = 0;
-  for (size_t j = 0; j < in_tensors.size(); ++j) {
-    e += in_tensors[j].dims()[0];
-    auto sub_out = out_tensor->Slice(s, e);
-    paddle::framework::TensorCopy(in_tensors[j], out_place,
-                                  *(dev_ctxes_[in_places[j]]), &sub_out);
-    s = e;
+  auto dev_ctx = dev_ctxes_[out_place];
+  RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
+    int s = 0, e = 0;
+    for (size_t j = 0; j < in_tensors.size(); ++j) {
+      e += in_tensors[j].dims()[0];
+      auto sub_out = out_tensor->Slice(s, e);
+      paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx),
+                                    &sub_out);
+      s = e;
+    }
+  });
+}
+
+void GatherOpHandle::WaitInputVarGenerated(
+    const std::vector<VarHandle *> &in_var_handles) {
+  for (auto *in : in_var_handles) {
+    if (in->generated_op_) {
+      for (auto &pair : dev_ctxes_) {
+        in->generated_op_->Wait(pair.second);
+      }
+    }
   }
 }
 
diff --git a/paddle/fluid/framework/details/gather_op_handle.h b/paddle/fluid/framework/details/gather_op_handle.h
index d11ef8556aa8840949ca8dc7aa176413f70b9f22..c394dd7a14b07cb956aa1aedfc0df4fa25744dd7 100644
--- a/paddle/fluid/framework/details/gather_op_handle.h
+++ b/paddle/fluid/framework/details/gather_op_handle.h
@@ -39,6 +39,7 @@ struct GatherOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
+  void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
 
  private:
   const std::vector<Scope *> &local_scopes_;
diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h
index 2b887c67e6fc6ea78e42fbb9fd170f740db27d97..9f7fd69e64fe9d7ef0bf3037bea7f686cb2eee0b 100644
--- a/paddle/fluid/framework/details/var_handle.h
+++ b/paddle/fluid/framework/details/var_handle.h
@@ -61,6 +61,11 @@ struct VarHandle : public VarHandleBase {
   size_t scope_idx_;
   std::string name_;
   platform::Place place_;
+
+  bool operator==(const VarHandle& o) const {
+    return o.generated_op_ == generated_op_ && o.name_ == name_ &&
+           o.scope_idx_ == scope_idx_;
+  }
 };
 
 // Dummy Variable. It is used to represent dependencies between operators
diff --git a/paddle/fluid/framework/details/variable_visitor.cc b/paddle/fluid/framework/details/variable_visitor.cc
new file mode 100644
index 0000000000000000000000000000000000000000..10bac0fae9504215fab11dd8cca7c278feaa4bda
--- /dev/null
+++ b/paddle/fluid/framework/details/variable_visitor.cc
@@ -0,0 +1,93 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/details/variable_visitor.h" +#include "paddle/fluid/framework/selected_rows.h" +namespace paddle { +namespace framework { +namespace details { +template +static void VisitVariable(Variable* var, Func* func) { + if (var->IsType()) { + (*func)(var->GetMutable()); + } else if (var->IsType()) { + (*func)(var->GetMutable()); + } else { + PADDLE_THROW("Not supported type %s", var->Type().name()); + } +} + +template +static void VisitVariable(const Variable& var, Func* func) { + if (var.IsType()) { + (*func)(var.Get()); + } else if (var.IsType()) { + (*func)(var.Get()); + } else { + PADDLE_THROW("Not supported type %s", var.Type().name()); + } +} + +struct TensorVisitor { + Tensor* result_{nullptr}; + + void operator()(LoDTensor* tensor) { result_ = tensor; } + + void operator()(SelectedRows* selected_rows) { + result_ = selected_rows->mutable_value(); + } + + template + void operator()() { + PADDLE_THROW("Not Support to get LoDTensor from %s", typeid(T).name()); + } +}; + +Tensor& VariableVisitor::GetMutableTensor(Variable* var) { + TensorVisitor vistor; + VisitVariable(var, &vistor); + return *vistor.result_; +} + +struct ShareDimsAndLoDVisitor { + Variable* trg_; + void operator()(const LoDTensor& val) { + auto* tensor = trg_->GetMutable(); + tensor->set_layout(val.layout()); + tensor->set_lod(val.lod()); + tensor->Resize(val.dims()); + } + + void operator()(const SelectedRows& val) { + auto* selected_rows = trg_->GetMutable(); + selected_rows->set_rows(val.rows()); + selected_rows->set_height(val.height()); + selected_rows->mutable_value()->Resize(val.value().dims()); + } + + template + void operator()(const T&) { + PADDLE_ENFORCE("ShareDimsAndLoD is not supported by type %s", + typeid(T).name()); + } +}; + +void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) { + ShareDimsAndLoDVisitor visitor{trg}; + VisitVariable(src, &visitor); +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/variable_visitor.h b/paddle/fluid/framework/details/variable_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..67baa1895e4513738fa73d49c46660da92279b9d --- /dev/null +++ b/paddle/fluid/framework/details/variable_visitor.h @@ -0,0 +1,33 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/variable.h" + +namespace paddle { +namespace framework { +namespace details { + +class VariableVisitor { + public: + static Tensor &GetMutableTensor(Variable *var); + + static void ShareDimsAndLoD(const Variable &src, Variable *trg); +}; + +} // namespace details +} // namespace framework +} // namespace paddle