提交 52b0e584 编写于 作者: 陈后江

Refine

上级 382e2ec9
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#include "common/util.h" #include "common/util.h"
namespace paddle_mobile {
char *ReadFileToBuff(std::string filename) { char *ReadFileToBuff(std::string filename) {
FILE *file = fopen(filename.c_str(), "rb"); FILE *file = fopen(filename.c_str(), "rb");
PADDLE_MOBILE_ENFORCE(file != nullptr, "can't open file: %s ", PADDLE_MOBILE_ENFORCE(file != nullptr, "can't open file: %s ",
...@@ -29,3 +31,5 @@ char *ReadFileToBuff(std::string filename) { ...@@ -29,3 +31,5 @@ char *ReadFileToBuff(std::string filename) {
fclose(file); fclose(file);
return data; return data;
} }
} // namespace paddle_mobile
...@@ -161,10 +161,11 @@ void Executor<Dtype, P>::InitMemory() { ...@@ -161,10 +161,11 @@ void Executor<Dtype, P>::InitMemory() {
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
continue; continue;
} }
char *data = char *origin_data =
ReadFileToBuff(program_.model_path + "/" + var_desc->Name()); ReadFileToBuff(program_.model_path + "/" + var_desc->Name());
char *data = origin_data;
LoadMemory((void**)&data, var_desc, tensor); LoadMemory((void**)&data, var_desc, tensor);
delete [] data; delete [] origin_data;
} else { } else {
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) { if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
varInputMemory(var_desc, var, tensor); varInputMemory(var_desc, var, tensor);
...@@ -176,15 +177,16 @@ void Executor<Dtype, P>::InitMemory() { ...@@ -176,15 +177,16 @@ void Executor<Dtype, P>::InitMemory() {
template <typename Dtype, Precision P> template <typename Dtype, Precision P>
void Executor<Dtype, P>::InitCombineMemory() { void Executor<Dtype, P>::InitCombineMemory() {
char *data = nullptr; char *origin_data = nullptr;
bool self_alloc = false; bool self_alloc = false;
if (program_.combined_params_buf && program_.combined_params_len) { if (program_.combined_params_buf && program_.combined_params_len) {
data = (char *)program_.combined_params_buf; origin_data = (char *)program_.combined_params_buf;
} else { } else {
self_alloc = true; self_alloc = true;
data = ReadFileToBuff(program_.para_path); origin_data = ReadFileToBuff(program_.para_path);
} }
PADDLE_MOBILE_ENFORCE(data != nullptr, "data == nullptr"); PADDLE_MOBILE_ENFORCE(origin_data != nullptr, "data == nullptr");
char *data = origin_data;
for (const auto &block : to_predict_program_->Blocks()) { for (const auto &block : to_predict_program_->Blocks()) {
for (const auto &var_desc : block->Vars()) { for (const auto &var_desc : block->Vars()) {
auto var = program_.scope->Var(var_desc->Name()); auto var = program_.scope->Var(var_desc->Name());
...@@ -202,7 +204,7 @@ void Executor<Dtype, P>::InitCombineMemory() { ...@@ -202,7 +204,7 @@ void Executor<Dtype, P>::InitCombineMemory() {
} }
} }
if (self_alloc) { if (self_alloc) {
delete [] data; delete [] origin_data;
} }
LOG(kLOG_INFO) << "init combine memory finish"; LOG(kLOG_INFO) << "init combine memory finish";
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册