提交 c96dbd29 编写于 作者: M Megvii Engine Team 提交者: XindaH

fix(dnn/arm_common): support more monotonous case in arm typecvt for performance

GitOrigin-RevId: 9e28a64d93799a04501910f5a05a988715ff4b04
上级 695d24f2
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "src/arm_common/type_cvt/opr_impl.h" #include "src/arm_common/type_cvt/opr_impl.h"
#include <cstring> #include <cstring>
#include <deque>
#include "midout.h" #include "midout.h"
#include "src/arm_common/quantized_converter.h" #include "src/arm_common/quantized_converter.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
...@@ -18,6 +19,7 @@ ...@@ -18,6 +19,7 @@
#include "src/naive/handle.h" #include "src/naive/handle.h"
MIDOUT_DECL(megdnn_arm_typecvt_fix2float) MIDOUT_DECL(megdnn_arm_typecvt_fix2float)
MIDOUT_DECL(megdnn_arm_typecvt_quan2float)
MIDOUT_DECL(megdnn_arm_typecvt_quantized) MIDOUT_DECL(megdnn_arm_typecvt_quantized)
MIDOUT_DECL(megdnn_arm_typecvt_float) MIDOUT_DECL(megdnn_arm_typecvt_float)
...@@ -326,8 +328,34 @@ struct FloatTypeCvter<float, __fp16> { ...@@ -326,8 +328,34 @@ struct FloatTypeCvter<float, __fp16> {
}; };
#endif #endif
template <typename TypeCvter>
void do_typecvt(
const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
DType src_dtype, DType dst_dtype, size_t nr_elems) {
TypeCvter typecvt(src_dtype, dst_dtype);
size_t i = 0;
for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) {
typecvt.cvt(src, dst);
src += TypeCvter::SIMD_WIDTH;
dst += TypeCvter::SIMD_WIDTH;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < nr_elems; i++) {
typecvt.cvt_remain(src, dst);
src++;
dst++;
}
}
template <typename ctype, typename dtype> template <typename ctype, typename dtype>
struct Fix2FloatTypeCvter; struct Fix2FloatTypeCvter;
template <typename ctype, typename dtype>
struct Quan2FloatTypeCvter;
template <> template <>
struct Fix2FloatTypeCvter<int16_t, float> { struct Fix2FloatTypeCvter<int16_t, float> {
using stype = int16_t; using stype = int16_t;
...@@ -368,62 +396,184 @@ struct Fix2FloatTypeCvter<uint16_t, float> { ...@@ -368,62 +396,184 @@ struct Fix2FloatTypeCvter<uint16_t, float> {
void cvt_remain(const uint16_t* src, float* dst) { *dst = *src; } void cvt_remain(const uint16_t* src, float* dst) { *dst = *src; }
}; };
template <typename TypeCvter> template <>
void do_typecvt( struct Fix2FloatTypeCvter<int8_t, float> {
const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst, using stype = int8_t;
DType src_dtype, DType dst_dtype, size_t nr_elems) { using dst_type = float;
TypeCvter typecvt(src_dtype, dst_dtype); static constexpr size_t SIMD_WIDTH = 16;
size_t i = 0;
for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) { Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
typecvt.cvt(src, dst); MEGDNN_MARK_USED_VAR(src_dtype);
src += TypeCvter::SIMD_WIDTH; MEGDNN_MARK_USED_VAR(dst_dtype);
dst += TypeCvter::SIMD_WIDTH;
} }
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize void cvt(const int8_t* src, float* dst) {
#pragma clang loop vectorize(disable) int8x16_t vitem = vld1q_s8(src);
int16x8_t vtrans_high = vmovl_s8(vget_high_s8(vitem));
int16x8_t vtrans_low = vmovl_s8(vget_low_s8(vitem));
auto vres_high = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_high);
auto vres_low = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_low);
vst1q_f32_x2(dst, vres_low);
vst1q_f32_x2(dst + 8, vres_high);
}
void cvt_remain(const int8_t* src, float* dst) { *dst = *src; }
};
template <>
struct Fix2FloatTypeCvter<uint8_t, float> {
using stype = uint8_t;
using dst_type = float;
static constexpr size_t SIMD_WIDTH = 16;
Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
MEGDNN_MARK_USED_VAR(src_dtype);
MEGDNN_MARK_USED_VAR(dst_dtype);
}
void cvt(const uint8_t* src, float* dst) {
uint8x16_t vitem = vld1q_u8(src);
uint16x8_t vtrans_high = vmovl_u8(vget_high_u8(vitem));
uint16x8_t vtrans_low = vmovl_u8(vget_low_u8(vitem));
auto vres_high = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_high);
auto vres_low = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_low);
vst1q_f32_x2(dst, vres_low);
vst1q_f32_x2(dst + 8, vres_high);
}
void cvt_remain(const uint8_t* src, float* dst) { *dst = *src; }
};
template <>
struct Quan2FloatTypeCvter<int8_t, float> {
using stype = int8_t;
using dst_type = float;
static constexpr size_t SIMD_WIDTH = 16;
float _scale = 0.0f;
float32x4_t _vscale;
Quan2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
_scale = src_dtype.param<dtype::QuantizedS8>().scale;
_vscale = vdupq_n_f32(_scale);
MEGDNN_MARK_USED_VAR(dst_dtype);
}
void cvt(const int8_t* src, float* dst) {
int8x16_t vitem = vld1q_s8(src);
int16x8_t vtrans_high = vmovl_s8(vget_high_s8(vitem));
int16x8_t vtrans_low = vmovl_s8(vget_low_s8(vitem));
auto vres_high = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_high);
auto vres_low = QConverter::convert<float32x4x2_t, int16x8_t>(vtrans_low);
vst1q_f32(dst, vmulq_f32(vres_low.val[0], _vscale));
vst1q_f32(dst + 4, vmulq_f32(vres_low.val[1], _vscale));
vst1q_f32(dst + 8, vmulq_f32(vres_high.val[0], _vscale));
vst1q_f32(dst + 12, vmulq_f32(vres_high.val[1], _vscale));
}
void cvt_remain(const int8_t* src, float* dst) { *dst = *src * _scale; }
};
#if defined(__ARM_FEATURE_FMA)
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m)
#else
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m)
#endif #endif
for (; i < nr_elems; i++) {
typecvt.cvt_remain(src, dst); template <>
src++; struct Quan2FloatTypeCvter<uint8_t, float> {
dst++; using stype = uint8_t;
using dst_type = float;
static constexpr size_t SIMD_WIDTH = 16;
float _scale = 0.0f;
float32x4_t _vscale;
uint8_t _zp = 0;
float32x4_t _vbias;
Quan2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
_scale = src_dtype.param<dtype::Quantized8Asymm>().scale;
_vscale = vdupq_n_f32(_scale);
_zp = src_dtype.param<dtype::Quantized8Asymm>().zero_point;
float bias = -_zp * 1.0f * _scale;
_vbias = vdupq_n_f32(bias);
MEGDNN_MARK_USED_VAR(dst_dtype);
} }
}
void cvt(const uint8_t* src, float* dst) {
uint8x16_t vitem = vld1q_u8(src);
uint16x8_t vtrans_high = vmovl_u8(vget_high_u8(vitem));
uint16x8_t vtrans_low = vmovl_u8(vget_low_u8(vitem));
auto vres_high = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_high);
auto vres_low = QConverter::convert<float32x4x2_t, uint16x8_t>(vtrans_low);
vst1q_f32(dst, Vfmaq_f32(_vbias, vres_low.val[0], _vscale));
vst1q_f32(dst + 4, Vfmaq_f32(_vbias, vres_low.val[1], _vscale));
vst1q_f32(dst + 8, Vfmaq_f32(_vbias, vres_high.val[0], _vscale));
vst1q_f32(dst + 12, Vfmaq_f32(_vbias, vres_high.val[1], _vscale));
}
void cvt_remain(const uint8_t* src, float* dst) { *dst = (*src - _zp) * _scale; }
};
#undef Vfmaq_f32
template <typename stype, typename dtype>
struct TypeCvtTask {
const stype* src;
dtype* dst;
size_t dim;
size_t nr_elems;
explicit TypeCvtTask(const stype* s, dtype* d, size_t n, size_t tot)
: src(s), dst(d), dim(n), nr_elems(tot) {}
};
template <typename TypeCvter> template <typename TypeCvter>
void do_typecvt( void do_typecvt(
const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst, const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
DType src_dtype, DType dst_dtype, const TensorLayout& src_layout) { DType src_dtype, DType dst_dtype, const TensorLayout& src_layout) {
TypeCvter typecvt(src_dtype, dst_dtype); TypeCvter typecvt(src_dtype, dst_dtype);
size_t calc_num = 1;
size_t nr_elems = src_layout.total_nr_elems();
size_t src_stride = nr_elems;
//! adjust calc_num nr_elems and src_stride according to src_collapse_layout
auto src_collapse_layout = src_layout.collapse_contiguous(); auto src_collapse_layout = src_layout.collapse_contiguous();
if (src_collapse_layout.ndim == 2) {
calc_num = src_collapse_layout.shape[0]; using TypeCvtTaskWithType =
nr_elems = src_collapse_layout.shape[1]; TypeCvtTask<typename TypeCvter::stype, typename TypeCvter::dst_type>;
src_stride = src_collapse_layout.stride[0]; std::deque<TypeCvtTaskWithType> task_queue;
} task_queue.emplace_back(src, dst, 0, src_collapse_layout.total_nr_elems());
for (size_t c = 0; c < calc_num; ++c) { while (!task_queue.empty()) {
size_t i = 0; auto&& task = task_queue.front();
for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) { const typename TypeCvter::stype* psrc = task.src;
typecvt.cvt(src, dst); typename TypeCvter::dst_type* pdst = task.dst;
src += TypeCvter::SIMD_WIDTH; size_t dim = task.dim;
dst += TypeCvter::SIMD_WIDTH; size_t nr_elems = task.nr_elems;
}
//! calc according to stride information
if (src_collapse_layout.stride[dim] == 1) {
size_t i = 0;
for (; i + TypeCvter::SIMD_WIDTH < nr_elems; i += TypeCvter::SIMD_WIDTH) {
typecvt.cvt(psrc, pdst);
psrc += TypeCvter::SIMD_WIDTH;
pdst += TypeCvter::SIMD_WIDTH;
}
#if MEGDNN_FIX_AARCH32_BUG #if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize // FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable) #pragma clang loop vectorize(disable)
#endif #endif
for (; i < nr_elems; i++) { for (; i < nr_elems; i++) {
typecvt.cvt_remain(src, dst); typecvt.cvt_remain(psrc, pdst);
src++; psrc++;
dst++; pdst++;
}
} else {
size_t calc_num = src_collapse_layout.shape[dim];
size_t src_stride = src_collapse_layout.stride[dim];
size_t dst_stride = nr_elems / calc_num;
for (size_t i = 0; i < calc_num; ++i) {
task_queue.emplace_back(psrc, pdst, dim + 1, dst_stride);
psrc += src_stride;
pdst += dst_stride;
}
} }
src += src_stride - nr_elems;
task_queue.pop_front();
} }
} }
...@@ -432,15 +582,56 @@ void do_typecvt( ...@@ -432,15 +582,56 @@ void do_typecvt(
void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
DType src_dtype = src.layout.dtype; DType src_dtype = src.layout.dtype;
DType dst_dtype = dst.layout.dtype; DType dst_dtype = dst.layout.dtype;
size_t nr_elems = src.layout.total_nr_elems();
bool execed = false; bool execed = false;
auto src_collapse_layout = src.layout.collapse_contiguous(); auto src_collapse_layout = src.layout.collapse_contiguous();
bool has_int16_special_impl =
(src.layout.dtype.enumv() == DTypeEnum::Int16 || if (src.layout.is_contiguous()) {
src.layout.dtype.enumv() == DTypeEnum::Uint16) && using namespace dtype;
(src.layout.is_contiguous() || src_collapse_layout.ndim == 2) && size_t nr_elems = src.layout.total_nr_elems();
dst.layout.is_contiguous(); #define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
if (has_int16_special_impl) { if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \
MIDOUT_BEGIN(megdnn_arm_typecvt_quantized, midout_iv(_midout_iv)) { \
using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>; \
MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \
src_dtype, dst_dtype, nr_elems)); \
execed = true; \
} \
MIDOUT_END(); \
}
DISPATCH_QUANTIZED(QuantizedS32, int32_t, Quantized8Asymm, uint8_t, 0);
DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS8, int8_t, 1);
DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS32, int32_t, 2);
DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS8, int8_t, 3);
DISPATCH_QUANTIZED(Quantized8Asymm, uint8_t, Quantized8Asymm, uint8_t, 4);
DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, int32_t, 5);
DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 6);
DISPATCH_QUANTIZED(float, float, Quantized8Asymm, uint8_t, 7);
#undef DISPATCH_QUANTIZED
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \
MIDOUT_BEGIN(megdnn_arm_typecvt_float, midout_iv(_midout_iv)) { \
using _TypeCvter = FloatTypeCvter<_stype, _dtype>; \
MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
reinterpret_cast<_stype*>(src.raw_ptr()), \
reinterpret_cast<_dtype*>(dst.raw_ptr()), src_dtype, dst_dtype, \
nr_elems)); \
execed = true; \
} \
MIDOUT_END(); \
}
DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0);
DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1);
#undef DISPATCH_FLOAT
#endif
}
size_t last_stride = src_collapse_layout.stride[src_collapse_layout.ndim - 1];
if (!execed && last_stride == 1 && dst.layout.is_contiguous()) {
using namespace dtype; using namespace dtype;
#define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ #define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
...@@ -456,9 +647,26 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -456,9 +647,26 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
} }
DISPATCH_FIX2FLOAT(Int16, int16_t, Float32, float, 0); DISPATCH_FIX2FLOAT(Int16, int16_t, Float32, float, 0);
DISPATCH_FIX2FLOAT(Uint16, uint16_t, Float32, float, 1); DISPATCH_FIX2FLOAT(Uint16, uint16_t, Float32, float, 1);
DISPATCH_FIX2FLOAT(Int8, int8_t, Float32, float, 2);
DISPATCH_FIX2FLOAT(Uint8, uint8_t, Float32, float, 3);
#undef DISPATCH_FIX2FLOAT #undef DISPATCH_FIX2FLOAT
} else if (src.layout.is_contiguous()) {
using namespace dtype; #define DISPATCH_QUAN2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \
MIDOUT_BEGIN(megdnn_arm_typecvt_quan2float, midout_iv(_midout_iv)) { \
using _TypeCvter = Quan2FloatTypeCvter<_stype, _dtype>; \
MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \
src_dtype, dst_dtype, src.layout)); \
execed = true; \
} \
MIDOUT_END(); \
}
DISPATCH_QUAN2FLOAT(QuantizedS8, int8_t, Float32, float, 0);
DISPATCH_QUAN2FLOAT(Quantized8Asymm, uint8_t, Float32, float, 1);
#undef DISPATCH_QUAN2FLOAT
#define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ #define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \
...@@ -466,12 +674,11 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -466,12 +674,11 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>; \ using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>; \
MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \ src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \
src_dtype, dst_dtype, nr_elems)); \ src_dtype, dst_dtype, src.layout)); \
execed = true; \ execed = true; \
} \ } \
MIDOUT_END(); \ MIDOUT_END(); \
} }
DISPATCH_QUANTIZED(QuantizedS32, int32_t, Quantized8Asymm, uint8_t, 0); DISPATCH_QUANTIZED(QuantizedS32, int32_t, Quantized8Asymm, uint8_t, 0);
DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS8, int8_t, 1); DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS8, int8_t, 1);
DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS32, int32_t, 2); DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS32, int32_t, 2);
...@@ -491,7 +698,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -491,7 +698,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
reinterpret_cast<_stype*>(src.raw_ptr()), \ reinterpret_cast<_stype*>(src.raw_ptr()), \
reinterpret_cast<_dtype*>(dst.raw_ptr()), src_dtype, dst_dtype, \ reinterpret_cast<_dtype*>(dst.raw_ptr()), src_dtype, dst_dtype, \
nr_elems)); \ src.layout)); \
execed = true; \ execed = true; \
} \ } \
MIDOUT_END(); \ MIDOUT_END(); \
...@@ -501,6 +708,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -501,6 +708,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
#undef DISPATCH_FLOAT #undef DISPATCH_FLOAT
#endif #endif
} }
if (!execed) { if (!execed) {
fallback::TypeCvtImpl::exec(src, dst); fallback::TypeCvtImpl::exec(src, dst);
} }
......
...@@ -161,24 +161,251 @@ TEST_F(ARM_COMMON, TYPE_CVT_RECORD) { ...@@ -161,24 +161,251 @@ TEST_F(ARM_COMMON, TYPE_CVT_RECORD) {
.execs({{1, 32, 24, 128}, {1, 32, 24, 128}}); .execs({{1, 32, 24, 128}, {1, 32, 24, 128}});
} }
TEST_F(ARM_COMMON, TYPE_CVT_16_F32) { TEST_F(ARM_COMMON, TYPE_CVT_NONCONTIGUOUS) {
UniformIntRNG rng32{INT32_MIN >> 1, INT32_MAX >> 1};
UniformIntRNG rng16{INT16_MIN >> 1, INT16_MAX >> 1};
UniformIntRNG rng8{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<TypeCvt> checker(handle()); Checker<TypeCvt> checker(handle());
UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1};
size_t N = 1;
size_t C = 96;
size_t H = 64;
size_t W = 120;
TensorShape shape{N, C, H, W};
std::vector<ptrdiff_t> stride{
static_cast<long>(C * H * (W + 8)), static_cast<long>(H * (W + 8)),
static_cast<long>(W + 8), 1};
TensorLayout src, dst;
//! float32 -> float16
src = TensorLayout{shape, stride, dtype::Float32()};
dst = TensorLayout{shape, dtype::Float16()};
checker.execl({src, dst});
//! float16 -> float32
src = TensorLayout{shape, stride, dtype::Float16()};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
//! float -> s8
src = TensorLayout{shape, stride, dtype::Float32()};
dst = TensorLayout{shape, dtype::QuantizedS8(0.245121f)};
checker.execl({src, dst});
//! float -> as8
src = TensorLayout{shape, stride, dtype::Float32()};
dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
checker.execl({src, dst});
checker.set_rng(0, &rng32);
//! s32 -> as8
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0000113264f)};
dst = TensorLayout{
shape, dtype::Quantized8Asymm(0.018909f, static_cast<uint8_t>(3))};
checker.execl({src, dst});
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0003f)};
dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
checker.execl({src, dst});
//! s32 -> s8
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.000815917f)};
dst = TensorLayout{shape, dtype::QuantizedS8(0.245121f)};
checker.execl({src, dst});
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0003f)};
dst = TensorLayout{shape, dtype::QuantizedS8(0.2f)};
checker.execl({src, dst});
checker.set_rng(0, &rng8);
//! s32 -> s32
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0004f)};
dst = TensorLayout{shape, dtype::QuantizedS32(0.0002f)};
checker.execl({src, dst});
//! s8 -> s8
src = TensorLayout{shape, stride, dtype::QuantizedS8(0.3f)};
dst = TensorLayout{shape, dtype::QuantizedS8(0.2f)};
checker.execl({src, dst});
//! as8 -> as8
src = TensorLayout{
shape, stride, dtype::Quantized8Asymm(0.3f, static_cast<uint8_t>(8))};
dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
checker.execl({src, dst});
//! s8 -> s32
src = TensorLayout{shape, stride, dtype::QuantizedS8(0.245121f)};
dst = TensorLayout{shape, dtype::QuantizedS32(0.000815917f)};
checker.execl({src, dst});
src = TensorLayout{shape, stride, dtype::QuantizedS8(0.2f)};
dst = TensorLayout{shape, dtype::QuantizedS32(0.0003f)};
checker.execl({src, dst});
//! s8 -> float
src = TensorLayout{shape, stride, dtype::QuantizedS8(0.3f)};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
//! as8 -> float
src = TensorLayout{
shape, stride, dtype::Quantized8Asymm(0.3f, static_cast<uint8_t>(8))};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
//! int8/uint8 -> float
src = TensorLayout{shape, stride, dtype::Int8()};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
src = TensorLayout{shape, stride, dtype::Uint8()};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
//! int16/uint16 -> float
checker.set_rng(0, &rng16);
for (size_t size : {3, 7, 15, 33, 10000}) { for (size_t size : {3, 7, 15, 33, 10000}) {
checker.set_rng(0, &rng);
checker.set_dtype(0, dtype::Int16()).execs({{size}, {size}}); checker.set_dtype(0, dtype::Int16()).execs({{size}, {size}});
checker.set_dtype(0, dtype::Uint16()).execs({{size}, {size}}); checker.set_dtype(0, dtype::Uint16()).execs({{size}, {size}});
} }
TensorLayout src_int16{
{1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Int16()}; src = TensorLayout{shape, stride, dtype::Int16()};
TensorLayout dst_int16{{1, 96, 64, 120}, dtype::Float32()}; dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src_int16, dst_int16}); checker.execl({src, dst});
TensorLayout src_uint16{ src = TensorLayout{shape, stride, dtype::Uint16()};
{1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Uint16()}; dst = TensorLayout{shape, dtype::Float32()};
TensorLayout dst_uint16{{1, 96, 64, 120}, dtype::Float32()}; checker.execl({src, dst});
checker.execl({src_uint16, dst_uint16});
UniformIntRNG narrow_rng{-40000, 40000};
checker.set_rng(0, &narrow_rng);
//! s32 -> as8
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.000163794f)};
dst = TensorLayout{
shape, dtype::Quantized8Asymm(0.0479196f, static_cast<uint8_t>(144))};
checker.execl({src, dst});
}
TEST_F(ARM_COMMON, TYPE_CVT_MONOTONOUS) {
UniformIntRNG rng32{INT32_MIN >> 1, INT32_MAX >> 1};
UniformIntRNG rng16{INT16_MIN >> 1, INT16_MAX >> 1};
UniformIntRNG rng8{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<TypeCvt> checker(handle());
size_t N = 1;
size_t C = 96;
size_t H = 64;
size_t W = 120;
TensorShape shape{N, C, H, W};
std::vector<ptrdiff_t> stride{
static_cast<long>((C + 8) * (H + 8) * (W + 8)),
static_cast<long>((H + 8) * (W + 8)), static_cast<long>(W + 8), 1};
TensorLayout src, dst;
//! float32 -> float16
src = TensorLayout{shape, stride, dtype::Float32()};
dst = TensorLayout{shape, dtype::Float16()};
checker.execl({src, dst});
//! float16 -> float32
src = TensorLayout{shape, stride, dtype::Float16()};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
//! float -> s8
src = TensorLayout{shape, stride, dtype::Float32()};
dst = TensorLayout{shape, dtype::QuantizedS8(0.245121f)};
checker.execl({src, dst});
//! float -> as8
src = TensorLayout{shape, stride, dtype::Float32()};
dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
checker.execl({src, dst});
checker.set_rng(0, &rng32);
//! s32 -> as8
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0000113264f)};
dst = TensorLayout{
shape, dtype::Quantized8Asymm(0.018909f, static_cast<uint8_t>(3))};
checker.execl({src, dst});
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0003f)};
dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
checker.execl({src, dst});
//! s32 -> s8
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.000815917f)};
dst = TensorLayout{shape, dtype::QuantizedS8(0.245121f)};
checker.execl({src, dst});
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0003f)};
dst = TensorLayout{shape, dtype::QuantizedS8(0.2f)};
checker.execl({src, dst});
checker.set_rng(0, &rng8);
//! s32 -> s32
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.0004f)};
dst = TensorLayout{shape, dtype::QuantizedS32(0.0002f)};
checker.execl({src, dst});
//! s8 -> s8
src = TensorLayout{shape, stride, dtype::QuantizedS8(0.3f)};
dst = TensorLayout{shape, dtype::QuantizedS8(0.2f)};
checker.execl({src, dst});
//! as8 -> as8
src = TensorLayout{
shape, stride, dtype::Quantized8Asymm(0.3f, static_cast<uint8_t>(8))};
dst = TensorLayout{shape, dtype::Quantized8Asymm(0.1f, static_cast<uint8_t>(3))};
checker.execl({src, dst});
//! s8 -> s32
src = TensorLayout{shape, stride, dtype::QuantizedS8(0.245121f)};
dst = TensorLayout{shape, dtype::QuantizedS32(0.000815917f)};
checker.execl({src, dst});
src = TensorLayout{shape, stride, dtype::QuantizedS8(0.2f)};
dst = TensorLayout{shape, dtype::QuantizedS32(0.0003f)};
checker.execl({src, dst});
//! s8 -> float
src = TensorLayout{shape, stride, dtype::QuantizedS8(0.3f)};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
//! as8 -> float
src = TensorLayout{
shape, stride, dtype::Quantized8Asymm(0.3f, static_cast<uint8_t>(8))};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
//! int8/uint8 -> float
src = TensorLayout{shape, stride, dtype::Int8()};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
src = TensorLayout{shape, stride, dtype::Uint8()};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
src = TensorLayout{shape, stride, dtype::Int16()};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
src = TensorLayout{shape, stride, dtype::Uint16()};
dst = TensorLayout{shape, dtype::Float32()};
checker.execl({src, dst});
UniformIntRNG narrow_rng{-40000, 40000};
checker.set_rng(0, &narrow_rng);
//! s32 -> as8
src = TensorLayout{shape, stride, dtype::QuantizedS32(0.000163794f)};
dst = TensorLayout{
shape, dtype::Quantized8Asymm(0.0479196f, static_cast<uint8_t>(144))};
checker.execl({src, dst});
} }
#if MEGDNN_WITH_BENCHMARK #if MEGDNN_WITH_BENCHMARK
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册