From 26634db7a8d9c5c1faa1c2b06f3a34006e835d46 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 28 Oct 2021 19:26:13 +0800 Subject: [PATCH] fix(dnn): support relayout for non-contigous layout GitOrigin-RevId: 44a0adddbaf8b23c744a50f8e7f58a7c1190df1a --- dnn/src/common/relayout.cpp | 18 +++++++++++++++++- dnn/test/aarch64/relayout.cpp | 14 ++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/dnn/src/common/relayout.cpp b/dnn/src/common/relayout.cpp index 14a400f5f..63fd53ccc 100644 --- a/dnn/src/common/relayout.cpp +++ b/dnn/src/common/relayout.cpp @@ -29,17 +29,24 @@ bool is_transpose_single( * assuming contig layout is: * shape: b, m, n, c * stride: mnc, nc, c, 1 + * assuming non-contig layout is: + * shape: b, m, n, c + * stride: m*stride_m*c, stride_m*c, c, 1 * * then given layout should be: * shape: b, n, m, c * stride: mnc, c, nc, 1 + * non-contig stride: m*stride_m*c, c, stride_m*c, 1 * * if c == 1: * shape: b, n, m * stride: mn, 1, n + * non-contig stride: m*stride_m, 1, stride_m + * * if b == 1: * shape: n, m, c * stride: c, nc, 1 + * non-contig stride: c, stride_m*c, 1 * * if b == 1 && c == 1: * shape: n, m @@ -65,7 +72,16 @@ bool is_transpose_single( p.n = layout[1]; p.m = layout[2]; p.c = 1; - return strd(2, p.n) && strd(0, p.m * p.n); + + if (strd(2, p.n) && strd(0, p.m * p.n)) { + return true; + } else if ( + allow_no_contig && (size_t)(layout.stride[2]) >= p.n && + strd(0, p.m * (size_t)(layout.stride[2])) && strd(1, 1)) { + p.stride_m = layout.stride[2]; + return true; + } + return false; } if (strd(2, 1)) { // b == 1 diff --git a/dnn/test/aarch64/relayout.cpp b/dnn/test/aarch64/relayout.cpp index 31e10eb02..2383420b0 100644 --- a/dnn/test/aarch64/relayout.cpp +++ b/dnn/test/aarch64/relayout.cpp @@ -41,6 +41,20 @@ TEST_F(AARCH64, Relayout) { } } +TEST_F(AARCH64, RelayoutNonContig) { + Checker checker(handle()); + std::vector<::megdnn::DType> dtype_vec; + dtype_vec.push_back(dtype::Float32()); + dtype_vec.push_back(dtype::Int16()); + dtype_vec.push_back(dtype::Uint16()); + dtype_vec.push_back(dtype::Int8()); + for (auto dtype : dtype_vec) { + TensorLayout src({4, 90, 15, 29}, {41760, 1, 2784, 96}, dtype); + TensorLayout dst({4, 90, 15, 29}, {39150, 435, 29, 1}, dtype); + checker.execl({src, dst}); + } +} + TEST_F(AARCH64, RelayoutBig) { Checker checker(handle()); ConsecutiveRNG rng; -- GitLab