提交 65ae9d0f 编写于 作者: H hjchen2

Refine memory optimize pass for multiple blocks

上级 ba8d089d
......@@ -28,8 +28,6 @@ limitations under the License. */
#include "framework/tensor_base.h"
#include "memory/t_malloc.h"
#include <iostream>
namespace paddle_mobile {
namespace framework {
......@@ -84,10 +82,8 @@ class Tensor : public TensorBase {
int64_t size = numel() * SizeOfType(type);
if (holder_ == nullptr || holder_->size() < size + offset_) {
if (holder_ == nullptr) {
std::cout << "reset holder... size " << size << std::endl;
holder_.reset(new PlaceholderImpl(size, type));
} else {
std::cout << "resize holder... size " << size << std::endl;
holder_->resize(size);
}
offset_ = 0;
......
......@@ -18,8 +18,8 @@ limitations under the License. */
namespace paddle_mobile {
namespace pass {
void MemoryOptPass::InitBlockVars(const framework::BlockDesc *block) {
block_vars_.clear();
void MemoryOptPass::AppendBlockVars(const framework::BlockDesc *block) {
// block_vars_.clear();
for (const auto var : block->Vars()) {
block_vars_[var->Name()] = var.get();
}
......@@ -51,8 +51,8 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program,
framework::Scope *scope) {
const auto &blocks = program->Blocks();
for (const auto &block : blocks) {
// access all variables in block, and stored in map
InitBlockVars(block.get());
// access all variables in each block
AppendBlockVars(block.get());
reused_nodes_.clear();
// collect all not persistable variables, and accumulate
......@@ -91,6 +91,8 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program,
}
}
DLOG << "analysis_nodes_ size: " << analysis_nodes_.size();
// apply optimize
while (!analysis_nodes_.empty()) {
auto *node = analysis_nodes_.top();
......@@ -117,21 +119,22 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program,
node->visited = true;
node->count -= 1;
}
}
// shared data within all variables in the same reused list
for (const auto &list : reused_nodes_) {
DLOG << "\n";
DLOG << "share memory within these variables";
std::string name = list[0]->name;
auto *reused_var = scope->Var(name);
auto *reuse_tensor =
reused_var->template GetMutable<framework::LoDTensor>();
reuse_tensor->mutable_data<float>();
for (const auto &node : list) {
DLOG << node->name;
auto *var = scope->Var(node->name);
auto *tensor = var->template GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(*reuse_tensor);
// shared data within all variables in the same reused list
for (const auto &list : reused_nodes_) {
DLOG << "\n";
DLOG << "share memory within these variables";
std::string name = list[0]->name;
auto *reused_var = scope->Var(name);
auto *reuse_tensor =
reused_var->template GetMutable<framework::LoDTensor>();
reuse_tensor->mutable_data<float>();
for (const auto &node : list) {
DLOG << node->name;
auto *var = scope->Var(node->name);
auto *tensor = var->template GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(*reuse_tensor);
}
}
}
}
......
......@@ -49,7 +49,7 @@ class MemoryOptPass : public PassBase {
void operator()(const framework::ProgramDesc *program,
framework::Scope *scope);
void InitBlockVars(const framework::BlockDesc *block);
void AppendBlockVars(const framework::BlockDesc *block);
bool IsPersistable(const std::string name);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册