Commit 5f4501e0 authored by Megvii Engine Team

fix(gopt): fix conv bias fuse 2 noline

GitOrigin-RevId: a6ab9f4e5ef6f0d607197dfded9ec0637b638301
Parent ac2f548c
@@ -1771,7 +1771,6 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
bool can_be_fused = true;
can_be_fused &= (elem->input().size() == 1);
can_be_fused &= (elem->param().mode == Mode::RELU) ||
(elem->param().mode == Mode::TANH) ||
(elem->param().mode == Mode::SIGMOID);
return can_be_fused;
@@ -1911,13 +1910,14 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
}
} else if (try_fuse_nonlinearity(elem)) {
auto inp = rewriter.get_var(elem->input(0));
auto elem_noline = get_nonlinearity_mode(elem);
{
auto conv = try_cast_as_op<opr::Convolution>(inp->owner_opr());
if (conv && check_conv(conv) &&
m_deps[elem->input(0)].size() == 1) {
opr::ConvBiasForward::Param param =
convert_to_conv_bias_param(conv->param());
param.nonlineMode = get_nonlinearity_mode(elem);
param.nonlineMode = elem_noline;
auto new_var = opr::ConvBiasForward::make(
conv->input(0), conv->input(1), param,
conv->execution_policy(), conv->config())
@@ -1941,9 +1941,16 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
;
};
if (conv && check_conv_bias(conv) &&
m_deps[elem->input(0)].size() == 1) {
m_deps[elem->input(0)].size() == 1 &&
conv->input().size() > 2) {
auto param = conv->param();
param.nonlineMode = get_nonlinearity_mode(elem);
bool noline_ok = param.nonlineMode == NonlineMode::IDENTITY ||
(param.nonlineMode == NonlineMode::RELU &&
elem_noline == NonlineMode::RELU);
if (!noline_ok) {
return;
}
param.nonlineMode = elem_noline;
auto new_var =
opr::ConvBiasForward::make(
conv->input(0), conv->input(1), conv->input(2),
......
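The hunks above tighten the path that folds a following elemwise activation into an existing ConvBias: the fold is now only performed when the ConvBias carries no activation yet (IDENTITY), or when both the fused and the following activation are RELU. Below is a minimal standalone sketch of that guard, with simplified names (can_fold_activation and the trimmed NonlineMode enum are illustrative only, not the actual MegEngine API):

// Standalone sketch of the new fusion guard (simplified names, hypothetical
// helper, not the MegEngine pass itself). `conv_bias_mode` is the activation
// already attached to the ConvBias operator, `elem_mode` is the elemwise
// activation the pass wants to fold into it.
#include <cassert>

enum class NonlineMode { IDENTITY, RELU, SIGMOID };

bool can_fold_activation(NonlineMode conv_bias_mode, NonlineMode elem_mode) {
    // Folding is only valid when it does not change the computed function:
    // either the ConvBias has no activation yet, or both activations are RELU
    // (RELU is idempotent, so RELU(RELU(x)) == RELU(x)). Anything else, e.g.
    // a RELU applied after an already-fused SIGMOID, must stay a separate op.
    return conv_bias_mode == NonlineMode::IDENTITY ||
           (conv_bias_mode == NonlineMode::RELU &&
            elem_mode == NonlineMode::RELU);
}

int main() {
    assert(can_fold_activation(NonlineMode::IDENTITY, NonlineMode::SIGMOID));
    assert(can_fold_activation(NonlineMode::RELU, NonlineMode::RELU));
    // The case exercised by the new ConvBiasNonlinearityFusePass2 test below:
    // a RELU following a ConvBias that already carries SIGMOID is not folded.
    assert(!can_fold_activation(NonlineMode::SIGMOID, NonlineMode::RELU));
    return 0;
}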
@@ -1731,6 +1731,52 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-4);
}
TEST(TestGoptInference, ConvBiasNonlinearityFusePass2) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;
auto cn = CompNode::load("cpu0");
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};
opr::Convolution::Param param;
auto x = mkvar("x", {5, 8, 16, 24}), w1 = mkcvar("w1", {4, 8, 1, 1}),
w2 = mkcvar("w2", {4, 8, 1, 1});
auto b1 = mkcvar("b1", {1, 4, 1, 1});
auto y_cut = opr::Convolution::make(x, w1, param);
auto y = opr::Elemwise::make({y_cut + b1}, opr::Elemwise::Param::Mode::SIGMOID);
y = opr::Elemwise::make({y}, opr::Elemwise::Param::Mode::RELU);
auto y_cut2 = opr::Convolution::make(x, w2, param);
y_cut2 = opr::Elemwise::make({y_cut2}, opr::Elemwise::Param::Mode::SIGMOID);
y_cut2 = opr::Elemwise::make({y_cut2}, opr::Elemwise::Param::Mode::RELU);
y = y + y_cut2;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nhwcd4().enable_fuse_conv_bias_nonlinearity();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(
opr::ConvBias::Param::NonlineMode::SIGMOID,
find_opr<opr::ConvBias>(y_opt).param().nonlineMode);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(
output_file("TestGoptInference.FuseConvBiasNonlinPass2.json"));
HostTensorND host_y, host_y_opt;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-4);
}
TEST(TestGoptInference, ConvBiasNonlinearityFusePass_FullBias) {
NaiveMegDNNHandleScope naive_megdnn_handle;
......