提交 eab7ab05 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(gopt): gen nchw_nchw44 when kernel is optimized

GitOrigin-RevId: 89083be200041dcbcbe12d64f9c67667bcf101b2
上级 777f3ea9
......@@ -1811,7 +1811,15 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
}
return new_var;
}
//! nchw_nchwxx_valid is used to indicate optimized nchw_nchw44 conv
static inline bool nchw_nchwxx_valid(const size_t oc, const size_t ic,
const size_t pack_c_size, const size_t fh,
const size_t fw, const size_t stride_h,
const size_t stride_w) {
return ic < pack_c_size && oc % pack_c_size == 0 && fh == fw &&
stride_h == stride_w && (stride_h == 1 || stride_h == 2) &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7);
}
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
using RelayoutMode = RelayoutPlaceholder::LayoutType;
using TestFilterResult = std::pair<TransType, RelayoutMode>;
......@@ -1848,15 +1856,19 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
weight_to_nchwxx_mode_group, weight_to_nchwxx_mode_chan,
hybrid_nchw_nchwxx](
const megdnn::param::Convolution::Sparse conv_mode,
const VarNode* filter) -> TestFilterResult {
const VarNode* filter, const size_t stride_h,
const size_t stride_w) -> TestFilterResult {
TestFilterResult ret{TransType::TRANS_NONE, {}};
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
size_t IC = filter->shape()[1];
size_t OC = filter->shape()[0];
size_t IC = filter->shape()[1];
size_t FH = filter->shape()[2];
size_t FW = filter->shape()[3];
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
ret.first = TransType::TRANS_PURE_NCHWXX;
ret.second = weight_to_nchwxx_mode_dense;
} else if (IC < pack_c_size && OC % pack_c_size == 0) {
} else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h,
stride_w)) {
ret.first = TransType::TRANS_HYBIRD_NCHWXX;
ret.second = hybrid_nchw_nchwxx;
}
......@@ -1883,7 +1895,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
mgb_assert(conv_opr.param().format ==
megdnn::param::Convolution::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHWXX");
auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1]);
auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1],
conv_opr.param().stride_h,
conv_opr.param().stride_w);
//! can not trans to nchwxx
if (is_trans.first == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 ||
......@@ -1957,8 +1971,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
mgb_assert(conv_bias_opr.param().format ==
megdnn::param::ConvBias::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHWXX");
auto is_trans =
test_trans_nchwxx(conv_bias_opr.param().sparse, new_inp[1]);
auto is_trans = test_trans_nchwxx(
conv_bias_opr.param().sparse, new_inp[1],
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w);
//! can not trans to nchwxx
if (is_trans.first == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 ||
......@@ -2199,17 +2214,21 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
constexpr size_t pack_c_size = 4_z;
auto test_trans_nchw44_dot =
[](const megdnn::param::Convolution::Sparse conv_mode,
const VarNode* filter) -> TestTransResult {
const VarNode* filter, const size_t stride_h,
const size_t stride_w) -> TestTransResult {
TestTransResult ret{TransType::TRANS_NONE, {}, {}};
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
size_t IC = filter->shape()[1];
size_t OC = filter->shape()[0];
size_t IC = filter->shape()[1];
size_t FH = filter->shape()[2];
size_t FW = filter->shape()[3];
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.relayout_mod =
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
} else if (IC < pack_c_size && OC % pack_c_size == 0) {
} else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h,
stride_w)) {
ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;
ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
......@@ -2241,8 +2260,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
megdnn::param::Convolution::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to "
"NCHW44_DOT");
auto is_trans =
test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]);
auto is_trans = test_trans_nchw44_dot(
conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
conv_opr.param().stride_w);
//! can not trans to nchwxx
if (is_trans.trans_type == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 ||
......@@ -2315,8 +2335,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
mgb_assert(conv_bias_opr.param().format ==
megdnn::param::ConvBias::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHWXX");
auto is_trans =
test_trans_nchw44_dot(conv_bias_opr.param().sparse, new_inp[1]);
auto is_trans = test_trans_nchw44_dot(
conv_bias_opr.param().sparse, new_inp[1],
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w);
//! can not trans to nchwxx
if (is_trans.trans_type == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 ||
......
......@@ -2976,11 +2976,16 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
//! Hybrid nchw88 mode
//! Hybrid nchw44 mode
opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1;
opr::ConvBias::Param param_conv_bias_stride4;
param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4;
auto w1 = mkcvar("w1", {8, 3, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param_conv);
auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}),
conv1_f4 = opr::Convolution::make(x, w1_1, param_conv);
auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4);
//! channel wise
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
......@@ -3015,22 +3020,35 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias);
SymbolVar y_opt;
SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw44();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
unpack_vector(gopt::optimize_for_inference(
{y, conv1, conv1_f4, conv1_s4, conv2}, options),
y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
find_opr<opr::Convolution>(conv1_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
find_opr<opr::ConvBias>(conv1_s4_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
find_opr<opr::Convolution>(conv1_f4_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
find_opr<opr::ConvBias>(conv2_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt).param().format);
graph->compile({{y_opt, {}}})
graph->compile({{y_opt, {}}, {conv2, {}}})
->to_json()
->writeto_fpath(
output_file("TestGoptInference.ConvertFormatNCHW44.json"));
HostTensorND host_y_opt, host_y;
HostTensorND host_conv1;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
make_callback_copy(y_opt, host_y_opt),
make_callback_copy(conv1, host_conv1)});
func->execute();
//! meybe go to winograd in x86-32, so set error 1e-1
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
......@@ -3140,11 +3158,16 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
//! Hybrid nchw88 mode
//! Hybrid nchw44 mode
opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1;
opr::ConvBias::Param param_conv_bias_stride4;
param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4;
auto w1 = mkcvar("w1", {8, 3, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param_conv);
auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}),
conv1_f4 = opr::Convolution::make(x, w1_1, param_conv);
auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4);
//! channel wise
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
......@@ -3179,11 +3202,19 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias);
SymbolVar y_opt;
SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw44_dot();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
unpack_vector(gopt::optimize_for_inference({y, conv1, conv1_f4, conv1_s4},
options),
y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
find_opr<opr::Convolution>(conv1_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
find_opr<opr::ConvBias>(conv1_s4_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
find_opr<opr::Convolution>(conv1_f4_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
find_opr<opr::Convolution>(y_opt).param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册