提交 cc4f55c8 编写于 作者: D dolphin8

Merge remote-tracking branch 'upstream/develop' into develop

......@@ -106,11 +106,14 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
}
std::vector<std::shared_ptr<framework::OpDesc>> op_descs;
for (int m = 0; m < nodes.size(); ++m) {
auto &node = nodes[m];
op_descs.push_back(node->op_desc_);
if (add_split) {
GenerateOps(&op_descs, begin_node.get(), add_split);
} else {
for (int m = 0; m < nodes.size(); ++m) {
auto &node = nodes[m];
op_descs.push_back(node->op_desc_);
}
}
// GenerateOps(&op_descs, begin_node.get());
block->ops_ = op_descs;
}
......@@ -267,12 +270,12 @@ void ProgramOptimize::GenerateOps(
}
void ProgramOptimize::GenerateOps(
std::vector<std::shared_ptr<framework::OpDesc>> *op_descs,
Node *begin_node) {
std::vector<std::shared_ptr<framework::OpDesc>> *op_descs, Node *begin_node,
bool can_add_split) {
// std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
// Node *input_node, Node *current_node, bool adding_thread, int
// thread_num
if (false) {
if (can_add_split) {
this->GenerateOps(op_descs, begin_node, begin_node, false, -1, nullptr);
} else {
this->GenerateOps(op_descs, begin_node, begin_node);
......
......@@ -34,7 +34,7 @@ class ProgramOptimize {
int current_block_;
std::vector<std::shared_ptr<BlockDesc>> new_blocks_;
void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_descs,
Node *begin_node);
Node *begin_node, bool can_add_split);
void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
Node *input_node, Node *current_node);
void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
......
......@@ -76,8 +76,9 @@ static size_t ReadBuffer(const char *file_name, uint8_t **out) {
template <typename Dtype, Precision P>
const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
const std::string &dirname, bool optimize) {
auto program = this->LoadProgram(dirname + "/__model__", optimize);
const std::string &dirname, bool optimize, bool can_add_split) {
auto program =
this->LoadProgram(dirname + "/__model__", optimize, can_add_split);
program.model_path = dirname;
return program;
}
......@@ -94,7 +95,7 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
template <typename Dtype, Precision P>
const framework::Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
const std::string &model_path, bool optimize) {
const std::string &model_path, bool optimize, bool can_add_split) {
std::string model_filename = model_path;
PaddleMobile__Framework__Proto__ProgramDesc *c_program;
uint8_t *buf = NULL;
......@@ -146,7 +147,7 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
if (optimize) {
framework::ProgramOptimize program_optimize;
program.optimizeProgram =
program_optimize.FushionOptimize(originProgramDesc);
program_optimize.FushionOptimize(originProgramDesc, can_add_split);
}
if (optimize) {
program.optimizeProgram->Description("optimize: ");
......@@ -310,6 +311,7 @@ void Executor<Dtype, P>::InitMemory() {
template <typename Dtype, Precision P>
void Executor<Dtype, P>::InitCombineMemory() {
LOG(kLOG_INFO) << " begin init combine memory";
char *origin_data = Get_binary_data(program_.para_path);
char *data = origin_data;
for (const auto &block : to_predict_program_->Blocks()) {
......@@ -330,6 +332,7 @@ void Executor<Dtype, P>::InitCombineMemory() {
}
}
delete origin_data;
LOG(kLOG_INFO) << " end init combine memory ";
}
template <typename Dtype, Precision P>
......
......@@ -35,7 +35,8 @@ class Loader {
* @b 加载分开形式的 fluid 模型
* */
const framework::Program<Dtype, P> Load(const std::string &dirname,
bool optimize = false);
bool optimize = false,
bool can_add_split = false);
/*
* @b load combine format fluid mode
......@@ -47,7 +48,8 @@ class Loader {
private:
const framework::Program<Dtype, P> LoadProgram(const std::string &model_path,
bool optimize = false);
bool optimize = false,
bool can_add_split = false);
};
template <typename Dtype = CPU, Precision P = Precision::FP32>
......
......@@ -19,9 +19,10 @@ int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader;
// ../../../test/models/googlenet
// ../../../test/models/mobilenet
auto program = loader.Load(g_resnet, true);
loader.Load(g_googlenet_combine + "/model", g_googlenet_combine + "/params",
true);
auto program = loader.Load(g_googlenet, true, true);
// loader.Load(g_googlenet_combine + "/model", g_googlenet_combine +
// "/params",
// true);
program.originProgram->Description("program desc: ");
return 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册