提交 e2e67b04 编写于 作者: J Jiaying Zhao 提交者: GitHub

reduce memory when the input shrinks in lod_mode (#1756)

上级 c1f540ca
......@@ -299,31 +299,26 @@ void Executor<Device, T>::InitNoPersistableMemory(const Tensor &input_tensor) {
for (const auto &block : program_desc_->Blocks()) {
for (const auto &var_desc : block->Vars()) {
auto var = program_.scope->Var(var_desc->Name());
auto tensor = var->template GetMutable<LoDTensor>();
if (var_desc->Persistable()) {
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
var->template GetMutable<framework::LoDTensorArray>();
continue;
}
} else {
if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
if (!var_desc->Persistable() &&
var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
DLOG << "InitNoPersistableMemory var " << var_desc->Name();
auto tensor = var->template GetMutable<LoDTensor>();
if (tensor->IsInitialized()) {
DLOG << "var's tensor is Initialized";
DDim tensor_dim = tensor->dims();
DDim new_dim =
make_ddim({tensor_dim[0], tensor_dim[1], input_tensor.dims()[2],
input_tensor.dims()[3]});
tensor->Resize(new_dim);
tensor->template mutable_data<T>();
tensor->template mutable_data_new<T>();
DLOG << "var's tensor dims " << tensor_dim;
DLOG << "var's tensor new dims " << new_dim;
} else {
PADDLE_MOBILE_THROW_EXCEPTION("Unsupported var type `%d`",
var_desc->Type());
DLOG << "var's tensor is not Initialized ???";
}
}
}
}
std::shared_ptr<LoDTensor> output = GetOutput("fetch");
output->Resize(input_tensor.dims());
output->mutable_data<T>();
}
template <typename Device, typename T>
......@@ -411,6 +406,9 @@ void Executor<Device, T>::SetInput(const Tensor &input,
target.ShareDataWith(input);
if (feed_indices_.size() == 1) {
auto &dim = input.dims();
if (lod_mode_ && product(dim) < 0.9 * product(input_dim_last_)) {
InitNoPersistableMemory(target);
}
input_dim_has_changed_ = input_dim_last_ != dim;
input_dim_last_ = static_cast<DDim>(dim);
}
......@@ -432,6 +430,9 @@ void Executor<Device, T>::SetInput(const LoDTensor &input,
target.set_lod(input.lod());
if (feed_indices_.size() == 1) {
auto &dim = input.dims();
if (lod_mode_ && product(dim) < 0.9 * product(input_dim_last_)) {
InitNoPersistableMemory(target);
}
input_dim_has_changed_ = input_dim_last_ != dim;
input_dim_last_ = static_cast<DDim>(dim);
}
......
......@@ -68,7 +68,8 @@ class Tensor : public TensorBase {
Resize(ddim);
auto type = type_id<T>().hash_code();
int64_t size = numel() * SizeOfType(type);
holder_.reset(new PlaceholderImpl(size, type, (uint8_t *)input));
holder_.reset(
new PlaceholderImpl(size, type, reinterpret_cast<uint8_t *>(input)));
holder_->set_type(type);
offset_ = 0;
}
......@@ -103,6 +104,29 @@ class Tensor : public TensorBase {
return *this;
}
template <typename T>
// Lazily (re)allocates the backing storage to exactly fit the current dims()
// and returns a typed pointer to it. Unlike mutable_data<T>(), when an
// existing allocation's byte size differs from what is now required this
// calls holder_->realloc(), which frees the old buffer and allocates a fresh
// one (previous contents are NOT preserved). Added so lod_mode can shrink
// activation memory when the input becomes smaller.
inline T *mutable_data_new() {
static_assert(std::is_pod<T>::value, "T must be POD");
const kTypeId_t type = type_id<T>().hash_code();
// Record the element type on the existing holder before the size check.
if (holder_ != nullptr) {
holder_->set_type(type);
}
PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor's numel must >=0.")
int64_t size = numel() * SizeOfType(type);
// Reallocate when there is no buffer yet, or when the buffer's size does
// not exactly match the required bytes plus the slice offset. Note this
// triggers for both growth AND shrinkage (exact-fit, not capacity-based).
if (holder_ == nullptr || holder_->size() != size + offset_) {
if (holder_ == nullptr) {
holder_.reset(new PlaceholderImpl(size, type));
} else {
// NOTE(review): realloc discards the old data; callers must not rely
// on previous contents surviving this call.
holder_->realloc(size);
}
// Any prior slice offset is invalidated by the fresh allocation.
offset_ = 0;
}
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
inline void *mutable_data(const kTypeId_t type) {
if (holder_ != nullptr) {
holder_->set_type(type);
......@@ -244,6 +268,12 @@ class Tensor : public TensorBase {
size_ = size;
}
// Replace the backing buffer with a freshly allocated block of `size` bytes.
// NOTE: despite the name, this does NOT copy the old contents (unlike C's
// realloc) -- ptr_.reset() frees the previous buffer and the new block is
// uninitialized. `capatity_` [sic, project-wide member spelling] is set to
// the newly allocated byte count along with size_.
virtual void realloc(size_t size) {
capatity_ = size;
ptr_.reset(static_cast<uint8_t *>(memory::Alloc(capatity_)));
size_ = size;
}
std::unique_ptr<uint8_t, std::function<void(uint8_t *)>> ptr_;
/*! the size of memory block. */
......
......@@ -117,6 +117,8 @@ class TensorBase {
virtual void set_type(kTypeId_t type) = 0;
virtual void resize(size_t size) = 0;
virtual void realloc(size_t size) = 0;
};
/**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册