From d4b007af8bfa82df134220690115fcd58122de26 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Tue, 8 Mar 2022 10:53:28 +0800 Subject: [PATCH] add share dims (#40238) --- paddle/fluid/framework/infershape_utils.cc | 20 +++++++------ paddle/phi/core/meta_tensor.cc | 35 +++++++++++++++++----- paddle/phi/core/meta_tensor.h | 3 +- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 7232a707916..91ef59575c3 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -232,16 +232,8 @@ class CompatMetaTensor : public phi::MetaTensor { } } - void share_meta(const MetaTensor& meta_tensor) override { + void share_dims(const MetaTensor& meta_tensor) override { set_dims(meta_tensor.dims()); - set_dtype(meta_tensor.dtype()); - // VarDesc doesn't contains layout, so we cannot share layout - // set_layout(meta_tensor.layout()); - - // special case 1: share lod of LoDTensor - share_lod(meta_tensor); - - // special case 2: share height and rows of SelectedRows in runtime if (is_runtime_) { auto* var = BOOST_GET(Variable*, var_); if (var->IsType()) { @@ -254,6 +246,16 @@ class CompatMetaTensor : public phi::MetaTensor { } } + void share_meta(const MetaTensor& meta_tensor) override { + set_dtype(meta_tensor.dtype()); + // VarDesc doesn't contains layout, so we cannot share layout + // set_layout(meta_tensor.layout()); + + // special case 1: share lod of LoDTensor + share_lod(meta_tensor); + share_dims(meta_tensor); + } + private: const LoD& GetRuntimeLoD() const { auto* var = BOOST_GET_CONST(Variable*, var_); diff --git a/paddle/phi/core/meta_tensor.cc b/paddle/phi/core/meta_tensor.cc index 2aadce4feda..eb114304f53 100644 --- a/paddle/phi/core/meta_tensor.cc +++ b/paddle/phi/core/meta_tensor.cc @@ -98,13 +98,9 @@ const LoD& MetaTensor::lod() const { } void MetaTensor::share_meta(const MetaTensor& meta_tensor) { - if (phi::DenseTensor::classof(tensor_)) { - set_dims(meta_tensor.dims()); - set_dtype(meta_tensor.dtype()); - set_layout(meta_tensor.layout()); - share_lod(meta_tensor); - } else if (phi::SelectedRows::classof(tensor_)) { - set_dims(meta_tensor.dims()); + if (phi::DenseTensor::classof(tensor_) || + phi::SelectedRows::classof(tensor_)) { + share_dims(meta_tensor); set_dtype(meta_tensor.dtype()); set_layout(meta_tensor.layout()); share_lod(meta_tensor); @@ -114,4 +110,29 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) { } } +TensorBase* MetaTensor::get_tensor() const { return tensor_; } + +void MetaTensor::share_dims(const MetaTensor& meta_tensor) { + bool is_dense_tensor = phi::DenseTensor::classof(tensor_); + bool is_selected_rows = phi::SelectedRows::classof(tensor_); + if (is_dense_tensor || is_selected_rows) { + set_dims(meta_tensor.dims()); + if (is_selected_rows) { + const auto in_tensor_base = meta_tensor.get_tensor(); + PADDLE_ENFORCE_EQ( + phi::SelectedRows::classof(in_tensor_base), + true, + errors::InvalidArgument("The input MetaTensor is SelectedRows, but " + "the output MetaTensor is not this type.")); + auto* selected_rows_out = static_cast(tensor_); + auto* selected_rows_in = static_cast(in_tensor_base); + selected_rows_out->set_rows(selected_rows_in->rows()); + selected_rows_out->set_height(selected_rows_in->height()); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported sharing dims for `%s`.", tensor_->type_info().name())); + } +} + } // namespace phi diff --git a/paddle/phi/core/meta_tensor.h b/paddle/phi/core/meta_tensor.h index 1a32019a190..3971a9f7e99 100644 --- a/paddle/phi/core/meta_tensor.h +++ b/paddle/phi/core/meta_tensor.h @@ -60,12 +60,13 @@ class MetaTensor { virtual void share_lod(const MetaTensor& meta_tensor); virtual void share_meta(const MetaTensor& meta_tensor); + virtual void share_dims(const MetaTensor& meta_tensor); private: // Because the lod in compiletime and runtime is different, // so `LoD` cannot in public methods const LoD& lod() const; - + TensorBase* get_tensor() const; TensorBase* tensor_; }; -- GitLab