未验证 提交 4d536678 编写于 作者: 石晓伟 提交者: GitHub

adjust the COLUMNS=128; (#37120)

上级 9396f286
...@@ -34,6 +34,11 @@ class ExternalStorage : public pten::Storage { ...@@ -34,6 +34,11 @@ class ExternalStorage : public pten::Storage {
"The external shared storage cannot be reallocated.")); "The external shared storage cannot be reallocated."));
} }
void Clear() override {
data_.Clear();
size_ = 0;
}
size_t size() const noexcept override { return size_; } size_t size() const noexcept override { return size_; }
const paddle::platform::Place& place() const override { const paddle::platform::Place& place() const override {
return data_.place(); return data_.place();
...@@ -41,7 +46,7 @@ class ExternalStorage : public pten::Storage { ...@@ -41,7 +46,7 @@ class ExternalStorage : public pten::Storage {
bool OwnsMemory() const noexcept override { return false; } bool OwnsMemory() const noexcept override { return false; }
private: private:
const int64_t size_{0}; int64_t size_{0};
}; };
class SharedStorage : public pten::Storage { class SharedStorage : public pten::Storage {
...@@ -65,6 +70,11 @@ class SharedStorage : public pten::Storage { ...@@ -65,6 +70,11 @@ class SharedStorage : public pten::Storage {
"The external shared storage cannot be reallocated.")); "The external shared storage cannot be reallocated."));
} }
void Clear() override {
data_.Clear();
size_ = 0;
}
size_t size() const noexcept override { return size_; } size_t size() const noexcept override { return size_; }
const paddle::platform::Place& place() const override { const paddle::platform::Place& place() const override {
return data_.place(); return data_.place();
......
...@@ -73,9 +73,9 @@ class Allocation final { ...@@ -73,9 +73,9 @@ class Allocation final {
operator bool() const noexcept { return data_ || ctx_.Get(); } operator bool() const noexcept { return data_ || ctx_.Get(); }
const Place& place() const noexcept { return place_; } const Place& place() const noexcept { return place_; }
void Clear() noexcept { void Clear() {
data_ = nullptr;
ctx_.Clear(); ctx_.Clear();
data_ = nullptr;
} }
/// \brief Statically cast the void pointer of the context object to /// \brief Statically cast the void pointer of the context object to
...@@ -107,12 +107,11 @@ class Allocation final { ...@@ -107,12 +107,11 @@ class Allocation final {
swap(*this, other); swap(*this, other);
return *this; return *this;
} }
~Context() { ~Context() { Clear(); }
void Clear() {
if (deleter_) { if (deleter_) {
deleter_(ctx_); deleter_(ctx_);
} }
}
void Clear() noexcept {
ctx_ = nullptr; ctx_ = nullptr;
deleter_ = nullptr; deleter_ = nullptr;
} }
......
...@@ -113,6 +113,20 @@ void DenseTensor::check_memory_size() const { ...@@ -113,6 +113,20 @@ void DenseTensor::check_memory_size() const {
bytes)); bytes));
} }
void DenseTensor::Resize(const DDim& dims) {
if (product(dims) == product(meta_.dims)) {
set_dims(dims);
} else {
meta_.dims = dims;
storage_->Clear();
}
}
void DenseTensor::set_dims(const DDim& dims) {
CHECK(product(dims) == product(meta_.dims));
meta_.dims = dims;
}
#define DATA_MEMBER_FUNC_INSTANTIATION(dtype) \ #define DATA_MEMBER_FUNC_INSTANTIATION(dtype) \
template dtype* DenseTensor::mutable_data(); \ template dtype* DenseTensor::mutable_data(); \
template const dtype* DenseTensor::data() const; template const dtype* DenseTensor::data() const;
......
...@@ -119,10 +119,16 @@ class DenseTensor : public TensorBase, ...@@ -119,10 +119,16 @@ class DenseTensor : public TensorBase,
/// \return Whether the storage is shared with other objects. /// \return Whether the storage is shared with other objects.
bool IsSharedWith(const DenseTensor& b) const; bool IsSharedWith(const DenseTensor& b) const;
/// \brief Change the dims information in the metadata, and the corresponding /// \brief Change the dims information in the metadata. If the new size is
/// memory allocation will occur when the `mutable_data` is called. /// inconsistent with the original value, the storage area will be released
/// to avoid wrong access.
/// \param dims The new dims of the dense tensor. /// \param dims The new dims of the dense tensor.
void Resize(const DDim& dims) noexcept { meta_.dims = dims; } void Resize(const DDim& dims);
/// \brief Change the dims information in the metadata.
/// \param dims The new dims of the dense tensor. The product of the dims
/// elements must be consistent with the original value.
void set_dims(const DDim& dims);
/// \brief Returns the actual storage size occupied by tensor, may be larger /// \brief Returns the actual storage size occupied by tensor, may be larger
/// than its shape dims. /// than its shape dims.
......
...@@ -44,6 +44,8 @@ class Storage : public intrusive_ref_counter<Storage> { ...@@ -44,6 +44,8 @@ class Storage : public intrusive_ref_counter<Storage> {
/// \return The mutable data pointer of the storage. /// \return The mutable data pointer of the storage.
void* data() const noexcept { return data_.operator->(); } void* data() const noexcept { return data_.operator->(); }
virtual void Clear() = 0;
virtual size_t size() const = 0; virtual size_t size() const = 0;
virtual const Place& place() const = 0; virtual const Place& place() const = 0;
virtual bool OwnsMemory() const = 0; virtual bool OwnsMemory() const = 0;
...@@ -68,6 +70,12 @@ class TensorStorage : public Storage { ...@@ -68,6 +70,12 @@ class TensorStorage : public Storage {
void Realloc(size_t size) override; void Realloc(size_t size) override;
size_t size() const noexcept override { return size_; } size_t size() const noexcept override { return size_; }
void Clear() override {
data_.Clear();
size_ = 0;
}
const Place& place() const override { return data_.place(); } const Place& place() const override { return data_.place(); }
bool OwnsMemory() const noexcept override { return true; } bool OwnsMemory() const noexcept override { return true; }
const std::shared_ptr<Allocator>& allocator() const noexcept { const std::shared_ptr<Allocator>& allocator() const noexcept {
......
...@@ -115,7 +115,9 @@ TEST(dense_tensor, resize) { ...@@ -115,7 +115,9 @@ TEST(dense_tensor, resize) {
CHECK_EQ(tensor_0.memory_size(), 2u); CHECK_EQ(tensor_0.memory_size(), 2u);
tensor_0.check_memory_size(); tensor_0.check_memory_size();
tensor_0.Resize({1, 2, 3}); tensor_0.Resize({1, 2, 3});
CHECK_EQ(tensor_0.memory_size(), 2u); CHECK_EQ(tensor_0.memory_size(), 0u);
tensor_0.set_dims({2, 3});
CHECK_EQ(tensor_0.memory_size(), 0u);
tensor_0.mutable_data<int8_t>(); tensor_0.mutable_data<int8_t>();
CHECK_EQ(tensor_0.memory_size(), 6u); CHECK_EQ(tensor_0.memory_size(), 6u);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册