Commit a94fb7b1 authored by Megvii Engine Team

fix(mgb/gopt): fix convert format opt pass by ensuring the replacement won't change output var's channel

GitOrigin-RevId: 170994935a0f3920e1537171d0f789a5bfaa238a
Parent 2c2caf33
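The arithmetic behind the fix, as a minimal standalone sketch (hypothetical values, not MegEngine code): NHWCD4 (CD4) packs the channel axis in groups of 4, so converting a var whose channel count is not a multiple of 4 pads it up, and any consumer built for the original channel count no longer matches.

    #include <cstddef>
    #include <cstdio>

    int main() {
        // A var of shape {1, 6, 128, 128}: C = 6 is not divisible by 4.
        const std::size_t c = 6;
        const std::size_t c_cd4 = (c + 3) / 4 * 4;  // CD4 pads the channel axis: 6 -> 8
        // A conv/conv_bias replacement expecting 6 input channels would now
        // see 8 -- the channel mismatch the commit message refers to.
        std::printf("C = %zu -> CD4-padded C = %zu\n", c, c_cd4);
    }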
@@ -1205,17 +1205,33 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
                             TensorFormat::Type::IMAGE2D_PACK4);
         return ret;
     };
+    /* This helper function guarantees the format convert pass won't change
+     * output var's channel. Changing output's channel will cause channel
+     * mismatch problem for replacing conv/conv_bias operator.
+     */
+    auto replace_helper = [](OperatorNodeBase* opr,
+                             const VarNodeArray& new_inp) -> OperatorNodeBase* {
+        auto&& new_shp = new_inp[0]->shape();
+        size_t inp_channel = new_shp[1];
+        if (new_shp.eq_shape(opr->input(0)->shape()) && inp_channel % 4 != 0) {
+            auto new_opr = serialization::copy_opr_shallow(*opr, new_inp,
+                                                           opr->config());
+            return new_opr;
+        }
+        return nullptr;
+    };
-    auto replace_resize_opr = [](OperatorNodeBase* opr,
+    auto replace_resize_opr = [replace_helper](OperatorNodeBase* opr,
                                  const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
+            return opr_shallow_copy;
+        }
         auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>();
         mgb_assert(resize_opr.param().format ==
                            megdnn::param::Resize::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NHWCD4");
         VarNode* inp = nullptr;
         if (new_inp[0]->shape().ndim == 4) {
+            // new input src is NCHW
             auto param = megdnn::param::RelayoutFormat();
             param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
             auto rf = opr::RelayoutFormat::make(new_inp[0], param);
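The guard in replace_helper fires exactly when the rewritten input kept its original shape and its channel count is not a multiple of 4. A standalone model of that condition (plain C++; the names here are hypothetical stand-ins, the real code uses MegEngine's TensorShape::eq_shape):

    #include <cassert>
    #include <cstddef>

    // True when the pass must keep the operator in NCHW via a shallow copy:
    // the input var was not changed by the pass and C is not a multiple of 4.
    bool must_keep_nchw(const std::size_t new_shape[4],
                        const std::size_t old_shape[4]) {
        bool unchanged = true;
        for (int i = 0; i < 4; ++i)
            unchanged = unchanged && new_shape[i] == old_shape[i];
        return unchanged && new_shape[1] % 4 != 0;
    }

    int main() {
        const std::size_t shp6[4] = {1, 6, 128, 128};
        assert(must_keep_nchw(shp6, shp6));    // C = 6: shallow copy, no conversion
        const std::size_t shp8[4] = {1, 8, 128, 128};
        assert(!must_keep_nchw(shp8, shp8));   // C = 8: safe to convert to NHWCD4
    }

Each replace_* lambda in the hunks below calls the helper first and returns the shallow copy when it is non-null, so only convertible vars ever reach the NCHW-to-NHWCD4 rewrite.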
@@ -1235,9 +1251,13 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
         return new_resize_opr.node()->owner_opr();
     };
-    auto replace_warp_perspective_opr = [](OperatorNodeBase* opr,
-                                           const VarNodeArray& new_inp) {
+    auto replace_warp_perspective_opr = [replace_helper](
+                                                OperatorNodeBase* opr,
+                                                const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
+            return opr_shallow_copy;
+        }
         auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>();
         mgb_assert(warp_opr.param().format ==
                            megdnn::param::WarpPerspective::Format::NCHW,
@@ -1273,9 +1293,12 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
         return new_warp_opr.node()->owner_opr();
     };
-    auto replace_warp_affine_opr = [](OperatorNodeBase* opr,
+    auto replace_warp_affine_opr = [replace_helper](OperatorNodeBase* opr,
                                       const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
+            return opr_shallow_copy;
+        }
         auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>();
         mgb_assert(warp_opr.param().format ==
                            megdnn::param::WarpAffine::Format::NCHW,
@@ -1303,9 +1326,12 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
         return new_warp_opr.node()->owner_opr();
     };
-    auto replace_pooling_opr = [](OperatorNodeBase* opr,
+    auto replace_pooling_opr = [replace_helper](OperatorNodeBase* opr,
                                   const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
+            return opr_shallow_copy;
+        }
         auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
         mgb_assert(pooling_opr.param().format ==
                            megdnn::param::Pooling::Format::NCHW,
@@ -1344,24 +1370,6 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
                 mgb_assert(new_inp[i]->shape().ndim == 5 &&
                            new_inp[i]->format().type() ==
                                    TensorFormat::Type::IMAGE2D_PACK4);
-                // Oprs which will change the shape of input like concat,
-                // reshape etc. should not be used after cd4 convertion padding,
-                // due to the padding info will be lost and we cannot recover
-                // the origin unpadded data. For example, concat two tensors of
-                // shape {1, 6, 128, 128}, if both tensors convert to cd4 then
-                // the channel will be 8, and the result of concat channel will
-                // be 16, but there will be 2 padding zeros in the middle of
-                // channel axis, which will cause problems in succeding opr.
-                if (opr->dyn_typeinfo() == opr::Concat::typeinfo()) {
-                    auto concat = try_cast_as_op<opr::Concat>(opr);
-                    mgb_assert(
-                            !(concat->param().axis == 1 &&
-                              concat->input(i)->shape()[1] % 4 != 0),
-                            "We cannot concat tensor in channel axis which has "
-                            "been padded, as it may lost padding pos if we "
-                            "pass "
-                            "the output to conv etc.");
-                }
                 auto param = megdnn::param::RelayoutFormat();
                 param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
                 auto rf = opr::RelayoutFormat::make(new_inp[i], param);
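Worked through, the scenario the removed comment and assert guarded against (illustrative arithmetic only, shapes taken from the comment above): with replace_helper in place, vars with C % 4 != 0 are never converted, so this case can no longer arise and the assert could be dropped.

    #include <cstddef>
    #include <cstdio>

    int main() {
        // Concat two {1, 6, 128, 128} tensors along the channel axis.
        const std::size_t c = 6;
        std::printf("NCHW concat:     C = %zu\n", c + c);          // 12
        // If both inputs were CD4-padded first (6 -> 8), concat would see:
        const std::size_t c_pad = (c + 3) / 4 * 4;
        std::printf("CD4 then concat: C = %zu\n", c_pad + c_pad);  // 16
        // Channels 6..7 and 14..15 of that result are padding zeros embedded
        // mid-axis; later oprs cannot recover the original 12-channel data.
    }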
...
@@ -1045,13 +1045,18 @@ TEST(TestGoptInference, ConvertFormatPadIC) {
     param.sparse = opr::Convolution::Param::Sparse::DENSE;
     auto w1 = mkcvar("w1", {12, 12, 3, 3});
     auto y = opr::Convolution::make(concat, w1, param);
-    MGB_MARK_USED_VAR(y);
     SymbolVar y_opt;
-    ASSERT_THROW(unpack_vector(gopt::optimize_for_inference(
-                                       {y}, gopt::OptimizeForInferenceOptions{}
-                                                    .enable_use_nhwcd4()),
-                               y_opt),
-                 AssertionError);
+    unpack_vector(
+            gopt::optimize_for_inference(
+                    {y},
+                    gopt::OptimizeForInferenceOptions{}.enable_use_nhwcd4()),
+            y_opt);
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
 }
 TEST(TestGoptInference, ConvertBatchNormPass) {
...