未验证 提交 ebcdb28b 编写于 作者: W Wilber 提交者: GitHub

[CUDA] [Fuse] Add scals fuse. (#4094)

上级 8f3e0ef3
......@@ -40,6 +40,7 @@ USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass);
USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
USE_MIR_PASS(lite_match_matrix_activation_fuse_pass);
USE_MIR_PASS(lite_scales_fuse_pass);
USE_MIR_PASS(lite_elementwise_activation_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(type_precision_cast_pass);
......
......@@ -30,6 +30,7 @@ lite_cc_library(mir_passes
fusion/__xpu__fc_fuse_pass.cc
fusion/__xpu__mmdnn_fuse_pass.cc
fusion/match_matrix_activation_fuse_pass.cc
fusion/scales_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
......
......@@ -40,6 +40,9 @@ lite_cc_library(fuse_scale_activation
lite_cc_library(fuse_match_matrix_activation
SRCS match_matrix_activation_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_scales
SRCS scales_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
......@@ -56,6 +59,7 @@ set(mir_fusers
fuse_sequence_pool_concat
fuse_scale_activation
fuse_match_matrix_activation
fuse_scales
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scales_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/scales_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ScalesFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ScalesFuser fuser;
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_scales_fuse_pass, paddle::lite::mir::ScalesFusePass)
.BindTargets({TARGET(kCUDA)});
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ScalesFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scales_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ScalesFuser::BuildPattern() {
// create input nodes.
auto* x = VarNode("x")->assert_is_op_input("scale", "X")->AsInput();
auto scales_teller = [](const Node* node) -> bool {
bool bias_after_scale =
const_cast<Node*>(node)->AsStmt().op_info()->GetAttr<bool>(
"bias_after_scale");
return bias_after_scale;
};
// create op nodes
auto* scale1 = OpNode("scale1", "scale")
->assert_is_op("scale")
->assert_node_satisfied(scales_teller)
->AsIntermediate();
auto* scale2 = OpNode("scale2", "scale")
->assert_is_op("scale")
->assert_node_satisfied(scales_teller)
->AsIntermediate();
// create intermediate nodes
auto* scale1_out = VarNode("scale1_out")
->assert_is_op_output("scale", "Out")
->assert_is_op_input("scale", "X")
->AsIntermediate();
// create output node
auto* out = VarNode("out")->assert_is_op_output("scale", "Out")->AsOutput();
// create topology.
*x >> *scale1 >> *scale1_out >> *scale2 >> *out;
}
void ScalesFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto scale_op = LiteOpRegistry::Global().Create("scale");
auto scale = matched.at("scale1")->stmt()->op();
auto* scope = scale->scope();
auto& valid_places = scale->valid_places();
scale_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(scale_op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("out"));
}
cpp::OpDesc ScalesFuser::GenOpDesc(const key2nodes_t& matched) {
auto op_desc = *matched.at("scale1")->stmt()->op_info();
float scale1 = op_desc.GetAttr<float>("scale");
float bias1 = op_desc.GetAttr<float>("bias");
float scale2 =
matched.at("scale2")->stmt()->op_info()->GetAttr<float>("scale");
float bias2 = matched.at("scale2")->stmt()->op_info()->GetAttr<float>("bias");
op_desc.SetAttr<float>("scale", scale1 * scale2);
op_desc.SetAttr<float>("bias", bias1 * scale2 + bias2);
auto& out_name = matched.at("out")->arg()->name;
op_desc.SetOutput("Out", {out_name});
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ScalesFuser : public FuseBase {
public:
ScalesFuser() {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -23,8 +23,11 @@ namespace lite {
namespace mir {
void SequencePoolConcatFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::SequencePoolConcatFuser fuser;
fusion::SequencePool7ConcatFuser fuser;
fuser(graph.get());
fusion::SequencePool2ConcatFuser fuser2;
fuser2(graph.get());
}
} // namespace mir
......
......@@ -21,22 +21,6 @@ namespace lite {
namespace mir {
namespace fusion {
// """
// merge {sequence_pool x 7, concat} => merge_sequence_pool_and_concat
// src1 src2 src7 src1 src2 src7
// | | | | |
// v v | | ... |
// sequence_pool sequence_pool ...(sequence_pool) | | |
// | | | => -------------------
// --------------------------------- |
// | |
// v v
// concat sequence_pool_concat
// """
void SequencePoolConcatFuser::BuildPattern() {
// create nodes.
auto* concat = OpNode("concat", "concat")->AsIntermediate();
#define STR1(R) #R
#define STR2(R) STR1(R)
......@@ -58,6 +42,22 @@ void SequencePoolConcatFuser::BuildPattern() {
*sequence_pool_##num >> *sequence_pool_##num##_idx; \
*x_##num >> *sequence_pool_##num >> *sequence_pool_##num##_out >> *concat;
// """
// merge {sequence_pool x 7, concat} => merge_sequence_pool_and_concat
// src1 src2 src7 src1 src2 src7
// | | | | |
// v v | | ... |
// sequence_pool sequence_pool ...(sequence_pool) | | |
// | | | => -------------------
// --------------------------------- |
// | |
// v v
// concat sequence_pool_concat
// """
void SequencePool7ConcatFuser::BuildPattern() {
// create nodes.
auto* concat = OpNode("concat", "concat")->AsIntermediate();
auto* concat_out =
VarNode("concat_out")->assert_is_op_output("concat", "Out");
*concat >> *concat_out;
......@@ -69,14 +69,10 @@ void SequencePoolConcatFuser::BuildPattern() {
POOL_CONCAT_PATTERN(5);
POOL_CONCAT_PATTERN(6);
POOL_CONCAT_PATTERN(7);
#undef POOL_CONCAT_PATTERN
#undef STR1
#undef STR2
}
void SequencePoolConcatFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
void SequencePool7ConcatFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto sequence_pool_concat_op =
LiteOpRegistry::Global().Create("sequence_pool_concat");
......@@ -99,7 +95,7 @@ void SequencePoolConcatFuser::InsertNewNode(SSAGraph* graph,
IR_NODE_LINK_TO(new_op_node, matched.at("concat_out"));
}
cpp::OpDesc SequencePoolConcatFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc SequencePool7ConcatFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("concat")->stmt()->op_info();
op_desc.SetType("sequence_pool_concat");
op_desc.SetInput("X",
......@@ -147,6 +143,64 @@ cpp::OpDesc SequencePoolConcatFuser::GenOpDesc(const key2nodes_t& matched) {
return op_desc;
}
void SequencePool2ConcatFuser::BuildPattern() {
// create nodes.
auto* concat = OpNode("concat", "concat")->AsIntermediate();
auto* concat_out =
VarNode("concat_out")->assert_is_op_output("concat", "Out");
*concat >> *concat_out;
POOL_CONCAT_PATTERN(1);
POOL_CONCAT_PATTERN(2);
}
void SequencePool2ConcatFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto sequence_pool_concat_op =
LiteOpRegistry::Global().Create("sequence_pool_concat");
auto concat = matched.at("concat")->stmt()->op();
auto* scope = concat->scope();
auto& valid_places = concat->valid_places();
sequence_pool_concat_op->Attach(op_desc, scope);
auto* new_op_node =
graph->GraphCreateInstructNode(sequence_pool_concat_op, valid_places);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_1"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_2"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("concat_out"));
}
cpp::OpDesc SequencePool2ConcatFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("concat")->stmt()->op_info();
op_desc.SetType("sequence_pool_concat");
op_desc.SetInput("X",
{matched.at("sequence_pool_x_1")->arg()->name,
matched.at("sequence_pool_x_2")->arg()->name});
std::vector<std::string> pooltypes;
pooltypes.push_back(matched.at("sequence_pool_1")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_2")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
op_desc.SetAttr("pooltype", pooltypes);
op_desc.SetOutput("Out", {matched.at("concat_out")->arg()->name});
return op_desc;
}
#undef POOL_CONCAT_PATTERN
#undef STR1
#undef STR2
} // namespace fusion
} // namespace mir
} // namespace lite
......
......@@ -23,7 +23,16 @@ namespace lite {
namespace mir {
namespace fusion {
class SequencePoolConcatFuser : public FuseBase {
class SequencePool7ConcatFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
class SequencePool2ConcatFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
......
......@@ -97,6 +97,7 @@ class Optimizer {
"lite_transpose_softmax_transpose_fuse_pass", //
"lite_interpolate_fuse_pass", //
"identity_scale_eliminate_pass", //
"lite_scales_fuse_pass", //
"elementwise_mul_constant_eliminate_pass", //
"lite_sequence_pool_concat_fuse_pass", //
"lite_scale_activation_fuse_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册