Commit 30b3d3aa authored by Megvii Engine Team, committed by Xu Xinran

fix(dnn/gopt): add convolution nchw44-dot format gopt

GitOrigin-RevId: e8e1e9637944ead470ebe4e2b697ddf7d437aaba
Parent 48d1ac14
......@@ -287,8 +287,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
}
} else if (filter_meta.format == Format::NCHW44 ||
filter_meta.format == Format::NCHW44_DOT) {
if (filter_meta.format == Format::NCHW44 && !is_output &&
src.layout.ndim == 4) {
if (!is_output && src.layout.ndim == 4) {
return n * layout.stride[0] + c * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3];
} else {
......
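The hunk above relaxes the naive convolution's index computation so that a plain 4-dim (NCHW) tensor is accepted for both the NCHW44 and the new NCHW44_DOT filter formats (the hybrid case, where the feature map is still NCHW). As background, a minimal NumPy sketch of the 4-channel blocked layout these formats use for feature maps; this is an illustration of the layout convention, not code from the repository.

```python
import numpy as np

def pack_nchw_to_nchw44(x):
    """Illustrative NCHW -> NCHW44-style packing: (N, C, H, W) -> (N, C//4, H, W, 4).

    Assumes C is a multiple of 4, matching the pack_c_size == 4 checks used by
    the gopt passes in this commit.
    """
    n, c, h, w = x.shape
    assert c % 4 == 0, "channel count must be divisible by the pack size 4"
    return x.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)

def unpack_nchw44_to_nchw(x):
    """Inverse transform: (N, C//4, H, W, 4) -> (N, C, H, W)."""
    n, cb, h, w, _ = x.shape
    return x.transpose(0, 1, 4, 2, 3).reshape(n, cb * 4, h, w)

x = np.arange(2 * 8 * 3 * 3, dtype=np.float32).reshape(2, 8, 3, 3)
assert np.array_equal(unpack_nchw44_to_nchw(pack_nchw_to_nchw44(x)), x)
```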
......@@ -554,6 +554,7 @@ def optimize_for_inference(
use_nchw4=False,
use_nchw88=False,
use_nchw44=False,
use_nchw44_dot=False,
use_chwn4=False
):
"""optimize computing graph for inference
......@@ -577,6 +578,8 @@ def optimize_for_inference(
times.
:param use_nchw44: whether to use NCHW44 tensor format. This may be faster in
some cases.
:param use_nchw44_dot: whether to use NCHW44_DOT tensor format. This format is
optimized for ARMv8.2 dot-product inference.
:param use_nchw32: whether to use NCHW32 tensor format. Mainly used for
nvidia tensorcore.
:param use_chwn4: whether to use CHWN4 tensor format. Mainly used for
......@@ -605,6 +608,7 @@ def optimize_for_inference(
"use_nchw32": "nchw32",
"use_nchw88": "nchw88",
"use_nchw44": "nchw44",
"use_nchw44_dot": "nchw44_dot",
"use_chwn4": "chwn4",
}.items():
if settings[k]:
......
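For reference, a minimal usage sketch of the new keyword added to `optimize_for_inference` above. The import path and the `outputs` variable are placeholders/assumptions (they depend on the MegEngine/MegBrain version); only the function name and the `use_nchw44_dot` keyword come from this diff.

```python
# Hedged usage sketch: enable the NCHW44_DOT layout transform when preparing a
# graph for inference on ARMv8.2 CPUs with dot-product instructions.
from megengine._internal import optimize_for_inference  # assumed module path

# `outputs`: list of output symbol variables of the network to deploy (placeholder).
optimized_outputs = optimize_for_inference(
    outputs,
    use_nchw44_dot=True,  # convert eligible NCHW convolutions to NCHW44_DOT
)
```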
......@@ -84,6 +84,7 @@ struct _OptimizeForInferenceOptions {
SET(nhwcd4, NHWCD4);
SET(nchw88, NCHW88);
SET(nchw44, NCHW44);
SET(nchw44_dot, NCHW44_DOT);
SET(nchw32, NCHW32);
SET(chwn4, CHWN4);
#undef SET
......
......@@ -255,6 +255,7 @@ def optimize_for_inference(args, outputs):
'enable_nchw4': 'use_nchw4',
'enable_nchw88': 'use_nchw88',
'enable_nchw44': 'use_nchw44',
'enable_nchw44_dot': 'use_nchw44_dot',
'enable_nchw32': 'use_nchw32',
'enable_chwn4': 'use_chwn4',
'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
......@@ -400,6 +401,12 @@ def main():
help='transform the model format from NCHW to NCHW44 '
'for inference'
)
parser.add_argument(
'--enable-nchw44-dot',
action='store_true',
help='transform the model format from NCHW to NCHW44_DOT '
'for ARMv8.2 dot-product inference'
)
parser.add_argument(
'--enable-nchw32',
action='store_true',
......
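The converter tool maps each `--enable-*` switch onto the matching `use_*` keyword of `optimize_for_inference` through the dictionary shown above. A small runnable sketch of that mapping for the new flag; the parser here is simplified and the argument list is supplied inline for illustration.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--enable-nchw44-dot', action='store_true',
                    help='transform the model format from NCHW to NCHW44_DOT '
                         'for ARMv8.2 dot-product inference')
args = parser.parse_args(['--enable-nchw44-dot'])

# Same renaming scheme as in optimize_for_inference() above:
# 'enable_nchw44_dot' -> 'use_nchw44_dot'.
kwargs = {}
if args.enable_nchw44_dot:
    kwargs['use_nchw44_dot'] = True
print(kwargs)  # {'use_nchw44_dot': True}
```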
......@@ -97,14 +97,15 @@ struct GraphCommonOptimizeOptions {
bool fuse_conv_bias_with_z = false;
enum LayoutTransform : uint32_t {
DEFAULT,
NCHW4, ///< compute using NCHW4 tensor format
NHWCD4, ///< compute using NHWCD4 tensor format
NCHW88, ///< compute using NCHW88 tensor format
NCHW44, ///< compute using NCHW44 tensor format
NCHW32, ///< compute using NCHW32 tensor format, used for
///< tensorcore
CHWN4, ///< compute using CHWN4 tensor format, transformed mainly
///< used for cuda
NCHW4, ///< compute using NCHW4 tensor format
NHWCD4, ///< compute using NHWCD4 tensor format
NCHW88, ///< compute using NCHW88 tensor format
NCHW44, ///< compute using NCHW44 tensor format
NCHW44_DOT, ///< compute using NCHW44_DOT tensor format
NCHW32, ///< compute using NCHW32 tensor format, used for
///< tensorcore
CHWN4, ///< compute using CHWN4 tensor format, transformed mainly
///< used for cuda
};
LayoutTransform layout_transform = LayoutTransform::DEFAULT;
......@@ -142,6 +143,7 @@ struct GraphCommonOptimizeOptions {
SET(nhwcd4, NHWCD4);
SET(nchw88, NCHW88);
SET(nchw44, NCHW44);
SET(nchw44_dot, NCHW44_DOT);
SET(nchw32, NCHW32);
SET(chwn4, CHWN4);
#undef SET
......
......@@ -738,6 +738,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
});
cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); });
cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); });
cb(nchw44_dot,
{ add_pass(EnableNchw44DotPass::make_nchw44_dot_converter()); });
cb(nchw32, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
......
......@@ -28,6 +28,7 @@
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/tensor_format.h"
#if MGB_ENABLE_TENSOR_RT
......@@ -59,19 +60,19 @@ MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
public:
//! relayout type of this opr
enum class LayoutType {
NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout
NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout
CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout
NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout
NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout
NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout
NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout
WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4
//!< layout
WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to
//!< nchw4 layout
NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout
NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout
CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout
NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout
NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout
NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout
NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout
WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4
//!< layout
WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to
//!< nchw4 layout
WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88
//!< layout
......@@ -92,6 +93,10 @@ public:
//!< the weight layout of input is nchw output is nchw44, special for
//!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
WEIGHT_HYBIRD_NCHW_NCHW44,
WEIGHT_NCHW_TO_NCHW44_DOT_DENSE, //!< dense weight from NCHW layout to
//!< NCHW44_DOT layout
WEIGHT_NCHW_TO_NCHW44_DOT_GROUP, //!< group weight from NCHW layout to
//!< NCHW44_DOT layout
};
RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
......@@ -268,7 +273,9 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[3] = inp_shape[1];
dst[4] = 8;
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_DENSE) {
WEIGHT_NCHW_TO_NCHW44_DENSE ||
layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_DOT_DENSE) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
inp_shape[1] % 4 == 0);
dst.ndim = 6;
......@@ -279,7 +286,9 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[4] = 4;
dst[5] = 4;
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_GROUP) {
WEIGHT_NCHW_TO_NCHW44_GROUP ||
layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_DOT_GROUP) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 &&
inp_shape[2] % 4 == 0);
dst.ndim = 7;
......@@ -646,6 +655,42 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 1, 3});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
cv(4), sub(3), sub(4)},
0),
tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
sub(4), cv(4), cv(4)},
0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
auto rewriter = opt.graph().make_rewriter();
auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
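For intuition, a NumPy transcription of the two reformat lambdas added above. Both produce weights whose innermost block is a 4(OC) x 4(IC) tile, the order consumed by the ARMv8.2 dot-product kernels; the reshape/transpose patterns mirror the Reshape/Dimshuffle/Reshape sequences in the diff. This is an illustration, not repository code.

```python
import numpy as np

def weight_nchw_to_nchw44_dot_dense(w):
    """(OC, IC, FH, FW) -> (OC//4, IC//4, FH, FW, 4, 4), inner tile ordered (oc, ic).

    Mirrors reformat[WEIGHT_NCHW_TO_NCHW44_DOT_DENSE]: reshape followed by
    Dimshuffle {0, 2, 4, 5, 1, 3}. OC and IC are assumed to be multiples of 4.
    """
    oc, ic, fh, fw = w.shape
    return w.reshape(oc // 4, 4, ic // 4, 4, fh, fw).transpose(0, 2, 4, 5, 1, 3)

def weight_nchw_to_nchw44_dot_group(w):
    """(G, OCpg, ICpg, FH, FW) -> (G, OCpg//4, ICpg//4, FH, FW, 4, 4).

    Mirrors reformat[WEIGHT_NCHW_TO_NCHW44_DOT_GROUP]: reshape followed by
    Dimshuffle {0, 1, 3, 5, 6, 2, 4}.
    """
    g, ocpg, icpg, fh, fw = w.shape
    return (w.reshape(g, ocpg // 4, 4, icpg // 4, 4, fh, fw)
             .transpose(0, 1, 3, 5, 6, 2, 4))

w_dense = np.random.rand(8, 4, 3, 3).astype(np.float32)
assert weight_nchw_to_nchw44_dot_dense(w_dense).shape == (2, 1, 3, 3, 4, 4)
w_group = np.random.rand(2, 4, 4, 3, 3).astype(np.float32)
assert weight_nchw_to_nchw44_dot_group(w_group).shape == (2, 1, 1, 3, 3, 4, 4)
```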
......@@ -1601,12 +1646,7 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
return new_var;
}
std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
size_t pack_c_size) {
auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
//! First is whether the conv can trans to nchwxx, second is the filter
//! trans mode
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
using RelayoutMode = RelayoutPlaceholder::LayoutType;
using TestFilterResult = std::pair<TransType, RelayoutMode>;
RelayoutMode weight_to_nchwxx_mode_dense =
......@@ -1954,8 +1994,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
};
ret->set_name(convter_pass_name);
auto&& replace_func = ret->m_opr_replace_func;
auto&& replace_func = m_opr_replace_func;
//! supported nchwxx
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
......@@ -1978,6 +2017,246 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
relayout_inp_to_nchw;
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
}
std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
size_t pack_c_size) {
auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
std::string convter_pass_name = "conv_format_nchw88";
if (pack_c_size == 4) {
convter_pass_name = "conv_format_nchw44";
}
ret->fill_opr_convert_fun(pack_c_size);
ret->set_name(convter_pass_name);
return ret;
}
/* ================ EnableNchw44DotPass =============== */
VarNode* EnableNchw44DotPass::on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const {
if (!orig_var->shape().eq_shape(new_var->shape())) {
return RelayoutPlaceholder::make(
new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
.node();
}
return new_var;
}
std::unique_ptr<EnableNchw44DotPass>
EnableNchw44DotPass::make_nchw44_dot_converter() {
auto ret = std::make_unique<EnableNchw44DotPass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
//! First is whether the conv can trans to nchwxx, second is the filter
//! trans mode
using RelayoutMode = RelayoutPlaceholder::LayoutType;
using TestTransResult = std::pair<TransType, RelayoutMode>;
megdnn::param::ConvolutionV0::Format conv_dot_format =
megdnn::param::ConvBias::Format::NCHW44_DOT;
constexpr size_t pack_c_size = 4_z;
auto test_trans_nchw44_dot =
[](const megdnn::param::Convolution::Sparse conv_mode,
const VarNode* filter) -> TestTransResult {
TestTransResult ret{TransType::TRANS_NONE, {}};
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
size_t IC = filter->shape()[1];
size_t OC = filter->shape()[0];
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
ret.first = TransType::TRANS_PURE_NCHWXX;
ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
} else if (IC < pack_c_size && OC % pack_c_size == 0) {
ret.first = TransType::TRANS_HYBIRD_NCHWXX;
ret.second = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
}
} else {
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
size_t group = filter->shape()[0];
size_t ocpg = filter->shape()[1];
size_t icpg = filter->shape()[2];
if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
ret.first = TransType::TRANS_NONE;
} else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
ret.first = TransType::TRANS_PURE_NCHWXX;
ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
}
}
return ret;
};
auto replace_conv_opr = [test_trans_nchw44_dot, conv_dot_format](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
mgb_assert(conv_opr.param().format ==
megdnn::param::Convolution::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to "
"NCHW44_DOT");
auto is_trans =
test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]);
//! can not trans to nchwxx
if (is_trans.first == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode");
VarNodeArray temp_inp = new_inp;
//! if src is nchwxx, should RelayoutPlaceholder to nchw
if (temp_inp[0]->shape().ndim == 5) {
auto new_src = RelayoutPlaceholder::make(
new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
temp_inp[0] = new_src.node();
}
auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
opr->config());
return new_opr;
} else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
//! filter trans to nchwxx mode
mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode");
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
auto new_filter =
RelayoutPlaceholder::make(new_inp[1], is_trans.second);
conv_filter = new_filter.node();
//! src trans to nchwxx mode
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src = RelayoutPlaceholder::make(
new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
conv_src = new_src.node();
}
auto new_param = conv_opr.param();
new_param.format = conv_dot_format;
mgb_assert(conv_src->shape().ndim == 5 &&
conv_filter->shape().ndim >= 6,
"The conv src dim is not trans to nchwxx");
auto new_conv_opr = opr::Convolution::make(
conv_src, conv_filter, new_param,
conv_opr.execution_policy(), conv_opr.config());
OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
mgb_assert(new_conv_opr.shape().ndim == 5,
"The conv dst dim is not trans to nchwxx");
return new_opr;
} else {
mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
auto new_filter =
RelayoutPlaceholder::make(new_inp[1], is_trans.second);
conv_filter = new_filter.node();
mgb_assert(conv_src->shape().ndim == 4 &&
conv_filter->shape().ndim == 5,
"The src and filter is OK");
auto new_param = conv_opr.param();
new_param.format = conv_dot_format;
auto new_conv_opr = opr::Convolution::make(
conv_src, conv_filter, new_param,
conv_opr.execution_policy(), conv_opr.config());
OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
mgb_assert(new_conv_opr.shape().ndim == 5,
"The conv dst dim is not trans to nchwxx");
return new_opr;
}
};
auto replace_conv_bias_opr = [test_trans_nchw44_dot, conv_dot_format](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
mgb_assert(conv_bias_opr.param().format ==
megdnn::param::ConvBias::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHWXX");
auto is_trans =
test_trans_nchw44_dot(conv_bias_opr.param().sparse, new_inp[1]);
//! can not trans to nchwxx
if (is_trans.first == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode");
VarNodeArray temp_inp = new_inp;
//! if src is nchwxx, should RelayoutPlaceholder to nchw
if (temp_inp[0]->shape().ndim == 5) {
auto new_src = RelayoutPlaceholder::make(
new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
temp_inp[0] = new_src.node();
}
//! the bias is nchwxx
if (temp_inp[2]->shape().ndim == 5) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], RelayoutMode::NCHW4_TO_NCHW);
temp_inp[2] = new_bias.node();
}
auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
opr->config());
return new_opr;
} else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
*conv_bias_bias = new_inp[2];
//! filter trans to nchwxx mode
mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode");
auto new_filter =
RelayoutPlaceholder::make(new_inp[1], is_trans.second);
conv_bias_filter = new_filter.node();
//! src trans to nchwxx mode
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src = RelayoutPlaceholder::make(
new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
conv_bias_src = new_src.node();
}
//! bias trans to nchwxx mode, bias may be a scalar
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
conv_bias_bias = new_bias.node();
}
auto new_param = conv_bias_opr.param();
new_param.format = conv_dot_format;
mgb_assert(conv_bias_src->shape().ndim == 5 &&
conv_bias_filter->shape().ndim >= 6,
"The conv_bias src dim is not trans to nchwxx");
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchwxx");
return new_opr;
} else {
mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
*conv_bias_bias = new_inp[2];
auto new_filter =
RelayoutPlaceholder::make(new_inp[1], is_trans.second);
conv_bias_filter = new_filter.node();
//! bias trans to nchwxx mode, bias may be a scalar
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
conv_bias_bias = new_bias.node();
}
mgb_assert(conv_bias_src->shape().ndim == 4 &&
conv_bias_filter->shape().ndim == 5);
mgb_assert((conv_bias_bias->shape().ndim == 5) ||
conv_bias_bias->shape().is_scalar());
auto new_param = conv_bias_opr.param();
new_param.format = conv_dot_format;
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"The conv dst dim is not trans to nchwxx");
return new_opr;
}
};
ret->fill_opr_convert_fun(4);
auto&& replace_func = ret->m_opr_replace_func;
//! supported nchwxx
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
return ret;
}
......
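To summarize how `EnableNchw44DotPass` decides whether a convolution can be converted, here is a small Python paraphrase of `test_trans_nchw44_dot` above (shapes only, no graph objects). It is a readability aid, not repository code.

```python
PACK_C_SIZE = 4

def trans_type_for_nchw44_dot(sparse, filter_shape):
    """Return ("pure" | "hybrid" | "none", relayout_mode_or_None).

    Dense filters are (OC, IC, FH, FW); group filters are (G, OCpg, ICpg, FH, FW).
    """
    if sparse == "dense":
        oc, ic = filter_shape[0], filter_shape[1]
        if oc % PACK_C_SIZE == 0 and ic % PACK_C_SIZE == 0:
            return "pure", "WEIGHT_NCHW_TO_NCHW44_DOT_DENSE"
        if ic < PACK_C_SIZE and oc % PACK_C_SIZE == 0:
            return "hybrid", "WEIGHT_HYBIRD_NCHW_NCHW44"
        return "none", None
    # group convolution
    group, ocpg, icpg = filter_shape[0], filter_shape[1], filter_shape[2]
    if icpg == 1 and ocpg == 1 and group % PACK_C_SIZE == 0:
        return "none", None  # channel-wise conv is left untouched by this pass
    if icpg % PACK_C_SIZE == 0 and ocpg % PACK_C_SIZE == 0:
        return "pure", "WEIGHT_NCHW_TO_NCHW44_DOT_GROUP"
    return "none", None

# Matches the filters used in TestGoptInference.ConvertFormatNCHW44_DOT below.
assert trans_type_for_nchw44_dot("dense", (8, 3, 3, 3))[0] == "hybrid"
assert trans_type_for_nchw44_dot("group", (2, 4, 4, 3, 3))[0] == "pure"
```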
......@@ -236,8 +236,10 @@ namespace gopt {
VarNode* on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const override;
public:
const char* name() const override { return mgb_cstr_log("tensor_format_nchw4"); }
const char* name() const override {
return mgb_cstr_log("tensor_format_nchw4");
}
//! make nchw -> nchw4 converter opt pass
static std::unique_ptr<EnableNCHW4Pass> make_nchw4_converter();
};
......@@ -246,30 +248,48 @@ namespace gopt {
* \brief convert tensor format to nchwxx to speed up inference on certain
* devices
*/
class EnableNchwxxPass final : public TensorReformatPass {
class EnableNchwxxPass : public TensorReformatPass {
std::string m_name = "tensor_format_nchwxx";
size_t m_pack_c_size;
VarNode* on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const override;
public:
EnableNchwxxPass(size_t pack_c_size) : m_pack_c_size(pack_c_size) {}
//! the flag for conv to transform to nchwxx
enum class TransType {
TRANS_PURE_NCHWXX, //!< weight and src all trans to nchwxx
TRANS_HYBIRD_NCHWXX, //!< input is nchw, output is nchwxx
TRANS_NONE, //!< no need trans
};
public:
EnableNchwxxPass(size_t pack_c_size) : m_pack_c_size(pack_c_size) {}
const char* name() const override {
return mgb_cstr_log(m_name.c_str());
}
void set_name(std::string in_name) { m_name = in_name; }
void fill_opr_convert_fun(size_t pack_c_size);
//! make nchw -> nchwxx converter opt pass, pack_c_size is the x, like
//! 4,8,16
static std::unique_ptr<EnableNchwxxPass> make_nchwxx_converter(
size_t pack_c_size);
};
/*!
* \brief convert tensor format from nchw to nchw44_dot to speed up
* inference on armv8.2
*/
class EnableNchw44DotPass final : public EnableNchwxxPass {
std::string m_name = "tensor_format_nchw44_dot";
VarNode* on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const override;
public:
EnableNchw44DotPass() : EnableNchwxxPass(4) {}
//! make nchw44 -> nchw44_dot converter opt pass
static std::unique_ptr<EnableNchw44DotPass> make_nchw44_dot_converter();
};
struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {};
/*!
......
......@@ -2356,7 +2356,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
.rename(name),
dtype);
};
auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param_conv_bias;
param_conv_bias.format = opr::ConvBias::Param::Format::NCHW;
......@@ -2376,7 +2376,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv2 = opr::ConvBiasForward::make(conv1, w2, b2,
param_conv_bias, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
auto y = opr::TypeCvt::make(conv2, dtype::Float32());
SymbolVar y_opt;
......@@ -2617,4 +2617,86 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
}
TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name);
};
auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
//! hybrid nchw44_dot mode
opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 3, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param_conv);
//!channel wise
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}),
conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias);
//! group
auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}),
conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias);
auto shape_of = opr::GetVarShape::make(conv3);
auto subtensor = opr::Subtensor::make(
shape_of, {opr::Subtensor::AxisIndexer::make_interval(
0, x.make_scalar(2), None, x.make_scalar(1))});
opr::Resize::Param param_resize;
param_resize.format = opr::Resize::Param::Format::NCHW;
auto resize = opr::ResizeForward::make(conv3, subtensor * 2, param_resize);
auto mat = mkcvar("mat", {2, 3, 3}),
warp = opr::WarpPerspectiveForward::make(
resize, mat, nullptr, cg::var_from_tensor_shape(x, {4, 4}));
auto b = mkvar("b", {1, 8, 1, 1}),
elem = opr::Elemwise::make({warp + b},
opr::Elemwise::Param::Mode::RELU);
//! Dense
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias);
auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}),
conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias);
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias);
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw44_dot();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
find_opr<opr::Convolution>(y_opt).param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::ConvBias>(y_opt).param().format);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(
output_file("TestGoptInference.ConvertFormatNCHW44.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
//! may go to winograd on x86-32, so set the error tolerance to 1e-1
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
*host_x = *gen({2, 3, 32, 32}, cn);
func->execute();
//! may go to winograd on x86-32, so set the error tolerance to 1e-1
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}