Commit 1b1ad56a authored by Megvii Engine Team

fix(mgb/gopt): fix warp fusion opt pass

GitOrigin-RevId: a40bbcd71929c1b542acc77d7959a1191217bf62
Parent 99b17623
......@@ -676,9 +676,9 @@ void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const {
auto on_opr = [&try_warp_nchw_typecvt, &try_warp_nhwc2nchw_typecvt,
&try_warp_nhwc2nchw4_typecvt, &try_warp_nchw2nchw4_typecvt,
&rewriter](OperatorNodeBase* opr) {
if (!try_warp_nchw_typecvt(opr) && !try_warp_nhwc2nchw_typecvt(opr) &&
!try_warp_nhwc2nchw4_typecvt(opr) &&
!try_warp_nchw2nchw4_typecvt(opr)) {
if (!try_warp_nhwc2nchw4_typecvt(opr) &&
!try_warp_nchw2nchw4_typecvt(opr) &&
!try_warp_nchw_typecvt(opr) && !try_warp_nhwc2nchw_typecvt(opr)) {
rewriter.auto_replace_outputs(opr);
}
};
......
......@@ -3723,7 +3723,7 @@ TEST(TestGoptInference, PreProcessCase1) {
ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::RelayoutFormat>());
}
TEST(TestGoptInference, WarpAndPreProcessCase) {
TEST(TestGoptInference, WarpAndPreProcessCase0) {
REQUIRE_GPU(1);
HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);
auto cn = CompNode::load("gpu0");
......@@ -3774,7 +3774,57 @@ TEST(TestGoptInference, WarpAndPreProcessCase) {
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.WarpAndPreProcessCase.json"));
"TestGoptInference.WarpAndPreProcessCase0.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
}
TEST(TestGoptInference, WarpAndPreProcessCase1) {
REQUIRE_GPU(1);
HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);
auto cn = CompNode::load("gpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
size_t n = 1;
size_t c = 3;
size_t h = 16;
size_t w = 16;
auto host_x1 = gen({n, h, w, c}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
auto mat_host = std::make_shared<HostTensorND>(cn, TensorShape{n, 3, 3},
dtype::Float32());
warp_perspective_mat_gen(*mat_host, n, h, w);
auto mat = opr::Host2DeviceCopy::make(*graph, mat_host).rename("mat");
opr::WarpPerspective::Param warp_param;
warp_param.format = opr::WarpPerspective::Param::Format::NHWC;
auto x_warp =
opr::WarpPerspective::make(x, mat, TensorShape{h, w}, warp_param);
auto x_nchw = opr::Dimshuffle::make(x_warp, {0, 3, 1, 2}, 4, cn);
auto result = opr::TypeCvt::make(x_nchw, dtype::Float32(), cn);
auto y = result;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_preprocess();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::WarpPerspective>());
ASSERT_EQ(opr::WarpPerspective::Param::Format::NHWC_NCHW,
find_opr<opr::WarpPerspective>(y_opt).param().format);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.WarpAndPreProcessCase1.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
......
......@@ -93,6 +93,7 @@ void WarpPerspectiveForward::outshape_by_symvar_do_get_output_shape(
dest.shape[4] = 4;
break;
case Param::Format::NHWC_NCHW:
dest.ndim = 4;
dest[0] = matshp[0];
dest.shape[1] = imgshp.shape[3];
dest.shape[2] = oshp2d.shape[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.