From 9303b0951407880dca4ddac6aaf2fa9c9c737b53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Fri, 12 Nov 2021 10:44:55 +0800 Subject: [PATCH] add the shallow clone member func of the dense tensor, test=develop (#37146) --- paddle/pten/core/dense_tensor.h | 6 ++++++ paddle/pten/core/tensor_meta.h | 7 +++++++ paddle/pten/tests/core/test_dense_tensor.cc | 15 +++++++++++++++ 3 files changed, 28 insertions(+) diff --git a/paddle/pten/core/dense_tensor.h b/paddle/pten/core/dense_tensor.h index 9354f92e7dd..8c2b711015c 100644 --- a/paddle/pten/core/dense_tensor.h +++ b/paddle/pten/core/dense_tensor.h @@ -172,6 +172,12 @@ class DenseTensor : public TensorBase, /// \return The const data pointer value of raw type. const void* data() const; + /// \brief Get the shallow clone of current tensor. + /// \return The shallow clone of current tensor. + DenseTensor shallow_clone() const { + return DenseTensor(copy_intrusive(storage_), meta_); + } + private: friend class CompatibleDenseTensorUtils; diff --git a/paddle/pten/core/tensor_meta.h b/paddle/pten/core/tensor_meta.h index 85afc3f2f01..a7cba706dd7 100644 --- a/paddle/pten/core/tensor_meta.h +++ b/paddle/pten/core/tensor_meta.h @@ -82,4 +82,11 @@ inline bool DenseTensorMeta::valid() const noexcept { return valid; } +inline bool operator==(const DenseTensorMeta& lhs, const DenseTensorMeta& rhs) { + bool ret = true; + return ret && (lhs.is_scalar == rhs.is_scalar) && (lhs.dims == rhs.dims) && + (lhs.type == rhs.type) && (lhs.layout == rhs.layout) && + (lhs.lod == rhs.lod); +} + } // namespace pten diff --git a/paddle/pten/tests/core/test_dense_tensor.cc b/paddle/pten/tests/core/test_dense_tensor.cc index dac2575713b..59f67320752 100644 --- a/paddle/pten/tests/core/test_dense_tensor.cc +++ b/paddle/pten/tests/core/test_dense_tensor.cc @@ -125,5 +125,20 @@ TEST(dense_tensor, resize) { CHECK_EQ(storage->size(), 6u); } +TEST(dense_tensor, shallow_clone) { + const DDim dims({1, 2}); + const DataType dtype{DataType::INT8}; + const DataLayout layout{DataLayout::NHWC}; + const std::vector> lod{}; + DenseTensorMeta meta(dtype, dims, layout, lod); + + auto alloc = std::make_shared(); + DenseTensor tensor_0(alloc, meta); + + auto tensor_1 = tensor_0.shallow_clone(); + CHECK(tensor_0.meta() == tensor_1.meta()); + CHECK(tensor_0.release() == tensor_1.release()); +} + } // namespace tests } // namespace pten -- GitLab