From fbe8f49f7e4ac4462d460429ef6da350f0628947 Mon Sep 17 00:00:00 2001 From: liuruilong Date: Mon, 16 Jul 2018 15:53:20 +0800 Subject: [PATCH] add fusion priority --- .../program/program-optimize/fusion_op_register.h | 13 +++++++++++-- .../program/program-optimize/program_optimize.cpp | 5 ++--- src/operators/feed_op.cpp | 4 +--- src/operators/fetch_op.cpp | 4 +--- src/operators/fusion_dwconv_bn_relu_op.cpp | 2 +- src/operators/fusion_dwconv_bn_relu_op.h | 2 -- src/operators/op_param.h | 2 +- 7 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/framework/program/program-optimize/fusion_op_register.h b/src/framework/program/program-optimize/fusion_op_register.h index 1cd6b1dd77..f16a65c28f 100644 --- a/src/framework/program/program-optimize/fusion_op_register.h +++ b/src/framework/program/program-optimize/fusion_op_register.h @@ -42,8 +42,17 @@ class FusionOpRegister { matchers_[matcher->Type()] = shared_matcher; } - const std::map> Matchers() { - return matchers_; + const std::vector> Matchers() { + std::vector> matchers; + for (const auto& match : matchers_) { + matchers.push_back(match.second); + } + std::sort(matchers.begin(), matchers.end(), + [](std::shared_ptr first, + std::shared_ptr second) { + return first->BeginNode().Depth() > second->BeginNode().Depth(); + }); + return matchers; } private: diff --git a/src/framework/program/program-optimize/program_optimize.cpp b/src/framework/program/program-optimize/program_optimize.cpp index 3619bc79f5..82d33bc65d 100644 --- a/src/framework/program/program-optimize/program_optimize.cpp +++ b/src/framework/program/program-optimize/program_optimize.cpp @@ -78,9 +78,8 @@ std::shared_ptr ProgramOptimize::FusionOptimize( } for (auto ®isted : FusionOpRegister::Instance()->Matchers()) { - std::string fusion_type = registed.first; - std::shared_ptr matcher = registed.second; - // DLOG << " registed node \n " << matcher->BeginNode(); + std::string fusion_type = registed->Type(); + std::shared_ptr matcher = registed; auto match_vector = type_map[matcher->BeginType()]; diff --git a/src/operators/feed_op.cpp b/src/operators/feed_op.cpp index 7fc9101fa8..4447f2c699 100644 --- a/src/operators/feed_op.cpp +++ b/src/operators/feed_op.cpp @@ -14,9 +14,7 @@ limitations under the License. */ #include "feed_op.h" namespace paddle_mobile { -namespace operators { - -} +namespace operators {} } // namespace paddle_mobile namespace ops = paddle_mobile::operators; diff --git a/src/operators/fetch_op.cpp b/src/operators/fetch_op.cpp index cecfb28ee7..adbd61d5ec 100644 --- a/src/operators/fetch_op.cpp +++ b/src/operators/fetch_op.cpp @@ -14,9 +14,7 @@ limitations under the License. */ #include "fetch_op.h" namespace paddle_mobile { -namespace operators { - -} +namespace operators {} } // namespace paddle_mobile namespace ops = paddle_mobile::operators; diff --git a/src/operators/fusion_dwconv_bn_relu_op.cpp b/src/operators/fusion_dwconv_bn_relu_op.cpp index ba03a436c3..e55295830e 100644 --- a/src/operators/fusion_dwconv_bn_relu_op.cpp +++ b/src/operators/fusion_dwconv_bn_relu_op.cpp @@ -44,7 +44,7 @@ void FusionDWConvBNReluOp::InferShape() const { framework::DDim ddim = framework::make_ddim(output_shape); this->param_.Output()->Resize(ddim); } -template class FusionDWConvBNReluOp; + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/fusion_dwconv_bn_relu_op.h b/src/operators/fusion_dwconv_bn_relu_op.h index bf95b51da4..6f9f03e493 100644 --- a/src/operators/fusion_dwconv_bn_relu_op.h +++ b/src/operators/fusion_dwconv_bn_relu_op.h @@ -38,8 +38,6 @@ class FusionDWConvBNReluMatcher : public framework::FusionOpMatcher { void FolderNodes( framework::Node *node, std::vector> *removed_nodes) { - vector> origin_descs = - node->OpDescs(node_.Depth()); node->Folder(node_.Depth(), Type(), {{G_OP_TYPE_BATCHNORM, {{"Scale", "Scale"}, diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 994392d678..390de2d5cf 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -371,7 +371,7 @@ class BatchNormParam : OpParam { input_variance_ = InputVarianceFrom(inputs, scope); epsilon_ = GetAttr("epsilon", attrs); momentum_ = GetAttr("momentum", attrs); -// is_test_ = GetAttr("is_test", attrs); + // is_test_ = GetAttr("is_test", attrs); } const Tensor *InputX() const { return input_x_; } -- GitLab