提交 c20d4cc6 编写于 作者: M Megvii Engine Team

feat(dnn): fix opt pass nchw44 can not dump resnet

GitOrigin-RevId: 28e5c37f53349d482b191751923b5a4b05b0633d
上级 3dbac4f4
......@@ -1815,6 +1815,15 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
return new_var;
}
static inline TensorShape nchwxx_shape_2_nchw_shape(
const TensorShape& origin_shape) {
mgb_assert(origin_shape.ndim == 5);
TensorShape result = origin_shape;
result[1] *= result[4];
result.ndim = 4;
return result;
}
template <typename OprType>
static inline bool nchw_nchwxx_valid(
const OprType& opr, const VarNodeArray& new_inp, const size_t pack_size,
......@@ -1847,7 +1856,10 @@ static inline bool nchw_nchwxx_valid(
megdnn::ConvBiasForward::BiasMode bias_mode =
megdnn::ConvBiasForward::BiasMode::NO_BIAS;
if (std::is_same<OprType, opr::ConvBiasForward>::value) {
auto& bias_shape = new_inp[2]->shape();
TensorShape bias_shape = new_inp[2]->shape();
if (bias_shape.ndim == 5) {
bias_shape = nchwxx_shape_2_nchw_shape(bias_shape);
}
if (bias_shape.ndim == 0) {
bias_mode = megdnn::ConvBiasForward::BiasMode::NO_BIAS;
} else if (bias_shape.eq_shape(dst_node->shape())) {
......
......@@ -3069,12 +3069,18 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
//! Dense
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
auto w4 = mkcvar("w4", {4, 32, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
auto w4 = mkcvar("w4", {16, 32, 3, 3}), b4 = mkcvar("b4", {1, 16, 1, 1}),
conv4 = opr::ConvBias::make(conv3_3, w4, b4, param_conv_bias, {},
OperatorNodeConfig("conv4"));
auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}),
conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias, {},
auto w4_1 = mkcvar("w4_1", {16, 32, 1, 1}),
b4_1 = mkcvar("b4_1", {2, 16, 4, 4}),
conv4_1 =
opr::ConvBias::make(conv3_3, w4_1, b4_1, param_conv_bias_pad0,
{}, OperatorNodeConfig("conv4_1"));
auto conv4_add = conv4 + conv4_1;
auto w5 = mkcvar("w5", {6, 16, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}),
conv5 = opr::ConvBias::make(conv4_add, w5, b5, param_conv_bias, {},
OperatorNodeConfig("conv5"));
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias, {},
......@@ -3082,6 +3088,7 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_conv_bias_nonlinearity();
options.enable_nchw44();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册