Unverified · Commit 2dcff5ca · authored by zhupengyang, committed by GitHub

fix conv-act-fuse-pass when there is no "bias" (#2003)

test=develop

Parent 81132a32

The fusion pattern previously required the conv op to have a "Bias" input, so convolutions without a bias were never matched. The fuser now takes a has_bias flag, and the pass runs it for both the bias and no-bias cases.
The pass replaces its two hard-coded fuser instances with a loop over every (conv_type, act_type, has_bias) combination:

```diff
@@ -23,11 +23,14 @@ namespace lite {
 namespace mir {
 
 void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ConvActivationFuser fuser("conv2d", "relu");
-  fuser(graph.get());
-
-  fusion::ConvActivationFuser depthwise_fuser("depthwise_conv2d", "relu");
-  depthwise_fuser(graph.get());
+  for (auto conv_type : {"conv2d", "depthwise_conv2d"}) {
+    for (auto act_type : {"relu"}) {
+      for (auto has_bias : {true, false}) {
+        fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias);
+        fuser(graph.get());
+      }
+    }
+  }
 }
 
 }  // namespace mir
```
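For reference, here is a self-contained sketch (illustrative only, not part of the patch) that enumerates the same combinations the rewritten loop covers; with a single activation type this yields four fuser configurations:

```cpp
#include <iostream>
#include <string>

int main() {
  // Enumerate the combinations the rewritten pass loops over:
  // 2 conv types x 1 activation x 2 bias modes = 4 fuser configurations.
  for (const std::string conv_type : {"conv2d", "depthwise_conv2d"}) {
    for (const std::string act_type : {"relu"}) {
      for (bool has_bias : {true, false}) {
        std::cout << conv_type << " + " << act_type
                  << (has_bias ? " (with bias)" : " (no bias)") << "\n";
      }
    }
  }
  return 0;
}
```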
In the fuser implementation, BuildPattern now creates the bias node only when has_bias_ is set and wires it into the topology conditionally:

```diff
@@ -22,35 +22,33 @@ namespace mir {
 namespace fusion {
 
 void ConvActivationFuser::BuildPattern() {
-  // create input nodes.
+  // create nodes.
   auto* input =
       VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
   auto* filter =
       VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
-  auto* bias =
-      VarNode("bias")->assert_is_op_input(conv_type_, "Bias")->AsInput();
-
-  // create op nodes
-  auto* conv2d =
-      OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
-  auto* act =
-      OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
-
-  // create intermediate nodes
+  PMNode* bias = nullptr;
+  if (has_bias_) {
+    bias = VarNode("bias")->assert_is_op_input(conv_type_, "Bias")->AsInput();
+  }
+  auto* conv2d = OpNode("conv2d", conv_type_)->AsIntermediate();
+
+  auto* act = OpNode("act", act_type_)->AsIntermediate();
+
   auto* conv2d_out = VarNode("conv2d_out")
                          ->assert_is_op_output(conv_type_, "Output")
                          ->assert_is_op_input(act_type_, "X")
                          ->AsIntermediate();
 
-  // create output node
   auto* out =
       VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
 
   // create topology.
-  std::vector<PMNode*> conv2d_inputs{filter, input, bias};
-  conv2d_inputs >> *conv2d >> *conv2d_out;
-  *conv2d_out >> *act >> *out;
+  std::vector<PMNode*> conv2d_inputs{filter, input};
+  conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out;
+  if (has_bias_) {
+    *bias >> *conv2d;
+  }
 }
 
 void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
```
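The `>>` chains above come from the pattern-matcher's edge-building DSL. As a toy sketch (invented types, not the real PMNode API), overloading operator>> to record an edge and return its right operand lets a whole path be written left to right, with a vector overload providing fan-in:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Toy sketch of a pattern-matcher edge DSL (hypothetical types; NOT the
// real PMNode API). operator>> records an edge and returns its right
// operand, so a whole path reads left to right in one expression.
struct PatNode {
  std::string name;
};

PatNode& operator>>(PatNode& from, PatNode& to) {
  std::cout << from.name << " -> " << to.name << "\n";  // record one edge
  return to;  // returning `to` lets the chain continue
}

PatNode& operator>>(std::vector<PatNode*>& froms, PatNode& to) {
  for (auto* f : froms) *f >> to;  // fan-in: every input feeds the op node
  return to;
}

int main() {
  PatNode input{"input"}, filter{"filter"}, conv{"conv2d"},
      conv_out{"conv2d_out"}, act{"relu"}, out{"out"};
  std::vector<PatNode*> conv_inputs{&filter, &input};
  // Mirrors: conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out;
  conv_inputs >> conv >> conv_out >> act >> out;
  return 0;
}
```

In the real fuser, these recorded edges plus the per-node asserts (assert_is_op_input and so on) define the subgraph the matcher searches for.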
InsertNewNode links the bias variable only when it exists, and GenOpDesc is simplified: instead of re-setting inputs and attributes one by one, it copies the matched conv op's full description and overrides only the output and the fuse_relu attribute:

```diff
@@ -66,33 +64,15 @@ void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
   IR_NODE_LINK_TO(matched.at("input"), new_op_node);
   IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
+  if (has_bias_) {
+    IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
+  }
   IR_NODE_LINK_TO(new_op_node, matched.at("output"));
 }
 
 cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
-  auto* desc = matched.at("conv2d")->stmt()->op_info();
-  cpp::OpDesc op_desc = *desc;
-  op_desc.SetType(conv_type_);
-  op_desc.SetInput("Input", {matched.at("input")->arg()->name});
-  op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
-  op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
+  cpp::OpDesc op_desc = *matched.at("conv2d")->stmt()->op_info();
   op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
-  // Other inputs. See operators/conv_op.h
-  std::vector<std::string> input_arg_names = desc->InputArgumentNames();
-  if (std::find(input_arg_names.begin(),
-                input_arg_names.end(),
-                "ResidualData") != input_arg_names.end()) {
-    op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
-  }
-  // Only consider strides, padding, groups, dilations, fuse_relu for now
-  op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
-  op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
-  op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
-  op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
-  // TODO(sangoly): support other activation types
   op_desc.SetAttr("fuse_relu", true);
   return op_desc;
 }
```
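The GenOpDesc change swaps per-field reconstruction for a copy-then-override approach. A rough sketch of that shape, using a hypothetical MiniOpDesc stand-in rather than the real cpp::OpDesc:

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for cpp::OpDesc, only to show the shape of the
// simplification: copy the matched conv op's whole description, then
// override just the fields the fusion changes.
struct MiniOpDesc {
  std::map<std::string, std::string> fields;
};

MiniOpDesc FuseDesc(const MiniOpDesc& conv_desc, const std::string& out_var) {
  MiniOpDesc desc = conv_desc;        // whole-description copy
  desc.fields["Output"] = out_var;    // redirect output past the relu op
  desc.fields["fuse_relu"] = "true";  // record the folded activation
  return desc;
}

int main() {
  MiniOpDesc conv{{{"Input", "x"}, {"strides", "1,1"}, {"Output", "conv_out"}}};
  MiniOpDesc fused = FuseDesc(conv, "relu_out");
  std::cout << fused.fields["Output"] << " " << fused.fields["fuse_relu"] << "\n";
  return 0;
}
```

This is also why the removed ResidualData and attribute-copying code became unnecessary: a full copy already carries those inputs and attributes along.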
The fuser's constructor and members gain the has_bias flag:

```diff
@@ -26,10 +26,12 @@ namespace fusion {
 class ConvActivationFuser : public FuseBase {
  public:
   explicit ConvActivationFuser(const std::string& conv_type,
-                               const std::string& act_type) {
+                               const std::string& act_type,
+                               bool has_bias) {
     CHECK(act_type == "relu") << "Only relu activation be supported now";
     conv_type_ = conv_type;
     act_type_ = act_type;
+    has_bias_ = has_bias;
   }
 
   void BuildPattern() override;
@@ -39,6 +41,7 @@ class ConvActivationFuser : public FuseBase {
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
 
   std::string conv_type_;
   std::string act_type_;
+  bool has_bias_;
 };
 
 }  // namespace fusion
```
...@@ -70,7 +70,8 @@ void CompareOutData(const lite::Predictor& tgt, const lite::Predictor& ref) { ...@@ -70,7 +70,8 @@ void CompareOutData(const lite::Predictor& tgt, const lite::Predictor& ref) {
const auto* ref_pdata = ref_otensor->data<float>(); const auto* ref_pdata = ref_otensor->data<float>();
EXPECT_EQ(tgt_otensor->dims().production(), ref_otensor->dims().production()); EXPECT_EQ(tgt_otensor->dims().production(), ref_otensor->dims().production());
for (size_t i = 0; i < tgt_otensor->dims().production(); ++i) { for (size_t i = 0; i < tgt_otensor->dims().production(); ++i) {
auto diff = std::fabs((tgt_pdata[i] - ref_pdata[i]) / ref_pdata[i]); auto diff = std::fabs(tgt_pdata[i] - ref_pdata[i]) /
(std::fabs(ref_pdata[i]) + 1e-6);
VLOG(3) << diff; VLOG(3) << diff;
EXPECT_LT(diff, 0.1); EXPECT_LT(diff, 0.1);
} }
......
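The updated formula keeps the relative error finite when a reference value is zero or near zero. A minimal self-contained sketch of the check (the 1e-6 floor is taken from the patch; the sample values are illustrative):

```cpp
#include <cassert>
#include <cmath>

// Relative error with a small floor in the denominator, as in the updated
// test: |tgt - ref| / (|ref| + 1e-6). The old form
// fabs((tgt - ref) / ref) produced inf or NaN whenever ref was 0.
float RelativeDiff(float tgt, float ref) {
  return std::fabs(tgt - ref) / (std::fabs(ref) + 1e-6f);
}

int main() {
  assert(RelativeDiff(1e-8f, 0.0f) < 0.1f);  // finite even when ref is 0
  assert(RelativeDiff(1.05f, 1.0f) < 0.1f);  // ordinary 5% error still passes
  return 0;
}
```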