diff --git a/dnn/include/megdnn/oprs/cv.h b/dnn/include/megdnn/oprs/cv.h index 55306f3276feccf36e82ad590ab8ea16b9948597..0c05da3258bcdbcf5304ac25ee140370a191a5d9 100644 --- a/dnn/include/megdnn/oprs/cv.h +++ b/dnn/include/megdnn/oprs/cv.h @@ -197,7 +197,11 @@ public: protected: //! get origin coord - std::pair get_origin_coord(float scale, int size, int idx, bool cubic=false); + std::pair get_cubic_coord(float scale, int idx); + + std::tuple get_nearest_linear_coord( + InterpolationMode imode, float scale, int size, int idx); + //! get nearest index in src int get_nearest_src(float scale, int size, int idx); diff --git a/dnn/src/arm_common/resize/opr_impl.cpp b/dnn/src/arm_common/resize/opr_impl.cpp index 3f448ae7c4e1c1f4a9dd908486c208cc5838442b..9d145a836fc087ce18a2f49f53b143bcf9d26fc5 100644 --- a/dnn/src/arm_common/resize/opr_impl.cpp +++ b/dnn/src/arm_common/resize/opr_impl.cpp @@ -6,12 +6,14 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "src/arm_common/resize/opr_impl.h" #include "src/arm_common/handle.h" #include "src/arm_common/resize/resize_cv.h" +#include "src/arm_common/simd_macro/marm_neon.h" using namespace megdnn; using namespace arm_common; @@ -19,9 +21,58 @@ using namespace arm_common; void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); - if (param().format == param::Resize::Format::NCHW || - (src.layout[3] != 1 && src.layout[3] != 3) || - !is_nhwc_contig_wc(src.layout)) { + + if (param().format == param::Resize::Format::NCHW44 || + param().format == param::Resize::Format::NCHW88) { + bool is_contiguous = + src.layout.is_contiguous() && dst.layout.is_contiguous(); + bool dtype_same = src.layout.dtype == dst.layout.dtype; + bool nchw44_enable = param().format == param::Resize::Format::NCHW44 && + src.layout.dtype == dtype::Float32(); + bool nchw88_enable = + param().format == param::Resize::Format::NCHW88 && + DNN_FLOAT16_SELECT(src.layout.dtype == dtype::Float16(), false); + bool interp_supported = + param().imode == + param::Resize::InterpolationMode::INTER_NEAREST || + param().imode == param::Resize::InterpolationMode::INTER_LINEAR; + bool is_upsample2 = + param().imode == + param::Resize::InterpolationMode::INTER_NEAREST && + src.layout.shape[2] * 2 == dst.layout.shape[2] && + src.layout.shape[3] * 2 == dst.layout.shape[3]; + bool need_fallback = !is_contiguous || !dtype_same || + !interp_supported || + (!nchw44_enable && !nchw88_enable); + + if (need_fallback) { + fallback::ResizeImpl::exec(src, dst, workspace); + } else if (nchw44_enable) { + auto kern_param = KernParam::from_tensors( + param().format, param().imode, src, dst, workspace); + if (is_upsample2) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + kern_nearest_upsample2_pack_simd_width(src, dst)); + } else { + MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw44_fp32(kern_param)); + } +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } else if (nchw88_enable) { + auto kern_param = KernParam::from_tensors( + param().format, param().imode, src, dst, workspace); + if (is_upsample2) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + kern_nearest_upsample2_pack_simd_width(src, dst)); + } else { + MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw88_fp16(kern_param)); + } +#endif + } else { + fallback::ResizeImpl::exec(src, dst, workspace); + } + } else if (param().format == param::Resize::Format::NCHW || + (src.layout[3] != 1 && src.layout[3] != 3) || + !is_nhwc_contig_wc(src.layout)) { fallback::ResizeImpl::exec(src, dst, workspace); } else { megdnn_assert(param().format == param::Resize::Format::NHWC, @@ -30,4 +81,143 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, } } +template +void ResizeImpl::kern_nchw44_fp32(const KernParam& kern_param) { + UNPACK_RESIZE_FWD_KERN_PARAM(kern_param); + float scale_h = static_cast(OH) / IH; + float scale_w = static_cast(OW) / IW; + + for (size_t n = 0; n < N; ++n) { + for (size_t c = 0; c < C / 4; ++c) { + for (size_t oh = 0; oh < OH; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + int ih0, ih1, iw0, iw1; + float ah0, ah1, aw0, aw1; + + std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord( + kern_param.imode, scale_h, IH, oh); + std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord( + kern_param.imode, scale_w, IW, ow); + +#define SRC_ADDRESS(ih, iw) \ + (sptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 4) +#define DST_ADDRESS(oh, ow) \ + (dptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 4) + float32x4_t r0 = vld1q_f32(SRC_ADDRESS(ih0, iw0)); + float32_t a0 = ah0 * aw0; + float32x4_t r1 = vld1q_f32(SRC_ADDRESS(ih0, iw1)); + float32_t a1 = ah0 * aw1; + float32x4_t r2 = vld1q_f32(SRC_ADDRESS(ih1, iw0)); + float32_t a2 = ah1 * aw0; + float32x4_t r3 = vld1q_f32(SRC_ADDRESS(ih1, iw1)); + float32_t a3 = ah1 * aw1; + + r0 = vmulq_n_f32(r0, a0); +#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__) + r0 = vfmaq_n_f32(r0, r1, a1); + r0 = vfmaq_n_f32(r0, r2, a2); + r0 = vfmaq_n_f32(r0, r3, a3); +#else + r0 = vaddq_f32(r0, vmulq_n_f32(r1, a1)); + r0 = vaddq_f32(r0, vmulq_n_f32(r2, a2)); + r0 = vaddq_f32(r0, vmulq_n_f32(r3, a3)); +#endif + + vst1q_f32(DST_ADDRESS(oh, ow), r0); +#undef SRC_ADDRESS +#undef DST_ADDRESS + } + } + } + } +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template +void ResizeImpl::kern_nchw88_fp16(const KernParam& kern_param) { + UNPACK_RESIZE_FWD_KERN_PARAM(kern_param); + float scale_h = static_cast(OH) / IH; + float scale_w = static_cast(OW) / IW; + const float16_t* src_ptr = reinterpret_cast(sptr); + float16_t* dst_ptr = reinterpret_cast(dptr); + + for (size_t n = 0; n < N; ++n) { + for (size_t c = 0; c < C / 8; ++c) { + for (size_t oh = 0; oh < OH; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + int ih0, ih1, iw0, iw1; + float ah0, ah1, aw0, aw1; + + std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord( + kern_param.imode, scale_h, IH, oh); + std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord( + kern_param.imode, scale_w, IW, ow); + +#define SRC_ADDRESS(ih, iw) \ + (src_ptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 8) +#define DST_ADDRESS(oh, ow) \ + (dst_ptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 8) + float16x8_t r0 = vld1q_f16(SRC_ADDRESS(ih0, iw0)); + float32_t a0 = ah0 * aw0; + float16x8_t r1 = vld1q_f16(SRC_ADDRESS(ih0, iw1)); + float32_t a1 = ah0 * aw1; + float16x8_t r2 = vld1q_f16(SRC_ADDRESS(ih1, iw0)); + float32_t a2 = ah1 * aw0; + float16x8_t r3 = vld1q_f16(SRC_ADDRESS(ih1, iw1)); + float32_t a3 = ah1 * aw1; + + r0 = vmulq_n_f16(r0, a0); +#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__) + r0 = vfmaq_n_f16(r0, r1, a1); + r0 = vfmaq_n_f16(r0, r2, a2); + r0 = vfmaq_n_f16(r0, r3, a3); +#else + r0 = vaddq_f16(r0, vmulq_n_f16(r1, a1)); + r0 = vaddq_f16(r0, vmulq_n_f16(r2, a2)); + r0 = vaddq_f16(r0, vmulq_n_f16(r3, a3)); +#endif + + vst1q_f16(DST_ADDRESS(oh, ow), r0); +#undef SRC_ADDRESS +#undef DST_ADDRESS + } + } + } + } +} +#endif + +void ResizeImpl::kern_nearest_upsample2_pack_simd_width( + _megdnn_tensor_in src, _megdnn_tensor_out dst) { + const uint8_t* src_ptr = reinterpret_cast(src.raw_ptr); + uint8_t* dst_ptr = reinterpret_cast(dst.raw_ptr); + + size_t S = 2; + size_t N = src.layout.shape[0]; + size_t IC = src.layout.shape[1]; + size_t IH = src.layout.shape[2]; + size_t IW = src.layout.shape[3]; + size_t OH = dst.layout.shape[2]; + size_t OW = dst.layout.shape[3]; + + for (size_t i = 0; i < N * IC; ++i) { + for (size_t ih = 0; ih < IH; ++ih) { + for (size_t iw = 0; iw < IW; ++iw) { + size_t oh = ih * S; + size_t ow = iw * S; + uint8x16_t r0 = vld1q_u8(src_ptr + i * IH * IW * 16 + + ih * IW * 16 + iw * 16); + + for (size_t fh = 0; fh < S; ++fh) { + for (size_t fw = 0; fw < S; ++fw) { + vst1q_u8(dst_ptr + i * OH * OW * 16 + + (oh + fh) * OW * 16 + (ow + fw) * 16, + r0); + } + } + } + } + } +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/resize/opr_impl.h b/dnn/src/arm_common/resize/opr_impl.h index 59ac687bc5d222efc7643fef08cc1f47c884c42b..f40f252123bcdf9ea499538f0ac4c405a3cae488 100644 --- a/dnn/src/arm_common/resize/opr_impl.h +++ b/dnn/src/arm_common/resize/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "megdnn/oprs.h" @@ -25,6 +26,16 @@ public: const TensorLayout&) override { return 0; } + +private: + template + void kern_nchw44_fp32(const KernParam& kern_param); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + template + void kern_nchw88_fp16(const KernParam& kern_param); +#endif + void kern_nearest_upsample2_pack_simd_width(_megdnn_tensor_in src, + _megdnn_tensor_out dst); }; } // namespace arm_common diff --git a/dnn/src/common/resize.cpp b/dnn/src/common/resize.cpp index f9c78602c32692da10a871324528f796bd109cf3..d7821c227bc7c3dd46398622ccf706522914de0b 100644 --- a/dnn/src/common/resize.cpp +++ b/dnn/src/common/resize.cpp @@ -40,11 +40,29 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src, megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8); megdnn_assert(src.shape[4] == 4); megdnn_assert(dst.shape[4] == 4); + } else if (param().format == Param::Format::NCHW44) { + megdnn_assert(src.ndim == 5); + megdnn_assert(src.shape[4] == 4); + megdnn_assert(dst.shape[4] == 4); + megdnn_assert(param().imode == + param::Resize::InterpolationMode::INTER_LINEAR || + param().imode == + param::Resize::InterpolationMode::INTER_NEAREST); + } else if (param().format == Param::Format::NCHW88) { + megdnn_assert(src.ndim == 5); + megdnn_assert(src.shape[4] == 8); + megdnn_assert(dst.shape[4] == 8); + megdnn_assert(param().imode == + param::Resize::InterpolationMode::INTER_LINEAR || + param().imode == + param::Resize::InterpolationMode::INTER_NEAREST); } else { megdnn_assert(param().format == Param::Format::NHWCD4, "invalid resize tensor format"); megdnn_assert(param().imode == - param::Resize::InterpolationMode::INTER_LINEAR); + param::Resize::InterpolationMode::INTER_LINEAR || + param().imode == + param::Resize::InterpolationMode::INTER_NEAREST); megdnn_assert(dst.shape[2] == src.shape[2], "%s", errmsg().c_str()); } } @@ -67,24 +85,39 @@ void ResizeBackward::check_exec(const TensorLayout& diff, "Backward resize only supports Float32 and NCHW."); } -std::pair ResizeBase::get_origin_coord(float scale, int size, - int idx, bool cubic) { - //! copy from resize_cv.cpp +std::pair ResizeBase::get_cubic_coord(float scale, int idx) { float alpha = (idx + 0.5f) / scale - 0.5f; int origin_idx = static_cast(floor(alpha)); alpha -= origin_idx; - if (!cubic) { - if (origin_idx < 0) { - origin_idx = 0; - alpha = 0; - } else if (origin_idx + 1 >= size) { - origin_idx = size - 2; - alpha = 1; - } - } return {alpha, origin_idx}; } +std::tuple ResizeBase::get_nearest_linear_coord( + InterpolationMode imode, float scale, int size, int idx) { + if (size == 1) { + return std::make_tuple(1.0f, 0, 0.0f, 0); + } + + float alpha = (idx + 0.5f) / scale - 0.5f; + int origin_idx = static_cast(floor(alpha)); + alpha -= origin_idx; + + if (imode == InterpolationMode::INTER_NEAREST) { + origin_idx = get_nearest_src(scale, size, idx); + alpha = 0; + } + + if (origin_idx < 0) { + origin_idx = 0; + alpha = 0; + } else if (origin_idx + 1 >= size) { + origin_idx = size - 2; + alpha = 1; + } + + return std::make_tuple(1 - alpha, origin_idx, alpha, origin_idx + 1); +} + int ResizeBase::get_nearest_src(float scale, int size, int idx) { return std::min(static_cast(idx / scale), size - 1); } diff --git a/dnn/src/fallback/resize/opr_impl.cpp b/dnn/src/fallback/resize/opr_impl.cpp index 207c892f171f3de39945ed6484e81cf224936eca..777560d66703284863963f860f970732e8d55b87 100644 --- a/dnn/src/fallback/resize/opr_impl.cpp +++ b/dnn/src/fallback/resize/opr_impl.cpp @@ -6,13 +6,14 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "src/fallback/resize/opr_impl.h" #include -#include "src/fallback/handle.h" #include "src/common/rounding_converter.cuh" +#include "src/fallback/handle.h" using namespace megdnn; using namespace fallback; @@ -30,37 +31,36 @@ void ResizeImpl::kern_fallback(const KernParam& kern_param) { float scale_h = static_cast(OH) / IH; float scale_w = static_cast(OW) / IW; - auto build_table = [this](float scale, int isize, - int osize) -> std::vector> { - std::vector> table; - rep(i, osize) { table.push_back(get_origin_coord(scale, isize, i)); } + auto build_table = [this](InterpolationMode imode, float scale, int isize, + int osize) { + std::vector> table; + rep(i, osize) { + table.push_back(get_nearest_linear_coord(imode, scale, isize, i)); + } return table; }; - auto table_h = build_table(scale_h, IH, OH); - auto table_w = build_table(scale_w, IW, OW); + auto table_h = build_table(kern_param.imode, scale_h, IH, OH); + auto table_w = build_table(kern_param.imode, scale_w, IW, OW); rep(n, N) { rep(c, static_cast(C)) { rep(oh, OH) { - auto coord_h = table_h[oh]; - float alphah = coord_h.first; - int ih0 = coord_h.second; - int ih1 = ih0 + 1; + float ah0, ah1, aw0, aw1; + int ih0, ih1, iw0, iw1; + + std::tie(ah0, ih0, ah1, ih1) = table_h[oh]; rep(ow, OW) { - auto coord_w = table_w[ow]; - float alphaw = coord_w.first; - int iw0 = coord_w.second; - int iw1 = iw0 + 1; + std::tie(aw0, iw0, aw1, iw1) = table_w[ow]; dptr[c * OH * OW + oh * OW + ow] = output_converter( - sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * - (1.0f - alphaw) * (1.0f - alphah) + - sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * - alphaw * (1.0f - alphah) + - sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * - (1.0f - alphaw) * alphah + - sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * - alphaw * alphah); + sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * ah0 * + aw0 + + sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * ah0 * + aw1 + + sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * ah1 * + aw0 + + sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * ah1 * + aw1); } } } @@ -76,35 +76,31 @@ void ResizeImpl::kern_fallback_nhwc(const KernParam& kern_param) { float scale_h = static_cast(OH) / IH; float scale_w = static_cast(OW) / IW; - auto build_table = [this](float scale, int isize, - int osize) -> std::vector> { - std::vector> table; - rep(i, osize) { table.push_back(get_origin_coord(scale, isize, i)); } + auto build_table = [this](InterpolationMode imode, float scale, int isize, + int osize) { + std::vector> table; + rep(i, osize) { + table.push_back(get_nearest_linear_coord(imode, scale, isize, i)); + } return table; }; - auto table_h = build_table(scale_h, IH, OH); - auto table_w = build_table(scale_w, IW, OW); + auto table_h = build_table(kern_param.imode, scale_h, IH, OH); + auto table_w = build_table(kern_param.imode, scale_w, IW, OW); rep(n, N) { rep(oh, OH) { - auto coord_h = table_h[oh]; - float alphah = coord_h.first; - int ih0 = coord_h.second; - int ih1 = ih0 + 1; + float ah0, ah1, aw0, aw1; + int ih0, ih1, iw0, iw1; + + std::tie(ah0, ih0, ah1, ih1) = table_h[oh]; rep(ow, OW) { - auto coord_w = table_w[ow]; - float alphaw = coord_w.first; - int iw0 = coord_w.second; - int iw1 = iw0 + 1; + std::tie(aw0, iw0, aw1, iw1) = table_w[ow]; rep(c, C) { dptr[(oh * OW + ow) * C + c] = output_converter( - sptr[(ih0 * IW + iw0) * C + c] * (1.0f - alphaw) * - (1.0f - alphah) + - sptr[(ih0 * IW + iw1) * C + c] * alphaw * - (1.0f - alphah) + - sptr[(ih1 * IW + iw0) * C + c] * (1.0f - alphaw) * - alphah + - sptr[(ih1 * IW + iw1) * C + c] * alphaw * alphah); + sptr[(ih0 * IW + iw0) * C + c] * ah0 * aw0 + + sptr[(ih0 * IW + iw1) * C + c] * ah0 * aw1 + + sptr[(ih1 * IW + iw0) * C + c] * ah1 * aw0 + + sptr[(ih1 * IW + iw1) * C + c] * ah1 * aw1); } } } @@ -117,6 +113,8 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); if (param().format == param::Resize::Format::NCHW4 || + param().format == param::Resize::Format::NCHW44 || + param().format == param::Resize::Format::NCHW88 || (param().format == param::Resize::Format::NCHW && param().imode != param::Resize::InterpolationMode::INTER_LINEAR)) { naive::ResizeImpl::exec(src, dst, workspace); @@ -125,12 +123,12 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, if ((param().format == param::Resize::Format::NCHW || (src.layout[3] != 1 && src.layout[3] != 3)) || (param().imode == param::Resize::InterpolationMode::LINEAR)) { -#define cb(dt, ct) \ - case DTypeTrait
::enumv: { \ - auto kparam = KernParam::from_tensors(param().format, src, dst, \ - workspace); \ - MEGDNN_DISPATCH_CPU_KERN_OPR(kern_fallback(kparam)); \ - return; \ +#define cb(dt, ct) \ + case DTypeTrait
::enumv: { \ + auto kparam = KernParam::from_tensors( \ + param().format, param().imode, src, dst, workspace); \ + MEGDNN_DISPATCH_CPU_KERN_OPR(kern_fallback(kparam)); \ + return; \ } switch (src.layout.dtype.enumv()) { @@ -141,10 +139,9 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, cb(dtype::Uint8, uint8_t); cb(dtype::Quantized8Asymm, uint8_t); default: - megdnn_throw( - ssprintf("Unsupported input DType in Resize: %s", - src.layout.dtype.name()) - .c_str()); + megdnn_throw(ssprintf("Unsupported input DType in Resize: %s", + src.layout.dtype.name()) + .c_str()); return; } diff --git a/dnn/src/naive/resize/opr_impl.cpp b/dnn/src/naive/resize/opr_impl.cpp index df6b527441aaa896f03f7bb14d3263334c09463d..24d174280d73ebbf67b0b52f0f7883bd018a0fb5 100644 --- a/dnn/src/naive/resize/opr_impl.cpp +++ b/dnn/src/naive/resize/opr_impl.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "src/naive/resize/opr_impl.h" @@ -27,10 +28,11 @@ using namespace resize; template ResizeImpl::KernParam ResizeImpl::KernParam::from_tensors( - Format format, _megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) { + Format format, InterpolationMode imode, _megdnn_tensor_in src, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { KernParam ret; ret.format = format; + ret.imode = imode; ret.n = src.layout.shape[0]; if (format == Format::NCHW) { ret.c = src.layout.shape[1]; @@ -54,6 +56,18 @@ ResizeImpl::KernParam ResizeImpl::KernParam::from_tensors( ret.iw = src.layout.shape[3]; ret.oh = dst.layout.shape[2]; ret.ow = dst.layout.shape[3]; + } else if (format == Format::NCHW44) { + ret.c = src.layout.shape[1] * 4; + ret.ih = src.layout.shape[2]; + ret.iw = src.layout.shape[3]; + ret.oh = dst.layout.shape[2]; + ret.ow = dst.layout.shape[3]; + } else if (format == Format::NCHW88) { + ret.c = src.layout.shape[1] * 8; + ret.ih = src.layout.shape[2]; + ret.iw = src.layout.shape[3]; + ret.oh = dst.layout.shape[2]; + ret.ow = dst.layout.shape[3]; } else { megdnn_assert(format == Format::NHWCD4); ret.c = src.layout.shape[2] * 4; @@ -115,33 +129,30 @@ void ResizeImpl::kern_nchw(const KernParam& kern_param, break; } case InterpolationMode::INTER_LINEAR: { - auto coord_h = get_origin_coord(scale_h, IH, oh); - auto coord_w = get_origin_coord(scale_w, IW, ow); - - float alphah = coord_h.first; - float alphaw = coord_w.first; + int ih0, ih1, iw0, iw1; + float ah0, ah1, aw0, aw1; - int ih0 = coord_h.second; - int ih1 = ih0 + 1; - int iw0 = coord_w.second; - int iw1 = iw0 + 1; + std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord( + kern_param.imode, scale_h, IH, oh); + std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord( + kern_param.imode, scale_w, IW, ow); rep(c, static_cast(C)) { dptr[c * OH * OW + oh * OW + ow] = output_converter( - sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * - (1.0f - alphaw) * (1.0f - alphah) + - sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * - alphaw * (1.0f - alphah) + - sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * - (1.0f - alphaw) * alphah + - sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * - alphaw * alphah); + sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * ah0 * + aw0 + + sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * ah0 * + aw1 + + sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * ah1 * + aw0 + + sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * ah1 * + aw1); } break; } case InterpolationMode::INTER_CUBIC: { - auto coord_h = get_origin_coord(scale_h, IH, oh, true); - auto coord_w = get_origin_coord(scale_w, IW, ow, true); + auto coord_h = get_cubic_coord(scale_h, oh); + auto coord_w = get_cubic_coord(scale_w, ow); float alphah = coord_h.first; float alphaw = coord_w.first; @@ -193,7 +204,19 @@ void ResizeImpl::kern_naive(const KernParam& kern_param) { return; } else if (kern_param.format == Format::NCHW4) { MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(2)) { - kern_naive_nchw4(kern_param); + kern_naive_nchwx(kern_param); + } + MIDOUT_END(); + return; + } else if (kern_param.format == Format::NCHW44) { + MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(3)) { + kern_naive_nchwx(kern_param); + } + MIDOUT_END(); + return; + } else if (kern_param.format == Format::NCHW88) { + MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(4)) { + kern_naive_nchwx(kern_param); } MIDOUT_END(); return; @@ -209,25 +232,20 @@ void ResizeImpl::kern_naive_nhwc(const KernParam& kern_param) { rep(n, N) { rep(oh, OH) rep(ow, OW) { - auto coord_h = get_origin_coord(scale_h, IH, oh); - auto coord_w = get_origin_coord(scale_w, IW, ow); + int ih0, ih1, iw0, iw1; + float ah0, ah1, aw0, aw1; - float alphah = coord_h.first; - float alphaw = coord_w.first; + std::tie(ah0, ih0, ah1, ih1) = + get_nearest_linear_coord(kern_param.imode, scale_h, IH, oh); + std::tie(aw0, iw0, aw1, iw1) = + get_nearest_linear_coord(kern_param.imode, scale_w, IW, ow); - int ih0 = coord_h.second; - int ih1 = ih0 + 1; - int iw0 = coord_w.second; - int iw1 = iw0 + 1; rep(c, C) { dptr[(oh * OW + ow) * C + c] = output_converter( - sptr[(ih0 * IW + iw0) * C + c] * (1.0f - alphaw) * - (1.0f - alphah) + - sptr[(ih0 * IW + iw1) * C + c] * alphaw * - (1.0f - alphah) + - sptr[(ih1 * IW + iw0) * C + c] * (1.0f - alphaw) * - alphah + - sptr[(ih1 * IW + iw1) * C + c] * alphaw * alphah); + sptr[(ih0 * IW + iw0) * C + c] * ah0 * aw0 + + sptr[(ih0 * IW + iw1) * C + c] * ah0 * aw1 + + sptr[(ih1 * IW + iw0) * C + c] * ah1 * aw0 + + sptr[(ih1 * IW + iw1) * C + c] * ah1 * aw1); } } sptr += C * IH * IW; @@ -251,26 +269,20 @@ void ResizeImpl::kern_naive_nhwcd4(const KernParam& kern_param) { rep(n, N) { rep(oh, OH) rep(ow, OW) { - auto coord_h = get_origin_coord(scale_h, IH, oh); - auto coord_w = get_origin_coord(scale_w, IW, ow); + int ih0, ih1, iw0, iw1; + float ah0, ah1, aw0, aw1; - float alphah = coord_h.first; - float alphaw = coord_w.first; + std::tie(ah0, ih0, ah1, ih1) = + get_nearest_linear_coord(kern_param.imode, scale_h, IH, oh); + std::tie(aw0, iw0, aw1, iw1) = + get_nearest_linear_coord(kern_param.imode, scale_w, IW, ow); - int ih0 = coord_h.second; - int ih1 = ih0 + 1; - int iw0 = coord_w.second; - int iw1 = iw0 + 1; rep(c, C) { dptr[get_tensor_addr(oh, ow, c, OW, C)] = output_converter( - sptr[get_tensor_addr(ih0, iw0, c, IW, C)] * - (1.0f - alphaw) * (1.0f - alphah) + - sptr[get_tensor_addr(ih0, iw1, c, IW, C)] * alphaw * - (1.0f - alphah) + - sptr[get_tensor_addr(ih1, iw0, c, IW, C)] * - (1.0f - alphaw) * alphah + - sptr[get_tensor_addr(ih1, iw1, c, IW, C)] * alphaw * - alphah); + sptr[get_tensor_addr(ih0, iw0, c, IW, C)] * ah0 * aw0 + + sptr[get_tensor_addr(ih0, iw1, c, IW, C)] * ah0 * aw1 + + sptr[get_tensor_addr(ih1, iw0, c, IW, C)] * ah1 * aw0 + + sptr[get_tensor_addr(ih1, iw1, c, IW, C)] * ah1 * aw1); } } sptr += IH * (C / 4) * IW * 4; @@ -278,41 +290,46 @@ void ResizeImpl::kern_naive_nhwcd4(const KernParam& kern_param) { } } -template -void ResizeImpl::kern_naive_nchw4(const KernParam& kern_param) { +template +void ResizeImpl::kern_naive_nchwx(const KernParam& kern_param) { UNPACK_RESIZE_FWD_KERN_PARAM(kern_param); rounding::RoundingConverter output_converter; float scale_h = static_cast(OH) / IH; float scale_w = static_cast(OW) / IW; + megdnn_assert(pack_size == 4 || pack_size == 8); + size_t log_pack_size = 2; + if (pack_size == 8) { + log_pack_size = 3; + } + auto get_tensor_addr = [&](size_t h, size_t w, size_t c, size_t H, size_t W, size_t C) -> size_t { - megdnn_assert((C & 0x3) == 0); - return (((c >> 2) * H * W + h * W + w) << 2) + (c & 0b11); + megdnn_assert((C & (pack_size - 1)) == 0); + return (((c >> log_pack_size) * H * W + h * W + w) << log_pack_size) + + (c & (pack_size - 1)); }; rep(n, N) { rep(oh, OH) rep(ow, OW) { - auto coord_h = get_origin_coord(scale_h, IH, oh); - auto coord_w = get_origin_coord(scale_w, IW, ow); + int ih0, ih1, iw0, iw1; + float ah0, ah1, aw0, aw1; - float alphah = coord_h.first; - float alphaw = coord_w.first; + std::tie(ah0, ih0, ah1, ih1) = + get_nearest_linear_coord(kern_param.imode, scale_h, IH, oh); + std::tie(aw0, iw0, aw1, iw1) = + get_nearest_linear_coord(kern_param.imode, scale_w, IW, ow); - int ih0 = coord_h.second; - int ih1 = ih0 + 1; - int iw0 = coord_w.second; - int iw1 = iw0 + 1; rep(c, C) { dptr[get_tensor_addr(oh, ow, c, OH, OW, C)] = output_converter( - sptr[get_tensor_addr(ih0, iw0, c, IH, IW, C)] * - (1.0f - alphaw) * (1.0f - alphah) + - sptr[get_tensor_addr(ih0, iw1, c, IH, IW, C)] * alphaw * - (1.0f - alphah) + - sptr[get_tensor_addr(ih1, iw0, c, IH, IW, C)] * - (1.0f - alphaw) * alphah + - sptr[get_tensor_addr(ih1, iw1, c, IH, IW, C)] * alphaw * - alphah); + sptr[get_tensor_addr(ih0, iw0, c, IH, IW, C)] * ah0 * + aw0 + + sptr[get_tensor_addr(ih0, iw1, c, IH, IW, C)] * ah0 * + aw1 + + sptr[get_tensor_addr(ih1, iw0, c, IH, IW, C)] * ah1 * + aw0 + + sptr[get_tensor_addr(ih1, iw1, c, IH, IW, C)] * ah1 * + aw1); } } sptr += IH * IW * C; @@ -327,8 +344,8 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, #define cb(dt, ct, _midout_iv) \ case DTypeTrait
::enumv: { \ MIDOUT_BEGIN(megdnn_naive_resize_nchw, midout_iv(_midout_iv)) { \ - auto kparam = KernParam::from_tensors(param().format, src, \ - dst, workspace); \ + auto kparam = KernParam::from_tensors( \ + param().format, param().imode, src, dst, workspace); \ MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw(kparam, param().imode)); \ } \ MIDOUT_END(); \ @@ -356,15 +373,15 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, if (((src.layout[3] != 1 && src.layout[3] != 3) || !is_nhwc_contig_wc(src.layout)) || (param().imode == param::Resize::InterpolationMode::LINEAR)) { -#define cb(dt, ct, _midout_iv) \ - case DTypeTrait
::enumv: { \ - MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(_midout_iv)) { \ - auto kparam = KernParam::from_tensors(param().format, src, \ - dst, workspace); \ - MEGDNN_DISPATCH_CPU_KERN_OPR(kern_naive(kparam)); \ - } \ - MIDOUT_END(); \ - return; \ +#define cb(dt, ct, _midout_iv) \ + case DTypeTrait
::enumv: { \ + MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(_midout_iv)) { \ + auto kparam = KernParam::from_tensors( \ + param().format, param().imode, src, dst, workspace); \ + MEGDNN_DISPATCH_CPU_KERN_OPR(kern_naive(kparam)); \ + } \ + MIDOUT_END(); \ + return; \ } switch (src.layout.dtype.enumv()) { @@ -409,27 +426,24 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, rep(oh, OH) rep(ow, OW) { switch (param().imode) { case InterpolationMode::INTER_LINEAR: { - auto coord_h = get_origin_coord(scale_h, IH, oh); - auto coord_w = get_origin_coord(scale_w, IW, ow); - - float alphah = coord_h.first; - float alphaw = coord_w.first; + int ih0, ih1, iw0, iw1; + float ah0, ah1, aw0, aw1; - int ih0 = coord_h.second; - int ih1 = ih0 + 1; - int iw0 = coord_w.second; - int iw1 = iw0 + 1; + std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord( + param().imode, scale_h, IH, oh); + std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord( + param().imode, scale_w, IW, ow); rep(c, C) { float hidden = hptr[c * OH * OW + oh * OW + ow]; sptr[c * IH * IW + ih0 * IW + iw0] += - (1.0f - alphaw) * (1.0f - alphah) * hidden; + ah0 * aw0 * hidden; sptr[c * IH * IW + ih1 * IW + iw0] += - (1.0f - alphaw) * alphah * hidden; + ah1 * aw0 * hidden; sptr[c * IH * IW + ih0 * IW + iw1] += - alphaw * (1.0f - alphah) * hidden; + ah0 * aw1 * hidden; sptr[c * IH * IW + ih1 * IW + iw1] += - alphaw * alphah * hidden; + ah1 * aw1 * hidden; } break; } @@ -443,8 +457,8 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, break; } case InterpolationMode::INTER_CUBIC: { - auto coord_h = get_origin_coord(scale_h, IH, oh, true); - auto coord_w = get_origin_coord(scale_w, IW, ow, true); + auto coord_h = get_cubic_coord(scale_h, oh); + auto coord_w = get_cubic_coord(scale_w, ow); float alphah = coord_h.first; float alphaw = coord_w.first; @@ -460,7 +474,8 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, rep(kh, ksize) { int h = saturate(ih0 + kh, 0, IH - 1); rep(kw, ksize) { - int w = saturate(iw0 + kw, 0, IW - 1); + int w = saturate(iw0 + kw, 0, + IW - 1); sptr[c * IH * IW + h * IW + w] += hptr[c * OH * OW + oh * OW + ow] * h_coeff[kh] * w_coeff[kw]; diff --git a/dnn/src/naive/resize/opr_impl.h b/dnn/src/naive/resize/opr_impl.h index 59c6fb9dc1c7a845538300eeb4ce9c8a77173083..15a945de725ffb4ed497fdd47d249ae40e92011d 100644 --- a/dnn/src/naive/resize/opr_impl.h +++ b/dnn/src/naive/resize/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -19,15 +20,18 @@ namespace naive { class ResizeImpl : public Resize { public: using Format = Param::Format; + using InterpolationMode = Param::InterpolationMode; template struct KernParam { Format format; + InterpolationMode imode; size_t n, c, ih, iw, oh, ow; ptrdiff_t s_in, s_ic, s_ih, s_iw; ctype *sptr, *dptr; Workspace workspace; - static KernParam from_tensors(Format format, _megdnn_tensor_in src, + static KernParam from_tensors(Format format, InterpolationMode imode, + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace); }; @@ -41,6 +45,7 @@ public: const TensorLayout&) override { return 0; } + private: // ctype: C type of input data type. template @@ -55,8 +60,8 @@ private: template void kern_naive_nhwcd4(const KernParam& kern_param); - template - void kern_naive_nchw4(const KernParam& kern_param); + template + void kern_naive_nchwx(const KernParam& kern_param); }; // class ResizeImpl @@ -65,15 +70,15 @@ private: ctype* __restrict sptr = p.sptr; \ ctype* __restrict dptr = p.dptr; -#define UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(p) \ - UNPACK_RESIZE_FWD_KERN_PARAM(p) \ +#define UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(p) \ + UNPACK_RESIZE_FWD_KERN_PARAM(p) \ auto S_IN = p.s_in, S_IC = p.s_ic, S_IH = p.s_ih, S_IW = p.s_iw; -class ResizeBackwardImpl: public ResizeBackward { +class ResizeBackwardImpl : public ResizeBackward { public: using ResizeBackward::ResizeBackward; - void exec(_megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { return 0; diff --git a/dnn/test/arm_common/resize.cpp b/dnn/test/arm_common/resize.cpp index 360bc8494f998868963d385fad050db50ef77803..1455c725a928d7ca0b8e0a6df60aecd73fe8cd2c 100644 --- a/dnn/test/arm_common/resize.cpp +++ b/dnn/test/arm_common/resize.cpp @@ -6,40 +6,66 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ -#include "test/arm_common/fixture.h" #include "test/common/resize.h" +#include "test/arm_common/fixture.h" #include "test/common/checker.h" namespace megdnn { namespace test { -TEST_F(ARM_COMMON, RESIZE_CV) -{ +TEST_F(ARM_COMMON, RESIZE_CV) { using namespace resize; std::vector args = get_cv_args(); Checker checker(handle()); - for (auto &&arg: args) { + for (auto&& arg : args) { checker.set_param(arg.param) - .set_epsilon(1 + 1e-3) - .set_dtype(0, dtype::Uint8()) - .set_dtype(1, dtype::Uint8()) - .execs({arg.src, arg.dst}); + .set_epsilon(1 + 1e-3) + .set_dtype(0, dtype::Uint8()) + .set_dtype(1, dtype::Uint8()) + .execs({arg.src, arg.dst}); } - for (auto &&arg: args) { + for (auto&& arg : args) { checker.set_param(arg.param) - .set_dtype(0, dtype::Float32()) - .set_dtype(1, dtype::Float32()) - .execs({arg.src, arg.dst}); + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .execs({arg.src, arg.dst}); } +} + +TEST_F(ARM_COMMON, RESIZE_NCHW44) { + using namespace resize; + std::vector args = get_nchw44_args(); + Checker checker(handle()); + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .execs({arg.src, arg.dst}); + } +} + +TEST_F(ARM_COMMON, RESIZE_NCHW88) { + using namespace resize; + std::vector args = get_nchw88_args(); + Checker checker(handle()); + + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_epsilon(0.01) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .execs({arg.src, arg.dst}); + } } -} // namespace test -} // namespace megdnn +} // namespace test +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/test/common/resize.h b/dnn/test/common/resize.h index d154adefe106fc0d68503a34569b49ed984b5c9e..cd9ead105776cd2245189b3bc3203fd711472a4e 100644 --- a/dnn/test/common/resize.h +++ b/dnn/test/common/resize.h @@ -6,12 +6,13 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once -#include "megdnn/opr_param_defs.h" -#include "megdnn/basic_types.h" #include +#include "megdnn/basic_types.h" +#include "megdnn/opr_param_defs.h" #include "./rng.h" namespace megdnn { @@ -68,13 +69,15 @@ static inline std::vector get_args(IMode imode = IMode::INTER_LINEAR) { std::vector args; set_nchw_args(args); - if(imode == IMode::INTER_LINEAR) { - //! test NHWC with ch != 1 or ch != 3 + if (imode == IMode::INTER_LINEAR) { + //! test NHWC with ch != 1 or ch != 3 param::Resize param; param.format = param::Resize::Format::NHWC; param.imode = imode; - args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 4, 6, 4}); - args.emplace_back(param, TensorShape{2, 4, 6, 4}, TensorShape{2, 2, 3, 4}); + args.emplace_back(param, TensorShape{2, 2, 3, 4}, + TensorShape{2, 4, 6, 4}); + args.emplace_back(param, TensorShape{2, 4, 6, 4}, + TensorShape{2, 2, 3, 4}); } return args; } @@ -108,6 +111,48 @@ static inline std::vector get_nchw4_args() { return args; } +static inline std::vector get_nchw44_args() { + std::vector args; + + param::Resize param; + param.format = param::Resize::Format::NCHW44; + param.imode = param::Resize::InterpolationMode::LINEAR; + rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul) + args.emplace_back( + param, + TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 4ul}, + TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 4ul}); + + param.imode = param::Resize::InterpolationMode::NEAREST; + rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul) + args.emplace_back( + param, + TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 4ul}, + TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 4ul}); + return args; +} + +static inline std::vector get_nchw88_args() { + std::vector args; + + param::Resize param; + param.format = param::Resize::Format::NCHW88; + param.imode = param::Resize::InterpolationMode::LINEAR; + rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul) + args.emplace_back( + param, + TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 8ul}, + TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 8ul}); + + param.imode = param::Resize::InterpolationMode::NEAREST; + rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul) + args.emplace_back( + param, + TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 8ul}, + TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 8ul}); + return args; +} + static inline std::vector get_cv_args() { std::vector args; diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index c2fb749685a6b0c09686d2d5efd62a0cd9dfd1fc..a6be2bd5c2e22993893ae927db59547a89a21360 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -68,87 +68,90 @@ using namespace gopt; * oprs should not get involved in any actual computing. */ MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder, - cg::SingleCNOperatorNodeBase) // { + cg::SingleCNOperatorNodeBase) // { public: - //! relayout type of this opr - enum class LayoutType { - NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout - NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout - NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout - CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout - NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout - NCHW_TO_NCHW4_IC_SMALL_CONV, ///< from nchw layout to nchw4 whose - ///< channel size less than 4 - NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout - NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout - NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout - - WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4 - //!< layout - WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to - //!< nchw4 layout - WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, //!< weight from nchw layout - //!< to nchw4 layout whose - //! channel size less than 4 - - WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88 - //!< layout - WEIGHT_NCHW_TO_NCHW88_GROUP, //!< group weight from nchw layout to - //!< nchw88 layout - WEIGHT_NCHW_TO_NCHW88_CHAN, //!< channel wise weight from nchw layout - //!< to nchw88 layout - //!< the weight layout of input is nchw output is nchw88, special for - //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8} - WEIGHT_HYBIRD_NCHW_NCHW88, - - WEIGHT_NCHW_TO_NCHW44_DENSE, //!< weight from nchw layout to nchw44 - //!< layout - WEIGHT_NCHW_TO_NCHW44_GROUP, //!< group weight from nchw layout to - //!< nchw44 layout - WEIGHT_NCHW_TO_NCHW44_CHAN, //!< channel wise weight from nchw layout - //!< to nchw44 layout - //!< the weight layout of input is nchw output is nchw44, special for - //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4} - WEIGHT_HYBIRD_NCHW_NCHW44, - WEIGHT_NCHW_TO_NCHW44_DOT_DENSE, //!< weight from NCHW44 layout to - //!< NCHW44_DOT layout dense - WEIGHT_NCHW_TO_NCHW44_DOT_GROUP, //!< weight from NCHW44 layout to - //!< NCHW44_DOT layout group - NCHW32_TO_NCHW, //! EnableNCHW4Pass::make_nchw4_converter() { }; auto replace_deconv_opr = [trans_nchw4, conv_format]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { if (new_inp[1]->dtype().enumv() == DTypeEnum::Float32) { return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); @@ -1881,7 +1885,8 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { opr->config()); } VarNode *deconv_src = new_inp[1], *deconv_filter = new_inp[0]; - auto deconv_mode = trans_nchw4(deconv_opr.param().sparse, deconv_filter); + auto deconv_mode = + trans_nchw4(deconv_opr.param().sparse, deconv_filter); // src: NCHW --> NCWH4 if (deconv_src->shape().ndim != 5) { mgb_assert(deconv_src->shape().ndim == 4); @@ -2028,10 +2033,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { conv_bias_src, conv_bias_filter, new_param, conv_bias_opr.execution_policy(), conv_bias_opr.config()); OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); - mgb_assert( - new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || - new_conv_bias_opr.shape().ndim == 5, - "The conv_bias dst dim is not trans to nchw4"); + mgb_assert(new_conv_bias_opr.node()->dtype().enumv() == + DTypeEnum::Float32 || + new_conv_bias_opr.shape().ndim == 5, + "The conv_bias dst dim is not trans to nchw4"); return new_opr; } // bias: NCHW --> NCHW4 when bias_dtype is not Float32 @@ -2047,10 +2052,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { conv_bias_src, conv_bias_filter, conv_bias_bias, new_param, conv_bias_opr.execution_policy(), conv_bias_opr.config()); OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); - mgb_assert( - new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || - new_conv_bias_opr.shape().ndim == 5, - "The conv_bias dst dim is not trans to nchw4"); + mgb_assert(new_conv_bias_opr.node()->dtype().enumv() == + DTypeEnum::Float32 || + new_conv_bias_opr.shape().ndim == 5, + "The conv_bias dst dim is not trans to nchw4"); return new_opr; } // z_inp: NCHW --> NCHW4 when bias_dtype is not Float32 @@ -2066,10 +2071,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { new_param, conv_bias_opr.execution_policy(), conv_bias_opr.config()); OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); - mgb_assert( - new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || - new_conv_bias_opr.shape().ndim == 5, - "The conv_bias dst dim is not trans to nchw4"); + mgb_assert(new_conv_bias_opr.node()->dtype().enumv() == + DTypeEnum::Float32 || + new_conv_bias_opr.shape().ndim == 5, + "The conv_bias dst dim is not trans to nchw4"); return new_opr; }; auto replace_elemwise_opr = [=](OperatorNodeBase* opr, @@ -2210,8 +2215,7 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { auto&& replace_func = ret->m_opr_replace_func; //! supportted nchw4 replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; - replace_func[opr::ConvolutionBackwardData::typeinfo()] = - replace_deconv_opr; + replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr; replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr; replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; @@ -2348,6 +2352,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { megdnn::param::Convolution::Format::NCHW88; megdnn::param::Pooling::Format pooling_format = megdnn::param::Pooling::Format::NCHW88; + megdnn::param::Resize::Format resize_format = + megdnn::param::Resize::Format::NCHW88; std::string convter_pass_name = "conv_format_nchw88"; if (pack_c_size == 4) { @@ -2360,6 +2366,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { conv_bias_format = megdnn::param::ConvBias::Format::NCHW44; conv_format = megdnn::param::Convolution::Format::NCHW44; pooling_format = megdnn::param::Pooling::Format::NCHW44; + resize_format = megdnn::param::Resize::Format::NCHW44; convter_pass_name = "conv_format_nchw44"; } auto test_trans_nchwxx = @@ -2634,6 +2641,43 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { return new_opr; } }; + + auto replace_resize_opr = [=](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + auto& resize_opr = opr->cast_final_safe(); + mgb_throw_if( + resize_opr.param().format != + megdnn::param::Resize::Format::NCHW && + resize_opr.param().format != + megdnn::param::Resize::Format::NHWC, + MegBrainError, + "ConvertFormat Pass only support converting NCHW to NCHWxx"); + + VarNode* inp = new_inp[0]; + if (resize_opr.param().format == megdnn::param::Resize::Format::NHWC) { + auto temp_inp = new_inp; + if (inp->shape().ndim == 5) { + auto new_var = RelayoutPlaceholder::make(inp, src_to_nchw_mode); + temp_inp[0] = new_var.node(); + } + return serialization::copy_opr_shallow(*opr, temp_inp, + opr->config()); + } else { + auto temp_inp = new_inp; + if (inp->shape().ndim == 5) { + auto new_param = resize_opr.param(); + new_param.format = resize_format; + auto new_resize_opr = opr::ResizeForward::make( + new_inp[0], new_inp[1], new_param, opr->config()); + return new_resize_opr.node()->owner_opr(); + } else { + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + } + }; + //! When input change and all input can convert to nchwxx, this opr will run //! in nchwxx mode, else it will run in nchw mode, for example concat and //! elemwise opr @@ -2704,6 +2748,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; + replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr; replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr; replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr; @@ -2718,7 +2763,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::WarpPerspectiveForward::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; @@ -3236,26 +3280,27 @@ public: MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr, cg::SingleCNOperatorNodeBase) // { public: - AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format, - TensorFormat out_format); +AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format, + TensorFormat out_format); - static SymbolVar make(VarNode* inpvar, TensorFormat inp_format, - TensorFormat out_format); +static SymbolVar make(VarNode* inpvar, TensorFormat inp_format, + TensorFormat out_format); - TensorFormat inp_format() const { - return m_inp_format; - } +TensorFormat inp_format() const { + return m_inp_format; +} - TensorFormat out_format() const { - return m_out_format; - } +TensorFormat out_format() const { + return m_out_format; +} private: - void init_output_static_infer_desc() override; - void scn_do_execute() override; - const TensorFormat m_inp_format; - const TensorFormat m_out_format; -}; +void init_output_static_infer_desc() override; +void scn_do_execute() override; +const TensorFormat m_inp_format; +const TensorFormat m_out_format; +} +; MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr); @@ -3910,8 +3955,8 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { opr_set.insert(opr); // check dimshuffle - auto shuffle = try_cast_as_op( - reshape->input(0)->owner_opr()); + auto shuffle = + try_cast_as_op(reshape->input(0)->owner_opr()); if (shuffle == nullptr) return false; auto&& param = shuffle->param(); @@ -3981,10 +4026,9 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { auto conv_bias_shuffle = opr::ConvBias::make( src, filter, new_bias, new_param, conv_bias->execution_policy(), OperatorNodeConfig{out_dtype}); - rewriter.replace_var( - opr->output(0), conv_bias_shuffle.node(), - mgb_cstr_log("replace conv_bias + " - "reformat to conv_bias(NCHW4_NHWC)")); + rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), + mgb_cstr_log("replace conv_bias + " + "reformat to conv_bias(NCHW4_NHWC)")); return true; }; @@ -4036,8 +4080,8 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { return false; auto inp_dtype = conv_bias->input(0)->dtype(); bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && - conv_bias->param().format == - megdnn::param::ConvBias::Format::NCHW32; + conv_bias->param().format == + megdnn::param::ConvBias::Format::NCHW32; if (!is_s8nchw32) return false; if (conv_bias->input().size() != 3) @@ -4078,9 +4122,8 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { &rewriter](OperatorNodeBase* opr) { if (!try_conv_dimshuffle_reshape_typecvt(opr) && !try_conv_reformat_nchw42nchw32(opr) && - !try_conv_reformat_nchw42nhwc(opr) - && !try_conv_reformat_nchw322nchw4(opr) - ) { + !try_conv_reformat_nchw42nhwc(opr) && + !try_conv_reformat_nchw322nchw4(opr)) { rewriter.auto_replace_outputs(opr); } }; @@ -4497,7 +4540,7 @@ void PaddingChannelPass::apply(OptState& opt) const { /* ================ EnableNCHW64Pass =============== */ VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var, - VarNode* orig_var) const { + VarNode* orig_var) const { if (!orig_var->shape().eq_shape(new_var->shape())) { auto iter = m_opr_format_map.find(new_var->owner_opr()); mgb_assert(iter != m_opr_format_map.end(), @@ -4532,8 +4575,7 @@ VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var, return new_var; } -std::unique_ptr -EnableNCHW64Pass::make_nchw64_converter() { +std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { MIDOUT_B("EnableNCHW64Pass::make") auto ret = std::make_unique(); ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ @@ -4618,15 +4660,15 @@ EnableNCHW64Pass::make_nchw64_converter() { [make_new_conv, &format_map]( OperatorNodeBase* opr, const VarNodeArray& new_inp) -> VarNode* { - mgb_assert(opr->input().size()==new_inp.size()); + mgb_assert(opr->input().size() == new_inp.size()); bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; mgb_assert(opr->output().size() > 0); bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32; if (opr->input().size() >= 3) { - auto dtype_expect = dst_float ? DTypeEnum::Float32 - : DTypeEnum::QuantizedS32; + auto dtype_expect = + dst_float ? DTypeEnum::Float32 : DTypeEnum::QuantizedS32; check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect; } if (opr->input().size() >= 4) { @@ -4677,12 +4719,13 @@ EnableNCHW64Pass::make_nchw64_converter() { for (size_t i = 0; i < inps.size(); ++i) { // do not format bias and z when dst_float is true bool skip = dst_float && i >= 2; - if (!skip) inps[i] = process(i); + if (!skip) + inps[i] = process(i); } auto& conv_bias = opr->cast_final_safe(); - auto ret = make_new_conv( - inps, &conv_bias, - dst_float ? Format::NCHW4_NCHW : Format::NCHW4); + auto ret = + make_new_conv(inps, &conv_bias, + dst_float ? Format::NCHW4_NCHW : Format::NCHW4); if (!dst_float) format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4)); return ret; @@ -4692,7 +4735,7 @@ EnableNCHW64Pass::make_nchw64_converter() { [make_new_conv, &format_map]( OperatorNodeBase* opr, const VarNodeArray& new_inp) -> VarNode* { - mgb_assert(opr->input().size()==new_inp.size()); + mgb_assert(opr->input().size() == new_inp.size()); bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; @@ -4754,18 +4797,17 @@ EnableNCHW64Pass::make_nchw64_converter() { OperatorNodeBase* opr, const VarNodeArray& new_inp) -> VarNode* { // fint4XWint4 and fuint4XWint4 - mgb_assert(opr->input().size()==new_inp.size()); + mgb_assert(opr->input().size() == new_inp.size()); bool check_dtype = (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || - new_inp[0]->dtype().enumv() == - DTypeEnum::Quantized4Asymm) && + new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) && new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4; if (opr->input().size() >= 3) check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; if (opr->input().size() >= 4) - check_dtype &= new_inp[3]->dtype().enumv() == - new_inp[0]->dtype().enumv(); + check_dtype &= + new_inp[3]->dtype().enumv() == new_inp[0]->dtype().enumv(); if (!check_dtype) return nullptr; size_t out_channels = opr->input(1)->shape()[0]; @@ -4818,18 +4860,17 @@ EnableNCHW64Pass::make_nchw64_converter() { OperatorNodeBase* opr, const VarNodeArray& new_inp) -> VarNode* { // fint4XWint4 and fuint4XWint4 - mgb_assert(opr->input().size()==new_inp.size()); + mgb_assert(opr->input().size() == new_inp.size()); bool check_dtype = (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || - new_inp[0]->dtype().enumv() == - DTypeEnum::Quantized4Asymm) && + new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) && new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4; if (opr->input().size() >= 3) check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; if (opr->input().size() >= 4) - check_dtype &= new_inp[3]->dtype().enumv() == - new_inp[0]->dtype().enumv(); + check_dtype &= + new_inp[3]->dtype().enumv() == new_inp[0]->dtype().enumv(); if (!check_dtype) return nullptr; size_t out_channels = opr->input(1)->shape()[0]; @@ -4842,8 +4883,7 @@ EnableNCHW64Pass::make_nchw64_converter() { auto iter = format_map.find(new_inp[i]->owner_opr()); if (iter == format_map.end()) { auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC); + inps[i], RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC); return ovar.node(); } else { const auto& fmt = iter->second; @@ -4973,7 +5013,7 @@ EnableNCHW64Pass::make_nchw64_converter() { default: mgb_assert(cur == Format::NCHW4); } - + auto param = deconv.param(); param.format = Format::NCHW4; auto new_deconv = opr::ConvolutionBackwardData::make( @@ -4990,7 +5030,7 @@ EnableNCHW64Pass::make_nchw64_converter() { break; } } - mgb_assert(!shape_changed, + mgb_assert(!shape_changed, "EnableNCHW64Pass won't change format of output tensor " "of non quantized deconv operator(name:%s)", opr->cname()); @@ -5000,8 +5040,9 @@ EnableNCHW64Pass::make_nchw64_converter() { }; // replace rule for elemwise like opr - auto replace_elemwise_like_opr = [&format_map](OperatorNodeBase* opr, - const VarNodeArray& new_inp) { + auto replace_elemwise_like_opr = [&format_map]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); ThinHashMap format_size; bool same_format = true; @@ -5073,7 +5114,7 @@ EnableNCHW64Pass::make_nchw64_converter() { cur = Format::NCHW; } if (cur != max_format) { - inps[i] = map.at(std::make_pair(cur, max_format))(inps[i]); + inps[i] = map.at(std::make_pair(cur, max_format))(inps[i]); } } auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config()); @@ -5131,8 +5172,7 @@ EnableNCHW64Pass::make_nchw64_converter() { SymbolVar new_warp; if (inps.size() == 3) { new_warp = opr::WarpPerspectiveForward::make( - inps[0], inps[1], inps[2], param, - warp.config()); + inps[0], inps[1], inps[2], param, warp.config()); } else { mgb_assert(inps.size() == 4); new_warp = opr::WarpPerspectiveForward::make( @@ -5179,14 +5219,13 @@ EnableNCHW64Pass::make_nchw64_converter() { default: mgb_assert(cur == Format::NCHW4); } - + auto param = warp.param(); param.format = Format::NCHW4; SymbolVar new_warp; if (inps.size() == 3) { new_warp = opr::WarpPerspectiveForward::make( - inps[0], inps[1], inps[2], param, - warp.config()); + inps[0], inps[1], inps[2], param, warp.config()); } else { mgb_assert(inps.size() == 4); new_warp = opr::WarpPerspectiveForward::make( @@ -5204,7 +5243,7 @@ EnableNCHW64Pass::make_nchw64_converter() { break; } } - mgb_assert(!shape_changed, + mgb_assert(!shape_changed, "EnableNCHW64Pass won't change format of output tensor " "of non quantized warp perspective operator(name:%s)", opr->cname()); @@ -5212,9 +5251,8 @@ EnableNCHW64Pass::make_nchw64_converter() { opr->config()); } }; - auto replace_pooling_opr = [&format_map]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { + auto replace_pooling_opr = [&format_map](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); auto& pooling = opr->cast_final_safe(); if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || @@ -5300,7 +5338,7 @@ EnableNCHW64Pass::make_nchw64_converter() { mgb_assert(cur == Format::NCHW4); } Format out_format = use_nchw32 ? Format::NCHW32 : Format::NCHW4; - + auto param = pooling.param(); param.format = out_format; auto new_pool = @@ -5336,7 +5374,7 @@ EnableNCHW64Pass::make_nchw64_converter() { auto inps = new_inp; for (size_t i = 0; i < opr->input().size(); ++i) { auto iter = format_map.find(new_inp[i]->owner_opr()); - auto fmt = iter != format_map.end()?iter->second:Format::NCHW; + auto fmt = iter != format_map.end() ? iter->second : Format::NCHW; if (iter != format_map.end()) { switch (fmt) { case Format::NHWC: diff --git a/src/opr/impl/imgproc.cpp b/src/opr/impl/imgproc.cpp index eedc3e5377bf39065ba419e5ce017046a112d838..22bc515ecc5dba86f4140cf282fe53b4cb434db8 100644 --- a/src/opr/impl/imgproc.cpp +++ b/src/opr/impl/imgproc.cpp @@ -10,9 +10,9 @@ * implied. */ +#include "megbrain/opr/imgproc.h" #include "./internal/megdnn_opr_wrapper.inl" #include "megbrain/graph/grad_impl.h" -#include "megbrain/opr/imgproc.h" #include "megbrain/opr/io.h" #include "megbrain/opr/utility.h" @@ -340,7 +340,9 @@ void ResizeForward::outshape_by_symvar_do_get_output_shape( //! The index of height, e.g.,[b, h, w, c], the height_idx = 1 size_t height_idx = 0; if (param().format == Param::Format::NCHW || - param().format == Param::Format::NCHW4) { + param().format == Param::Format::NCHW4 || + param().format == Param::Format::NCHW44 || + param().format == Param::Format::NCHW88) { height_idx = 2; } else { height_idx = 1;