Commit 8e218ac8 authored by WangLiu, committed by GitHub

Merge pull request #389 from codeWorm2015/develop

fix #388: support combined-format fluid model
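In fluid, an inference model can be exported in two layouts: a "separate" format, where the model directory holds one file per parameter (read as `program_.model_path + "/" + var_desc->Name()`, as in the `InitMemory` hunk below), and a "combined" format, where a single `model` file is paired with a single `params` file containing every parameter back to back. This PR reworks parameter loading so the combined file can be consumed in one pass. A minimal sketch of the two call styles, lifted from the test change at the bottom of this diff (`g_googlenet` and `g_googlenet_combine` are fixture paths from the repo's test utilities):

```cpp
paddle_mobile::Loader<paddle_mobile::CPU> loader;
bool optimize = false;

// Separate format: a model directory with one file per parameter.
auto program_separate = loader.Load(g_googlenet, optimize);

// Combined format (what this PR wires up end to end): a single model
// file plus a single params file holding every parameter back to back.
auto program_combined = loader.Load(g_googlenet_combine + "/model",
                                    g_googlenet_combine + "/params", optimize);
```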
@@ -11,10 +11,10 @@ else()
set(CMAKE_BUILD_TYPE Release)
endif ()
if(DEBUGING)
message(STATUS "debuging")
add_definitions(-DPADDLE_MOBILE_DEBUG)
else()
message(STATUS "releasing")
add_definitions(-fvisibility=hidden -fvisibility-inlines-hidden)
@@ -25,7 +25,6 @@ if (USE_EXCEPTION)
add_definitions(-fexceptions)
else()
add_definitions(-fno-exceptions)
endif ()
if(IS_MAC)
@@ -119,7 +118,6 @@ else ()
add_definitions(-DTRANSPOSE_OP)
endif()
add_library(paddle-mobile SHARED ${PADDLE_MOBILE_CC} ${PADDLE_MOBILE_H})
if(DEBUGING)
......
@@ -279,17 +279,14 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
template <typename Dtype, Precision P>
void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc,
framework::LoDTensor *tensor,
const std::string &file_path, char *data) {
framework::LoDTensor *tensor, char *&data) {
// 1. version
uint32_t version = *(uint32_t *)data;
data += sizeof(uint32_t);
DLOG << "version: " << version;
// 2 Lod information
uint64_t lod_level = *(uint64_t *)data;
data += sizeof(uint64_t);
DLOG << "lod_level: " << lod_level;
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
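The heart of the change is the new `LoadMemory` signature: the per-parameter `file_path` argument is gone, and the buffer pointer is now taken by reference (`char *&data`). Each call consumes exactly one tensor's bytes and leaves the caller's pointer at the start of the next tensor, which is what lets a single combined buffer be walked call after call. A hypothetical helper, not in the patch (which inlines the same arithmetic with `*(uint32_t *)data` casts), showing the read-and-advance idiom:

```cpp
#include <cstdint>
#include <cstring>

// Read one T at the cursor and advance it past the bytes just read.
// Taking char *& makes the advance visible to the caller, which is
// exactly why LoadMemory switched to char *&data.
template <typename T>
static T ReadAndAdvance(char *&cursor) {
  T value;
  std::memcpy(&value, cursor, sizeof(T));  // memcpy avoids unaligned casts
  cursor += sizeof(T);
  return value;
}

// The first reads of LoadMemory would then be:
//   uint32_t version   = ReadAndAdvance<uint32_t>(data);
//   uint64_t lod_level = ReadAndAdvance<uint64_t>(data);
```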
@@ -297,7 +294,6 @@ void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc,
uint64_t size = *(uint64_t *)data;
data += sizeof(uint64_t);
DLOG << "lod size: " << i << size;
std::vector<size_t> tmp(size / sizeof(size_t));
for (int k = 0; k < tmp.size(); ++k) {
@@ -315,12 +311,10 @@ void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc,
// 3. tensor version
uint32_t tensor_version = *(uint32_t *)data;
data += sizeof(uint32_t);
DLOG << "tensor_version: " << tensor_version;
// 4. tensor desc
int32_t size = *(int32_t *)data;
data += sizeof(int32_t);
DLOG << "tensor desc size: " << size;
std::unique_ptr<char[]> buf(new char[size]);
for (int m = 0; m < size; ++m) {
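Piecing together the reads above, each serialized parameter that `LoadMemory` walks has the following layout; the trailing raw element bytes are handled in lines collapsed from this diff, so that last entry is inferred rather than shown:

```cpp
// Per-tensor wire format consumed by LoadMemory, in order:
//   uint32_t version;                    // 1. LoD-tensor format version
//   uint64_t lod_level;                  // 2. number of LoD levels
//   repeated lod_level times:
//     uint64_t size;                     //    byte size of this level
//     size_t   offsets[size / sizeof(size_t)];
//   uint32_t tensor_version;             // 3. tensor format version
//   int32_t  size;                       // 4. length of the tensor desc
//   char     desc[size];                 //    serialized descriptor bytes
//   char     raw[numel * type_size];     // 5. element data (inferred)
```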
@@ -344,7 +338,6 @@ void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc,
break;
case framework::VARTYPE_TYPE_FP32:
type_size = 4;
DLOG << " type size: " << type_size;
memory = tensor->mutable_data<float>();
break;
case framework::VARTYPE_TYPE_FP64:
@@ -382,8 +375,8 @@ void Executor<Dtype, P>::InitMemory() {
char *origin_data =
Get_binary_data(program_.model_path + "/" + var_desc->Name());
LoadMemory(*var_desc, tensor,
program_.model_path + "/" + var_desc->Name(), origin_data);
char *data = origin_data;
LoadMemory(*var_desc, tensor, data);
delete origin_data;
} else {
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
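`InitMemory`, the separate-format path, still reads each parameter from its own file, but the call site now matches the new signature: `origin_data` keeps ownership of the allocation while the copy `data` serves as the cursor that `LoadMemory` advances. The pattern, condensed from the hunk above (note: if `Get_binary_data` allocates with `new char[]`, the strictly matching release would be `delete[]`, whereas the patch uses `delete`):

```cpp
char *origin_data =
    Get_binary_data(program_.model_path + "/" + var_desc->Name());
char *data = origin_data;             // cursor; LoadMemory moves this copy
LoadMemory(*var_desc, tensor, data);  // reads one tensor, advances data
delete origin_data;                   // freed via the untouched owner pointer
```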
@@ -399,7 +392,7 @@ void Executor<Dtype, P>::InitCombineMemory() {
template <typename Dtype, Precision P>
void Executor<Dtype, P>::InitCombineMemory() {
char *origin_data = Get_binary_data(program_.para_path);
char *data = origin_data;
for (const auto &block : to_predict_program_->Blocks()) {
for (const auto &var_desc : block->Vars()) {
auto var = program_.scope->Var(var_desc->Name());
@@ -408,18 +401,15 @@ void Executor<Dtype, P>::InitCombineMemory() {
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
continue;
}
LoadMemory(*var_desc, tensor,
program_.model_path + "/" + var_desc->Name(), origin_data);
LoadMemory(*var_desc, tensor, data);
} else {
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
auto tensor = var->template GetMutable<framework::LoDTensor>();
tensor->template mutable_data<Ptype>();
}
}
}
}
delete origin_data;
}
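`InitCombineMemory` is the combined-format counterpart: it reads `program_.para_path` once, threads a single cursor through every persistable variable (skipping `feed` and `fetch`), and frees the buffer once at the end. Because `LoadMemory` advances `data` by exactly one tensor per call, the loop needs no per-parameter file names or offsets. The contrast with `InitMemory`, summarized:

```cpp
// Separate format (InitMemory): one read and one free per parameter.
//   N x { Get_binary_data(model_path + "/" + name); LoadMemory(...); delete }
// Combined format (InitCombineMemory): one read and one free in total.
//   1 x Get_binary_data(para_path);  N x LoadMemory(..., data);  1 x delete
```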
......
@@ -63,8 +63,7 @@ class Executor {
void InitMemory();
void LoadMemory(const framework::VarDesc var_desc,
framework::LoDTensor *tensor, const std::string &file_path,
char *data);
framework::LoDTensor *tensor, char *&data);
void InitCombineMemory();
framework::Program<Dtype> program_;
int batch_size_ = 1;
......
@@ -20,10 +20,9 @@ int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader;
bool optimize = false;
auto time1 = time();
auto program = loader.Load(g_googlenet, optimize);
// auto program = loader.Load(g_googlenet_combine + "/model",
// g_googlenet_combine + "/params", optimize);
// auto program = loader.Load(g_googlenet, optimize);
auto program = loader.Load(g_googlenet_combine + "/model",
g_googlenet_combine + "/params", optimize);
auto time2 = time();
DLOG << "load cost :" << time_diff(time1, time2) << "ms\n";
paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, optimize);
......
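The GoogLeNet test flips its default accordingly: the separate-format load is now the commented-out branch and the combined-format load is live, so the `load cost` printed via `time_diff(time1, time2)` measures the new single-file path. The remainder of the test is collapsed from this view.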