提交 bde5cf35 编写于 作者: M Megvii Engine Team

feat(dnn): add resize linear for arm

GitOrigin-RevId: 14ac5bda3f60f530ca9d42e94b1f4e401d0a1309
上级 b6142bee
/**
* \file dnn/src/arm_common/resize/direct_nchwxx.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/resize/direct_nchwxx.h"
#include "src/arm_common/resize/helper.h"
#include "src/arm_common/simd_macro/marm_neon.h"
using namespace megdnn;
using namespace arm_common;
using namespace resize;
namespace {
template <typename ctype, InterpolationMode imode>
void resize_direct_nchwxx(const ctype* sptr, ctype* dptr, size_t N, size_t IH,
size_t IW, size_t OH, size_t OW) {
using simd_helper = SIMDHelper<ctype>;
constexpr size_t PC = simd_helper::simd_width;
using simd_type = typename simd_helper::simd_type;
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
for (size_t n = 0; n < N; ++n) {
for (size_t oh = 0; oh < OH; ++oh) {
for (size_t ow = 0; ow < OW; ++ow) {
int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1;
std::tie(ah0, ih0, ah1, ih1) =
get_nearest_linear_coord(imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) =
get_nearest_linear_coord(imode, scale_w, IW, ow);
simd_type r0 = simd_helper::load(sptr + (ih0 * IW + iw0) * PC);
simd_type r1 = simd_helper::load(sptr + (ih0 * IW + iw1) * PC);
simd_type r2 = simd_helper::load(sptr + (ih1 * IW + iw0) * PC);
simd_type r3 = simd_helper::load(sptr + (ih1 * IW + iw1) * PC);
// FIXME: weight fp16 may cause precision problem
ctype a0 = ah0 * aw0;
ctype a1 = ah0 * aw1;
ctype a2 = ah1 * aw0;
ctype a3 = ah1 * aw1;
simd_type c = simd_helper::dup(0);
c = simd_helper::fma(c, r0, a0);
c = simd_helper::fma(c, r1, a1);
c = simd_helper::fma(c, r2, a2);
c = simd_helper::fma(c, r3, a3);
simd_helper::store(dptr + (oh * OW + ow) * PC, c);
}
}
sptr += IH * IW * PC;
dptr += OH * OW * PC;
}
}
}
void megdnn::arm_common::resize_direct_nearest_nchw44_fp32(
const ResizeImpl::KernParam<float>& kern_param) {
resize_direct_nchwxx<float, InterpolationMode::INTER_NEAREST>(
kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4,
kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow);
}
void megdnn::arm_common::resize_direct_linear_nchw44_fp32(
const ResizeImpl::KernParam<float>& kern_param) {
resize_direct_nchwxx<float, InterpolationMode::INTER_LINEAR>(
kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4,
kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void megdnn::arm_common::resize_direct_nearest_nchw88_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param) {
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr);
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr);
resize_direct_nchwxx<__fp16, InterpolationMode::INTER_NEAREST>(
sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih,
kern_param.iw, kern_param.oh, kern_param.ow);
}
void megdnn::arm_common::resize_direct_linear_nchw88_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param) {
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr);
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr);
resize_direct_nchwxx<__fp16, InterpolationMode::INTER_LINEAR>(
sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih,
kern_param.iw, kern_param.oh, kern_param.ow);
}
#endif
/**
* \file dnn/src/arm_common/resize/direct_nchwxx.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/resize/opr_impl.h"
namespace megdnn {
namespace arm_common {
void resize_direct_linear_nchw44_fp32(
const ResizeImpl::KernParam<float>& kern_param);
void resize_direct_nearest_nchw44_fp32(
const ResizeImpl::KernParam<float>& kern_param);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void resize_direct_linear_nchw88_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param);
void resize_direct_nearest_nchw88_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param);
#endif
} // namespace arm_common
} // namespace megdnn
/**
* \file dnn/src/arm_common/resize/helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/simd_macro/marm_neon.h"
namespace megdnn {
namespace arm_common {
namespace resize {
using InterpolationMode = Resize::InterpolationMode;
template <typename ctype>
struct SIMDHelper {};
template <>
struct SIMDHelper<float> {
using simd_type = float32x4_t;
using simd_type_x2 = float32x4x2_t;
using ctype = float;
static constexpr size_t simd_width = 4;
static inline simd_type load(const ctype* src_ptr) {
return vld1q_f32(src_ptr);
}
static inline void store(ctype* dst_ptr, const simd_type& rdst) {
vst1q_f32(dst_ptr, rdst);
}
static inline void store2_interleave(ctype* dst_ptr, const simd_type& rdst1,
const simd_type& rdst2) {
simd_type_x2 rdst;
rdst.val[0] = rdst1;
rdst.val[1] = rdst2;
vst2q_f32(dst_ptr, rdst);
}
static inline simd_type fma(const simd_type& a, const simd_type& b,
ctype n) {
#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
return vfmaq_n_f32(a, b, n);
#else
return vmlaq_n_f32(a, b, n);
#endif
}
static inline simd_type fma(const simd_type& a, const simd_type& b,
const simd_type& c) {
#if defined(__ARM_FEATURE_FMA)
return vfmaq_f32(a, b, c);
#else
return vmlaq_f32(a, b, c);
#endif
}
static inline simd_type dup(float val) { return vdupq_n_f32(val); }
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
struct SIMDHelper<__fp16> {
using simd_type = float16x8_t;
using simd_type_x2 = float16x8x2_t;
using ctype = __fp16;
static constexpr size_t simd_width = 8;
static inline simd_type load(const ctype* src_ptr) {
return vld1q_f16(src_ptr);
}
static inline void store(ctype* dst_ptr, const simd_type& rdst) {
vst1q_f16(dst_ptr, rdst);
}
static inline void store2_interleave(ctype* dst_ptr, const simd_type& rdst1,
const simd_type& rdst2) {
simd_type_x2 rdst;
rdst.val[0] = rdst1;
rdst.val[1] = rdst2;
vst2q_f16(dst_ptr, rdst);
}
static inline simd_type fma(const simd_type& a, const simd_type& b,
ctype n) {
#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
return vfmaq_n_f16(a, b, n);
#else
return vaddq_f16(a, vmulq_n_f16(b, n));
#endif
}
static inline simd_type fma(const simd_type& a, const simd_type& b,
const simd_type& c) {
return vfmaq_f16(a, b, c);
}
static inline simd_type dup(float val) { return vdupq_n_f16(val); }
};
#endif
static inline int get_nearest_src(float scale, int size, int idx) {
return std::min(static_cast<int>(idx / scale), size - 1);
}
static inline std::tuple<float, int, float, int> get_nearest_linear_coord(
InterpolationMode imode, float scale, int size, int idx) {
if (size == 1) {
return std::make_tuple(1.0f, 0, 0.0f, 0);
}
float alpha = (idx + 0.5f) / scale - 0.5f;
int origin_idx = static_cast<int>(floor(alpha));
alpha -= origin_idx;
if (imode == InterpolationMode::INTER_NEAREST) {
origin_idx = get_nearest_src(scale, size, idx);
alpha = 0;
}
if (origin_idx < 0) {
origin_idx = 0;
alpha = 0;
} else if (origin_idx + 1 >= size) {
origin_idx = size - 2;
alpha = 1;
}
return std::make_tuple(1 - alpha, origin_idx, alpha, origin_idx + 1);
}
};
};
};
......@@ -12,212 +12,181 @@
#include "src/arm_common/resize/opr_impl.h"
#include "src/arm_common/handle.h"
#include "src/arm_common/resize/direct_nchwxx.h"
#include "src/arm_common/resize/resize_cv.h"
#include "src/arm_common/resize/upsample2_nchw.h"
#include "src/arm_common/resize/upsample2_nchwxx.h"
#include "src/arm_common/simd_macro/marm_neon.h"
using namespace megdnn;
using namespace arm_common;
#include "midout.h"
MIDOUT_DECL(megdnn_arm_resize)
namespace megdnn {
namespace arm_common {
void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
if (param().format == param::Resize::Format::NCHW44 ||
param().format == param::Resize::Format::NCHW88) {
bool is_contiguous =
src.layout.is_contiguous() && dst.layout.is_contiguous();
bool dtype_same = src.layout.dtype == dst.layout.dtype;
bool nchw44_enable = param().format == param::Resize::Format::NCHW44 &&
src.layout.dtype == dtype::Float32();
bool nchw88_enable =
param().format == param::Resize::Format::NCHW88 &&
DNN_FLOAT16_SELECT(src.layout.dtype == dtype::Float16(), false);
bool interp_supported =
param().imode ==
param::Resize::InterpolationMode::INTER_NEAREST ||
param().imode == param::Resize::InterpolationMode::INTER_LINEAR;
bool is_upsample2 =
param().imode ==
param::Resize::InterpolationMode::INTER_NEAREST &&
src.layout.shape[2] * 2 == dst.layout.shape[2] &&
src.layout.shape[3] * 2 == dst.layout.shape[3];
bool need_fallback = !is_contiguous || !dtype_same ||
!interp_supported ||
(!nchw44_enable && !nchw88_enable);
bool is_contiguous =
src.layout.is_contiguous() && dst.layout.is_contiguous();
bool is_dtype_same = src.layout.dtype == dst.layout.dtype;
bool is_dtype_fp32 = src.layout.dtype == dtype::Float32();
bool is_dtype_fp16 =
DNN_FLOAT16_SELECT(src.layout.dtype == dtype::Float16(), false);
bool is_dtype_supported = is_dtype_same && (is_dtype_fp32 || is_dtype_fp16);
bool is_nchw = param().format == param::Resize::Format::NCHW &&
(is_dtype_fp32 || is_dtype_fp16);
bool is_nchw44_fp32 =
param().format == param::Resize::Format::NCHW44 && is_dtype_fp32;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
bool is_nchw88_fp16 =
param().format == param::Resize::Format::NCHW88 && is_dtype_fp16;
#endif
if (need_fallback) {
fallback::ResizeImpl::exec(src, dst, workspace);
} else if (nchw44_enable) {
auto kern_param = KernParam<float>::from_tensors(
param().format, param().imode, src, dst, workspace);
bool is_imode_nearest =
param().imode == param::Resize::InterpolationMode::INTER_NEAREST;
bool is_imode_linear =
param().imode == param::Resize::InterpolationMode::INTER_LINEAR;
bool is_imode_supported = is_imode_nearest || is_imode_linear;
bool is_upsample2 = src.layout.shape[2] * 2 == dst.layout.shape[2] &&
src.layout.shape[3] * 2 == dst.layout.shape[3];
bool usable = is_contiguous && is_dtype_supported && is_imode_supported;
if (param().format == param::Resize::Format::NHWC &&
(src.layout[3] == 1 || src.layout[3] == 3) &&
is_nhwc_contig_wc(src.layout)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_exec(src, dst, param().imode));
} else if (!usable) {
fallback::ResizeImpl::exec(src, dst, workspace);
} else if (is_dtype_fp32) {
auto kern_param = KernParam<float>::from_tensors(
param().format, param().imode, src, dst, workspace);
if (is_nchw44_fp32) {
if (is_upsample2) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
kern_nearest_upsample2_pack_simd_width(src, dst));
if (is_imode_nearest) {
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(0)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_nearest_upsample2_nchw44_fp32(
kern_param));
}
MIDOUT_END();
} else {
megdnn_assert(is_imode_linear, "invalid imode");
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(1)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_linear_upsample2_nchw44_fp32(
kern_param));
}
MIDOUT_END();
}
} else {
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw44_fp32(kern_param));
if (is_imode_nearest) {
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(2)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_direct_nearest_nchw44_fp32(kern_param));
}
MIDOUT_END();
} else {
megdnn_assert(is_imode_linear, "invalid imode");
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(3)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_direct_linear_nchw44_fp32(kern_param));
}
MIDOUT_END();
}
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} else if (nchw88_enable) {
auto kern_param = KernParam<dt_float16>::from_tensors(
param().format, param().imode, src, dst, workspace);
} else if (is_nchw) {
if (is_upsample2) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
kern_nearest_upsample2_pack_simd_width(src, dst));
if (is_imode_nearest) {
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(4)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_nearest_upsample2_nchw_fp32(kern_param));
}
MIDOUT_END();
} else {
megdnn_assert(is_imode_linear, "invalid imode");
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(5)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_linear_upsample2_nchw_fp32(kern_param));
}
MIDOUT_END();
}
} else {
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw88_fp16(kern_param));
fallback::ResizeImpl::exec(src, dst, workspace);
}
#endif
} else {
fallback::ResizeImpl::exec(src, dst, workspace);
}
} else if (param().format == param::Resize::Format::NCHW ||
(src.layout[3] != 1 && src.layout[3] != 3) ||
!is_nhwc_contig_wc(src.layout)) {
fallback::ResizeImpl::exec(src, dst, workspace);
} else {
megdnn_assert(param().format == param::Resize::Format::NHWC,
"invalid resize format");
MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_exec(src, dst, param().imode));
}
}
template <typename ctype>
void ResizeImpl::kern_nchw44_fp32(const KernParam<ctype>& kern_param) {
UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
for (size_t n = 0; n < N; ++n) {
for (size_t c = 0; c < C / 4; ++c) {
for (size_t oh = 0; oh < OH; ++oh) {
for (size_t ow = 0; ow < OW; ++ow) {
int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1;
std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
kern_param.imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
kern_param.imode, scale_w, IW, ow);
#define SRC_ADDRESS(ih, iw) \
(sptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 4)
#define DST_ADDRESS(oh, ow) \
(dptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 4)
float32x4_t r0 = vld1q_f32(SRC_ADDRESS(ih0, iw0));
float32_t a0 = ah0 * aw0;
float32x4_t r1 = vld1q_f32(SRC_ADDRESS(ih0, iw1));
float32_t a1 = ah0 * aw1;
float32x4_t r2 = vld1q_f32(SRC_ADDRESS(ih1, iw0));
float32_t a2 = ah1 * aw0;
float32x4_t r3 = vld1q_f32(SRC_ADDRESS(ih1, iw1));
float32_t a3 = ah1 * aw1;
r0 = vmulq_n_f32(r0, a0);
#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
r0 = vfmaq_n_f32(r0, r1, a1);
r0 = vfmaq_n_f32(r0, r2, a2);
r0 = vfmaq_n_f32(r0, r3, a3);
#else
r0 = vaddq_f32(r0, vmulq_n_f32(r1, a1));
r0 = vaddq_f32(r0, vmulq_n_f32(r2, a2));
r0 = vaddq_f32(r0, vmulq_n_f32(r3, a3));
#endif
vst1q_f32(DST_ADDRESS(oh, ow), r0);
#undef SRC_ADDRESS
#undef DST_ADDRESS
}
}
}
}
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename ctype>
void ResizeImpl::kern_nchw88_fp16(const KernParam<ctype>& kern_param) {
UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
const float16_t* src_ptr = reinterpret_cast<float16_t*>(sptr);
float16_t* dst_ptr = reinterpret_cast<float16_t*>(dptr);
for (size_t n = 0; n < N; ++n) {
for (size_t c = 0; c < C / 8; ++c) {
for (size_t oh = 0; oh < OH; ++oh) {
for (size_t ow = 0; ow < OW; ++ow) {
int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1;
std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
kern_param.imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
kern_param.imode, scale_w, IW, ow);
#define SRC_ADDRESS(ih, iw) \
(src_ptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 8)
#define DST_ADDRESS(oh, ow) \
(dst_ptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 8)
float16x8_t r0 = vld1q_f16(SRC_ADDRESS(ih0, iw0));
float32_t a0 = ah0 * aw0;
float16x8_t r1 = vld1q_f16(SRC_ADDRESS(ih0, iw1));
float32_t a1 = ah0 * aw1;
float16x8_t r2 = vld1q_f16(SRC_ADDRESS(ih1, iw0));
float32_t a2 = ah1 * aw0;
float16x8_t r3 = vld1q_f16(SRC_ADDRESS(ih1, iw1));
float32_t a3 = ah1 * aw1;
r0 = vmulq_n_f16(r0, a0);
#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
r0 = vfmaq_n_f16(r0, r1, a1);
r0 = vfmaq_n_f16(r0, r2, a2);
r0 = vfmaq_n_f16(r0, r3, a3);
#else
r0 = vaddq_f16(r0, vmulq_n_f16(r1, a1));
r0 = vaddq_f16(r0, vmulq_n_f16(r2, a2));
r0 = vaddq_f16(r0, vmulq_n_f16(r3, a3));
#endif
vst1q_f16(DST_ADDRESS(oh, ow), r0);
#undef SRC_ADDRESS
#undef DST_ADDRESS
} else if (is_dtype_fp16) {
auto kern_param = KernParam<dt_float16>::from_tensors(
param().format, param().imode, src, dst, workspace);
if (is_nchw88_fp16) {
if (is_upsample2) {
if (is_imode_nearest) {
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(6)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_nearest_upsample2_nchw88_fp16(
kern_param));
}
MIDOUT_END();
} else {
megdnn_assert(is_imode_linear, "invalid imode");
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(7)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_linear_upsample2_nchw88_fp16(
kern_param));
}
MIDOUT_END();
}
} else {
if (is_imode_nearest) {
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(8)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_direct_nearest_nchw88_fp16(kern_param));
}
MIDOUT_END();
} else {
megdnn_assert(is_imode_linear, "invalid imode");
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(9)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_direct_linear_nchw88_fp16(kern_param));
}
MIDOUT_END();
}
}
}
}
}
#endif
void ResizeImpl::kern_nearest_upsample2_pack_simd_width(
_megdnn_tensor_in src, _megdnn_tensor_out dst) {
const uint8_t* src_ptr = reinterpret_cast<uint8_t*>(src.raw_ptr);
uint8_t* dst_ptr = reinterpret_cast<uint8_t*>(dst.raw_ptr);
size_t S = 2;
size_t N = src.layout.shape[0];
size_t IC = src.layout.shape[1];
size_t IH = src.layout.shape[2];
size_t IW = src.layout.shape[3];
size_t OH = dst.layout.shape[2];
size_t OW = dst.layout.shape[3];
for (size_t i = 0; i < N * IC; ++i) {
for (size_t ih = 0; ih < IH; ++ih) {
for (size_t iw = 0; iw < IW; ++iw) {
size_t oh = ih * S;
size_t ow = iw * S;
uint8x16_t r0 = vld1q_u8(src_ptr + i * IH * IW * 16 +
ih * IW * 16 + iw * 16);
for (size_t fh = 0; fh < S; ++fh) {
for (size_t fw = 0; fw < S; ++fw) {
vst1q_u8(dst_ptr + i * OH * OW * 16 +
(oh + fh) * OW * 16 + (ow + fw) * 16,
r0);
} else if (is_nchw) {
if (is_upsample2) {
if (is_imode_nearest) {
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(10)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_nearest_upsample2_nchw_fp16(kern_param));
}
MIDOUT_END();
} else {
megdnn_assert(is_imode_linear, "invalid imode");
MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(11)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
resize_linear_upsample2_nchw_fp16(kern_param));
}
MIDOUT_END();
}
} else {
fallback::ResizeImpl::exec(src, dst, workspace);
}
} else {
fallback::ResizeImpl::exec(src, dst, workspace);
}
#endif
} else {
fallback::ResizeImpl::exec(src, dst, workspace);
}
}
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -26,16 +26,6 @@ public:
const TensorLayout&) override {
return 0;
}
private:
template <typename ctype>
void kern_nchw44_fp32(const KernParam<ctype>& kern_param);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename ctype>
void kern_nchw88_fp16(const KernParam<ctype>& kern_param);
#endif
void kern_nearest_upsample2_pack_simd_width(_megdnn_tensor_in src,
_megdnn_tensor_out dst);
};
} // namespace arm_common
......
/**
* \file dnn/src/arm_common/resize/upsample2_nchw.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/resize/upsample2_nchw.h"
#include "src/arm_common/resize/helper.h"
#include "src/arm_common/simd_macro/marm_neon.h"
using namespace megdnn;
using namespace arm_common;
using namespace resize;
namespace {
template <typename ctype, size_t fh, size_t fw>
static inline ctype compute_linear_element(const ctype src[4],
const ctype alpha[2]) {
return src[0] * alpha[0 ^ fh] * alpha[0 ^ fw] +
src[1] * alpha[0 ^ fh] * alpha[1 ^ fw] +
src[2] * alpha[1 ^ fh] * alpha[0 ^ fw] +
src[3] * alpha[1 ^ fh] * alpha[1 ^ fw];
}
template <typename simd_helper, size_t fh, size_t fw>
static inline typename simd_helper::simd_type compute_linear_element_simd(
const typename simd_helper::simd_type src[4],
const typename simd_helper::simd_type alpha[2][2]) {
typename simd_helper::simd_type c = simd_helper::dup(0);
c = simd_helper::fma(c, src[0], alpha[0 ^ fh][0 ^ fw]);
c = simd_helper::fma(c, src[1], alpha[0 ^ fh][1 ^ fw]);
c = simd_helper::fma(c, src[2], alpha[1 ^ fh][0 ^ fw]);
c = simd_helper::fma(c, src[3], alpha[1 ^ fh][1 ^ fw]);
return c;
}
template <typename ctype, bool has_right, bool has_bottom>
static inline void compute_linear_2x2_element(const ctype* src, ctype* dst,
size_t IW, size_t OW,
const ctype alpha[2]) {
const ctype* src_ptr[4] = {src, src, src, src};
if (has_right) {
src_ptr[1] += 1;
src_ptr[3] += 1;
}
if (has_bottom) {
src_ptr[2] += IW;
src_ptr[3] += IW;
}
ctype rsrc[4];
rsrc[0] = *src_ptr[0];
rsrc[1] = *src_ptr[1];
rsrc[2] = *src_ptr[2];
rsrc[3] = *src_ptr[3];
dst[0] = compute_linear_element<ctype, 0, 0>(rsrc, alpha);
if (has_right) {
dst[1] = compute_linear_element<ctype, 0, 1>(rsrc, alpha);
}
if (has_bottom) {
dst[OW] = compute_linear_element<ctype, 1, 0>(rsrc, alpha);
}
if (has_right && has_bottom) {
dst[OW + 1] = compute_linear_element<ctype, 1, 1>(rsrc, alpha);
}
}
template <typename simd_helper>
static inline void compute_linear_2x2_element_simd(
const typename simd_helper::ctype* src,
typename simd_helper::ctype* dst, size_t IW, size_t OW,
const typename simd_helper::simd_type alpha[2][2]) {
using simd_type = typename simd_helper::simd_type;
simd_type rsrc[4];
rsrc[0] = simd_helper::load(src);
rsrc[1] = simd_helper::load(src + 1);
rsrc[2] = simd_helper::load(src + IW);
rsrc[3] = simd_helper::load(src + IW + 1);
simd_type rdst[4];
rdst[0] = compute_linear_element_simd<simd_helper, 0, 0>(rsrc, alpha);
rdst[1] = compute_linear_element_simd<simd_helper, 0, 1>(rsrc, alpha);
rdst[2] = compute_linear_element_simd<simd_helper, 1, 0>(rsrc, alpha);
rdst[3] = compute_linear_element_simd<simd_helper, 1, 1>(rsrc, alpha);
simd_helper::store2_interleave(dst, rdst[0], rdst[1]);
simd_helper::store2_interleave(dst + OW, rdst[2], rdst[3]);
}
template <typename ctype>
void linear_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N,
size_t IH, size_t IW) {
using simd_helper = SIMDHelper<ctype>;
size_t OW = IW * 2;
constexpr size_t PC = simd_helper::simd_width;
ctype alpha[2] = {0.75, 0.25};
typename simd_helper::simd_type simd_alpha[2][2];
simd_alpha[0][0] = simd_helper::dup(0.75 * 0.75);
simd_alpha[0][1] = simd_helper::dup(0.75 * 0.25);
simd_alpha[1][0] = simd_helper::dup(0.25 * 0.75);
simd_alpha[1][1] = simd_helper::dup(0.25 * 0.25);
for (size_t i = 0; i < N; ++i) {
compute_linear_2x2_element<ctype, false, false>(src_ptr, dst_ptr, IW,
OW, alpha);
{
for (size_t iw = 0; iw + 1 < IW; ++iw) {
compute_linear_2x2_element<ctype, true, false>(
src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha);
}
}
compute_linear_2x2_element<ctype, false, false>(
src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha);
dst_ptr += OW;
for (size_t ih = 0; ih + 1 < IH; ++ih) {
compute_linear_2x2_element<ctype, false, true>(src_ptr, dst_ptr, IW,
OW, alpha);
size_t iw = 0;
for (; iw + PC < IW; iw += PC) {
compute_linear_2x2_element_simd<simd_helper>(
src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW,
simd_alpha);
}
for (; iw + 1 < IW; ++iw) {
compute_linear_2x2_element<ctype, true, true>(
src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha);
}
compute_linear_2x2_element<ctype, false, true>(
src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha);
src_ptr += IW;
dst_ptr += 2 * OW;
}
compute_linear_2x2_element<ctype, false, false>(src_ptr, dst_ptr, IW,
OW, alpha);
{
for (size_t iw = 0; iw + 1 < IW; ++iw) {
compute_linear_2x2_element<ctype, true, false>(
src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha);
}
}
compute_linear_2x2_element<ctype, false, false>(
src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha);
src_ptr += IW;
dst_ptr += OW;
}
}
template <typename ctype>
void nearest_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N,
size_t IH, size_t IW) {
using simd_helper = SIMDHelper<ctype>;
size_t OW = IW * 2;
constexpr size_t PC = simd_helper::simd_width;
for (size_t i = 0; i < N; ++i) {
for (size_t ih = 0; ih < IH; ++ih) {
size_t iw = 0;
for (; iw + PC - 1 < IW; iw += PC) {
typename simd_helper::simd_type r0 =
simd_helper::load(src_ptr + iw);
simd_helper::store2_interleave(dst_ptr + (iw * 2), r0, r0);
simd_helper::store2_interleave(dst_ptr + (OW + iw * 2), r0, r0);
}
for (; iw < IW; iw += 1) {
ctype v = src_ptr[iw];
dst_ptr[iw * 2] = v;
dst_ptr[iw * 2 + 1] = v;
dst_ptr[OW + iw * 2] = v;
dst_ptr[OW + iw * 2 + 1] = v;
}
src_ptr += IW;
dst_ptr += 2 * OW;
}
}
}
} // namespace
void megdnn::arm_common::resize_linear_upsample2_nchw_fp32(
const ResizeImpl::KernParam<float>& kern_param) {
linear_upsample2_nchw(kern_param.sptr, kern_param.dptr,
kern_param.n * kern_param.c, kern_param.ih,
kern_param.iw);
}
void megdnn::arm_common::resize_nearest_upsample2_nchw_fp32(
const ResizeImpl::KernParam<float>& kern_param) {
nearest_upsample2_nchw(kern_param.sptr, kern_param.dptr,
kern_param.n * kern_param.c, kern_param.ih,
kern_param.iw);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void megdnn::arm_common::resize_linear_upsample2_nchw_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param) {
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr);
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr);
linear_upsample2_nchw(sptr, dptr, kern_param.n * kern_param.c,
kern_param.ih, kern_param.iw);
}
void megdnn::arm_common::resize_nearest_upsample2_nchw_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param) {
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr);
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr);
nearest_upsample2_nchw(sptr, dptr, kern_param.n * kern_param.c,
kern_param.ih, kern_param.iw);
}
#endif
/**
* \file dnn/src/arm_common/resize/upsample2_nchw.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/resize/opr_impl.h"
namespace megdnn {
namespace arm_common {
void resize_linear_upsample2_nchw_fp32(
const ResizeImpl::KernParam<float>& kern_param);
void resize_nearest_upsample2_nchw_fp32(
const ResizeImpl::KernParam<float>& kern_param);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void resize_linear_upsample2_nchw_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param);
void resize_nearest_upsample2_nchw_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param);
#endif
} // namespace arm_common
} // namespace megdnn
/**
* \file dnn/src/arm_common/resize/upsample2_nchwxx.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/resize/upsample2_nchwxx.h"
#include "src/arm_common/resize/helper.h"
#include "src/arm_common/simd_macro/marm_neon.h"
using namespace megdnn;
using namespace arm_common;
using namespace resize;
namespace {
template <typename simd_helper, size_t fh, size_t fw>
static inline typename simd_helper::simd_type compute_linear_element(
const typename simd_helper::simd_type src[4],
const typename simd_helper::simd_type alpha[2][2]) {
typename simd_helper::simd_type c = simd_helper::dup(0);
c = simd_helper::fma(c, src[0], alpha[0 ^ fh][0 ^ fw]);
c = simd_helper::fma(c, src[1], alpha[0 ^ fh][1 ^ fw]);
c = simd_helper::fma(c, src[2], alpha[1 ^ fh][0 ^ fw]);
c = simd_helper::fma(c, src[3], alpha[1 ^ fh][1 ^ fw]);
return c;
}
template <typename simd_helper, bool has_right, bool has_bottom>
static inline void compute_linear_2x2_element(
const typename simd_helper::ctype* src,
typename simd_helper::ctype* dst, size_t IW, size_t OW,
const typename simd_helper::simd_type alpha[2][2]) {
constexpr size_t PC = simd_helper::simd_width;
const typename simd_helper::ctype* src_ptr[4] = {src, src, src, src};
if (has_right) {
src_ptr[1] += PC;
src_ptr[3] += PC;
}
if (has_bottom) {
src_ptr[2] += IW * PC;
src_ptr[3] += IW * PC;
}
typename simd_helper::simd_type rsrc[4];
rsrc[0] = simd_helper::load(src_ptr[0]);
rsrc[1] = simd_helper::load(src_ptr[1]);
rsrc[2] = simd_helper::load(src_ptr[2]);
rsrc[3] = simd_helper::load(src_ptr[3]);
typename simd_helper::simd_type rdst[4];
rdst[0] = compute_linear_element<simd_helper, 0, 0>(rsrc, alpha);
rdst[1] = compute_linear_element<simd_helper, 0, 1>(rsrc, alpha);
rdst[2] = compute_linear_element<simd_helper, 1, 0>(rsrc, alpha);
rdst[3] = compute_linear_element<simd_helper, 1, 1>(rsrc, alpha);
simd_helper::store(dst, rdst[0]);
if (has_right) {
simd_helper::store(dst + PC, rdst[1]);
}
if (has_bottom) {
simd_helper::store(dst + OW * PC, rdst[2]);
}
if (has_right && has_bottom) {
simd_helper::store(dst + (OW + 1) * PC, rdst[3]);
}
}
template <typename ctype>
void linear_upsample2_nchwxx(const ctype* src_ptr, ctype* dst_ptr, size_t N,
size_t IH, size_t IW) {
using simd_helper = SIMDHelper<ctype>;
size_t OW = IW * 2;
constexpr size_t PC = simd_helper::simd_width;
typename simd_helper::simd_type alpha[2][2];
alpha[0][0] = simd_helper::dup(0.75 * 0.75);
alpha[0][1] = simd_helper::dup(0.75 * 0.25);
alpha[1][0] = simd_helper::dup(0.25 * 0.75);
alpha[1][1] = simd_helper::dup(0.25 * 0.25);
for (size_t i = 0; i < N; ++i) {
compute_linear_2x2_element<simd_helper, false, false>(src_ptr, dst_ptr,
IW, OW, alpha);
{
for (size_t iw = 0; iw + 1 < IW; ++iw) {
compute_linear_2x2_element<simd_helper, true, false>(
src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW,
alpha);
}
}
compute_linear_2x2_element<simd_helper, false, false>(
src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW,
alpha);
dst_ptr += OW * PC;
for (size_t ih = 0; ih + 1 < IH; ++ih) {
compute_linear_2x2_element<simd_helper, false, true>(
src_ptr, dst_ptr, IW, OW, alpha);
for (size_t iw = 0; iw + 1 < IW; ++iw) {
compute_linear_2x2_element<simd_helper, true, true>(
src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW,
alpha);
}
compute_linear_2x2_element<simd_helper, false, true>(
src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW,
alpha);
src_ptr += IW * PC;
dst_ptr += 2 * OW * PC;
}
compute_linear_2x2_element<simd_helper, false, false>(src_ptr, dst_ptr,
IW, OW, alpha);
{
for (size_t iw = 0; iw + 1 < IW; ++iw) {
compute_linear_2x2_element<simd_helper, true, false>(
src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW,
alpha);
}
}
compute_linear_2x2_element<simd_helper, false, false>(
src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW,
alpha);
src_ptr += IW * PC;
dst_ptr += OW * PC;
}
}
template <typename ctype>
void nearest_upsample2_nchwxx(const ctype* src_ptr, ctype* dst_ptr, size_t N,
size_t IH, size_t IW) {
using simd_helper = SIMDHelper<ctype>;
size_t OW = IW * 2;
constexpr size_t PC = simd_helper::simd_width;
for (size_t i = 0; i < N; ++i) {
for (size_t ih = 0; ih < IH; ++ih) {
for (size_t iw = 0; iw < IW; ++iw) {
typename simd_helper::simd_type r0 =
simd_helper::load(src_ptr + iw * PC);
simd_helper::store(dst_ptr + (iw * 2) * PC, r0);
simd_helper::store(dst_ptr + (iw * 2 + 1) * PC, r0);
simd_helper::store(dst_ptr + (OW + iw * 2) * PC, r0);
simd_helper::store(dst_ptr + (OW + iw * 2 + 1) * PC, r0);
}
src_ptr += IW * PC;
dst_ptr += 2 * OW * PC;
}
}
}
} // namespace
void megdnn::arm_common::resize_linear_upsample2_nchw44_fp32(
const ResizeImpl::KernParam<float>& kern_param) {
linear_upsample2_nchwxx(kern_param.sptr, kern_param.dptr,
kern_param.n * kern_param.c / 4, kern_param.ih,
kern_param.iw);
}
void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32(
const ResizeImpl::KernParam<float>& kern_param) {
nearest_upsample2_nchwxx(kern_param.sptr, kern_param.dptr,
kern_param.n * kern_param.c / 4, kern_param.ih,
kern_param.iw);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void megdnn::arm_common::resize_linear_upsample2_nchw88_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param) {
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr);
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr);
linear_upsample2_nchwxx(sptr, dptr, kern_param.n * kern_param.c / 8,
kern_param.ih, kern_param.iw);
}
void megdnn::arm_common::resize_nearest_upsample2_nchw88_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param) {
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr);
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr);
nearest_upsample2_nchwxx(sptr, dptr, kern_param.n * kern_param.c / 8,
kern_param.ih, kern_param.iw);
}
#endif
/**
* \file dnn/src/arm_common/resize/upsample2_nchwxx.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/resize/opr_impl.h"
namespace megdnn {
namespace arm_common {
void resize_linear_upsample2_nchw44_fp32(
const ResizeImpl::KernParam<float>& kern_param);
void resize_nearest_upsample2_nchw44_fp32(
const ResizeImpl::KernParam<float>& kern_param);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void resize_linear_upsample2_nchw88_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param);
void resize_nearest_upsample2_nchw88_fp16(
const ResizeImpl::KernParam<dt_float16>& kern_param);
#endif
} // namespace arm_common
} // namespace megdnn
......@@ -16,8 +16,25 @@
namespace megdnn {
namespace test {
using namespace resize;
static void set_nchw_args(IMode imode, std::vector<TestArg>& args) {
param::Resize param;
param.format = param::Resize::Format::NCHW;
param.imode = imode;
rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
args.emplace_back(
param, TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul},
TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul});
args.emplace_back(param, TensorShape{1, 1, 10, 10},
TensorShape{1, 1, 20, 20});
args.emplace_back(param, TensorShape{1, 1, 10, 10},
TensorShape{1, 1, 7, 9});
args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8});
args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
}
TEST_F(ARM_COMMON, RESIZE_CV) {
using namespace resize;
std::vector<TestArg> args = get_cv_args();
Checker<Resize> checker(handle());
......@@ -37,8 +54,38 @@ TEST_F(ARM_COMMON, RESIZE_CV) {
}
}
TEST_F(ARM_COMMON, RESIZE_NCHW44) {
using namespace resize;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON, RESIZE_NCHW_FP16) {
std::vector<TestArg> args;
set_nchw_args(IMode::INTER_LINEAR, args);
set_nchw_args(IMode::INTER_NEAREST, args);
Checker<Resize> checker(handle());
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_epsilon(0.01)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.execs({arg.src, arg.dst});
}
}
#endif
TEST_F(ARM_COMMON, RESIZE_NCHW_FP32) {
std::vector<TestArg> args;
set_nchw_args(IMode::INTER_LINEAR, args);
set_nchw_args(IMode::INTER_NEAREST, args);
Checker<Resize> checker(handle());
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.execs({arg.src, arg.dst});
}
}
TEST_F(ARM_COMMON, RESIZE_NCHW44_FP32) {
std::vector<TestArg> args = get_nchw44_args();
Checker<Resize> checker(handle());
......@@ -50,8 +97,8 @@ TEST_F(ARM_COMMON, RESIZE_NCHW44) {
}
}
TEST_F(ARM_COMMON, RESIZE_NCHW88) {
using namespace resize;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON, RESIZE_NCHW88_FP16) {
std::vector<TestArg> args = get_nchw88_args();
Checker<Resize> checker(handle());
......@@ -63,6 +110,7 @@ TEST_F(ARM_COMMON, RESIZE_NCHW88) {
.execs({arg.src, arg.dst});
}
}
#endif
} // namespace test
} // namespace megdnn
......
......@@ -52,6 +52,7 @@ TEST_F(CUDA, RESIZE_FORWARD) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Uint8())
.set_epsilon(1)
.execs({arg.src, arg.dst});
}
......@@ -67,7 +68,7 @@ TEST_F(CUDA, RESIZE_FORWARD) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_epsilon(1e-3)
.set_epsilon(1)
.execs({arg.src, arg.dst});
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册