提交 c6ee8ac6 编写于 作者: L liuruilong

add fusion priority

上级 2a949339
...@@ -42,8 +42,17 @@ class FusionOpRegister { ...@@ -42,8 +42,17 @@ class FusionOpRegister {
matchers_[matcher->Type()] = shared_matcher; matchers_[matcher->Type()] = shared_matcher;
} }
const std::map<std::string, std::shared_ptr<FusionOpMatcher>> Matchers() { const std::vector<std::shared_ptr<FusionOpMatcher>> Matchers() {
return matchers_; std::vector<std::shared_ptr<FusionOpMatcher>> matchers;
for (const auto& match : matchers_) {
matchers.push_back(match.second);
}
std::sort(matchers.begin(), matchers.end(),
[](std::shared_ptr<FusionOpMatcher> first,
std::shared_ptr<FusionOpMatcher> second) {
return first->BeginNode().Depth() > second->BeginNode().Depth();
});
return matchers;
} }
private: private:
......
...@@ -78,9 +78,8 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FusionOptimize( ...@@ -78,9 +78,8 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FusionOptimize(
} }
for (auto &registed : FusionOpRegister::Instance()->Matchers()) { for (auto &registed : FusionOpRegister::Instance()->Matchers()) {
std::string fusion_type = registed.first; std::string fusion_type = registed->Type();
std::shared_ptr<FusionOpMatcher> matcher = registed.second; std::shared_ptr<FusionOpMatcher> matcher = registed;
// DLOG << " registed node \n " << matcher->BeginNode();
auto match_vector = type_map[matcher->BeginType()]; auto match_vector = type_map[matcher->BeginType()];
......
...@@ -14,9 +14,7 @@ limitations under the License. */ ...@@ -14,9 +14,7 @@ limitations under the License. */
#include "feed_op.h" #include "feed_op.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {}
}
} // namespace paddle_mobile } // namespace paddle_mobile
namespace ops = paddle_mobile::operators; namespace ops = paddle_mobile::operators;
......
...@@ -14,9 +14,7 @@ limitations under the License. */ ...@@ -14,9 +14,7 @@ limitations under the License. */
#include "fetch_op.h" #include "fetch_op.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {}
}
} // namespace paddle_mobile } // namespace paddle_mobile
namespace ops = paddle_mobile::operators; namespace ops = paddle_mobile::operators;
......
...@@ -44,7 +44,7 @@ void FusionDWConvBNReluOp<Dtype, T>::InferShape() const { ...@@ -44,7 +44,7 @@ void FusionDWConvBNReluOp<Dtype, T>::InferShape() const {
framework::DDim ddim = framework::make_ddim(output_shape); framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim); this->param_.Output()->Resize(ddim);
} }
template class FusionDWConvBNReluOp<CPU, float>;
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -38,8 +38,6 @@ class FusionDWConvBNReluMatcher : public framework::FusionOpMatcher { ...@@ -38,8 +38,6 @@ class FusionDWConvBNReluMatcher : public framework::FusionOpMatcher {
void FolderNodes( void FolderNodes(
framework::Node *node, framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) { std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
vector<std::shared_ptr<framework::OpDesc>> origin_descs =
node->OpDescs(node_.Depth());
node->Folder(node_.Depth(), Type(), node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_BATCHNORM, {{G_OP_TYPE_BATCHNORM,
{{"Scale", "Scale"}, {{"Scale", "Scale"},
......
...@@ -371,7 +371,7 @@ class BatchNormParam : OpParam { ...@@ -371,7 +371,7 @@ class BatchNormParam : OpParam {
input_variance_ = InputVarianceFrom<LoDTensor>(inputs, scope); input_variance_ = InputVarianceFrom<LoDTensor>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs); epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs); momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs); // is_test_ = GetAttr<bool>("is_test", attrs);
} }
const Tensor *InputX() const { return input_x_; } const Tensor *InputX() const { return input_x_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册