Commit 779a560c, authored by hjchen2

Refine memory optimize pass for multiple blocks

Parent 72d3cb39
@@ -28,8 +28,6 @@ limitations under the License. */
 #include "framework/tensor_base.h"
 #include "memory/t_malloc.h"
-#include <iostream>

 namespace paddle_mobile {
 namespace framework {
@@ -84,10 +82,8 @@ class Tensor : public TensorBase {
     int64_t size = numel() * SizeOfType(type);
     if (holder_ == nullptr || holder_->size() < size + offset_) {
       if (holder_ == nullptr) {
-        std::cout << "reset holder... size " << size << std::endl;
         holder_.reset(new PlaceholderImpl(size, type));
       } else {
-        std::cout << "resize holder... size " << size << std::endl;
         holder_->resize(size);
       }
       offset_ = 0;
......
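The hunk above only strips leftover std::cout debugging from the tensor's mutable_data path; what remains is a grow-only allocation policy: the holder is reused while it is large enough and replaced (or resized) only on growth. A minimal standalone sketch of that policy, with illustrative names (GrowOnlyBuffer and Require are not paddle_mobile APIs):

    #include <cstddef>
    #include <memory>

    // Grow-only buffer: keep the current allocation while it is big enough,
    // replace it only when a request exceeds capacity. This mirrors the
    // branch structure retained in mutable_data above.
    class GrowOnlyBuffer {
     public:
      void *Require(std::size_t size) {
        if (data_ == nullptr || capacity_ < size) {
          data_.reset(::operator new(size));  // grow; old contents are dropped
          capacity_ = size;
        }
        return data_.get();  // shrinking requests reuse the existing block
      }

     private:
      struct Deleter {
        void operator()(void *p) const { ::operator delete(p); }
      };
      std::unique_ptr<void, Deleter> data_;
      std::size_t capacity_ = 0;
    };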
@@ -18,8 +18,8 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace pass {

-void MemoryOptPass::InitBlockVars(const framework::BlockDesc *block) {
-  block_vars_.clear();
+void MemoryOptPass::AppendBlockVars(const framework::BlockDesc *block) {
+  // block_vars_.clear();
   for (const auto var : block->Vars()) {
     block_vars_[var->Name()] = var.get();
   }
@@ -51,8 +51,8 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program,
                                framework::Scope *scope) {
   const auto &blocks = program->Blocks();
   for (const auto &block : blocks) {
-    // access all variables in block, and stored in map
-    InitBlockVars(block.get());
+    // access all variables in each block
+    AppendBlockVars(block.get());
     reused_nodes_.clear();
     // collect all not persistable variables, and accumulate
@@ -91,6 +91,8 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program,
       }
     }
+    DLOG << "analysis_nodes_ size: " << analysis_nodes_.size();
     // apply optimize
     while (!analysis_nodes_.empty()) {
       auto *node = analysis_nodes_.top();
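For orientation between the two hunks: analysis_nodes_ holds each non-persistable variable together with a remaining-use count, and the "apply optimize" loop drains it, marking nodes visited and decrementing counts exactly as in the visible lines. A compact sketch of that draining idea, with a hypothetical VarNode and the reuse-list construction elided:

    #include <stack>
    #include <string>
    #include <vector>

    // Hypothetical node in the spirit of the pass's VarNode: a variable plus
    // the number of ops that still reference it.
    struct VarNode {
      std::string name;
      int count = 0;
      bool visited = false;
    };

    // Drain the analysis stack: each pop consumes one reference; a node whose
    // count reaches zero is dead, so its memory may join a reuse list.
    void Drain(std::stack<VarNode *> &analysis_nodes,
               std::vector<VarNode *> &free_for_reuse) {
      while (!analysis_nodes.empty()) {
        VarNode *node = analysis_nodes.top();
        analysis_nodes.pop();
        node->visited = true;
        node->count -= 1;
        if (node->count <= 0) {
          free_for_reuse.push_back(node);
        }
      }
    }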
@@ -117,21 +119,22 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program,
       node->visited = true;
       node->count -= 1;
     }
-  }
     // shared data within all variables in the same reused list
     for (const auto &list : reused_nodes_) {
       DLOG << "\n";
       DLOG << "share memory within these variables";
       std::string name = list[0]->name;
       auto *reused_var = scope->Var(name);
       auto *reuse_tensor =
           reused_var->template GetMutable<framework::LoDTensor>();
       reuse_tensor->mutable_data<float>();
       for (const auto &node : list) {
         DLOG << node->name;
         auto *var = scope->Var(node->name);
         auto *tensor = var->template GetMutable<framework::LoDTensor>();
         tensor->ShareDataWith(*reuse_tensor);
       }
     }
+  }
 }
......
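With the brace moved, the memory-sharing loop now runs inside the per-block loop, so each block's reuse lists are applied before reused_nodes_ is cleared for the next block. Within each list the first variable owns the allocation and every other variable aliases it through ShareDataWith. A toy illustration of that aliasing step, with simplified types (the real pass goes through framework::Scope and framework::LoDTensor):

    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Toy tensor whose storage can be aliased, standing in for
    // LoDTensor::ShareDataWith.
    struct Tensor {
      std::shared_ptr<std::vector<float>> data;
      void ShareDataWith(const Tensor &other) { data = other.data; }
    };

    // Apply one reuse list: the first name owns the buffer, the rest alias it
    // (the owner's self-share in the loop is a harmless no-op).
    void ApplyReuseList(const std::vector<std::string> &list,
                        std::unordered_map<std::string, Tensor> &scope) {
      Tensor &owner = scope[list.front()];
      owner.data = std::make_shared<std::vector<float>>();
      for (const auto &name : list) {
        scope[name].ShareDataWith(owner);
      }
    }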
@@ -49,7 +49,7 @@ class MemoryOptPass : public PassBase {
   void operator()(const framework::ProgramDesc *program,
                   framework::Scope *scope);

-  void InitBlockVars(const framework::BlockDesc *block);
+  void AppendBlockVars(const framework::BlockDesc *block);

   bool IsPersistable(const std::string name);
......