提交 a94fb7b1 编写于 作者: M Megvii Engine Team

fix(mgb/gopt): fix convert format opt pass by ensuring the replacement won't...

fix(mgb/gopt): fix convert format opt pass by ensuring the replacement won't change output var's channel

GitOrigin-RevId: 170994935a0f3920e1537171d0f789a5bfaa238a
上级 2c2caf33
......@@ -1205,17 +1205,33 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
TensorFormat::Type::IMAGE2D_PACK4);
return ret;
};
auto replace_resize_opr = [](OperatorNodeBase* opr,
/* This helper function guarantees the format convert pass won't change
* output var's channel. Changing output's channel will cause channel
* mismatch problem for replacing conv/conv_bias operator.
*/
auto replace_helper = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> OperatorNodeBase* {
auto&& new_shp = new_inp[0]->shape();
size_t inp_channel = new_shp[1];
if (new_shp.eq_shape(opr->input(0)->shape())&& inp_channel % 4 != 0) {
auto new_opr = serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
return new_opr;
}
return nullptr;
};
auto replace_resize_opr = [replace_helper](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
return opr_shallow_copy;
}
auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>();
mgb_assert(resize_opr.param().format ==
megdnn::param::Resize::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NHWCD4");
VarNode* inp = nullptr;
if (new_inp[0]->shape().ndim == 4) {
// new input src is NCHW
auto param = megdnn::param::RelayoutFormat();
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
auto rf = opr::RelayoutFormat::make(new_inp[0], param);
......@@ -1235,9 +1251,13 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return new_resize_opr.node()->owner_opr();
};
auto replace_warp_perspective_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
auto replace_warp_perspective_opr = [replace_helper](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
return opr_shallow_copy;
}
auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>();
mgb_assert(warp_opr.param().format ==
megdnn::param::WarpPerspective::Format::NCHW,
......@@ -1273,9 +1293,12 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return new_warp_opr.node()->owner_opr();
};
auto replace_warp_affine_opr = [](OperatorNodeBase* opr,
auto replace_warp_affine_opr = [replace_helper](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
return opr_shallow_copy;
}
auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>();
mgb_assert(warp_opr.param().format ==
megdnn::param::WarpAffine::Format::NCHW,
......@@ -1303,9 +1326,12 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return new_warp_opr.node()->owner_opr();
};
auto replace_pooling_opr = [](OperatorNodeBase* opr,
auto replace_pooling_opr = [replace_helper](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
return opr_shallow_copy;
}
auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
mgb_assert(pooling_opr.param().format ==
megdnn::param::Pooling::Format::NCHW,
......@@ -1344,24 +1370,6 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
mgb_assert(new_inp[i]->shape().ndim == 5 &&
new_inp[i]->format().type() ==
TensorFormat::Type::IMAGE2D_PACK4);
// Oprs which will change the shape of input like concat,
// reshape etc. should not be used after cd4 convertion padding,
// due to the padding info will be lost and we cannot recover
// the origin unpadded data. For example, concat two tensors of
// shape {1, 6, 128, 128}, if both tensors convert to cd4 then
// the channel will be 8, and the result of concat channel will
// be 16, but there will be 2 padding zeros in the middle of
// channel axis, which will cause problems in succeding opr.
if (opr->dyn_typeinfo() == opr::Concat::typeinfo()) {
auto concat = try_cast_as_op<opr::Concat>(opr);
mgb_assert(
!(concat->param().axis == 1 &&
concat->input(i)->shape()[1] % 4 != 0),
"We cannot concat tensor in channel axis which has "
"been padded, as it may lost padding pos if we "
"pass "
"the output to conv etc.");
}
auto param = megdnn::param::RelayoutFormat();
param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
auto rf = opr::RelayoutFormat::make(new_inp[i], param);
......
......@@ -1045,13 +1045,18 @@ TEST(TestGoptInference, ConvertFormatPadIC) {
param.sparse = opr::Convolution::Param::Sparse::DENSE;
auto w1 = mkcvar("w1", {12, 12, 3, 3});
auto y = opr::Convolution::make(concat, w1, param);
MGB_MARK_USED_VAR(y);
SymbolVar y_opt;
ASSERT_THROW(unpack_vector(gopt::optimize_for_inference(
{y}, gopt::OptimizeForInferenceOptions{}
.enable_use_nhwcd4()),
y_opt),
AssertionError);
unpack_vector(
gopt::optimize_for_inference(
{y},
gopt::OptimizeForInferenceOptions{}.enable_use_nhwcd4()),
y_opt);
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, ConvertBatchNormPass) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册