diff --git a/paddle/pten/core/dense_tensor.h b/paddle/pten/core/dense_tensor.h index 9354f92e7dd018a420037ecbc1d1f1b43a97610b..8c2b711015c9da2c5daac2005eae47fdfa2e403d 100644 --- a/paddle/pten/core/dense_tensor.h +++ b/paddle/pten/core/dense_tensor.h @@ -172,6 +172,12 @@ class DenseTensor : public TensorBase, /// \return The const data pointer value of raw type. const void* data() const; + /// \brief Get the shallow clone of current tensor. + /// \return The shallow clone of current tensor. + DenseTensor shallow_clone() const { + return DenseTensor(copy_intrusive(storage_), meta_); + } + private: friend class CompatibleDenseTensorUtils; diff --git a/paddle/pten/core/tensor_meta.h b/paddle/pten/core/tensor_meta.h index 85afc3f2f01ea803a4cbd71221f741d862367033..a7cba706dd7d0341335899c954c4386eca5e4ac1 100644 --- a/paddle/pten/core/tensor_meta.h +++ b/paddle/pten/core/tensor_meta.h @@ -82,4 +82,11 @@ inline bool DenseTensorMeta::valid() const noexcept { return valid; } +inline bool operator==(const DenseTensorMeta& lhs, const DenseTensorMeta& rhs) { + bool ret = true; + return ret && (lhs.is_scalar == rhs.is_scalar) && (lhs.dims == rhs.dims) && + (lhs.type == rhs.type) && (lhs.layout == rhs.layout) && + (lhs.lod == rhs.lod); +} + } // namespace pten diff --git a/paddle/pten/tests/core/test_dense_tensor.cc b/paddle/pten/tests/core/test_dense_tensor.cc index dac2575713bfb40db4d5c86bd9920f99014b7dbb..59f67320752551abb1e5b909bee7532aaea8ae41 100644 --- a/paddle/pten/tests/core/test_dense_tensor.cc +++ b/paddle/pten/tests/core/test_dense_tensor.cc @@ -125,5 +125,20 @@ TEST(dense_tensor, resize) { CHECK_EQ(storage->size(), 6u); } +TEST(dense_tensor, shallow_clone) { + const DDim dims({1, 2}); + const DataType dtype{DataType::INT8}; + const DataLayout layout{DataLayout::NHWC}; + const std::vector> lod{}; + DenseTensorMeta meta(dtype, dims, layout, lod); + + auto alloc = std::make_shared(); + DenseTensor tensor_0(alloc, meta); + + auto tensor_1 = tensor_0.shallow_clone(); + CHECK(tensor_0.meta() == tensor_1.meta()); + CHECK(tensor_0.release() == tensor_1.release()); +} + } // namespace tests } // namespace pten