提交 2e70cf1d 编写于 作者: M Megvii Engine Team

feat(mgb/opt): add nchw->nchw4 in tensorcore pass

GitOrigin-RevId: 755f8dfefe28bb14dba5d86a54a9bb725af285ed
上级 1e8337f1
...@@ -756,6 +756,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( ...@@ -756,6 +756,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
cb(nchw32, { cb(nchw32, {
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>(); add_pass<FuseConvBiasZPass>();
add_pass(EnableNCHW4Pass::make_nchw4_converter());
add_pass(EnableTensorCorePass::make_tensorcore_converter()); add_pass(EnableTensorCorePass::make_tensorcore_converter());
add_pass<ShuffleShuffleRemovePass>(); add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>(); add_pass<RemoveRedundantTypeCvtPass>();
...@@ -763,6 +764,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( ...@@ -763,6 +764,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
cb(chwn4, { cb(chwn4, {
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>(); add_pass<FuseConvBiasZPass>();
add_pass(EnableNCHW4Pass::make_nchw4_converter());
add_pass(EnableCHWN4Pass::make_chwn4_converter()); add_pass(EnableCHWN4Pass::make_chwn4_converter());
add_pass<ShuffleShuffleRemovePass>(); add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>(); add_pass<RemoveRedundantTypeCvtPass>();
......
...@@ -1356,16 +1356,17 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ ...@@ -1356,16 +1356,17 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
using RelayoutMode = RelayoutPlaceholder::LayoutType; using RelayoutMode = RelayoutPlaceholder::LayoutType;
megdnn::param::Convolution::Format conv_format = megdnn::param::Convolution::Format conv_format =
megdnn::param::Convolution::Format::NCHW4; megdnn::param::Convolution::Format::NCHW4;
megdnn::param::ConvBias::Format conv_bias_format = megdnn::param::ConvBias::Format conv_bias_format =
megdnn::param::ConvBias::Format::NCHW4; megdnn::param::ConvBias::Format::NCHW4;
megdnn::param::BatchConvBias::Format batch_conv_bias_format = megdnn::param::BatchConvBias::Format batch_conv_bias_format =
megdnn::param::BatchConvBias::Format::NCHW4; megdnn::param::BatchConvBias::Format::NCHW4;
RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4; RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4;
RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW; RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
RelayoutMode weight_to_nchw4_mode_dense = RelayoutMode weight_to_nchw4_mode_dense =
RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE; RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE;
RelayoutMode weight_to_nchw4_mode_group = RelayoutMode weight_to_nchw4_mode_group =
RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP; RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP;
auto trans_nchw4 = [weight_to_nchw4_mode_dense, auto trans_nchw4 = [weight_to_nchw4_mode_dense,
weight_to_nchw4_mode_group]( weight_to_nchw4_mode_group](
const megdnn::param::Convolution::Sparse conv_mode, const megdnn::param::Convolution::Sparse conv_mode,
...@@ -1391,9 +1392,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ ...@@ -1391,9 +1392,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
OperatorNodeBase* opr, const VarNodeArray& new_inp) { OperatorNodeBase* opr, const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
mgb_assert(conv_opr.param().format == if (conv_opr.param().format !=
megdnn::param::Convolution::Format::NCHW, megdnn::param::Convolution::Format::NCHW) {
"ConvertFormat Pass only support converting NCHW to NCHW4"); return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
// src: NCHW --> NCWH4 // src: NCHW --> NCWH4
if (new_inp[0]->shape().ndim != 5) { if (new_inp[0]->shape().ndim != 5) {
...@@ -1427,7 +1430,13 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ ...@@ -1427,7 +1430,13 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& batch_conv_bias_opr = auto& batch_conv_bias_opr =
opr->cast_final_safe<opr::BatchConvBiasForward>(); opr->cast_final_safe<opr::BatchConvBiasForward>();
mgb_assert(batch_conv_bias_opr.param().format == if (batch_conv_bias_opr.param().format !=
megdnn::param::BatchConvBias::Format::NCHW) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(batch_conv_bias_opr.param().format ==
megdnn::param::BatchConvBias::Format::NCHW, megdnn::param::BatchConvBias::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHW4"); "ConvertFormat Pass only support converting NCHW to NCHW4");
// what should be converted: src, weight // what should be converted: src, weight
...@@ -1494,9 +1503,12 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ ...@@ -1494,9 +1503,12 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
mgb_assert(conv_bias_opr.param().format == if (conv_bias_opr.param().format !=
megdnn::param::ConvBias::Format::NCHW, megdnn::param::Convolution::Format::NCHW) {
"ConvertFormat Pass only support converting NCHW to NCHW4"); return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
// what should be converted: src, weight // what should be converted: src, weight
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1]; VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
// src: NCHW --> NCHW4 // src: NCHW --> NCHW4
...@@ -1604,8 +1616,9 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ ...@@ -1604,8 +1616,9 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
using Format = Param::Format; using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& pooling = opr->cast_final_safe<opr::PoolingForward>(); auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
mgb_assert(pooling.param().format == Format::NCHW, if (pooling.param().format != Format::NCHW) {
"ConvertFormat Pass only support converting NCHW to NCHW4."); return opr;
}
if (new_inp[0]->shape().ndim == 5) { if (new_inp[0]->shape().ndim == 5) {
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
auto new_param = pooling.param(); auto new_param = pooling.param();
...@@ -1628,8 +1641,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ ...@@ -1628,8 +1641,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
using Format = Param::Format; using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& resize = opr->cast_final_safe<opr::ResizeForward>(); auto& resize = opr->cast_final_safe<opr::ResizeForward>();
mgb_assert(resize.param().format == Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHW4.");
if (new_inp[0]->shape().ndim == 5) { if (new_inp[0]->shape().ndim == 5) {
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
auto new_param = resize.param(); auto new_param = resize.param();
...@@ -1652,8 +1663,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ ...@@ -1652,8 +1663,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
using Format = Param::Format; using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>(); auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>();
mgb_assert(warp.param().format == Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHW4.");
if (new_inp[0]->shape().ndim == 5) { if (new_inp[0]->shape().ndim == 5) {
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
auto new_param = warp.param(); auto new_param = warp.param();
......
...@@ -1512,6 +1512,7 @@ TEST_PASS(FuseConvBiasNonlinPass, Basic) { ...@@ -1512,6 +1512,7 @@ TEST_PASS(FuseConvBiasNonlinPass, Basic) {
#if MGB_CUDA #if MGB_CUDA
TEST(TestEnableTensorCore, SmallInputShape) { TEST(TestEnableTensorCore, SmallInputShape) {
REQUIRE_GPU(1); REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0"); auto cn = CompNode::load("gpu0");
...@@ -1579,6 +1580,104 @@ TEST(TestEnableTensorCore, SmallInputShape) { ...@@ -1579,6 +1580,104 @@ TEST(TestEnableTensorCore, SmallInputShape) {
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
} }
TEST(TestEnableTensorCore, Nchw4Nchw) {
REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0");
cn.activate();
auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
auto sm_ver = prop.major * 10 + prop.minor;
if (sm_ver < 75) {
printf("This testcast ignored due to insufficient cuda cap(got: %d, "
"expected: %d)\n",
sm_ver, 75);
return;
}
HostTensorGenerator<dtype::Int8> gen;
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
dtype);
};
auto mkcvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};
auto mkshape = [](opr::ConvBias::Param::Format format, size_t N, size_t C,
size_t H, size_t W) -> TensorShape {
mgb_assert(C % 4 == 0);
if (format == opr::ConvBias::Param::Format::NCHW4) {
return {N, C / 4, H, W, 4};
} else {
mgb_assert(format == opr::ConvBias::Param::Format::NCHW);
return {N, C, H, W};
}
};
for (auto format : {opr::ConvBias::Param::Format::NCHW,
opr::ConvBias::Param::Format::NCHW4}) {
auto x = mkvar("x", mkshape(format, 32, 64, 16, 16),
dtype::QuantizedS8(2.5f)),
w = mkcvar("w1", mkshape(format, 64, 64, 3, 3),
dtype::QuantizedS8(2.5f)),
b = mkcvar("b", mkshape(format, 1, 64, 1, 1),
dtype::QuantizedS32(6.25f)),
z = mkcvar("b1", mkshape(format, 32, 64, 8, 8),
dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param;
param.format = format;
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
param.stride_h = param.stride_w = 2;
param.pad_h = param.pad_w = 1;
auto y = opr::ConvBias::make(
x, w, b, z, param, {},
OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::ConvBias::make(y, w, b, param, {},
OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::TypeCvt::make(y, dtype::Float32());
SymbolVar y_opt;
SymbolVar y_no_tc;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw32().enable_fuse_conv_bias_nonlinearity();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
{
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_conv_bias_nonlinearity();
unpack_vector(gopt::optimize_for_inference({y}, options), y_no_tc);
}
auto nr_dimshuffle = find_opr_num<mgb::opr::Dimshuffle>(y_opt);
std::string json_name;
ASSERT_EQ(2u, nr_dimshuffle);
if (format == opr::ConvBias::Param::Format::NCHW4) {
json_name = "TestGoptInference.Nchw4Nchw.NCHW4.json";
} else {
mgb_assert(format == opr::ConvBias::Param::Format::NCHW);
json_name = "TestGoptInference.Nchw4Nchw.NCHW.json";
}
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(json_name.c_str()));
HostTensorND host_y, host_y_opt;
auto func = graph->compile({make_callback_copy(y_no_tc, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
}
}
TEST(TestEnableTensorCore, ConvBiasWithZ) { TEST(TestEnableTensorCore, ConvBiasWithZ) {
REQUIRE_GPU(1); REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0"); auto cn = CompNode::load("gpu0");
...@@ -2043,53 +2142,74 @@ TEST(TestGoptInference, EnableCHWN4) { ...@@ -2043,53 +2142,74 @@ TEST(TestGoptInference, EnableCHWN4) {
.rename(name), .rename(name),
dtype); dtype);
}; };
auto mkshape = [](opr::ConvBias::Param::Format format, size_t N, size_t C,
size_t H, size_t W) -> TensorShape {
mgb_assert(C % 4 == 0);
if (format == opr::ConvBias::Param::Format::NCHW4) {
return {N, C / 4, H, W, 4};
} else {
mgb_assert(format == opr::ConvBias::Param::Format::NCHW);
return {N, C, H, W};
}
};
auto x = mkvar("x", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)), for (auto format : {opr::ConvBias::Param::Format::NCHW,
w = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)), opr::ConvBias::Param::Format::NCHW4}) {
b = mkcvar("b", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)), auto x = mkvar("x", mkshape(format, 32, 64, 16, 16),
b1 = mkvar("b1", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)); dtype::QuantizedS8(2.5f)),
opr::ConvBias::Param param; w = mkcvar("w1", mkshape(format, 64, 64, 3, 3),
param.format = opr::ConvBias::Param::Format::NCHW4; dtype::QuantizedS8(2.5f)),
param.stride_h = param.stride_w = 1; b = mkcvar("b", mkshape(format, 1, 64, 1, 1),
param.pad_h = param.pad_w = 1; dtype::QuantizedS32(6.25f)),
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; b1 = mkvar("b1", mkshape(format, 32, 64, 16, 16),
dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param;
param.format = format;
param.stride_h = param.stride_w = 1;
param.pad_h = param.pad_w = 1;
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
auto y = opr::ConvBiasForward::make( auto y = opr::ConvBiasForward::make(
x, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); x, w, b, param, {},
auto y1 = opr::ElemwiseMultiType::make( OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
{y, b1}, opr::ElemwiseMultiType::Mode::QFUSE_ADD_RELU, auto y1 = opr::ElemwiseMultiType::make(
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); {y, b1}, opr::ElemwiseMultiType::Mode::QFUSE_ADD_RELU,
auto y2 = opr::ConvBiasForward::make( OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
y, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); auto y2 = opr::ConvBiasForward::make(
auto y3 = opr::ElemwiseMultiType::make( y, w, b, param, {},
{y, b1}, opr::ElemwiseMultiType::Param::Mode::QSUB, OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); auto y3 = opr::ElemwiseMultiType::make(
auto y4 = opr::ElemwiseMultiType::make( {y, b1}, opr::ElemwiseMultiType::Param::Mode::QSUB,
{y1, y2}, opr::ElemwiseMultiType::Param::Mode::QADD, OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); auto y4 = opr::ElemwiseMultiType::make(
y4 = opr::ElemwiseMultiType::make( {y1, y2}, opr::ElemwiseMultiType::Param::Mode::QADD,
{y3, y4}, opr::ElemwiseMultiType::Param::Mode::QADD, OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); y4 = opr::ElemwiseMultiType::make(
y4 = opr::TypeCvt::make(y4, dtype::Float32()); {y3, y4}, opr::ElemwiseMultiType::Param::Mode::QADD,
SymbolVar y_opt; OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
SymbolVar y_cudnn; y4 = opr::TypeCvt::make(y4, dtype::Float32());
{ SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{}; SymbolVar y_cudnn;
options.enable_chwn4(); {
unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt); auto options = gopt::OptimizeForInferenceOptions{};
options.enable_chwn4();
unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt);
}
unpack_vector(gopt::GraphOptimizer{}
.add_pass<gopt::FuseConvBiasNonlinPass>()
.add_pass<gopt::FuseConvBiasZPass>()
.apply({{y4}})
.endpoint_vars(),
y_cudnn);
ASSERT_EQ(opr::ConvBias::Param::Format::CHWN4,
find_opr<opr::ConvBias>(y_opt).param().format);
HostTensorND host_y, host_y_opt;
auto func = graph->compile({make_callback_copy(y_cudnn, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
} }
unpack_vector(gopt::GraphOptimizer{}
.add_pass<gopt::FuseConvBiasNonlinPass>()
.add_pass<gopt::FuseConvBiasZPass>()
.apply({{y4}})
.endpoint_vars(),
y_cudnn);
HostTensorND host_y, host_y_opt;
auto func = graph->compile({make_callback_copy(y_cudnn, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
} }
TEST(TestGoptInference, EnableCHWN4WarpPespective) { TEST(TestGoptInference, EnableCHWN4WarpPespective) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册