提交 87de704a 编写于 作者: M Megvii Engine Team

feat(gopt): fuse conv h_swish

GitOrigin-RevId: a3d12991fbb2e16a5e91b41fe3f3143257515c2a
上级 4adba378
......@@ -392,6 +392,13 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
MIDOUT_END(); \
break; \
} \
case param::ConvBias::NonlineMode::H_SWISH: { \
MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \
cb(_bmode, HSwishOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \
} \
MIDOUT_END(); \
break; \
} \
default: \
megdnn_assert(0); \
break; \
......
......@@ -137,6 +137,12 @@ struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> {
return QConverter::convert<int8x8_t, float32x4_t>(vitem0);
}
int8x8_t operator()(const float32x4_t& src) const {
auto vitem0 = vmulq_f32(src, this->vscale_src);
H_SWISH_KERN_N1(f32, vitem0);
vitem0 = vmulq_f32(vitem0, this->vscale_dst);
return QConverter::convert<int8x8_t, float32x4_t>(vitem0);
}
};
template <>
......
......@@ -421,7 +421,7 @@ std::vector<TestArg> get_int8_nchw44_args(
using NLMode = param::ConvBias::NonlineMode;
// clang-format off
for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU}) {
for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU, NLMode::H_SWISH}) {
for (auto mode : {param::ConvBias::Mode::CROSS_CORRELATION}) {
for (size_t b : {1,2}) {
for (size_t ic : {8,16}) {
......@@ -542,7 +542,7 @@ std::vector<TestArg> get_int8_nchw4_args_small_batch(size_t kernel_size) {
using NLMode = param::ConvBias::NonlineMode;
// clang-format off
for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU}) {
for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU, NLMode::H_SWISH}) {
for (auto mode : {param::ConvBias::Mode::CROSS_CORRELATION}) {
for (size_t b : {12, 8, 4}) {
for (size_t ic : {16, 32}) {
......@@ -577,7 +577,7 @@ std::vector<TestArg> get_int8_nchw4_small_channel_args(size_t kernel_size) {
using NLMode = param::ConvBias::NonlineMode;
// clang-format off
for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU}) {
for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU, NLMode::H_SWISH}) {
for (auto mode : {param::ConvBias::Mode::CROSS_CORRELATION}) {
for (size_t b : {64, 16}) {
for (size_t ic : {4, 12}) {
......@@ -696,7 +696,7 @@ std::vector<TestArg> get_int8_nchw4_tensorcore_args(size_t kernel_size) {
using NLMode = param::ConvBias::NonlineMode;
// clang-format off
for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU}) {
for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU, NLMode::H_SWISH}) {
for (auto mode : {param::ConvBias::Mode::CROSS_CORRELATION}) {
size_t b = 64, oc = 128;
for (size_t ic : {32, 64}) {
......
......@@ -1291,7 +1291,8 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_TENSORCORE_INT8) {
param.format = ConvBias::Param::Format::NCHW32;
using NonlineMode = ConvBias::Param::NonlineMode;
for (NonlineMode mode : {NonlineMode::IDENTITY, NonlineMode::RELU}) {
for (NonlineMode mode :
{NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::H_SWISH}) {
for (size_t batch : {2}) {
for (size_t ic : {64, 32}) {
for (size_t oc : {32}) {
......
......@@ -1083,7 +1083,8 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32) {
for (size_t p : {0, 2})
for (size_t size : {8, 24})
for (NonlineMode nonline_mode :
{NonlineMode::IDENTITY, NonlineMode::RELU}) {
{NonlineMode::IDENTITY, NonlineMode::RELU,
NonlineMode::H_SWISH}) {
run(oc, ic, size, size, kernel, p, nonline_mode);
}
......@@ -1185,7 +1186,8 @@ TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32_RECORD) {
1, oc, (h + 2 * p - kernel) / param.stride_h + 1,
(w + 2 * p - kernel) / param.stride_w + 1});
};
for (NonlineMode nonline_mode : {NonlineMode::IDENTITY, NonlineMode::RELU}) {
for (NonlineMode nonline_mode :
{NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::H_SWISH}) {
run(1, 1, 24, 24, 2, 2, nonline_mode);
}
......@@ -1230,7 +1232,8 @@ TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32_NOPACK_PREPROCESS) {
for (size_t p : {0, 2})
for (size_t size : {8, 24})
for (NonlineMode nonline_mode :
{NonlineMode::IDENTITY, NonlineMode::RELU}) {
{NonlineMode::IDENTITY, NonlineMode::RELU,
NonlineMode::H_SWISH}) {
run(oc, ic, size, size, kernel, p, nonline_mode);
}
......@@ -1285,7 +1288,8 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_6x16) {
for (size_t p : {0, 2})
for (size_t size : {8, 24})
for (NonlineMode nonline_mode :
{NonlineMode::IDENTITY, NonlineMode::RELU}) {
{NonlineMode::IDENTITY, NonlineMode::RELU,
NonlineMode::H_SWISH}) {
run(oc, ic, size, size, kernel, p, nonline_mode);
}
......@@ -1351,7 +1355,8 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA) {
for (size_t p : {0, 1})
for (size_t size : {8, 24})
for (NonlineMode nonline_mode :
{NonlineMode::IDENTITY, NonlineMode::RELU}) {
{NonlineMode::IDENTITY, NonlineMode::RELU,
NonlineMode::H_SWISH}) {
run(oc, ic, size, size, kernel, p, nonline_mode);
}
......@@ -1418,7 +1423,8 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA_FILTER_PREPROCESS) {
for (size_t p : {0, 1})
for (size_t size : {8, 24})
for (NonlineMode nonline_mode :
{NonlineMode::IDENTITY, NonlineMode::RELU}) {
{NonlineMode::IDENTITY, NonlineMode::RELU,
NonlineMode::H_SWISH}) {
run(oc, ic, size, size, kernel, p, nonline_mode);
}
......
......@@ -1827,6 +1827,10 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
elem->param().mode == Mode::FUSE_ADD_SIGMOID ||
elem->param().mode == Mode::SIGMOID) {
return NonlineMode::SIGMOID;
} else if (
elem->param().mode == Mode::FUSE_ADD_H_SWISH ||
elem->param().mode == Mode::H_SWISH) {
return NonlineMode::H_SWISH;
} else {
return NonlineMode::IDENTITY;
}
......@@ -1836,8 +1840,8 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
bool can_be_fused = true;
can_be_fused &= (elem->input().size() == 2);
can_be_fused &= (elem->param().mode == Mode::FUSE_ADD_RELU) ||
(elem->param().mode == Mode::FUSE_ADD_TANH) ||
(elem->param().mode == Mode::FUSE_ADD_SIGMOID);
(elem->param().mode == Mode::FUSE_ADD_SIGMOID) ||
(elem->param().mode == Mode::FUSE_ADD_H_SWISH);
return can_be_fused;
};
......@@ -1853,7 +1857,8 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
bool can_be_fused = true;
can_be_fused &= (elem->input().size() == 1);
can_be_fused &= (elem->param().mode == Mode::RELU) ||
(elem->param().mode == Mode::SIGMOID);
(elem->param().mode == Mode::SIGMOID) ||
(elem->param().mode == Mode::H_SWISH);
return can_be_fused;
};
......
......@@ -1937,6 +1937,52 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass2) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-4);
}
TEST(TestGoptInference, ConvBiasNonlinearityFusePassHswish) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;
auto cn = CompNode::load("cpu0");
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};
opr::Convolution::Param param;
auto x = mkvar("x", {5, 8, 16, 24}), w1 = mkcvar("w1", {4, 8, 1, 1}),
w2 = mkcvar("w2", {4, 8, 1, 1});
auto b1 = mkcvar("b1", {1, 4, 1, 1});
auto y_cut = opr::Convolution::make(x, w1, param);
auto y = opr::Elemwise::make({y_cut + b1}, opr::Elemwise::Param::Mode::H_SWISH);
y = opr::Elemwise::make({y}, opr::Elemwise::Param::Mode::RELU);
auto y_cut2 = opr::Convolution::make(x, w2, param);
y_cut2 = opr::Elemwise::make({y_cut2}, opr::Elemwise::Param::Mode::H_SWISH);
y_cut2 = opr::Elemwise::make({y_cut2}, opr::Elemwise::Param::Mode::RELU);
y = y + y_cut2;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nhwcd4().enable_fuse_conv_bias_nonlinearity();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(
opr::ConvBias::Param::NonlineMode::H_SWISH,
find_opr<opr::ConvBias>(y_opt).param().nonlineMode);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(
output_file("TestGoptInference.FuseConvBiasNonlinPassHswish.json"));
HostTensorND host_y, host_y_opt;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-4);
}
TEST(TestGoptInference, ConvBiasNonlinearityFusePass_FullBias) {
NaiveMegDNNHandleScope naive_megdnn_handle;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册