Commit 5f4501e0 authored by Megvii Engine Team

fix(gopt): fix conv bias fuse 2 noline

GitOrigin-RevId: a6ab9f4e5ef6f0d607197dfded9ec0637b638301
Parent ac2f548c
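
In short, this commit tightens FuseConvBiasNonlinPass in two ways: the trailing-elemwise fusion no longer accepts TANH, and an elemwise nonlinearity is only folded into an existing ConvBias when that ConvBias still has nonlineMode IDENTITY (or when both sides are RELU, which is idempotent) and the ConvBias actually carries a bias input (conv->input().size() > 2). Previously a second nonlinearity such as RELU could overwrite an already-fused SIGMOID and silently drop it. The sketch below is not the MegEngine source; it is a minimal restatement of the legality rule, with the NonlineMode values and the helper name chosen only for illustration.

// Minimal sketch (hypothetical helper, not MegEngine code) of the
// fusion-legality rule this commit enforces.
enum class NonlineMode { IDENTITY, RELU, SIGMOID };

// May an elemwise nonlinearity `elem_mode` be folded into a ConvBias whose
// current nonlinearity is `conv_mode`?
bool can_fold_nonlinearity(NonlineMode conv_mode, NonlineMode elem_mode) {
    if (conv_mode == NonlineMode::IDENTITY)
        return true;  // nothing fused yet: any supported mode may be absorbed
    // RELU is idempotent, RELU(RELU(x)) == RELU(x), so a second RELU is safe;
    // every other combination must stay as a separate Elemwise operator.
    return conv_mode == NonlineMode::RELU && elem_mode == NonlineMode::RELU;
}
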
@@ -1771,7 +1771,6 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
         bool can_be_fused = true;
         can_be_fused &= (elem->input().size() == 1);
         can_be_fused &= (elem->param().mode == Mode::RELU) ||
-                        (elem->param().mode == Mode::TANH) ||
                         (elem->param().mode == Mode::SIGMOID);
         return can_be_fused;
@@ -1911,13 +1910,14 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
             }
         } else if (try_fuse_nonlinearity(elem)) {
             auto inp = rewriter.get_var(elem->input(0));
+            auto elem_noline = get_nonlinearity_mode(elem);
             {
                 auto conv = try_cast_as_op<opr::Convolution>(inp->owner_opr());
                 if (conv && check_conv(conv) &&
                     m_deps[elem->input(0)].size() == 1) {
                     opr::ConvBiasForward::Param param =
                             convert_to_conv_bias_param(conv->param());
-                    param.nonlineMode = get_nonlinearity_mode(elem);
+                    param.nonlineMode = elem_noline;
                     auto new_var = opr::ConvBiasForward::make(
                             conv->input(0), conv->input(1), param,
                             conv->execution_policy(), conv->config())
@@ -1941,9 +1941,16 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
                     ;
                 };
                 if (conv && check_conv_bias(conv) &&
-                    m_deps[elem->input(0)].size() == 1) {
+                    m_deps[elem->input(0)].size() == 1 &&
+                    conv->input().size() > 2) {
                     auto param = conv->param();
-                    param.nonlineMode = get_nonlinearity_mode(elem);
+                    bool noline_ok = param.nonlineMode == NonlineMode::IDENTITY ||
+                                     (param.nonlineMode == NonlineMode::RELU &&
+                                      elem_noline == NonlineMode::RELU);
+                    if (!noline_ok) {
+                        return;
+                    }
+                    param.nonlineMode = elem_noline;
                     auto new_var =
                             opr::ConvBiasForward::make(
                                     conv->input(0), conv->input(1), conv->input(2),
......
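
Why the extra check matters: with the old code, a chain conv → bias → SIGMOID → RELU would first fuse into a ConvBias with nonlineMode = SIGMOID, and the later RELU fusion would then overwrite that field, discarding the sigmoid entirely. A standalone numeric sketch of the difference (plain C++, not MegEngine code; the input value is arbitrary):

#include <cmath>
#include <cstdio>

// Illustrates the bug the new check prevents: overwriting an already-fused
// SIGMOID nonlinearity with a trailing RELU changes the computed result.
static float sigmoid(float x) { return 1.f / (1.f + std::exp(-x)); }
static float relu(float x) { return x > 0.f ? x : 0.f; }

int main() {
    float conv_out = -2.f;                         // hypothetical conv + bias output
    float graph_result = relu(sigmoid(conv_out));  // graph semantics, ~0.119
    float bad_fusion = relu(conv_out);             // nonlineMode overwritten by RELU -> 0.0
    std::printf("graph=%f  bad fusion=%f\n", graph_result, bad_fusion);
    return 0;
}
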
@@ -1731,6 +1731,52 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-4);
 }
 
+TEST(TestGoptInference, ConvBiasNonlinearityFusePass2) {
+    // hwcd4 is only supported in naive handle
+    NaiveMegDNNHandleScope naive_megdnn_handle;
+    auto cn = CompNode::load("cpu0");
+    HostTensorGenerator<> gen;
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
+    };
+    opr::Convolution::Param param;
+    auto x = mkvar("x", {5, 8, 16, 24}), w1 = mkcvar("w1", {4, 8, 1, 1}),
+         w2 = mkcvar("w2", {4, 8, 1, 1});
+    auto b1 = mkcvar("b1", {1, 4, 1, 1});
+    auto y_cut = opr::Convolution::make(x, w1, param);
+    auto y = opr::Elemwise::make({y_cut + b1}, opr::Elemwise::Param::Mode::SIGMOID);
+    y = opr::Elemwise::make({y}, opr::Elemwise::Param::Mode::RELU);
+    auto y_cut2 = opr::Convolution::make(x, w2, param);
+    y_cut2 = opr::Elemwise::make({y_cut2}, opr::Elemwise::Param::Mode::SIGMOID);
+    y_cut2 = opr::Elemwise::make({y_cut2}, opr::Elemwise::Param::Mode::RELU);
+    y = y + y_cut2;
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_nhwcd4().enable_fuse_conv_bias_nonlinearity();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+    ASSERT_EQ(
+            opr::ConvBias::Param::NonlineMode::SIGMOID,
+            find_opr<opr::ConvBias>(y_opt).param().nonlineMode);
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(
+                    output_file("TestGoptInference.FuseConvBiasNonlinPass2.json"));
+    HostTensorND host_y, host_y_opt;
+    auto func = graph->compile(
+            {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-4);
+}
+
 TEST(TestGoptInference, ConvBiasNonlinearityFusePass_FullBias) {
     NaiveMegDNNHandleScope naive_megdnn_handle;
......
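
What the new ConvBiasNonlinearityFusePass2 test pins down: the chain Convolution → add bias → SIGMOID → RELU must become a ConvBias whose nonlineMode stays SIGMOID (the ASSERT_EQ above), rather than being overwritten by the trailing RELU. The numeric comparison also holds because sigmoid outputs lie in (0, 1), so

    ReLU(sigmoid(z)) = sigmoid(z)   for every z,

meaning the reference graph and the correctly optimized graph agree, while the old, buggy fusion would compute ReLU(conv(x, w1) + b1) instead and, for typical random inputs, fall outside the MGB_ASSERT_TENSOR_NEAR tolerance.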