提交 ba8d089d 编写于 作者: H hjchen2

Refine memory optimize

上级 be8788b4
...@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "framework/executor.h"
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "common/enforce.h" #include "common/enforce.h"
#include "common/log.h" #include "common/log.h"
#include "memory/t_malloc.h"
#include "framework/context.h" #include "framework/context.h"
#include "framework/framework.pb-c.h" #include "framework/framework.pb-c.h"
#include "framework/lod_tensor.h" #include "framework/lod_tensor.h"
...@@ -27,9 +27,8 @@ limitations under the License. */ ...@@ -27,9 +27,8 @@ limitations under the License. */
#include "framework/program/var_desc.h" #include "framework/program/var_desc.h"
#include "framework/scope.h" #include "framework/scope.h"
#include "framework/tensor.h" #include "framework/tensor.h"
#include "memory/t_malloc.h" #include "framework/executor.h"
#include "pass/memory_optimize.h" #include "pass/memory_optimize.h"
#ifdef PADDLE_MOBILE_CL #ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_image.h" #include "framework/cl/cl_image.h"
#endif #endif
...@@ -217,6 +216,7 @@ void Executor<Device, T>::InitMemory() { ...@@ -217,6 +216,7 @@ void Executor<Device, T>::InitMemory() {
var->template GetMutable<framework::LoDTensorArray>(); var->template GetMutable<framework::LoDTensorArray>();
continue; continue;
} }
DLOG << "init persistable var: " << var_desc->Name();
char *origin_data = char *origin_data =
ReadFileToBuff(program_.model_path + "/" + var_desc->Name()); ReadFileToBuff(program_.model_path + "/" + var_desc->Name());
char *data = origin_data; char *data = origin_data;
...@@ -329,7 +329,6 @@ bool Executor<Device, T>::varInputMemory( ...@@ -329,7 +329,6 @@ bool Executor<Device, T>::varInputMemory(
if (type == VARTYPE_TYPE_LOD_TENSOR) { if (type == VARTYPE_TYPE_LOD_TENSOR) {
auto data_type = var_desc->Tensor_desc().DataType(); auto data_type = var_desc->Tensor_desc().DataType();
framework::LoDTensor *tensor = var->template GetMutable<LoDTensor>(); framework::LoDTensor *tensor = var->template GetMutable<LoDTensor>();
tensor->mutable_data(TypeId(data_type));
} else if (type == VARTYPE_TYPE_STEP_SCOPES) { } else if (type == VARTYPE_TYPE_STEP_SCOPES) {
std::vector<framework::Scope *> *step_scopes = std::vector<framework::Scope *> *step_scopes =
var->template GetMutable<std::vector<framework::Scope *>>(); var->template GetMutable<std::vector<framework::Scope *>>();
...@@ -465,6 +464,7 @@ PMStatus Executor<Device, T>::Predict() { ...@@ -465,6 +464,7 @@ PMStatus Executor<Device, T>::Predict() {
clock_gettime(CLOCK_MONOTONIC, &ts); clock_gettime(CLOCK_MONOTONIC, &ts);
profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif #endif
DLOG << "run op: " << op_handler->Type();
if (lod_mode_) { if (lod_mode_) {
op_handler->InferShape(); op_handler->InferShape();
} }
......
...@@ -28,6 +28,8 @@ limitations under the License. */ ...@@ -28,6 +28,8 @@ limitations under the License. */
#include "framework/tensor_base.h" #include "framework/tensor_base.h"
#include "memory/t_malloc.h" #include "memory/t_malloc.h"
#include <iostream>
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
...@@ -69,7 +71,6 @@ class Tensor : public TensorBase { ...@@ -69,7 +71,6 @@ class Tensor : public TensorBase {
inline Tensor &ShareDataWith(const Tensor &src) { inline Tensor &ShareDataWith(const Tensor &src) {
src.check_memory_size(); src.check_memory_size();
if (holder_.get() != src.holder_.get()) { if (holder_.get() != src.holder_.get()) {
// *this = src;
holder_ = src.holder_; holder_ = src.holder_;
} }
return *this; return *this;
...@@ -82,7 +83,13 @@ class Tensor : public TensorBase { ...@@ -82,7 +83,13 @@ class Tensor : public TensorBase {
PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor's numel must >=0.") PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor's numel must >=0.")
int64_t size = numel() * SizeOfType(type); int64_t size = numel() * SizeOfType(type);
if (holder_ == nullptr || holder_->size() < size + offset_) { if (holder_ == nullptr || holder_->size() < size + offset_) {
holder_.reset(new PlaceholderImpl(size, type)); if (holder_ == nullptr) {
std::cout << "reset holder... size " << size << std::endl;
holder_.reset(new PlaceholderImpl(size, type));
} else {
std::cout << "resize holder... size " << size << std::endl;
holder_->resize(size);
}
offset_ = 0; offset_ = 0;
} }
return reinterpret_cast<void *>( return reinterpret_cast<void *>(
...@@ -181,6 +188,7 @@ class Tensor : public TensorBase { ...@@ -181,6 +188,7 @@ class Tensor : public TensorBase {
: ptr_(static_cast<uint8_t *>(memory::Alloc(size)), : ptr_(static_cast<uint8_t *>(memory::Alloc(size)),
memory::PODDeleter<uint8_t>()), memory::PODDeleter<uint8_t>()),
size_(size), size_(size),
capatity_(size),
type_(type) { type_(type) {
PADDLE_MOBILE_ENFORCE(ptr_ != nullptr, PADDLE_MOBILE_ENFORCE(ptr_ != nullptr,
"Insufficient memory to allocation"); "Insufficient memory to allocation");
...@@ -194,11 +202,21 @@ class Tensor : public TensorBase { ...@@ -194,11 +202,21 @@ class Tensor : public TensorBase {
virtual void set_type(std::type_index type) { type_ = type; } virtual void set_type(std::type_index type) { type_ = type; }
virtual void resize(size_t size) {
if (size > capatity_) {
capatity_ = size;
ptr_.reset(static_cast<uint8_t *>(memory::Alloc(capatity_)));
}
size_ = size;
}
std::unique_ptr<uint8_t, memory::PODDeleter<uint8_t>> ptr_; std::unique_ptr<uint8_t, memory::PODDeleter<uint8_t>> ptr_;
/*! the size of memory block. */ /*! the size of memory block. */
size_t size_; size_t size_;
size_t capatity_;
/* the current type of memory */ /* the current type of memory */
std::type_index type_; std::type_index type_;
}; };
......
...@@ -117,6 +117,8 @@ class TensorBase { ...@@ -117,6 +117,8 @@ class TensorBase {
virtual std::type_index type() const = 0; virtual std::type_index type() const = 0;
virtual void set_type(std::type_index type) = 0; virtual void set_type(std::type_index type) = 0;
virtual void resize(size_t size) = 0;
}; };
/** /**
......
...@@ -54,7 +54,6 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program, ...@@ -54,7 +54,6 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program,
// access all variables in block, and stored in map // access all variables in block, and stored in map
InitBlockVars(block.get()); InitBlockVars(block.get());
visited_nodes_.clear();
reused_nodes_.clear(); reused_nodes_.clear();
// collect all not persistable variables, and accumulate // collect all not persistable variables, and accumulate
// it's reference count // it's reference count
...@@ -63,8 +62,7 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program, ...@@ -63,8 +62,7 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program,
for (const auto &op : block->Ops()) { for (const auto &op : block->Ops()) {
DLOG << "op_desc->Type(): " << op->Type(); DLOG << "op_desc->Type(): " << op->Type();
const auto &outputs_map = op->GetOutputs(); for (const auto &outputs : op->GetOutputs()) {
for (const auto &outputs : outputs_map) {
for (const auto &output : outputs.second) { for (const auto &output : outputs.second) {
if (!IsPersistable(output)) { if (!IsPersistable(output)) {
DLOG << "output: " << output; DLOG << "output: " << output;
...@@ -73,8 +71,7 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program, ...@@ -73,8 +71,7 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program,
} }
} }
} }
const auto &inputs_map = op->GetInputs(); for (const auto &inputs : op->GetInputs()) {
for (const auto &inputs : inputs_map) {
for (const auto &input : inputs.second) { for (const auto &input : inputs.second) {
if (!IsPersistable(input)) { if (!IsPersistable(input)) {
DLOG << "input: " << input; DLOG << "input: " << input;
...@@ -83,6 +80,15 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program, ...@@ -83,6 +80,15 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program,
} }
} }
} }
for (const auto &outputs : op->GetOutputs()) {
for (const auto &output : outputs.second) {
if (!IsPersistable(output)) {
DLOG << "output: " << output;
VarNode *node = CreateNode(output);
analysis_nodes_.push(node);
}
}
}
} }
// apply optimize // apply optimize
...@@ -115,7 +121,7 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program, ...@@ -115,7 +121,7 @@ void MemoryOptPass::operator()(const framework::ProgramDesc *program,
// shared data within all variables in the same reused list // shared data within all variables in the same reused list
for (const auto &list : reused_nodes_) { for (const auto &list : reused_nodes_) {
DLOG << "\n"; DLOG << "\n";
DLOG << "share data within these variables"; DLOG << "share memory within these variables";
std::string name = list[0]->name; std::string name = list[0]->name;
auto *reused_var = scope->Var(name); auto *reused_var = scope->Var(name);
auto *reuse_tensor = auto *reuse_tensor =
......
...@@ -59,7 +59,6 @@ class MemoryOptPass : public PassBase { ...@@ -59,7 +59,6 @@ class MemoryOptPass : public PassBase {
std::stack<VarNode *> analysis_nodes_; std::stack<VarNode *> analysis_nodes_;
std::vector<std::vector<VarNode *>> reused_nodes_; std::vector<std::vector<VarNode *>> reused_nodes_;
std::unordered_map<std::string, VarNode *> created_nodes_; std::unordered_map<std::string, VarNode *> created_nodes_;
std::unordered_map<std::string, VarNode *> visited_nodes_;
std::unordered_map<std::string, framework::VarDesc *> block_vars_; std::unordered_map<std::string, framework::VarDesc *> block_vars_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册