Commit 2b92e037 authored by liuruilong

fix fc crash

Parent 7da97bcf
@@ -77,7 +77,7 @@ static const std::string G_OP_TYPE_BATCHNORM = "batch_norm";
 static const std::string G_OP_TYPE_BOX_CODER = "box_coder";
 static const std::string G_OP_TYPE_CONCAT = "concat";
 static const std::string G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
-static const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU = "FusionConvAddRelu";
+static const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
 static const std::string G_OP_TYPE_FC = "fc";
 static const std::string G_OP_TYPE_LRN = "lrn";
 static const std::string G_OP_TYPE_MUL = "mul";
@@ -92,6 +92,7 @@ static const std::string G_OP_TYPE_TRANSPOSE = "transpose";
 static const std::string G_OP_TYPE_SPLIT = "split";
 static const std::string G_OP_TYPE_FEED = "feed";
 static const std::string G_OP_TYPE_FETCH = "fetch";
+static const std::string G_OP_TYPE_DEPTHWISE_CONV = "depthwise_conv2d";
 static std::unordered_map<
     std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
...
@@ -45,6 +45,47 @@ bool Node::operator==(const Node &in) {
   return true;
 }
 
+bool Node::CanSplit(std::unordered_set<std::string> complex_compute_set) {
+  bool split = false;
+  CanSplit(&split, false, 0, &complex_compute_set, this);
+  return split;
+}
+
+void Node::CanSplit(bool *split, bool spliting, int complex_count,
+                    std::unordered_set<std::string> *complex_compute_set,
+                    Node *pre_node) {
+  // Inside a splitting region, count the "complex" ops seen on this path.
+  if (spliting) {
+    if (complex_compute_set->find(this->type_) != complex_compute_set->end()) {
+      complex_count++;
+    }
+  }
+
+  // Wait at a multi-input node until the walk arrives through its last input.
+  if (inputs_.size() > 1 && pre_node != inputs_.back()) {
+    return;
+  }
+  if (inputs_.size() > 1 && pre_node == inputs_.back()) {
+    if (complex_count > 1) {
+      *split = true;
+      return;
+    }
+  }
+
+  // A multi-output node opens a new splitting region; reset the counter.
+  // (The original wrote `complex_compute_set = 0;` here, which nulls the set
+  // pointer and crashes on the next lookup.)
+  if (outputs_.size() > 1) {
+    spliting = true;
+    complex_count = 0;
+  } else if (spliting && inputs_.size() > 0) {
+    spliting = false;
+  }
+
+  for (auto &output : outputs_) {
+    output->CanSplit(split, spliting, complex_count, complex_compute_set,
+                     this);
+  }
+}
+
 std::vector<std::shared_ptr<framework::OpDesc>> Node::OpDescs(uint size) {
   std::vector<std::shared_ptr<framework::OpDesc>> op_descs;
   OpDescs(size - 1, &op_descs);
...
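For context: `CanSplit` reports whether more than one op from the caller-supplied set of expensive ("complex") ops sits on the parallel branches between a fork (multi-output node) and the join that follows it, i.e. whether the subgraph carries enough work to be worth splitting. A minimal usage sketch, assuming `begin_node` is the root of an already-built node graph; the op-type set mirrors the commented-out call in `FushionOptimize` further down:

```cpp
std::unordered_set<std::string> complex_ops = {
    G_OP_TYPE_CONV, G_OP_TYPE_BATCHNORM, G_OP_TYPE_DEPTHWISE_CONV};
if (begin_node->CanSplit(complex_ops)) {
  // More than one heavy op lies on parallel branches between a fork and
  // its join, so the subgraph is a candidate for multi-threaded execution.
}
```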
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <string>
 #include <utility>
 #include <vector>
+#include <unordered_set>
 
 #include "common/log.h"
 #include "framework/paddle_mobile_object.h"
@@ -36,6 +37,7 @@ class Node : PaddleMobileObject {
       : op_desc_(op_desc), type_(op_desc->Type()) {}
   Node &operator>(std::shared_ptr<Node> node);
   bool operator==(const Node &in);
+  bool CanSplit(std::unordered_set<std::string> complex_compute_set);
   std::string ToString() const;
   std::shared_ptr<Node> To(int size);
   uint Depth(uint begin = 0);
@@ -49,6 +51,9 @@ class Node : PaddleMobileObject {
   void Description();
 
  private:
+  void CanSplit(bool *split, bool spliting, int complex_count,
+                std::unordered_set<std::string> *complex_compute_set,
+                Node *pre_node);
   void OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
                Node *node, bool adding_thread, int thread_num);
   void OpDescs(uint size,
...
@@ -99,6 +99,7 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
     //    DLOG << "node: \n" << *begin_node;
 
     std::vector<std::shared_ptr<framework::OpDesc>> op_descs;
+    // bool can_split = begin_node->CanSplit({G_OP_TYPE_CONV, G_OP_TYPE_BATCHNORM, G_OP_TYPE_DEPTHWISE_CONV});
     GenerateOps(&op_descs, begin_node.get());
     block->ops_ = op_descs;
   }
@@ -111,6 +112,28 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
   return optimize_program;
 }
 
+void ProgramOptimize::GenerateOps(
+    std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, Node *input_node,
+    Node *current_node) {
+  // Defer a multi-input node until the walk reaches it through its last
+  // input, so all of its producers have already been emitted.
+  if (current_node->inputs_.size() > 1 &&
+      input_node != current_node->inputs_.back()) {
+    return;
+  }
+  op_desc->push_back(current_node->op_desc_);
+
+  for (int i = 0; i < current_node->outputs_.size(); ++i) {
+    auto &output = current_node->outputs_[i];
+    GenerateOps(op_desc, current_node, output.get());
+  }
+}
+
 void ProgramOptimize::GenerateOps(
     std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, Node *input_node,
     Node *current_node, bool adding_thread, int thread_num,
@@ -234,7 +257,11 @@ void ProgramOptimize::GenerateOps(
   //  std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
   //  Node *input_node, Node *current_node, bool adding_thread, int
   //  thread_num
-  this->GenerateOps(op_descs, begin_node, begin_node, false, -1, nullptr);
+  if (false) {  // threaded variant disabled; take the simple traversal
+    this->GenerateOps(op_descs, begin_node, begin_node, false, -1, nullptr);
+  } else {
+    this->GenerateOps(op_descs, begin_node, begin_node);
+  }
 }
 
 }  // namespace framework
...
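The new two-argument `GenerateOps` linearizes the fused graph by plain depth-first traversal: a node with several inputs is only emitted when the walk enters it through its last input, which guarantees every producer appears before its consumer. A self-contained toy demonstrating the same rule (`ToyNode` and `Emit` are hypothetical names, not from this commit):

```cpp
#include <iostream>
#include <string>
#include <vector>

struct ToyNode {
  std::string type;
  std::vector<ToyNode *> inputs, outputs;
};

// Same rule as ProgramOptimize::GenerateOps: defer a join node until the
// walk arrives through its last input.
void Emit(ToyNode *in, ToyNode *cur, std::vector<std::string> *out) {
  if (cur->inputs.size() > 1 && in != cur->inputs.back()) return;
  out->push_back(cur->type);
  for (auto *o : cur->outputs) Emit(cur, o, out);
}

int main() {
  // Toy DAG:  feed -> conv -> concat,  feed -> pool -> concat.
  ToyNode a{"feed"}, b{"conv"}, c{"pool"}, d{"concat"};
  a.outputs = {&b, &c};
  b.inputs = {&a}; b.outputs = {&d};
  c.inputs = {&a}; c.outputs = {&d};
  d.inputs = {&b, &c};  // &c is the last input, so d is emitted exactly once
  std::vector<std::string> order;
  Emit(&a, &a, &order);
  for (auto &t : order) std::cout << t << " ";  // feed conv pool concat
}
```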
@@ -33,9 +33,11 @@ class ProgramOptimize {
  private:
   int current_block_;
   std::vector<std::shared_ptr<BlockDesc>> new_blocks_;
   void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_descs,
                    Node *begin_node);
+  void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
+                   Node *input_node, Node *current_node);
   void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
                    Node *input_node, Node *current_node, bool adding_thread,
                    int thread_num, std::shared_ptr<BlockDesc> new_block);
...
@@ -220,13 +220,18 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
       }
     }
   }
-  originProgramDesc->Description("program: ");
 
   if (optimize) {
     framework::ProgramOptimize program_optimize;
     program.optimizeProgram =
         program_optimize.FushionOptimize(originProgramDesc);
   }
+  if (optimize) {
+    program.optimizeProgram->Description("optimize: ");
+  } else {
+    originProgramDesc->Description("program: ");
+  }
 
   paddle_mobile__framework__proto__program_desc__free_unpacked(c_program, NULL);
 
   return program;
@@ -254,6 +259,7 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
   std::vector<std::shared_ptr<framework::OpDesc>> ops = block_desc->Ops();
   for (int j = 0; j < ops.size(); ++j) {
     std::shared_ptr<framework::OpDesc> op = ops[j];
+    DLOG << "create op: " << op->Type();
     auto op_base = framework::OpRegistry<Dtype>::CreateOp(
         op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
         program_.scope);
...
@@ -28,10 +28,10 @@ class FushionConvAddReluOpMatcher : public framework::FusionOpMatcher {
         std::make_shared<framework::Node>(G_OP_TYPE_RELU);
   }
 
-  void FolderNodes(framework::Node &node) {
+  void FolderNodes(framework::Node *node) {
     std::vector<std::shared_ptr<framework::OpDesc>> origin_descs =
-        node.OpDescs(node_.Depth());
-    node.Folder(node_.Depth(), Type(),
+        node->OpDescs(node_.Depth());
+    node->Folder(node_.Depth(), Type(),
                 {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}});
   }
 
   std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_RELU; }
...
@@ -32,10 +32,10 @@ class FusionFcMatcher : public framework::FusionOpMatcher {
     node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD);
   }
 
-  void FolderNodes(framework::Node &node) {
+  void FolderNodes(framework::Node *node) {
     vector<std::shared_ptr<framework::OpDesc>> origin_descs =
-        node.OpDescs(node_.Depth());
-    node.Folder(node_.Depth(), Type(),
+        node->OpDescs(node_.Depth());
+    node->Folder(node_.Depth(), Type(),
                 {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}});
   }
...
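Part of this commit is mechanical: `FolderNodes` now takes `framework::Node *` instead of a non-const reference, so the in-place fold is visible at every call site. A tiny illustrative toy of the call-site difference (all names below are hypothetical, not from this commit):

```cpp
#include <iostream>
#include <string>

// Toy stand-in for framework::Node; illustration only.
struct ToyNode {
  std::string type;
};

// Pointer parameter: the caller must write &node, making the in-place
// mutation explicit -- the convention FolderNodes now follows.
void FoldInto(ToyNode *node, const std::string &fused_type) {
  node->type = fused_type;
}

int main() {
  ToyNode node{"mul"};
  FoldInto(&node, "fc");           // explicit & signals mutation
  std::cout << node.type << "\n";  // prints: fc
}
```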
@@ -18,11 +18,12 @@ limitations under the License. */
 int main() {
   paddle_mobile::Loader<paddle_mobile::CPU> loader;
+  bool optimize = true;
   auto time1 = time();
-  auto program = loader.Load(g_googlenet, false);
+  auto program = loader.Load(g_googlenet, optimize);
   auto time2 = time();
   DLOG << "load cost :" << time_diff(time1, time2) << "ms\n";
-  paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, false);
+  paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, optimize);
   std::vector<float> input;
   std::vector<int64_t> dims{1, 3, 224, 224};
   GetInput<float>(g_test_image_1x3x224x224, &input, dims);
...
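Note that the test passes the same `optimize` flag to both `Loader::Load` and the `Executor` constructor: the loader decides whether `program.optimizeProgram` is populated, and the executor presumably runs the fused program only when its own flag is set, so the two must stay in sync.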