Commit f1c86606 authored by Megvii Engine Team, committed by Xu Xinran

fix(dnn/cuda): fix FuseConvBiasWithZ pass for HSwish activation

GitOrigin-RevId: b290469cb1993214c01cab823864476778963f70
Parent adfa4688
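FuseConvBiasZPass folds the pattern `conv_bias(x, w, b) + z`, optionally followed by a fused activation, into a single quantized ConvBias operator that takes z as an extra input. Before this commit only QADD and QFUSE_ADD_RELU were recognized, so the H-Swish fused-add variant was left unfused. Below is a minimal sketch of the pattern that now fuses, using only operators exercised by the test in this diff; x, w, b, z are assumed to be quantized SymbolVars built as in that test:

    // ConvBias with no activation, followed by a quantized fused add + H-Swish on z.
    opr::ConvBias::Param param;
    param.format = opr::Convolution::Param::Format::NCHW4;
    param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
    auto y = opr::ConvBias::make(x, w, b, param, {},
                                 OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
    auto out = opr::ElemwiseMultiType::make(
            {y, z}, {opr::ElemwiseMultiType::Param::Mode::QFUSE_ADD_H_SWISH},
            OperatorNodeConfig{dtype::QuantizedS8(1.2f)});
    // After the pass, y and out collapse into one ConvBias(x, w, b, z)
    // whose nonlineMode is H_SWISH.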
@@ -1875,6 +1875,8 @@ void FuseConvBiasZPass::apply(OptState& state) const {
             auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
             if (elem->param().mode == MultiMode::QFUSE_ADD_RELU)
                 return NonlineMode::RELU;
+            else if (elem->param().mode == MultiMode::QFUSE_ADD_H_SWISH)
+                return NonlineMode::H_SWISH;
         }
         return NonlineMode::IDENTITY;
     };
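The new branch maps the fused elemwise mode to the ConvBias nonlinearity. For reference, H-Swish is x * relu6(x + 3) / 6; conceptually, the fused quantized kernel applies it to the dequantized sum before requantizing to the output scale. A float reference, for illustration only:

    #include <algorithm>

    // Illustrative float reference of H-Swish: x * relu6(x + 3) / 6.
    static float h_swish_ref(float x) {
        return x * std::min(std::max(x + 3.f, 0.f), 6.f) * (1.f / 6.f);
    }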
@@ -1941,7 +1943,8 @@ void FuseConvBiasZPass::apply(OptState& state) const {
         if (elem->input().size() != 2)
             return false;
         if (elem->param().mode != MultiMode::QADD &&
-            elem->param().mode != MultiMode::QFUSE_ADD_RELU)
+            elem->param().mode != MultiMode::QFUSE_ADD_RELU &&
+            elem->param().mode != MultiMode::QFUSE_ADD_H_SWISH)
             return false;
         return try_replace_var_node(opr);
     };
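With both changes, a graph containing the H-Swish fused-add reaches try_replace_var_node just like the ReLU variant. A sketch of driving the pass through the public inference options, mirroring the test below (z is assumed to be the endpoint SymbolVar of such a graph):

    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_fuse_conv_bias_nonlinearity()
            .enable_fuse_conv_bias_with_z();
    SymbolVar z_fuse;
    unpack_vector(gopt::optimize_for_inference({z}, options), z_fuse);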
@@ -1701,80 +1701,96 @@ TEST(FuseConvBiasZPass, BlockFuse) {
                          dtype);
     };
-    auto x = mkvar("x", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)),
-         w1 = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
-         b1 = mkcvar("b1", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
-         w2 = mkcvar("w2", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
-         b2 = mkcvar("b2", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
-         w3 = mkcvar("w3", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
-         b3 = mkcvar("b3", {1, 16, 1, 1, 4}, dtype::QuantizedS32(3.0f));
+    using ElemMultiMode = opr::ElemwiseMultiType::Param::Mode;
+    using NonlineMode = opr::ConvBias::Param::NonlineMode;
+    for (auto mode :
+         {ElemMultiMode::QFUSE_ADD_RELU, ElemMultiMode::QFUSE_ADD_H_SWISH}) {
+        auto x = mkvar("x", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)),
+             w1 = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
+             b1 = mkcvar("b1", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
+             w2 = mkcvar("w2", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
+             b2 = mkcvar("b2", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
+             w3 = mkcvar("w3", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
+             b3 = mkcvar("b3", {1, 16, 1, 1, 4}, dtype::QuantizedS32(3.0f));
+        NonlineMode nonline_mode = NonlineMode::RELU;
+        if (mode == ElemMultiMode::QFUSE_ADD_H_SWISH) {
+            nonline_mode = NonlineMode::H_SWISH;
+        }
-    opr::ConvBias::Param param;
-    param.format = opr::Convolution::Param::Format::NCHW4;
-    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
-    param.stride_h = param.stride_w = 1;
-    param.pad_h = param.pad_w = 1;
+        opr::ConvBias::Param param;
+        param.format = opr::Convolution::Param::Format::NCHW4;
+        param.nonlineMode = nonline_mode;
+        param.stride_h = param.stride_w = 1;
+        param.pad_h = param.pad_w = 1;
-    auto y1 = opr::ConvBias::make(x, w1, b1, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
-    param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
-    auto y2 = opr::ConvBias::make(y1, w2, b2, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
-         y3 = opr::ElemwiseMultiType::make(
-                 {y1, y2},
-                 {opr::ElemwiseMultiType::Param::Mode::QFUSE_ADD_RELU},
-                 OperatorNodeConfig{dtype::QuantizedS8(1.2f)});
-    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
-    auto y4 = opr::ConvBias::make(y3, w3, b3, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
-         z = opr::ElemwiseMultiType::make(
-                 {y3, y4},
-                 {opr::ElemwiseMultiType::Param::Mode::QADD},
-                 OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
-    z = opr::TypeCvt::make(z, dtype::Float32());
-    //! fuse z manually
-    auto z0 = opr::ConvBias::make(x, w1, b1, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
-    auto z1 = opr::ConvBias::make(z0, w2, b2, z0, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(1.2f)}),
-         z2 = opr::ConvBias::make(z1, w3, b3, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
-         z4 = opr::ElemwiseMultiType::make(
-                 {z1, z2}, {opr::ElemwiseMultiType::Mode::QADD},
-                 OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
-    z4 = opr::TypeCvt::make(z4, dtype::Float32());
-    SymbolVar z_fuse;
-    SymbolVar z_nonfuse;
-    {
-        auto options = gopt::OptimizeForInferenceOptions{};
-        options.enable_fuse_conv_bias_nonlinearity()
-                .enable_fuse_conv_bias_with_z();
-        unpack_vector(gopt::optimize_for_inference({z}, options), z_fuse);
-    }
-    {
-        auto options = gopt::OptimizeForInferenceOptions{};
-        options.enable_fuse_conv_bias_nonlinearity();
-        unpack_vector(gopt::optimize_for_inference({z4}, options), z_nonfuse);
-    }
-    auto nr_elem_multi_type = find_opr_num<mgb::opr::ElemwiseMultiType>(z_fuse);
-    MGB_MARK_USED_VAR(nr_elem_multi_type);
-    ASSERT_EQ(1u, nr_elem_multi_type);
-    graph->compile({{z_fuse, {}}})
-            ->to_json()
-            ->writeto_fpath(
-                    output_file("FuseConvBiasZPass.BlockFuse_fuse.json"));
-    graph->compile({{z_nonfuse, {}}})
-            ->to_json()
-            ->writeto_fpath(
-                    output_file("FuseConvBiasZPass.BlockFuse_nonfuse.json"));
-    HostTensorND host_z_fuse, host_z_nonfuse;
-    auto func = graph->compile({make_callback_copy(z_nonfuse, host_z_nonfuse),
+        auto y1 = opr::ConvBias::make(
+                x, w1, b1, param, {},
+                OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+        param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
+        auto y2 = opr::ConvBias::make(
+                     y1, w2, b2, param, {},
+                     OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
+             y3 = opr::ElemwiseMultiType::make(
+                     {y1, y2}, {mode},
+                     OperatorNodeConfig{dtype::QuantizedS8(1.2f)});
+        param.nonlineMode = nonline_mode;
+        auto y4 = opr::ConvBias::make(
+                     y3, w3, b3, param, {},
+                     OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
+             z = opr::ElemwiseMultiType::make(
+                     {y3, y4}, {opr::ElemwiseMultiType::Param::Mode::QADD},
+                     OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+        z = opr::TypeCvt::make(z, dtype::Float32());
+        //! fuse z manually
+        auto z0 = opr::ConvBias::make(
+                x, w1, b1, param, {},
+                OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+        auto z1 = opr::ConvBias::make(
+                     z0, w2, b2, z0, param, {},
+                     OperatorNodeConfig{dtype::QuantizedS8(1.2f)}),
+             z2 = opr::ConvBias::make(
+                     z1, w3, b3, param, {},
+                     OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
+             z4 = opr::ElemwiseMultiType::make(
+                     {z1, z2}, {opr::ElemwiseMultiType::Mode::QADD},
+                     OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+        z4 = opr::TypeCvt::make(z4, dtype::Float32());
+        SymbolVar z_fuse;
+        SymbolVar z_nonfuse;
+        {
+            auto options = gopt::OptimizeForInferenceOptions{};
+            options.enable_fuse_conv_bias_nonlinearity()
+                    .enable_fuse_conv_bias_with_z();
+            unpack_vector(gopt::optimize_for_inference({z}, options), z_fuse);
+        }
+        {
+            auto options = gopt::OptimizeForInferenceOptions{};
+            options.enable_fuse_conv_bias_nonlinearity();
+            unpack_vector(gopt::optimize_for_inference({z4}, options),
+                          z_nonfuse);
+        }
+        auto nr_elem_multi_type =
+                find_opr_num<mgb::opr::ElemwiseMultiType>(z_fuse);
+        MGB_MARK_USED_VAR(nr_elem_multi_type);
+        ASSERT_EQ(1u, nr_elem_multi_type);
+        graph->compile({{z_fuse, {}}})
+                ->to_json()
+                ->writeto_fpath(
+                        output_file("FuseConvBiasZPass.BlockFuse_fuse.json"));
+        graph->compile({{z_nonfuse, {}}})
+                ->to_json()
+                ->writeto_fpath(output_file(
+                        "FuseConvBiasZPass.BlockFuse_nonfuse.json"));
+        HostTensorND host_z_fuse, host_z_nonfuse;
+        auto func =
+                graph->compile({make_callback_copy(z_nonfuse, host_z_nonfuse),
                                 make_callback_copy(z_fuse, host_z_fuse)});
-    func->execute();
-    MGB_ASSERT_TENSOR_EQ(host_z_fuse, host_z_nonfuse);
+        func->execute();
+        MGB_ASSERT_TENSOR_EQ(host_z_fuse, host_z_nonfuse);
+    }
 }
 TEST(TestEnableTensorCore, ShuffleMerge) {