未验证 提交 2dcff5ca 编写于 作者: Z zhupengyang 提交者: GitHub

fix conv-act-fuse-pass when there is no "bias" (#2003)

test=develop
上级 81132a32
......@@ -23,11 +23,14 @@ namespace lite {
namespace mir {
void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvActivationFuser fuser("conv2d", "relu");
for (auto conv_type : {"conv2d", "depthwise_conv2d"}) {
for (auto act_type : {"relu"}) {
for (auto has_bias : {true, false}) {
fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias);
fuser(graph.get());
fusion::ConvActivationFuser depthwise_fuser("depthwise_conv2d", "relu");
depthwise_fuser(graph.get());
}
}
}
}
} // namespace mir
......
......@@ -22,35 +22,33 @@ namespace mir {
namespace fusion {
void ConvActivationFuser::BuildPattern() {
// create input nodes.
// create nodes.
auto* input =
VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
auto* filter =
VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
auto* bias =
VarNode("bias")->assert_is_op_input(conv_type_, "Bias")->AsInput();
// create op nodes
auto* conv2d =
OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
PMNode* bias = nullptr;
if (has_bias_) {
bias = VarNode("bias")->assert_is_op_input(conv_type_, "Bias")->AsInput();
}
auto* conv2d = OpNode("conv2d", conv_type_)->AsIntermediate();
auto* act =
OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
auto* act = OpNode("act", act_type_)->AsIntermediate();
// create intermediate nodes
auto* conv2d_out = VarNode("conv2d_out")
->assert_is_op_output(conv_type_, "Output")
->assert_is_op_input(act_type_, "X")
->AsIntermediate();
// create output node
auto* out =
VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
// create topology.
std::vector<PMNode*> conv2d_inputs{filter, input, bias};
conv2d_inputs >> *conv2d >> *conv2d_out;
*conv2d_out >> *act >> *out;
std::vector<PMNode*> conv2d_inputs{filter, input};
conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out;
if (has_bias_) {
*bias >> *conv2d;
}
}
void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
......@@ -66,33 +64,15 @@ void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
IR_NODE_LINK_TO(matched.at("input"), new_op_node);
IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
if (has_bias_) {
IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
}
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
auto* desc = matched.at("conv2d")->stmt()->op_info();
cpp::OpDesc op_desc = *desc;
op_desc.SetType(conv_type_);
op_desc.SetInput("Input", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
cpp::OpDesc op_desc = *matched.at("conv2d")->stmt()->op_info();
op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
// Other inputs. See operators/conv_op.h
std::vector<std::string> input_arg_names = desc->InputArgumentNames();
if (std::find(input_arg_names.begin(),
input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
}
// Only consider strides, padding, groups, dilations, fuse_relu for now
op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
// TODO(sangoly): support other activation types
op_desc.SetAttr("fuse_relu", true);
return op_desc;
}
......
......@@ -26,10 +26,12 @@ namespace fusion {
class ConvActivationFuser : public FuseBase {
public:
explicit ConvActivationFuser(const std::string& conv_type,
const std::string& act_type) {
const std::string& act_type,
bool has_bias) {
CHECK(act_type == "relu") << "Only relu activation be supported now";
conv_type_ = conv_type;
act_type_ = act_type;
has_bias_ = has_bias;
}
void BuildPattern() override;
......@@ -39,6 +41,7 @@ class ConvActivationFuser : public FuseBase {
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string conv_type_;
std::string act_type_;
bool has_bias_;
};
} // namespace fusion
......
......@@ -70,7 +70,8 @@ void CompareOutData(const lite::Predictor& tgt, const lite::Predictor& ref) {
const auto* ref_pdata = ref_otensor->data<float>();
EXPECT_EQ(tgt_otensor->dims().production(), ref_otensor->dims().production());
for (size_t i = 0; i < tgt_otensor->dims().production(); ++i) {
auto diff = std::fabs((tgt_pdata[i] - ref_pdata[i]) / ref_pdata[i]);
auto diff = std::fabs(tgt_pdata[i] - ref_pdata[i]) /
(std::fabs(ref_pdata[i]) + 1e-6);
VLOG(3) << diff;
EXPECT_LT(diff, 0.1);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册