Unverified commit ead81230, authored by Chen Weihang, committed by GitHub

[PTen] Fix reshape move storage using error (#37765)

* fix reshape move storage error

* remove needless set type

* alloc tensor by shared storage
Parent 1bdb8578
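In short: before this commit the reshape kernel created its temporary output on a fresh DefaultAllocator allocation and then called MovesStorage to move that storage into the framework Tensor; the commit instead backs the temporary pten::DenseTensor with a SharedStorage and adds a MovesSharedStorage helper that resizes the destination and re-points its holder at the shared allocation. Below is a condensed sketch of the new path, assembled from the diff that follows; it collapses the kernel's pt_out indirection and the inplace branch, and assumes the usual kernel context (ctx, in, out) is in scope.

    // Temporary output tensor backed by SharedStorage (the change in this diff).
    pten::DenseTensorMeta meta{pten::TransToPtenDataType(in->type()),
                               in->dims(),
                               pten::TransToPtenDataLayout(in->layout())};
    auto pt_out_tmp = std::make_shared<pten::DenseTensor>(
        pten::make_intrusive<paddle::experimental::SharedStorage>(ctx.GetPlace()),
        std::move(meta));

    // ... the reshape kernel writes its result into the temporary tensor ...

    // Non-inplace case: hand the shared allocation (plus dims and dtype) over
    // to the framework output Tensor via the helper added by this commit.
    paddle::experimental::MovesSharedStorage(pt_out_tmp.get(),
                                             static_cast<Tensor *>(out));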
@@ -383,13 +383,13 @@ class ReshapeKernel {
     // 3. out tensor is view of input
     // We can't MakePtenDenseTensor for case 2, so we solve this case by
     // creating a temporary tensor here:
-    const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
-        ctx.GetPlace());
     pten::DenseTensorMeta meta{pten::TransToPtenDataType(in->type()),
                                in->dims(),
                                pten::TransToPtenDataLayout(in->layout())};
-    auto pt_out_tmp =
-        std::make_shared<pten::DenseTensor>(alloc, std::move(meta));
+    auto pt_out_tmp = std::make_shared<pten::DenseTensor>(
+        pten::make_intrusive<paddle::experimental::SharedStorage>(
+            ctx.GetPlace()),
+        std::move(meta));
     pten::DenseTensor *pt_out = nullptr;
     if (in == out) {
       pt_out = pt_x.get();
@@ -484,7 +484,8 @@ class ReshapeKernel {
     // non-inplace need move all result from pt_out to out, inplace need set
     // result dims.
     if (in != out) {
-      paddle::experimental::MovesStorage(pt_out, static_cast<Tensor *>(out));
+      paddle::experimental::MovesSharedStorage(pt_out,
+                                               static_cast<Tensor *>(out));
     } else {
       out->Resize(pt_out->dims());
     }
......
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/pten/api/lib/utils/tensor_utils.h"
 
+#include <utility>
 #include <vector>
 
 #include "paddle/pten/core/compat_utils.h"
@@ -342,6 +343,29 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::LoDTensor* dst) {
   MovesStorage(src, static_cast<paddle::framework::Tensor*>(dst));
 }
 
+void MovesSharedStorage(pten::DenseTensor* src,
+                        paddle::framework::Tensor* dst) {
+  PADDLE_ENFORCE_NOT_NULL(
+      src,
+      platform::errors::InvalidArgument(
+          "The source DenseTensor is nullptr when move allocation."));
+  PADDLE_ENFORCE_NOT_NULL(
+      dst,
+      platform::errors::InvalidArgument(
+          "The destination Tensor is nullptr when move allocation."));
+  dst->Resize(src->dims());
+  auto* storage = static_cast<SharedStorage*>(
+      pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src));
+  dst->ResetHolderWithType(storage->GetAllocation(),
+                           pten::TransToProtoVarType(src->dtype()));
+}
+
+void MovesSharedStorage(pten::DenseTensor* src,
+                        paddle::framework::LoDTensor* dst) {
+  MovesSharedStorage(src, static_cast<paddle::framework::Tensor*>(dst));
+  SetLoD(dst->mutable_lod(), src->lod());
+}
+
 void ReMakePtenDenseTensor(const paddle::framework::Tensor& src,
                            const pten::TensorArgDef& arg_def,
                            pten::DenseTensor* dst) {
......
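The LoDTensor overload above simply forwards to the Tensor overload and then copies the LoD across. A minimal, hypothetical usage sketch follows; the names src_dense, dst_lod, and place are illustrative, and src_dense must have been created on a SharedStorage, since the implementation casts the storage to SharedStorage unconditionally.

    // Hypothetical caller: src_dense was built with
    // pten::make_intrusive<paddle::experimental::SharedStorage>(place)
    // and has already been filled by a pten kernel.
    paddle::framework::LoDTensor dst_lod;
    paddle::experimental::MovesSharedStorage(&src_dense, &dst_lod);
    // dst_lod now carries src_dense's dims, dtype, and LoD, and shares its
    // underlying allocation instead of copying the data.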
@@ -58,6 +58,11 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::Tensor* dst);
 void MovesStorage(pten::DenseTensor* src, paddle::framework::LoDTensor* dst);
 
+void MovesSharedStorage(pten::DenseTensor* src, paddle::framework::Tensor* dst);
+
+void MovesSharedStorage(pten::DenseTensor* src,
+                        paddle::framework::LoDTensor* dst);
+
 /**
  * In order to improve the compatibility state performance, some tricky tool
  * functions are added.
......