diff --git a/paddle/pten/api/lib/utils/storage.h b/paddle/pten/api/lib/utils/storage.h index 242ea6476ae983781d3d9eb1e959b5091b2495f4..2ba1e2e90eef214b8701bb8b32dff34481b9d52e 100644 --- a/paddle/pten/api/lib/utils/storage.h +++ b/paddle/pten/api/lib/utils/storage.h @@ -34,6 +34,11 @@ class ExternalStorage : public pten::Storage { "The external shared storage cannot be reallocated.")); } + void Clear() override { + data_.Clear(); + size_ = 0; + } + size_t size() const noexcept override { return size_; } const paddle::platform::Place& place() const override { return data_.place(); @@ -41,7 +46,7 @@ class ExternalStorage : public pten::Storage { bool OwnsMemory() const noexcept override { return false; } private: - const int64_t size_{0}; + int64_t size_{0}; }; class SharedStorage : public pten::Storage { @@ -65,6 +70,11 @@ class SharedStorage : public pten::Storage { "The external shared storage cannot be reallocated.")); } + void Clear() override { + data_.Clear(); + size_ = 0; + } + size_t size() const noexcept override { return size_; } const paddle::platform::Place& place() const override { return data_.place(); diff --git a/paddle/pten/core/allocator.h b/paddle/pten/core/allocator.h index c16c4ffaa6a3768959db9de597cdad6aaae4108f..9c6f749609a48fdf0125bb48ac820c5da8bc15b4 100644 --- a/paddle/pten/core/allocator.h +++ b/paddle/pten/core/allocator.h @@ -73,9 +73,9 @@ class Allocation final { operator bool() const noexcept { return data_ || ctx_.Get(); } const Place& place() const noexcept { return place_; } - void Clear() noexcept { - data_ = nullptr; + void Clear() { ctx_.Clear(); + data_ = nullptr; } /// \brief Statically cast the void pointer of the context object to @@ -107,12 +107,11 @@ class Allocation final { swap(*this, other); return *this; } - ~Context() { + ~Context() { Clear(); } + void Clear() { if (deleter_) { deleter_(ctx_); } - } - void Clear() noexcept { ctx_ = nullptr; deleter_ = nullptr; } diff --git a/paddle/pten/core/dense_tensor.cc b/paddle/pten/core/dense_tensor.cc index 471b24f09fdc7e19208d3c9d40a487e3b42549bc..2dfd523579d26e780358c5b8c546785a463855df 100644 --- a/paddle/pten/core/dense_tensor.cc +++ b/paddle/pten/core/dense_tensor.cc @@ -113,6 +113,20 @@ void DenseTensor::check_memory_size() const { bytes)); } +void DenseTensor::Resize(const DDim& dims) { + if (product(dims) == product(meta_.dims)) { + set_dims(dims); + } else { + meta_.dims = dims; + storage_->Clear(); + } +} + +void DenseTensor::set_dims(const DDim& dims) { + CHECK(product(dims) == product(meta_.dims)); + meta_.dims = dims; +} + #define DATA_MEMBER_FUNC_INSTANTIATION(dtype) \ template dtype* DenseTensor::mutable_data(); \ template const dtype* DenseTensor::data() const; diff --git a/paddle/pten/core/dense_tensor.h b/paddle/pten/core/dense_tensor.h index e8e57b333ae99e0cd836ddf0cab1b4f09664a749..9354f92e7dd018a420037ecbc1d1f1b43a97610b 100644 --- a/paddle/pten/core/dense_tensor.h +++ b/paddle/pten/core/dense_tensor.h @@ -119,10 +119,16 @@ class DenseTensor : public TensorBase, /// \return Whether the storage is shared with other objects. bool IsSharedWith(const DenseTensor& b) const; - /// \brief Change the dims information in the metadata, and the corresponding - /// memory allocation will occur when the `mutable_data` is called. + /// \brief Change the dims information in the metadata. If the new size is + /// inconsistent with the original value, the storage area will be released + /// to avoid wrong access. /// \param dims The new dims of the dense tensor. - void Resize(const DDim& dims) noexcept { meta_.dims = dims; } + void Resize(const DDim& dims); + + /// \brief Change the dims information in the metadata. + /// \param dims The new dims of the dense tensor. The product of the dims + /// elements must be consistent with the original value. + void set_dims(const DDim& dims); /// \brief Returns the actual storage size occupied by tensor, may be larger /// than its shape dims. diff --git a/paddle/pten/core/storage.h b/paddle/pten/core/storage.h index 430572e253d6ec7ba81f3969fd61d28ea24431c5..ef9e22a0804e73f57cd8b3a7936adb7de02effdb 100644 --- a/paddle/pten/core/storage.h +++ b/paddle/pten/core/storage.h @@ -44,6 +44,8 @@ class Storage : public intrusive_ref_counter { /// \return The mutable data pointer of the storage. void* data() const noexcept { return data_.operator->(); } + virtual void Clear() = 0; + virtual size_t size() const = 0; virtual const Place& place() const = 0; virtual bool OwnsMemory() const = 0; @@ -68,6 +70,12 @@ class TensorStorage : public Storage { void Realloc(size_t size) override; size_t size() const noexcept override { return size_; } + + void Clear() override { + data_.Clear(); + size_ = 0; + } + const Place& place() const override { return data_.place(); } bool OwnsMemory() const noexcept override { return true; } const std::shared_ptr& allocator() const noexcept { diff --git a/paddle/pten/tests/core/test_dense_tensor.cc b/paddle/pten/tests/core/test_dense_tensor.cc index 12476373f8d98c7e349a5d82959a570f3df58e2c..dac2575713bfb40db4d5c86bd9920f99014b7dbb 100644 --- a/paddle/pten/tests/core/test_dense_tensor.cc +++ b/paddle/pten/tests/core/test_dense_tensor.cc @@ -115,7 +115,9 @@ TEST(dense_tensor, resize) { CHECK_EQ(tensor_0.memory_size(), 2u); tensor_0.check_memory_size(); tensor_0.Resize({1, 2, 3}); - CHECK_EQ(tensor_0.memory_size(), 2u); + CHECK_EQ(tensor_0.memory_size(), 0u); + tensor_0.set_dims({2, 3}); + CHECK_EQ(tensor_0.memory_size(), 0u); tensor_0.mutable_data(); CHECK_EQ(tensor_0.memory_size(), 6u);