You need to sign in or sign up before continuing.
未验证 提交 4d536678 编写于 作者: 石晓伟 提交者: GitHub

adjust the COLUMNS=128; (#37120)

上级 9396f286
......@@ -34,6 +34,11 @@ class ExternalStorage : public pten::Storage {
"The external shared storage cannot be reallocated."));
}
void Clear() override {
data_.Clear();
size_ = 0;
}
size_t size() const noexcept override { return size_; }
const paddle::platform::Place& place() const override {
return data_.place();
......@@ -41,7 +46,7 @@ class ExternalStorage : public pten::Storage {
bool OwnsMemory() const noexcept override { return false; }
private:
const int64_t size_{0};
int64_t size_{0};
};
class SharedStorage : public pten::Storage {
......@@ -65,6 +70,11 @@ class SharedStorage : public pten::Storage {
"The external shared storage cannot be reallocated."));
}
void Clear() override {
data_.Clear();
size_ = 0;
}
size_t size() const noexcept override { return size_; }
const paddle::platform::Place& place() const override {
return data_.place();
......
......@@ -73,9 +73,9 @@ class Allocation final {
operator bool() const noexcept { return data_ || ctx_.Get(); }
const Place& place() const noexcept { return place_; }
void Clear() noexcept {
data_ = nullptr;
void Clear() {
ctx_.Clear();
data_ = nullptr;
}
/// \brief Statically cast the void pointer of the context object to
......@@ -107,12 +107,11 @@ class Allocation final {
swap(*this, other);
return *this;
}
~Context() {
~Context() { Clear(); }
void Clear() {
if (deleter_) {
deleter_(ctx_);
}
}
void Clear() noexcept {
ctx_ = nullptr;
deleter_ = nullptr;
}
......
......@@ -113,6 +113,20 @@ void DenseTensor::check_memory_size() const {
bytes));
}
void DenseTensor::Resize(const DDim& dims) {
if (product(dims) == product(meta_.dims)) {
set_dims(dims);
} else {
meta_.dims = dims;
storage_->Clear();
}
}
void DenseTensor::set_dims(const DDim& dims) {
CHECK(product(dims) == product(meta_.dims));
meta_.dims = dims;
}
#define DATA_MEMBER_FUNC_INSTANTIATION(dtype) \
template dtype* DenseTensor::mutable_data(); \
template const dtype* DenseTensor::data() const;
......
......@@ -119,10 +119,16 @@ class DenseTensor : public TensorBase,
/// \return Whether the storage is shared with other objects.
bool IsSharedWith(const DenseTensor& b) const;
/// \brief Change the dims information in the metadata, and the corresponding
/// memory allocation will occur when the `mutable_data` is called.
/// \brief Change the dims information in the metadata. If the new size is
/// inconsistent with the original value, the storage area will be released
/// to avoid wrong access.
/// \param dims The new dims of the dense tensor.
void Resize(const DDim& dims) noexcept { meta_.dims = dims; }
void Resize(const DDim& dims);
/// \brief Change the dims information in the metadata.
/// \param dims The new dims of the dense tensor. The product of the dims
/// elements must be consistent with the original value.
void set_dims(const DDim& dims);
/// \brief Returns the actual storage size occupied by tensor, may be larger
/// than its shape dims.
......
......@@ -44,6 +44,8 @@ class Storage : public intrusive_ref_counter<Storage> {
/// \return The mutable data pointer of the storage.
void* data() const noexcept { return data_.operator->(); }
virtual void Clear() = 0;
virtual size_t size() const = 0;
virtual const Place& place() const = 0;
virtual bool OwnsMemory() const = 0;
......@@ -68,6 +70,12 @@ class TensorStorage : public Storage {
void Realloc(size_t size) override;
size_t size() const noexcept override { return size_; }
void Clear() override {
data_.Clear();
size_ = 0;
}
const Place& place() const override { return data_.place(); }
bool OwnsMemory() const noexcept override { return true; }
const std::shared_ptr<Allocator>& allocator() const noexcept {
......
......@@ -115,7 +115,9 @@ TEST(dense_tensor, resize) {
CHECK_EQ(tensor_0.memory_size(), 2u);
tensor_0.check_memory_size();
tensor_0.Resize({1, 2, 3});
CHECK_EQ(tensor_0.memory_size(), 2u);
CHECK_EQ(tensor_0.memory_size(), 0u);
tensor_0.set_dims({2, 3});
CHECK_EQ(tensor_0.memory_size(), 0u);
tensor_0.mutable_data<int8_t>();
CHECK_EQ(tensor_0.memory_size(), 6u);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册