Commit ee2e2b3c authored by Megvii Engine Team

fix(mgb/gopt): fix nchwxx optpass to handle conv_bias opr with no bias

GitOrigin-RevId: b2b053add464540c22a61e72f078950c18bf92b0
Parent 59a9275c
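Summary of the change: `nchw_nchwxx_valid` and the conv_bias replace functions previously assumed a bias input was always present at `new_inp[2]`. A ConvBias operator may in fact carry two inputs (src, filter), three (src, filter, bias), or four (src, filter, bias, z). This commit rejects the four-input fuse-Z form, guards every bias access with `new_inp.size() > 2`, and builds the replacement operator through the bias-less `opr::ConvBias::make` overload when no bias exists. A minimal standalone sketch of the arity handling (plain structs stand in for MegEngine's `VarNode`; the helper names are hypothetical and illustrative only, not the upstream code):

```cpp
#include <cassert>
#include <vector>

struct Var {};  // stand-in for mgb::VarNode*

// A conv_bias operator carries 2 inputs (src, filter), 3 (src, filter, bias)
// or 4 (src, filter, bias, z). The pass only converts the 2- and 3-input
// forms; the 4-input fuse-Z form is rejected, mirroring the new
// `new_inp.size() == 4` check in nchw_nchwxx_valid.
bool nchwxx_convertible(const std::vector<Var*>& new_inp) {
    return new_inp.size() == 2 || new_inp.size() == 3;
}

// Bias accesses must be guarded, mirroring the new `new_inp.size() > 2`
// checks: only dereference slot 2 when it actually exists.
Var* bias_or_null(const std::vector<Var*>& new_inp) {
    return new_inp.size() > 2 ? new_inp[2] : nullptr;
}

int main() {
    Var src, filter, bias, z;
    assert(nchwxx_convertible({&src, &filter}));              // no bias
    assert(nchwxx_convertible({&src, &filter, &bias}));       // with bias
    assert(!nchwxx_convertible({&src, &filter, &bias, &z}));  // fuse Z
    assert(bias_or_null({&src, &filter}) == nullptr);
    assert(bias_or_null({&src, &filter, &bias}) == &bias);
}
```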
@@ -1862,7 +1862,8 @@ static inline bool nchw_nchwxx_valid(
auto& src_node = new_inp[0];
auto& filter_node = new_inp[1];
auto dst_node = opr.output(0);
if (filter_node->shape().ndim != 4) {
//! already transformed or have fuse Z
if (filter_node->shape().ndim != 4 || new_inp.size() == 4) {
return false;
}
megdnn::ConvolutionBase<megdnn::param::Convolution>::CanonizedFilterMeta fm;
@@ -1884,7 +1885,8 @@ static inline bool nchw_nchwxx_valid(
megdnn::ConvBiasForward::BiasMode bias_mode =
megdnn::ConvBiasForward::BiasMode::NO_BIAS;
if (std::is_same<OprType, opr::ConvBiasForward>::value) {
if (std::is_same<OprType, opr::ConvBiasForward>::value &&
new_inp.size() > 2) {
TensorShape bias_shape = new_inp[2]->shape();
if (bias_shape.ndim == 5) {
bias_shape = nchwxx_shape_2_nchw_shape(bias_shape);
@@ -2067,6 +2069,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
pack_c_size](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(opr->input().size() <= 3,
"nchwxx does not support conv_bias fuse Z right now");
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
mgb_assert(conv_bias_opr.param().format ==
megdnn::param::ConvBias::Format::NCHW,
@@ -2092,7 +2096,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
temp_inp[0] = new_src.node();
}
//! the bias is nchwxx
if (temp_inp[2]->shape().ndim == 5) {
if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) {
auto new_bias =
RelayoutPlaceholder::make(new_inp[2], src_to_nchw_mode);
temp_inp[2] = new_bias.node();
@@ -2102,7 +2106,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
return new_opr;
} else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
*conv_bias_bias = new_inp[2];
*conv_bias_bias = nullptr;
//! filter trans to nchwxx mode
mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5,
@@ -2117,21 +2121,34 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
src_to_nchwxx_mode);
conv_bias_src = new_src.node();
}
//! bias trans to nchwxx mode, bias may be scale
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(new_inp[2],
src_to_nchwxx_mode);
conv_bias_bias = new_bias.node();
//! bias trans to nchwxx mode
if (new_inp.size() > 2) {
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], src_to_nchwxx_mode);
conv_bias_bias = new_bias.node();
} else {
mgb_assert(new_inp[2]->shape().ndim == 5);
conv_bias_bias = new_inp[2];
}
}
auto new_param = conv_bias_opr.param();
new_param.format = conv_bias_format;
mgb_assert(conv_bias_src->shape().ndim == 5 &&
conv_bias_filter->shape().ndim >= 6,
"The conv_bias src dim is not trans to nchwxx");
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
SymbolVar new_conv_bias_opr;
if (conv_bias_bias) {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias,
new_param, conv_bias_opr.execution_policy(),
conv_bias_opr.config());
} else {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, new_param,
conv_bias_opr.execution_policy(),
conv_bias_opr.config());
}
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchwxx");
@@ -2139,25 +2156,37 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
} else {
mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
*conv_bias_bias = new_inp[2];
*conv_bias_bias = nullptr;
auto new_filter =
RelayoutPlaceholder::make(new_inp[1], is_trans.second);
conv_bias_filter = new_filter.node();
//! bias trans to nchwxx mode, bias may be scale
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(new_inp[2],
src_to_nchwxx_mode);
conv_bias_bias = new_bias.node();
if (new_inp.size() > 2) {
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], src_to_nchwxx_mode);
conv_bias_bias = new_bias.node();
} else {
mgb_assert(new_inp[2]->shape().ndim == 5);
conv_bias_bias = new_inp[2];
}
}
mgb_assert(conv_bias_src->shape().ndim == 4 &&
conv_bias_filter->shape().ndim == 5);
mgb_assert((conv_bias_bias->shape().ndim == 5) ||
conv_bias_bias->shape().is_scalar());
auto new_param = conv_bias_opr.param();
new_param.format = conv_bias_format;
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
SymbolVar new_conv_bias_opr;
if (conv_bias_bias) {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias,
new_param, conv_bias_opr.execution_policy(),
conv_bias_opr.config());
} else {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, new_param,
conv_bias_opr.execution_policy(),
conv_bias_opr.config());
}
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"The conv dst dim is not trans to nchwxx");
@@ -2275,6 +2304,10 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
relayout_inp_to_nchw;
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Argmax::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::ImmutableTensor::typeinfo()] = relayout_inp_to_nchw;
}
std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
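For reference, the registrations added above (`AxisAddRemove`, `Argmax`, `Broadcast`, `ImmutableTensor`) all reuse the existing `relayout_inp_to_nchw` fallback: operators the pass does not convert have their NCHWxx inputs relayouted back to NCHW before the original operator is rebuilt. A minimal standalone model of that typeinfo-keyed dispatch (here `std::type_index` and plain strings stand in for MegEngine's `Typeinfo` and operator replace functions; this is an illustration, not the upstream API):

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <typeindex>
#include <unordered_map>

// Stand-ins for graph operators; the real pass keys its table on
// OperatorNodeBase::typeinfo() rather than std::type_index.
struct Argmax {};
struct Broadcast {};
struct ConvBias {};

using ReplaceFunc = std::function<std::string(const std::string& opr_name)>;

int main() {
    std::unordered_map<std::type_index, ReplaceFunc> replace_func;

    // Fallback used by operators the pass cannot convert: relayout their
    // NCHWxx inputs back to NCHW and rebuild the operator unchanged.
    ReplaceFunc relayout_inp_to_nchw = [](const std::string& opr_name) {
        return opr_name + ": inputs relayouted back to NCHW";
    };

    // Mirrors the registrations added in this commit.
    replace_func[typeid(Argmax)] = relayout_inp_to_nchw;
    replace_func[typeid(Broadcast)] = relayout_inp_to_nchw;

    // Converted operators get their own dedicated handler.
    replace_func[typeid(ConvBias)] = [](const std::string& opr_name) {
        return opr_name + ": rebuilt in NCHWxx format";
    };

    std::cout << replace_func[typeid(Argmax)]("Argmax") << "\n";
    std::cout << replace_func[typeid(ConvBias)]("ConvBias") << "\n";
}
```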
@@ -2459,6 +2492,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(opr->input().size() <= 3,
"nchwxx-dot does not support conv_bias fuse Z right now");
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
mgb_assert(conv_bias_opr.param().format ==
megdnn::param::ConvBias::Format::NCHW,
@@ -2489,7 +2524,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
}
//! the bias is nchwxx
if (temp_inp[2]->shape().ndim == 5) {
if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], RelayoutMode::NCHW4_TO_NCHW);
temp_inp[2] = new_bias.node();
@@ -2499,7 +2534,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
return new_opr;
} else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
*conv_bias_bias = new_inp[2];
*conv_bias_bias = nullptr;
//! filter trans to nchwxx mode
mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5,
@@ -2514,21 +2549,34 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
conv_bias_src = new_src.node();
}
//! bias trans to nchwxx mode, bias may be scale
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
conv_bias_bias = new_bias.node();
//! bias trans to nchwxx mode
if (new_inp.size() > 2) {
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
conv_bias_bias = new_bias.node();
} else {
mgb_assert(new_inp[2]->shape().ndim == 5);
conv_bias_bias = new_inp[2];
}
}
auto new_param = conv_bias_opr.param();
new_param.format = is_trans.conv_format;
mgb_assert(conv_bias_src->shape().ndim == 5 &&
conv_bias_filter->shape().ndim >= 6,
"The conv_bias src dim is not trans to nchwxx");
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
SymbolVar new_conv_bias_opr;
if (conv_bias_bias) {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias,
new_param, conv_bias_opr.execution_policy(),
conv_bias_opr.config());
} else {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, new_param,
conv_bias_opr.execution_policy(),
conv_bias_opr.config());
}
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchwxx");
@@ -2536,25 +2584,37 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
} else {
mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
*conv_bias_bias = new_inp[2];
*conv_bias_bias = nullptr;
auto new_filter = RelayoutPlaceholder::make(new_inp[1],
is_trans.relayout_mod);
conv_bias_filter = new_filter.node();
//! bias trans to nchwxx mode, bias may be scale
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
conv_bias_bias = new_bias.node();
if (new_inp.size() > 2) {
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
conv_bias_bias = new_bias.node();
} else {
mgb_assert(new_inp[2]->shape().ndim == 5);
conv_bias_bias = new_inp[2];
}
}
mgb_assert(conv_bias_src->shape().ndim == 4 &&
conv_bias_filter->shape().ndim == 5);
mgb_assert((conv_bias_bias->shape().ndim == 5) ||
conv_bias_bias->shape().is_scalar());
auto new_param = conv_bias_opr.param();
new_param.format = is_trans.conv_format;
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
SymbolVar new_conv_bias_opr;
if (conv_bias_bias) {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias,
new_param, conv_bias_opr.execution_policy(),
conv_bias_opr.config());
} else {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, new_param,
conv_bias_opr.execution_policy(),
conv_bias_opr.config());
}
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"The conv dst dim is not trans to nchwxx");
@@ -3009,9 +3009,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
//! no supported hybrid nchw44
opr::ConvBias::Param param_conv_bias_pad0;
param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0;
auto b1 = mkcvar("b1", {1, 8, 1, 1});
auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1});
auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {},
auto conv1_f1 = opr::ConvBias::make(x, w1_f1, param_conv_bias_pad0, {},
OperatorNodeConfig("conv1_f1"));
auto conv1_add = conv1_f1 * conv1;
@@ -3263,9 +3262,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f));
auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv_1_2 = opr::ConvBias::make(
conv_1_q8, w1_2, b1_2, param_conv_bias, {},
conv_1_q8, w1_2, param_conv_bias, {},
OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}});
auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32());