提交 fbe8f49f 编写于 作者: L liuruilong

add fusion priority

上级 f8072161
......@@ -42,8 +42,17 @@ class FusionOpRegister {
matchers_[matcher->Type()] = shared_matcher;
}
const std::map<std::string, std::shared_ptr<FusionOpMatcher>> Matchers() {
return matchers_;
const std::vector<std::shared_ptr<FusionOpMatcher>> Matchers() {
std::vector<std::shared_ptr<FusionOpMatcher>> matchers;
for (const auto& match : matchers_) {
matchers.push_back(match.second);
}
std::sort(matchers.begin(), matchers.end(),
[](std::shared_ptr<FusionOpMatcher> first,
std::shared_ptr<FusionOpMatcher> second) {
return first->BeginNode().Depth() > second->BeginNode().Depth();
});
return matchers;
}
private:
......
......@@ -78,9 +78,8 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FusionOptimize(
}
for (auto &registed : FusionOpRegister::Instance()->Matchers()) {
std::string fusion_type = registed.first;
std::shared_ptr<FusionOpMatcher> matcher = registed.second;
// DLOG << " registed node \n " << matcher->BeginNode();
std::string fusion_type = registed->Type();
std::shared_ptr<FusionOpMatcher> matcher = registed;
auto match_vector = type_map[matcher->BeginType()];
......
......@@ -14,9 +14,7 @@ limitations under the License. */
#include "feed_op.h"
namespace paddle_mobile {
namespace operators {
}
namespace operators {}
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
......
......@@ -14,9 +14,7 @@ limitations under the License. */
#include "fetch_op.h"
namespace paddle_mobile {
namespace operators {
}
namespace operators {}
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
......
......@@ -44,7 +44,7 @@ void FusionDWConvBNReluOp<Dtype, T>::InferShape() const {
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
template class FusionDWConvBNReluOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
......
......@@ -38,8 +38,6 @@ class FusionDWConvBNReluMatcher : public framework::FusionOpMatcher {
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
vector<std::shared_ptr<framework::OpDesc>> origin_descs =
node->OpDescs(node_.Depth());
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_BATCHNORM,
{{"Scale", "Scale"},
......
......@@ -371,7 +371,7 @@ class BatchNormParam : OpParam {
input_variance_ = InputVarianceFrom<LoDTensor>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
// is_test_ = GetAttr<bool>("is_test", attrs);
}
const Tensor *InputX() const { return input_x_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册