未验证 提交 ebcdb28b 编写于 作者: W Wilber 提交者: GitHub

[CUDA] [Fuse] Add scals fuse. (#4094)

上级 8f3e0ef3
...@@ -40,6 +40,7 @@ USE_MIR_PASS(lite_conv_elementwise_fuse_pass); ...@@ -40,6 +40,7 @@ USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass); USE_MIR_PASS(lite_conv_activation_fuse_pass);
USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass); USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
USE_MIR_PASS(lite_match_matrix_activation_fuse_pass); USE_MIR_PASS(lite_match_matrix_activation_fuse_pass);
USE_MIR_PASS(lite_scales_fuse_pass);
USE_MIR_PASS(lite_elementwise_activation_fuse_pass); USE_MIR_PASS(lite_elementwise_activation_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass); USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(type_precision_cast_pass); USE_MIR_PASS(type_precision_cast_pass);
......
...@@ -30,6 +30,7 @@ lite_cc_library(mir_passes ...@@ -30,6 +30,7 @@ lite_cc_library(mir_passes
fusion/__xpu__fc_fuse_pass.cc fusion/__xpu__fc_fuse_pass.cc
fusion/__xpu__mmdnn_fuse_pass.cc fusion/__xpu__mmdnn_fuse_pass.cc
fusion/match_matrix_activation_fuse_pass.cc fusion/match_matrix_activation_fuse_pass.cc
fusion/scales_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc elimination/elementwise_mul_constant_eliminate_pass.cc
......
...@@ -40,6 +40,9 @@ lite_cc_library(fuse_scale_activation ...@@ -40,6 +40,9 @@ lite_cc_library(fuse_scale_activation
lite_cc_library(fuse_match_matrix_activation lite_cc_library(fuse_match_matrix_activation
SRCS match_matrix_activation_fuser.cc SRCS match_matrix_activation_fuser.cc
DEPS pattern_matcher_high_api) DEPS pattern_matcher_high_api)
lite_cc_library(fuse_scales
SRCS scales_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers set(mir_fusers
fuse_fc fuse_fc
...@@ -56,6 +59,7 @@ set(mir_fusers ...@@ -56,6 +59,7 @@ set(mir_fusers
fuse_sequence_pool_concat fuse_sequence_pool_concat
fuse_scale_activation fuse_scale_activation
fuse_match_matrix_activation fuse_match_matrix_activation
fuse_scales
CACHE INTERNAL "fusers") CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scales_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/scales_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ScalesFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ScalesFuser fuser;
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_scales_fuse_pass, paddle::lite::mir::ScalesFusePass)
.BindTargets({TARGET(kCUDA)});
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ScalesFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scales_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ScalesFuser::BuildPattern() {
// create input nodes.
auto* x = VarNode("x")->assert_is_op_input("scale", "X")->AsInput();
auto scales_teller = [](const Node* node) -> bool {
bool bias_after_scale =
const_cast<Node*>(node)->AsStmt().op_info()->GetAttr<bool>(
"bias_after_scale");
return bias_after_scale;
};
// create op nodes
auto* scale1 = OpNode("scale1", "scale")
->assert_is_op("scale")
->assert_node_satisfied(scales_teller)
->AsIntermediate();
auto* scale2 = OpNode("scale2", "scale")
->assert_is_op("scale")
->assert_node_satisfied(scales_teller)
->AsIntermediate();
// create intermediate nodes
auto* scale1_out = VarNode("scale1_out")
->assert_is_op_output("scale", "Out")
->assert_is_op_input("scale", "X")
->AsIntermediate();
// create output node
auto* out = VarNode("out")->assert_is_op_output("scale", "Out")->AsOutput();
// create topology.
*x >> *scale1 >> *scale1_out >> *scale2 >> *out;
}
void ScalesFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto scale_op = LiteOpRegistry::Global().Create("scale");
auto scale = matched.at("scale1")->stmt()->op();
auto* scope = scale->scope();
auto& valid_places = scale->valid_places();
scale_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(scale_op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("out"));
}
cpp::OpDesc ScalesFuser::GenOpDesc(const key2nodes_t& matched) {
auto op_desc = *matched.at("scale1")->stmt()->op_info();
float scale1 = op_desc.GetAttr<float>("scale");
float bias1 = op_desc.GetAttr<float>("bias");
float scale2 =
matched.at("scale2")->stmt()->op_info()->GetAttr<float>("scale");
float bias2 = matched.at("scale2")->stmt()->op_info()->GetAttr<float>("bias");
op_desc.SetAttr<float>("scale", scale1 * scale2);
op_desc.SetAttr<float>("bias", bias1 * scale2 + bias2);
auto& out_name = matched.at("out")->arg()->name;
op_desc.SetOutput("Out", {out_name});
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ScalesFuser : public FuseBase {
public:
ScalesFuser() {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -23,8 +23,11 @@ namespace lite { ...@@ -23,8 +23,11 @@ namespace lite {
namespace mir { namespace mir {
void SequencePoolConcatFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void SequencePoolConcatFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::SequencePoolConcatFuser fuser; fusion::SequencePool7ConcatFuser fuser;
fuser(graph.get()); fuser(graph.get());
fusion::SequencePool2ConcatFuser fuser2;
fuser2(graph.get());
} }
} // namespace mir } // namespace mir
......
...@@ -21,22 +21,6 @@ namespace lite { ...@@ -21,22 +21,6 @@ namespace lite {
namespace mir { namespace mir {
namespace fusion { namespace fusion {
// """
// merge {sequence_pool x 7, concat} => merge_sequence_pool_and_concat
// src1 src2 src7 src1 src2 src7
// | | | | |
// v v | | ... |
// sequence_pool sequence_pool ...(sequence_pool) | | |
// | | | => -------------------
// --------------------------------- |
// | |
// v v
// concat sequence_pool_concat
// """
void SequencePoolConcatFuser::BuildPattern() {
// create nodes.
auto* concat = OpNode("concat", "concat")->AsIntermediate();
#define STR1(R) #R #define STR1(R) #R
#define STR2(R) STR1(R) #define STR2(R) STR1(R)
...@@ -58,6 +42,22 @@ void SequencePoolConcatFuser::BuildPattern() { ...@@ -58,6 +42,22 @@ void SequencePoolConcatFuser::BuildPattern() {
*sequence_pool_##num >> *sequence_pool_##num##_idx; \ *sequence_pool_##num >> *sequence_pool_##num##_idx; \
*x_##num >> *sequence_pool_##num >> *sequence_pool_##num##_out >> *concat; *x_##num >> *sequence_pool_##num >> *sequence_pool_##num##_out >> *concat;
// """
// merge {sequence_pool x 7, concat} => merge_sequence_pool_and_concat
// src1 src2 src7 src1 src2 src7
// | | | | |
// v v | | ... |
// sequence_pool sequence_pool ...(sequence_pool) | | |
// | | | => -------------------
// --------------------------------- |
// | |
// v v
// concat sequence_pool_concat
// """
void SequencePool7ConcatFuser::BuildPattern() {
// create nodes.
auto* concat = OpNode("concat", "concat")->AsIntermediate();
auto* concat_out = auto* concat_out =
VarNode("concat_out")->assert_is_op_output("concat", "Out"); VarNode("concat_out")->assert_is_op_output("concat", "Out");
*concat >> *concat_out; *concat >> *concat_out;
...@@ -69,13 +69,9 @@ void SequencePoolConcatFuser::BuildPattern() { ...@@ -69,13 +69,9 @@ void SequencePoolConcatFuser::BuildPattern() {
POOL_CONCAT_PATTERN(5); POOL_CONCAT_PATTERN(5);
POOL_CONCAT_PATTERN(6); POOL_CONCAT_PATTERN(6);
POOL_CONCAT_PATTERN(7); POOL_CONCAT_PATTERN(7);
#undef POOL_CONCAT_PATTERN
#undef STR1
#undef STR2
} }
void SequencePoolConcatFuser::InsertNewNode(SSAGraph* graph, void SequencePool7ConcatFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) { const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched); auto op_desc = GenOpDesc(matched);
auto sequence_pool_concat_op = auto sequence_pool_concat_op =
...@@ -99,7 +95,7 @@ void SequencePoolConcatFuser::InsertNewNode(SSAGraph* graph, ...@@ -99,7 +95,7 @@ void SequencePoolConcatFuser::InsertNewNode(SSAGraph* graph,
IR_NODE_LINK_TO(new_op_node, matched.at("concat_out")); IR_NODE_LINK_TO(new_op_node, matched.at("concat_out"));
} }
cpp::OpDesc SequencePoolConcatFuser::GenOpDesc(const key2nodes_t& matched) { cpp::OpDesc SequencePool7ConcatFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("concat")->stmt()->op_info(); cpp::OpDesc op_desc = *matched.at("concat")->stmt()->op_info();
op_desc.SetType("sequence_pool_concat"); op_desc.SetType("sequence_pool_concat");
op_desc.SetInput("X", op_desc.SetInput("X",
...@@ -147,6 +143,64 @@ cpp::OpDesc SequencePoolConcatFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -147,6 +143,64 @@ cpp::OpDesc SequencePoolConcatFuser::GenOpDesc(const key2nodes_t& matched) {
return op_desc; return op_desc;
} }
void SequencePool2ConcatFuser::BuildPattern() {
// create nodes.
auto* concat = OpNode("concat", "concat")->AsIntermediate();
auto* concat_out =
VarNode("concat_out")->assert_is_op_output("concat", "Out");
*concat >> *concat_out;
POOL_CONCAT_PATTERN(1);
POOL_CONCAT_PATTERN(2);
}
void SequencePool2ConcatFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto sequence_pool_concat_op =
LiteOpRegistry::Global().Create("sequence_pool_concat");
auto concat = matched.at("concat")->stmt()->op();
auto* scope = concat->scope();
auto& valid_places = concat->valid_places();
sequence_pool_concat_op->Attach(op_desc, scope);
auto* new_op_node =
graph->GraphCreateInstructNode(sequence_pool_concat_op, valid_places);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_1"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_2"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("concat_out"));
}
cpp::OpDesc SequencePool2ConcatFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("concat")->stmt()->op_info();
op_desc.SetType("sequence_pool_concat");
op_desc.SetInput("X",
{matched.at("sequence_pool_x_1")->arg()->name,
matched.at("sequence_pool_x_2")->arg()->name});
std::vector<std::string> pooltypes;
pooltypes.push_back(matched.at("sequence_pool_1")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_2")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
op_desc.SetAttr("pooltype", pooltypes);
op_desc.SetOutput("Out", {matched.at("concat_out")->arg()->name});
return op_desc;
}
#undef POOL_CONCAT_PATTERN
#undef STR1
#undef STR2
} // namespace fusion } // namespace fusion
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
......
...@@ -23,7 +23,16 @@ namespace lite { ...@@ -23,7 +23,16 @@ namespace lite {
namespace mir { namespace mir {
namespace fusion { namespace fusion {
class SequencePoolConcatFuser : public FuseBase { class SequencePool7ConcatFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
class SequencePool2ConcatFuser : public FuseBase {
public: public:
void BuildPattern() override; void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
......
...@@ -97,6 +97,7 @@ class Optimizer { ...@@ -97,6 +97,7 @@ class Optimizer {
"lite_transpose_softmax_transpose_fuse_pass", // "lite_transpose_softmax_transpose_fuse_pass", //
"lite_interpolate_fuse_pass", // "lite_interpolate_fuse_pass", //
"identity_scale_eliminate_pass", // "identity_scale_eliminate_pass", //
"lite_scales_fuse_pass", //
"elementwise_mul_constant_eliminate_pass", // "elementwise_mul_constant_eliminate_pass", //
"lite_sequence_pool_concat_fuse_pass", // "lite_sequence_pool_concat_fuse_pass", //
"lite_scale_activation_fuse_pass", // "lite_scale_activation_fuse_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册