Commit 45e2beea authored by Megvii Engine Team

feat(mgb/gopt): add nchw4 optpass

GitOrigin-RevId: 551b6b828d33916b8e0a8bec73e6d3c6abd65536
Parent f2e1bb41
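For context, the NCHW4 layout packs each group of four consecutive channels into a trailing dimension of size 4, so an `(N, C, H, W)` tensor becomes `(N, C/4, H, W, 4)`. A minimal numpy sketch of the packing this pass introduces (illustration only, not part of the commit):

```python
import numpy as np

n, c, h, w = 2, 8, 4, 4
x = np.arange(n * c * h * w).reshape(n, c, h, w)

# NCHW -> NCHW4: split the channel axis into (c // 4, 4) and move the
# 4-element group to the innermost position
x4 = x.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)
assert x4.shape == (n, c // 4, h, w, 4)
# channel (cg * 4 + i) in NCHW maps to [cg, ..., i] in NCHW4
assert x4[1, 1, 2, 3, 3] == x[1, 1 * 4 + 3, 2, 3]
```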
......@@ -541,6 +541,7 @@ def optimize_for_inference(
fuse_conv_bias_nonlinearity=False,
use_nchw32=False,
fuse_conv_bias_with_z=False,
use_nchw4=False,
use_nchw88=False,
use_nchw44=False,
use_chwn4=False
......@@ -561,6 +562,7 @@ def optimize_for_inference(
OpenCL devices
:param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr. This is supported only in NHWCD4 format.
:param use_nchw4: whether to use NCHW4 tensor format.
:param use_nchw88: whether to use NCHW88 tensor format. This may be faster in
some cases.
:param use_nchw44: whether to use NCHW44 tensor format. This may be faster in
......@@ -588,6 +590,7 @@ def optimize_for_inference(
layout_transform = None
for k, v in {
"use_nchw4": "nchw4",
"use_nhwcd4": "nhwcd4",
"use_nchw32": "nchw32",
"use_nchw88": "nchw88",
......
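A minimal sketch of how the new flag is used from Python; `outputs` (the network's output symbol vars) is an assumption here, everything else follows the signature patched above:

```python
# enable conv+bias+activation fusion first, then let the new pass
# convert eligible oprs to the NCHW4 layout
optimized_outputs = optimize_for_inference(
    outputs,
    fuse_conv_bias_nonlinearity=True,
    use_nchw4=True,  # flag added by this commit
)
```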
......@@ -463,6 +463,7 @@ class trace:
"enable_io16xc32": "f16_io_f32_comp",
"enable_ioc16": "f16_io_comp",
"enable_hwcd4": "use_nhwcd4",
"enable_nchw4": "use_nchw4",
"enable_nchw88": "use_nchw88",
"enable_nchw32": "use_nchw32",
"enable_nchw44": "use_nchw44",
......
......@@ -80,6 +80,7 @@ struct _OptimizeForInferenceOptions {
#define SET(_trans, _trans_capital) \
void enable_##_trans(); \
SET(nchw4, NCHW4);
SET(nhwcd4, NHWCD4);
SET(nchw88, NCHW88);
SET(nchw44, NCHW44);
......
......@@ -252,6 +252,7 @@ def optimize_for_inference(args, outputs):
'enable_io16xc32': 'f16_io_f32_comp',
'enable_ioc16': 'f16_io_comp',
'enable_hwcd4': 'use_nhwcd4',
'enable_nchw4': 'use_nchw4',
'enable_nchw88': 'use_nchw88',
'enable_nchw44': 'use_nchw44',
'enable_nchw32': 'use_nchw32',
......@@ -381,6 +382,12 @@ def main():
'for inference; you may need to disable CUDA and set '
'MGB_USE_MEGDNN_DBG=2'
)
parser.add_argument(
'--enable-nchw4',
action='store_true',
help='transform the model format from NCHW to NCHW4 '
'for inference'
)
parser.add_argument(
'--enable-nchw88',
action='store_true',
......
......@@ -980,6 +980,7 @@ Args Args::from_argv(int argc, char **argv) {
continue; \
}
cb(nchw4);
cb(chwn4);
cb(nchw44);
cb(nchw88);
......
......@@ -97,6 +97,7 @@ struct GraphCommonOptimizeOptions {
bool fuse_conv_bias_with_z = false;
enum LayoutTransform : uint32_t {
DEFAULT,
NCHW4, ///< compute using NCHW4 tensor format
NHWCD4, ///< compute using NHWCD4 tensor format
NCHW88, ///< compute using NCHW88 tensor format
NCHW44, ///< compute using NCHW44 tensor format
......@@ -137,6 +138,7 @@ struct GraphCommonOptimizeOptions {
return layout_transform == LayoutTransform::_trans_capital; \
}
SET(nchw4, NCHW4);
SET(nhwcd4, NHWCD4);
SET(nchw88, NCHW88);
SET(nchw44, NCHW44);
......
......@@ -725,6 +725,13 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
cb(nchw4, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
add_pass(EnableNCHW4Pass::make_nchw4_converter());
add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>();
});
cb(nhwcd4, {
add_pass<FuseConvBiasNonlinPass>();
add_pass(ConvertFormatPass::make_nhwcd4_converter());
......
......@@ -63,10 +63,15 @@ public:
NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout
CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout
NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout
NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout
NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout
NCHW_TO_NCHW44, //!< from nchw layout to nchw44 layout
NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout
NCHW44_TO_NCHW, //!< from nchw44 layout to nchw layout
WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4
//!< layout
WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to
//!< nchw4 layout
WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88
//!< layout
......@@ -167,6 +172,42 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[3] = inp_shape[2];
dst[4] = inp_shape[4];
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW4_DENSE) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW4_GROUP) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[2] % 4 == 0);
dst.ndim = 6;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1];
dst[2] = inp_shape[2] / 4;
dst[3] = inp_shape[3];
dst[4] = inp_shape[4];
dst[5] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0);
dst.ndim = 5;
......@@ -226,23 +267,6 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[2] = inp_shape[3];
dst[3] = inp_shape[1];
dst[4] = 8;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW44) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_DENSE) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
......@@ -394,6 +418,66 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
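The lambda above is the inverse of `NCHW_TO_NCHW4`: Dimshuffle `{0, 1, 4, 2, 3}` moves the packed dimension back next to the channel-group axis, and the reshape merges the two. A numpy equivalent (illustration only):

```python
import numpy as np

n, c4, h, w = 2, 2, 4, 4
x4 = np.random.randn(n, c4, h, w, 4)  # tensor in NCHW4 layout

# NCHW4 -> NCHW: move the packed dim back, then merge it into the channel axis
x = x4.transpose(0, 1, 4, 2, 3).reshape(n, c4 * 4, h, w)
assert x.shape == (n, 8, h, w)
```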
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2) / 4, cv(4), sub(3), sub(4)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1), sub(2) / 4, sub(3), sub(4), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 2, 4, 5, 3});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
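For group convolution the weight carries an extra group dimension, so the split happens on the per-group input channels (axis 2) and the Dimshuffle is `{0, 1, 2, 4, 5, 3}`. A numpy sketch of `WEIGHT_NCHW_TO_NCHW4_GROUP` (illustration only):

```python
import numpy as np

g, ocpg, icpg, fh, fw = 2, 4, 4, 3, 3
wgt = np.random.randn(g, ocpg, icpg, fh, fw)  # group-conv weight

# split the per-group input channels into (icpg // 4, 4) and move the
# 4-element group innermost
wgt4 = wgt.reshape(g, ocpg, icpg // 4, 4, fh, fw).transpose(0, 1, 2, 4, 5, 3)
assert wgt4.shape == (g, ocpg, icpg // 4, fh, fw, 4)
```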
reformat[LayoutType::NCHW_TO_NCHW88] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
......@@ -492,34 +576,6 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW_TO_NCHW44] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW44_TO_NCHW] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
......@@ -1239,6 +1295,293 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
return ret;
}
/* ================ EnableNCHW4Pass ================ */
VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const {
if (!orig_var->shape().eq_shape(new_var->shape())) {
return RelayoutPlaceholder::make(
new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
.node();
}
return new_var;
}
std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
auto ret = std::make_unique<EnableNCHW4Pass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
using RelayoutMode = RelayoutPlaceholder::LayoutType;
megdnn::param::Convolution::Format conv_format =
megdnn::param::Convolution::Format::NCHW4;
megdnn::param::ConvBias::Format conv_bias_format =
megdnn::param::ConvBias::Format::NCHW4;
megdnn::param::BatchConvBias::Format batch_conv_bias_format =
megdnn::param::BatchConvBias::Format::NCHW4;
RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4;
RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
RelayoutMode weight_to_nchw4_mode_dense =
RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE;
RelayoutMode weight_to_nchw4_mode_group =
RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP;
auto trans_nchw4 = [weight_to_nchw4_mode_dense,
weight_to_nchw4_mode_group](
const megdnn::param::Convolution::Sparse conv_mode,
const VarNode* filter) -> RelayoutMode {
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
mgb_assert(filter->shape().ndim == 4,
"the original filter is not in NCHW format");
size_t IC = filter->shape()[1];
mgb_assert(IC % 4 == 0,
"The input channel should be divisible by 4");
return weight_to_nchw4_mode_dense;
} else {
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
mgb_assert(filter->shape().ndim == 5,
"the original filter is not in group NCHW format");
size_t IC = filter->shape()[2];
mgb_assert(IC % 4 == 0,
"The input channel should be divisible by 4");
return weight_to_nchw4_mode_group;
}
};
auto replace_conv_opr = [trans_nchw4, conv_format, src_to_nchw4_mode](
OperatorNodeBase* opr, const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
mgb_assert(conv_opr.param().format ==
megdnn::param::Convolution::Format::NCHW,
"the ConvertFormat pass only supports converting NCHW to NCHW4");
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
// src: NCHW --> NCHW4
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src = RelayoutPlaceholder::make(new_inp[0],
src_to_nchw4_mode);
conv_src = new_src.node();
}
// weight: NCHW --> NCHW4
auto weight_mode =
trans_nchw4(conv_opr.param().sparse, new_inp[1]);
auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
conv_filter = new_filter.node();
// format: NCHW --> NCHW4
auto new_param = conv_opr.param();
new_param.format = conv_format;
// dst
auto new_conv_opr = opr::Convolution::make(
conv_src, conv_filter, new_param,
conv_opr.execution_policy(), conv_opr.config());
OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
mgb_assert(new_conv_opr.shape().ndim == 5,
"the conv output was not converted to NCHW4");
return new_opr;
};
auto replace_batch_conv_bias_opr = [batch_conv_bias_format,
src_to_nchw4_mode](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& batch_conv_bias_opr =
opr->cast_final_safe<opr::BatchConvBiasForward>();
mgb_assert(batch_conv_bias_opr.param().format ==
megdnn::param::BatchConvBias::Format::NCHW,
"the ConvertFormat pass only supports converting NCHW to NCHW4");
// what should be converted: src, weight
VarNode *src = new_inp[0], *filter = new_inp[1];
// src: NCHW --> NCHW4
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src = RelayoutPlaceholder::make(new_inp[0],
src_to_nchw4_mode);
src = new_src.node();
}
// weight: BNCHW --> BNCHW4
// only dense mode is supported; the batched weight is relayouted like a group-conv weight
auto weight_mode =
RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP;
auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
filter = new_filter.node();
// format: NCHW --> NCHW4
auto new_param = batch_conv_bias_opr.param();
new_param.format = batch_conv_bias_format;
if (new_inp.size() == 2) {
auto dst = opr::BatchConvBias::make(src, filter, new_param,
batch_conv_bias_opr.execution_policy(),
batch_conv_bias_opr.config());
OperatorNodeBase* new_opr = dst.node()->owner_opr();
mgb_assert(dst.shape().ndim == 5,
"the conv_bias output was not converted to NCHW4");
return new_opr;
}
// bias: NCHW --> NCHW4
VarNode* bias = new_inp[2];
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(new_inp[2],
src_to_nchw4_mode);
bias = new_bias.node();
}
if (new_inp.size() == 3) {
auto dst = opr::BatchConvBias::make(src, filter, bias, new_param,
batch_conv_bias_opr.execution_policy(),
batch_conv_bias_opr.config());
OperatorNodeBase* new_opr = dst.node()->owner_opr();
mgb_assert(dst.shape().ndim == 5,
"the conv_bias output was not converted to NCHW4");
return new_opr;
}
// z_inp: NCHW --> NCHW4
VarNode* z_inp = new_inp[3];
if (new_inp[3]->shape().ndim == 4) {
auto new_z = RelayoutPlaceholder::make(new_inp[3],
src_to_nchw4_mode);
z_inp = new_z.node();
}
auto dst = opr::BatchConvBias::make(src, filter, bias, z_inp,
new_param, batch_conv_bias_opr.execution_policy(),
batch_conv_bias_opr.config());
OperatorNodeBase* new_opr = dst.node()->owner_opr();
mgb_assert(dst.shape().ndim == 5,
"the conv_bias output was not converted to NCHW4");
return new_opr;
};
auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format,
src_to_nchw4_mode](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
mgb_assert(conv_bias_opr.param().format ==
megdnn::param::ConvBias::Format::NCHW,
"the ConvertFormat pass only supports converting NCHW to NCHW4");
// what should be converted: src, weight
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
// src: NCHW --> NCHW4
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src = RelayoutPlaceholder::make(new_inp[0],
src_to_nchw4_mode);
conv_bias_src = new_src.node();
}
// weight: NCHW --> NCHW4 or GNCHW --> GNCHW4
auto weight_mode =
trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
conv_bias_filter = new_filter.node();
// format: NCHW --> NCHW4
auto new_param = conv_bias_opr.param();
new_param.format = conv_bias_format;
if (new_inp.size() == 2) {
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"the conv_bias output was not converted to NCHW4");
return new_opr;
}
// bias: NCHW --> NCHW4
VarNode* conv_bias_bias = new_inp[2];
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(new_inp[2],
src_to_nchw4_mode);
conv_bias_bias = new_bias.node();
}
if (new_inp.size() == 3) {
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"the conv_bias output was not converted to NCHW4");
return new_opr;
}
// z_inp: NCHW --> NCHW4
VarNode* z_inp = new_inp[3];
if (new_inp[3]->shape().ndim == 4) {
auto new_z = RelayoutPlaceholder::make(new_inp[3],
src_to_nchw4_mode);
z_inp = new_z.node();
}
auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src,
conv_bias_filter, conv_bias_bias, z_inp, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"the conv_bias output was not converted to NCHW4");
return new_opr;
};
auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
bool has_inp_changed = false;
for (size_t i = 0; i < opr->input().size(); i++) {
if (new_inp[i]->shape().ndim == 5) {
has_inp_changed = true;
break;
}
}
if (has_inp_changed) {
auto temp_inp = new_inp;
for (size_t i = 0; i < opr->input().size(); i++) {
if (new_inp[i]->shape().ndim == 4) {
auto new_var = RelayoutPlaceholder::make(
new_inp[i], src_to_nchw4_mode);
temp_inp[i] = new_var.node();
} else {
mgb_assert((new_inp[i]->shape().ndim == 5) ||
new_inp[i]->shape().is_scalar());
}
}
return serialization::copy_opr_shallow(*opr, temp_inp,
opr->config());
} else {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
};
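Elemwise-like oprs need no format-aware kernel: once every 4-dim input is relayouted to NCHW4, plain broadcasting still matches per-channel semantics, which is why `replace_elemwise_opr` only inserts `RelayoutPlaceholder`s. A numpy check of that claim (illustration only):

```python
import numpy as np

n, c, h, w = 2, 8, 4, 4
x4 = np.random.randn(n, c // 4, h, w, 4)  # feature map already in NCHW4
b = np.random.randn(1, c, 1, 1)           # per-channel bias in NCHW

# relayout the bias the same way NCHW_TO_NCHW4 handles 4-dim inputs
b4 = b.reshape(1, c // 4, 4, 1, 1).transpose(0, 1, 3, 4, 2)

# adding in the packed layout equals adding in NCHW and then packing
ref = x4.transpose(0, 1, 4, 2, 3).reshape(n, c, h, w) + b
out = (x4 + b4).transpose(0, 1, 4, 2, 3).reshape(n, c, h, w)
assert np.allclose(ref, out)
```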
auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
VarNodeArray temp_inp = new_inp;
for (size_t i = 0; i < opr->input().size(); i++) {
if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
mgb_assert(opr->input(i)->shape().ndim == 4);
mgb_assert(new_inp[i]->shape().ndim == 5);
auto new_var =
RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode);
temp_inp[i] = new_var.node();
}
}
return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
};
auto&& replace_func = ret->m_opr_replace_func;
//! oprs supported in nchw4 format
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
replace_func[opr::BatchConvBias::typeinfo()] =
replace_batch_conv_bias_opr;
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr;
//! oprs not supported in nchw4 format; relayout their inputs back to nchw
replace_func[opr::PoolingForward::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::ConvolutionBackwardData::typeinfo()] =
relayout_inp_to_nchw;
replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
relayout_inp_to_nchw;
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
return ret;
}
/* ================ EnableNchwxxPass =============== */
VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const {
......@@ -1251,7 +1594,7 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
} else if (m_pack_c_size == 4) {
return RelayoutPlaceholder::make(
new_var,
RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW)
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
.node();
}
}
......@@ -1287,8 +1630,8 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW44;
src_to_nchw_mode = RelayoutMode::NCHW44_TO_NCHW;
src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW4;
src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
conv_format = megdnn::param::ConvolutionV0::Format::NCHW44;
pooling_format = megdnn::param::Pooling::Format::NCHW44;
......
......@@ -229,6 +229,19 @@ namespace gopt {
static std::unique_ptr<EnableCHWN4Pass> make_chwn4_converter();
};
/*!
* \brief convert tensor format to nchw4 to speed up inference on CUDA
*/
class EnableNCHW4Pass final : public TensorReformatPass {
VarNode* on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const override;
public:
const char* name() const override { return mgb_cstr_log("tensor_format_nchw4"); }
//! make nchw -> nchw4 converter opt pass
static std::unique_ptr<EnableNCHW4Pass> make_nchw4_converter();
};
/*!
* \brief convert tensor format to nchwxx to speed up inference on certain
* devices
......
......@@ -2327,8 +2327,134 @@ TEST(TestGoptInference, EnableCHWN4ShuffleRemove) {
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
}
TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0");
cn.activate();
auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
auto sm_ver = prop.major * 10 + prop.minor;
if (sm_ver < 61) {
printf("This testcase is ignored due to insufficient cuda cap (got: %d, "
"expected: %d)\n",
sm_ver, 61);
return;
}
HostTensorGenerator<dtype::Int8> gen;
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
dtype);
};
auto mkcvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};
auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param_conv_bias;
param_conv_bias.format = opr::ConvBias::Param::Format::NCHW;
param_conv_bias.stride_h = param_conv_bias.stride_w = 1;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
param_conv_bias.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
// dense
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv1 = opr::ConvBiasForward::make(
x, w1, b1, param_conv_bias, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
// group
// icpg != 1 && ocpg != 1
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv2 = opr::ConvBiasForward::make(conv1, w2, b2,
param_conv_bias, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
auto y = opr::TypeCvt::make(conv2, dtype::Float32());
SymbolVar y_opt;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
find_opr<opr::ConvBias>(y_opt).param().format);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(
output_file("TestGoptInference.ConvertFormatNCHW4GPU.json"));
HostTensorND host_y, host_y_opt;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
}
#endif
TEST(TestGoptInference, ConvertFormatNCHW4) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name);
};
auto x = mkvar("x", {2, 4, 16, 16});
// ConvBias
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
auto w1 = mkcvar("w1", {8, 4, 3, 3}), b1 = mkcvar("b1", {1, 8, 1, 1});
auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1});
auto conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias);
// Convolution
opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1;
param_conv.sparse = opr::Convolution::Param::Sparse::DENSE;
auto w3 = mkcvar("w3", {8, 8, 3, 3});
auto y = opr::Convolution::make(conv2, w3, param_conv);
SymbolVar y_opt;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
find_opr<opr::ConvBias>(y_opt).param().format);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(
output_file("TestGoptInference.ConvertFormatNCHW4.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, ConvertFormatNCHW88) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
......