diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 7232a707916dd5f0795c04cff8137c5e88132d42..91ef59575c3aa2a737f32c0ca90a7cbb2b3f3744 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -232,16 +232,8 @@ class CompatMetaTensor : public phi::MetaTensor { } } - void share_meta(const MetaTensor& meta_tensor) override { + void share_dims(const MetaTensor& meta_tensor) override { set_dims(meta_tensor.dims()); - set_dtype(meta_tensor.dtype()); - // VarDesc doesn't contains layout, so we cannot share layout - // set_layout(meta_tensor.layout()); - - // special case 1: share lod of LoDTensor - share_lod(meta_tensor); - - // special case 2: share height and rows of SelectedRows in runtime if (is_runtime_) { auto* var = BOOST_GET(Variable*, var_); if (var->IsType()) { @@ -254,6 +246,16 @@ class CompatMetaTensor : public phi::MetaTensor { } } + void share_meta(const MetaTensor& meta_tensor) override { + set_dtype(meta_tensor.dtype()); + // VarDesc doesn't contains layout, so we cannot share layout + // set_layout(meta_tensor.layout()); + + // special case 1: share lod of LoDTensor + share_lod(meta_tensor); + share_dims(meta_tensor); + } + private: const LoD& GetRuntimeLoD() const { auto* var = BOOST_GET_CONST(Variable*, var_); diff --git a/paddle/phi/core/meta_tensor.cc b/paddle/phi/core/meta_tensor.cc index 2aadce4feda96623553e8583f926af38458f8f9e..eb114304f53ea08b05d36792330cf5bd3ebbee5d 100644 --- a/paddle/phi/core/meta_tensor.cc +++ b/paddle/phi/core/meta_tensor.cc @@ -98,13 +98,9 @@ const LoD& MetaTensor::lod() const { } void MetaTensor::share_meta(const MetaTensor& meta_tensor) { - if (phi::DenseTensor::classof(tensor_)) { - set_dims(meta_tensor.dims()); - set_dtype(meta_tensor.dtype()); - set_layout(meta_tensor.layout()); - share_lod(meta_tensor); - } else if (phi::SelectedRows::classof(tensor_)) { - set_dims(meta_tensor.dims()); + if (phi::DenseTensor::classof(tensor_) || + phi::SelectedRows::classof(tensor_)) { + share_dims(meta_tensor); set_dtype(meta_tensor.dtype()); set_layout(meta_tensor.layout()); share_lod(meta_tensor); @@ -114,4 +110,29 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) { } } +TensorBase* MetaTensor::get_tensor() const { return tensor_; } + +void MetaTensor::share_dims(const MetaTensor& meta_tensor) { + bool is_dense_tensor = phi::DenseTensor::classof(tensor_); + bool is_selected_rows = phi::SelectedRows::classof(tensor_); + if (is_dense_tensor || is_selected_rows) { + set_dims(meta_tensor.dims()); + if (is_selected_rows) { + const auto in_tensor_base = meta_tensor.get_tensor(); + PADDLE_ENFORCE_EQ( + phi::SelectedRows::classof(in_tensor_base), + true, + errors::InvalidArgument("The input MetaTensor is SelectedRows, but " + "the output MetaTensor is not this type.")); + auto* selected_rows_out = static_cast(tensor_); + auto* selected_rows_in = static_cast(in_tensor_base); + selected_rows_out->set_rows(selected_rows_in->rows()); + selected_rows_out->set_height(selected_rows_in->height()); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported sharing dims for `%s`.", tensor_->type_info().name())); + } +} + } // namespace phi diff --git a/paddle/phi/core/meta_tensor.h b/paddle/phi/core/meta_tensor.h index 1a32019a190496804c0ef4c64f78f687b8af7577..3971a9f7e99e0282cae5e4d1e61ee6eb28c4b9a7 100644 --- a/paddle/phi/core/meta_tensor.h +++ b/paddle/phi/core/meta_tensor.h @@ -60,12 +60,13 @@ class MetaTensor { virtual void share_lod(const MetaTensor& meta_tensor); virtual void share_meta(const MetaTensor& meta_tensor); + virtual void share_dims(const MetaTensor& meta_tensor); private: // Because the lod in compiletime and runtime is different, // so `LoD` cannot in public methods const LoD& lod() const; - + TensorBase* get_tensor() const; TensorBase* tensor_; };