Commit ee2e2b3c authored by Megvii Engine Team

fix(mgb/gopt): handle bias-free conv_bias oprs in the nchwxx optpass

GitOrigin-RevId: b2b053add464540c22a61e72f078950c18bf92b0
Parent 59a9275c
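The bug this commit fixes: opr::ConvBias::make has an overload that takes only (src, filter), so a ConvBias operator may carry just two inputs, yet the nchwxx conversion passes below unconditionally indexed new_inp[2] for the bias. A minimal repro sketch, assuming the `x` / `mkcvar` graph-building helpers from the tests at the end of this diff:

    // Sketch only: `x` (NCHW input var) and `mkcvar` (const weight helper)
    // are the TestGoptInference fixtures used in the tests below.
    opr::ConvBias::Param param;          // no bias input is ever attached
    auto w = mkcvar("w", {8, 3, 1, 1});  // 4-dim OIHW filter
    auto y = opr::ConvBias::make(x, w, param, {},
                                 OperatorNodeConfig("conv_no_bias"));
    // new_inp.size() == 2 here; before this commit the nchwxx passes read
    // new_inp[2] anyway, so converting this graph failed.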
@@ -1862,7 +1862,8 @@ static inline bool nchw_nchwxx_valid(
     auto& src_node = new_inp[0];
     auto& filter_node = new_inp[1];
     auto dst_node = opr.output(0);
-    if (filter_node->shape().ndim != 4) {
+    //! already transformed or have fuse Z
+    if (filter_node->shape().ndim != 4 || new_inp.size() == 4) {
         return false;
     }
     megdnn::ConvolutionBase<megdnn::param::Convolution>::CanonizedFilterMeta fm;
@@ -1884,7 +1885,8 @@ static inline bool nchw_nchwxx_valid(
     megdnn::ConvBiasForward::BiasMode bias_mode =
             megdnn::ConvBiasForward::BiasMode::NO_BIAS;
-    if (std::is_same<OprType, opr::ConvBiasForward>::value) {
+    if (std::is_same<OprType, opr::ConvBiasForward>::value &&
+        new_inp.size() > 2) {
         TensorShape bias_shape = new_inp[2]->shape();
         if (bias_shape.ndim == 5) {
             bias_shape = nchwxx_shape_2_nchw_shape(bias_shape);
@@ -2067,6 +2069,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
             pack_c_size](OperatorNodeBase* opr,
                          const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        mgb_assert(opr->input().size() <= 3,
+                   "nchwxx does not support conv_bias fuse Z right now");
         auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
@@ -2092,7 +2096,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
                 temp_inp[0] = new_src.node();
             }
             //! the bias is nchwxx
-            if (temp_inp[2]->shape().ndim == 5) {
+            if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) {
                 auto new_bias =
                         RelayoutPlaceholder::make(new_inp[2], src_to_nchw_mode);
                 temp_inp[2] = new_bias.node();
@@ -2102,7 +2106,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
             return new_opr;
         } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             //! filter trans to nchwxx mode
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
@@ -2117,21 +2121,34 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
                         src_to_nchwxx_mode);
                 conv_bias_src = new_src.node();
             }
-            //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                          src_to_nchwxx_mode);
-                conv_bias_bias = new_bias.node();
+            //! bias trans to nchwxx mode
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], src_to_nchwxx_mode);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             auto new_param = conv_bias_opr.param();
             new_param.format = conv_bias_format;
             mgb_assert(conv_bias_src->shape().ndim == 5 &&
                                conv_bias_filter->shape().ndim >= 6,
                        "The conv_bias src dim is not trans to nchwxx");
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchwxx");
@@ -2139,25 +2156,37 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         } else {
             mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             auto new_filter =
                     RelayoutPlaceholder::make(new_inp[1], is_trans.second);
             conv_bias_filter = new_filter.node();
             //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                          src_to_nchwxx_mode);
-                conv_bias_bias = new_bias.node();
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], src_to_nchwxx_mode);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             mgb_assert(conv_bias_src->shape().ndim == 4 &&
                        conv_bias_filter->shape().ndim == 5);
-            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
-                       conv_bias_bias->shape().is_scalar());
             auto new_param = conv_bias_opr.param();
             new_param.format = conv_bias_format;
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv dst dim is not trans to nchwxx");
@@ -2275,6 +2304,10 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
                 relayout_inp_to_nchw;
         replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
         replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
+        replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_nchw;
+        replace_func[opr::Argmax::typeinfo()] = relayout_inp_to_nchw;
+        replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw;
+        replace_func[opr::ImmutableTensor::typeinfo()] = relayout_inp_to_nchw;
 }
 
 std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
@@ -2459,6 +2492,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                     OperatorNodeBase* opr,
                     const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        mgb_assert(opr->input().size() <= 3,
+                   "nchwxx-dot does not support conv_bias fuse Z right now");
         auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
@@ -2489,7 +2524,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             }
             //! the bias is nchwxx
-            if (temp_inp[2]->shape().ndim == 5) {
+            if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) {
                 auto new_bias = RelayoutPlaceholder::make(
                         new_inp[2], RelayoutMode::NCHW4_TO_NCHW);
                 temp_inp[2] = new_bias.node();
@@ -2499,7 +2534,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             return new_opr;
         } else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             //! filter trans to nchwxx mode
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
@@ -2514,21 +2549,34 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                         new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
                 conv_bias_src = new_src.node();
             }
-            //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(
-                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
-                conv_bias_bias = new_bias.node();
+            //! bias trans to nchwxx mode
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             auto new_param = conv_bias_opr.param();
             new_param.format = is_trans.conv_format;
             mgb_assert(conv_bias_src->shape().ndim == 5 &&
                                conv_bias_filter->shape().ndim >= 6,
                        "The conv_bias src dim is not trans to nchwxx");
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchwxx");
@@ -2536,25 +2584,37 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         } else {
             mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             auto new_filter = RelayoutPlaceholder::make(new_inp[1],
                                                         is_trans.relayout_mod);
             conv_bias_filter = new_filter.node();
             //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(
-                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
-                conv_bias_bias = new_bias.node();
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             mgb_assert(conv_bias_src->shape().ndim == 4 &&
                        conv_bias_filter->shape().ndim == 5);
-            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
-                       conv_bias_bias->shape().is_scalar());
             auto new_param = conv_bias_opr.param();
             new_param.format = is_trans.conv_format;
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv dst dim is not trans to nchwxx");
...
@@ -3009,9 +3009,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
     //! no supported hybrid nchw44
     opr::ConvBias::Param param_conv_bias_pad0;
     param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0;
-    auto b1 = mkcvar("b1", {1, 8, 1, 1});
     auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1});
-    auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {},
+    auto conv1_f1 = opr::ConvBias::make(x, w1_f1, param_conv_bias_pad0, {},
                                         OperatorNodeConfig("conv1_f1"));
     auto conv1_add = conv1_f1 * conv1;
@@ -3263,9 +3262,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
     opr::ConvBias::Param param_conv_bias;
     param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
     auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f));
-    auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
     auto conv_1_2 = opr::ConvBias::make(
-            conv_1_q8, w1_2, b1_2, param_conv_bias, {},
+            conv_1_q8, w1_2, param_conv_bias, {},
             OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}});
     auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32());
...
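For reference, a hedged sketch of how the fixed passes are typically exercised end to end; the gopt::OptimizeForInferenceOptions, enable_nchw44(), and unpack_vector names are assumptions drawn from MegEngine's test conventions, not verified against this exact revision:

    // Hypothetical usage (API names are assumptions): run the NCHW44
    // conversion over a graph containing the bias-free ConvBias above.
    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_nchw44();  // assumed to route through EnableNchwxxPass
    SymbolVar y_opt;
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
    // After this fix, the missing bias slot stays empty (conv_bias_bias ==
    // nullptr) and the two-input ConvBias::make overload is used instead.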