提交 746567b0 编写于 作者: L liuruilong

add support for mul fusion changes

上级 ccd7c7de
...@@ -6,7 +6,7 @@ option(USE_OPENMP "openmp support" OFF) ...@@ -6,7 +6,7 @@ option(USE_OPENMP "openmp support" OFF)
option(USE_EXCEPTION "use std exception" ON) option(USE_EXCEPTION "use std exception" ON)
option(LOG_PROFILE "log profile" ON) option(LOG_PROFILE "log profile" ON)
# select the platform to build # select the platform to build
option(CPU "cpu" OFF) option(CPU "cpu" ON)
option(MALI_GPU "mali gpu" ON) option(MALI_GPU "mali gpu" ON)
option(FPGA "fpga" OFF) option(FPGA "fpga" OFF)
......
...@@ -92,7 +92,7 @@ int Node::Depth(int begin) { ...@@ -92,7 +92,7 @@ int Node::Depth(int begin) {
Node &Node::Folder( Node &Node::Folder(
int size, std::string type, int size, std::string type,
std::map<std::string, std::pair<std::string, std::string>> change, std::map<std::string, std::vector<std::pair<std::string, std::string>>> change,
std::vector<std::shared_ptr<Node>> *removed_nodes) { std::vector<std::shared_ptr<Node>> *removed_nodes) {
std::shared_ptr<framework::OpDesc> op_desc = std::shared_ptr<framework::OpDesc> op_desc =
std::make_shared<framework::OpDesc>(); std::make_shared<framework::OpDesc>();
...@@ -109,12 +109,15 @@ Node &Node::Folder( ...@@ -109,12 +109,15 @@ Node &Node::Folder(
void Node::Folder( void Node::Folder(
std::shared_ptr<framework::OpDesc> op_desc, std::shared_ptr<framework::OpDesc> op_desc,
std::vector<std::shared_ptr<Node>> *outputs, int index, std::vector<std::shared_ptr<Node>> *outputs, int index,
std::map<std::string, std::pair<std::string, std::string>> *change, std::map<std::string, std::vector<std::pair<std::string, std::string>>> *change,
Node *begin_node, std::vector<std::shared_ptr<Node>> *removed_nodes) { Node *begin_node, std::vector<std::shared_ptr<Node>> *removed_nodes) {
if (change->find(this->type_) != change->end()) { if (change->find(this->type_) != change->end()) {
auto change_pair = (*change)[this->type_];
op_desc->GetInputs()[change_pair.second] = auto change_pairs = (*change)[this->type_];
this->op_desc_->GetInputs()[change_pair.first]; for (const auto &change_pair: change_pairs) {
op_desc->GetInputs()[change_pair.second] =
this->op_desc_->GetInputs()[change_pair.first];
}
} }
for (auto &attr_pair : this->op_desc_->attrs_) { for (auto &attr_pair : this->op_desc_->attrs_) {
......
...@@ -43,7 +43,7 @@ class Node { ...@@ -43,7 +43,7 @@ class Node {
int Depth(int begin = 0); int Depth(int begin = 0);
Node &Folder( Node &Folder(
int size, std::string type, int size, std::string type,
std::map<std::string, std::pair<std::string, std::string>> change_map, std::map<std::string, std::vector<std::pair<std::string, std::string>>> change,
std::vector<std::shared_ptr<Node>> *removed_nodes); std::vector<std::shared_ptr<Node>> *removed_nodes);
std::vector<std::shared_ptr<framework::OpDesc>> OpDescs(int size); std::vector<std::shared_ptr<framework::OpDesc>> OpDescs(int size);
std::shared_ptr<framework::OpDesc> OpDescOfNode() { return op_desc_; } std::shared_ptr<framework::OpDesc> OpDescOfNode() { return op_desc_; }
...@@ -56,7 +56,7 @@ class Node { ...@@ -56,7 +56,7 @@ class Node {
void Folder( void Folder(
std::shared_ptr<framework::OpDesc> op_desc, std::shared_ptr<framework::OpDesc> op_desc,
std::vector<std::shared_ptr<Node>> *outputs, int index, std::vector<std::shared_ptr<Node>> *outputs, int index,
std::map<std::string, std::pair<std::string, std::string>> *change, std::map<std::string, std::vector<std::pair<std::string, std::string>>> *change,
Node *begin_node, std::vector<std::shared_ptr<Node>> *removed_nodes); Node *begin_node, std::vector<std::shared_ptr<Node>> *removed_nodes);
std::shared_ptr<framework::OpDesc> op_desc_; std::shared_ptr<framework::OpDesc> op_desc_;
#ifdef PADDLE_MOBILE_DEBUG #ifdef PADDLE_MOBILE_DEBUG
......
...@@ -40,7 +40,7 @@ class FusionConvAddMatcher : public framework::FusionOpMatcher { ...@@ -40,7 +40,7 @@ class FusionConvAddMatcher : public framework::FusionOpMatcher {
vector<std::shared_ptr<framework::OpDesc>> origin_descs = vector<std::shared_ptr<framework::OpDesc>> origin_descs =
node->OpDescs(node_.Depth()); node->OpDescs(node_.Depth());
node->Folder(node_.Depth(), Type(), node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Y"}}}, removed_nodes); {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes);
} }
std::string Type() { return G_OP_TYPE_CONV_ADD; } std::string Type() { return G_OP_TYPE_CONV_ADD; }
...@@ -68,11 +68,13 @@ class FusionConvAddOp : public framework::OperatorWithKernel< ...@@ -68,11 +68,13 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
}; };
#ifdef PADDLE_MOBILE_CPU #ifdef PADDLE_MOBILE_CPU
#ifndef CONV_ADD_REGISTER #ifndef CONV_ADD_REGISTER
static framework::FusionOpRegistrar convadd_registrar( static framework::FusionOpRegistrar convadd_registrar(
new FusionConvAddMatcher()); new FusionConvAddMatcher());
#define CONV_ADD_REGISTER #define CONV_ADD_REGISTER
#endif #endif
#endif #endif
#ifdef PADDLE_MOBILE_MALI_GPU #ifdef PADDLE_MOBILE_MALI_GPU
......
...@@ -36,7 +36,7 @@ class FusionConvAddReluOpMatcher : public framework::FusionOpMatcher { ...@@ -36,7 +36,7 @@ class FusionConvAddReluOpMatcher : public framework::FusionOpMatcher {
framework::Node *node, framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) { std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(), node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Y"}}}, removed_nodes); {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes);
} }
std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_RELU; } std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_RELU; }
}; };
...@@ -65,11 +65,11 @@ class FusionConvAddReluOp : public framework::OperatorWithKernel< ...@@ -65,11 +65,11 @@ class FusionConvAddReluOp : public framework::OperatorWithKernel<
#ifdef PADDLE_MOBILE_CPU #ifdef PADDLE_MOBILE_CPU
#ifndef CONV_ADD_RELU_REGISTER //#ifndef CONV_ADD_RELU_REGISTER
#define CONV_ADD_RELU_REGISTER //#define CONV_ADD_RELU_REGISTER
// static framework::FusionOpRegistrar fusion_conv_add_relu_registrar(new // static framework::FusionOpRegistrar fusion_conv_add_relu_registrar(new
// FusionConvAddReluOpMatcher()); // FusionConvAddReluOpMatcher());
#endif //#endif
#endif #endif
#ifdef PADDLE_MOBILE_MALI_GPU #ifdef PADDLE_MOBILE_MALI_GPU
......
...@@ -38,7 +38,7 @@ class FusionFcMatcher : public framework::FusionOpMatcher { ...@@ -38,7 +38,7 @@ class FusionFcMatcher : public framework::FusionOpMatcher {
framework::Node *node, framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) { std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(), node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}}, removed_nodes); {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Z"}}}}, removed_nodes);
} }
std::string Type() { return G_OP_TYPE_FC; } std::string Type() { return G_OP_TYPE_FC; }
...@@ -66,17 +66,21 @@ class FusionFcOp ...@@ -66,17 +66,21 @@ class FusionFcOp
}; };
#ifdef PADDLE_MOBILE_CPU #ifdef PADDLE_MOBILE_CPU
#ifndef CONV_CPU_REGISTER #ifndef CONV_CPU_REGISTER
#define CONV_CPU_REGISTER #define CONV_CPU_REGISTER
static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher()); static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher());
#endif #endif
#endif #endif
#ifdef PADDLE_MOBILE_MALI_GPU #ifdef PADDLE_MOBILE_MALI_GPU
#ifndef CONV_CPU_REGISTER #ifndef CONV_CPU_REGISTER
#define CONV_CPU_REGISTER #define CONV_CPU_REGISTER
static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher()); static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher());
#endif #endif
#endif #endif
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
......
...@@ -19,7 +19,7 @@ int main() { ...@@ -19,7 +19,7 @@ int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader; paddle_mobile::Loader<paddle_mobile::CPU> loader;
// ../../../test/models/googlenet // ../../../test/models/googlenet
// ../../../test/models/mobilenet // ../../../test/models/mobilenet
auto program = loader.Load(g_mobilenet_ssd, false, false); auto program = loader.Load(g_googlenet, true);
// auto program = loader.Load(g_googlenet_combine + "/model", // auto program = loader.Load(g_googlenet_combine + "/model",
// g_googlenet_combine + // g_googlenet_combine +
// "/params", true); // "/params", true);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册