From ead812305326f4b5ce003646b013aab66cfe7a32 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sun, 5 Dec 2021 20:44:40 -0600 Subject: [PATCH] [PTen] Fix reshape move storage using error (#37765) * fix reshape move storage error * remove needless set type * alloc tensor by shared storage --- paddle/fluid/operators/reshape_op.cc | 11 ++++++----- paddle/pten/api/lib/utils/tensor_utils.cc | 24 +++++++++++++++++++++++ paddle/pten/api/lib/utils/tensor_utils.h | 5 +++++ 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 5148e3b0940..c12db129385 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -383,13 +383,13 @@ class ReshapeKernel { // 3. out tensor is view of input // We can't MakePtenDenseTensor for case 2, so we solve this case by // creating a temporary tensor here: - const auto alloc = std::make_shared( - ctx.GetPlace()); pten::DenseTensorMeta meta{pten::TransToPtenDataType(in->type()), in->dims(), pten::TransToPtenDataLayout(in->layout())}; - auto pt_out_tmp = - std::make_shared(alloc, std::move(meta)); + auto pt_out_tmp = std::make_shared( + pten::make_intrusive( + ctx.GetPlace()), + std::move(meta)); pten::DenseTensor *pt_out = nullptr; if (in == out) { pt_out = pt_x.get(); @@ -484,7 +484,8 @@ class ReshapeKernel { // non-inplace need move all result from pt_out to out, inplace need set // result dims. if (in != out) { - paddle::experimental::MovesStorage(pt_out, static_cast(out)); + paddle::experimental::MovesSharedStorage(pt_out, + static_cast(out)); } else { out->Resize(pt_out->dims()); } diff --git a/paddle/pten/api/lib/utils/tensor_utils.cc b/paddle/pten/api/lib/utils/tensor_utils.cc index 0983abfa921..f2b6e4841aa 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.cc +++ b/paddle/pten/api/lib/utils/tensor_utils.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/pten/api/lib/utils/tensor_utils.h" +#include #include #include "paddle/pten/core/compat_utils.h" @@ -342,6 +343,29 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::LoDTensor* dst) { MovesStorage(src, static_cast(dst)); } +void MovesSharedStorage(pten::DenseTensor* src, + paddle::framework::Tensor* dst) { + PADDLE_ENFORCE_NOT_NULL( + src, + platform::errors::InvalidArgument( + "The source DenseTensor is nullptr when move allocation.")); + PADDLE_ENFORCE_NOT_NULL( + dst, + platform::errors::InvalidArgument( + "The destination Tensor is nullptr when move allocation.")); + dst->Resize(src->dims()); + auto* storage = static_cast( + pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src)); + dst->ResetHolderWithType(storage->GetAllocation(), + pten::TransToProtoVarType(src->dtype())); +} + +void MovesSharedStorage(pten::DenseTensor* src, + paddle::framework::LoDTensor* dst) { + MovesSharedStorage(src, static_cast(dst)); + SetLoD(dst->mutable_lod(), src->lod()); +} + void ReMakePtenDenseTensor(const paddle::framework::Tensor& src, const pten::TensorArgDef& arg_def, pten::DenseTensor* dst) { diff --git a/paddle/pten/api/lib/utils/tensor_utils.h b/paddle/pten/api/lib/utils/tensor_utils.h index 04f0f6c1ff0..6397ca369ce 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.h +++ b/paddle/pten/api/lib/utils/tensor_utils.h @@ -58,6 +58,11 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::Tensor* dst); void MovesStorage(pten::DenseTensor* src, paddle::framework::LoDTensor* dst); +void MovesSharedStorage(pten::DenseTensor* src, paddle::framework::Tensor* dst); + +void MovesSharedStorage(pten::DenseTensor* src, + paddle::framework::LoDTensor* dst); + /** * In order to improve the compatibility state performance, some tricky tool * functions are added. -- GitLab