提交 36ba1d6d 编写于 作者: M Megvii Engine Team

fix(riscv): fix ci fp16 build and move test GI_TEST_NAIVE by megdnn_gi_api_test

GitOrigin-RevId: e463855d925d6ea8eb2da82c2c911dde4fcb3d45
上级 dcce4610
......@@ -13,18 +13,6 @@ using BcastType = megdnn::elemwise::BcastType;
///////////////////////////////// ParamElemVistor ///////////////////////////
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, _neon_type_v2) \
template <> \
struct ParamElemVisitor<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \
} \
}; \
template <> \
struct ParamElemVisitorDup<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vdupq_n_##_fun_suffix(*reinterpret_cast<const _inner_ctype*>(src)); \
} \
}; \
template <> \
struct ParamElemVisitorV2<_ctype> { \
_neon_type_v2 operator()(const _ctype* src, const _ctype* src_1) const { \
......@@ -53,16 +41,7 @@ cb(__fp16, __fp16, float16x8_t, f16, float16x8x2_t);
cb(dt_int16, int16_t, int16x8_t, s16, int16x8x2_t);
#undef cb
template <typename ctype>
struct ParamElemVisitorBcast101x4;
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix, _neon_type_v2) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vreinterpretq_##_fun_suffix##_##rel_suffix(vld1q_dup_##rel_suffix( \
reinterpret_cast<const _inner_ctype*>(src))); \
} \
}; \
template <> \
struct ParamElemVisitorBcast101x4V2<_ctype> { \
_neon_type_v2 operator()(const _ctype* src) const { \
......@@ -83,16 +62,20 @@ cb(__fp16, uint64_t, float16x8_t, f16, u64, float16x8x2_t);
#undef cb
template <typename ctype>
struct ParamElemVisitorBcast101x8;
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \
template <> \
struct ParamElemVisitorBcast101x8<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \
} \
struct ParamElemVisitorBcast101x8V2;
#define cb(_ctype, _inner_ctype, _neon_type_v2, _fun_suffix) \
template <> \
struct ParamElemVisitorBcast101x8V2<_ctype> { \
_neon_type_v2 operator()(const _ctype* src) const { \
_neon_type_v2 ret; \
ret.val[0] = \
vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \
ret.val[1] = ret.val[0]; \
return ret; \
} \
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb(__fp16, __fp16, float16x8_t, f16);
cb(__fp16, __fp16, float16x8x2_t, f16);
#endif
#undef cb
......@@ -106,8 +89,8 @@ struct OpCallerBinaryBcast101xXVec<__fp16, 8> {
const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
const Op& op, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
ParamElemVisitorBcast101x8<src_ctype> vis0;
ParamElemVisitor<src_ctype> vis1;
ParamElemVisitorBcast101x8V2<src_ctype> vis0;
ParamElemVisitorV2<src_ctype> vis1;
OpCallerBinaryBcast101xDVec<src_ctype, 8>::run(
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
channel_stride);
......@@ -122,8 +105,8 @@ struct OpCallerBinaryVecBcast101xX<__fp16, 8> {
const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
const Op& op, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
ParamElemVisitor<src_ctype> vis0;
ParamElemVisitorBcast101x8<src_ctype> vis1;
ParamElemVisitorV2<src_ctype> vis0;
ParamElemVisitorBcast101x8V2<src_ctype> vis1;
OpCallerBinaryVecBcast101xD<src_ctype, 8>::run(
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
channel_stride);
......@@ -138,9 +121,9 @@ struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> {
const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
typename Op::dst_ctype* dst, const Op& op, size_t batch,
size_t nr_channel_blocks, size_t channel_stride) {
ParamElemVisitorBcast101x8<src_ctype> vis0;
ParamElemVisitor<src_ctype> vis1;
ParamElemVisitorBcast101x8<src_ctype> vis2;
ParamElemVisitorBcast101x8V2<src_ctype> vis0;
ParamElemVisitorV2<src_ctype> vis1;
ParamElemVisitorBcast101x8V2<src_ctype> vis2;
OpCallerTernaryBcast101xDVecBcast101xD<src_ctype, 8>::run(
src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks,
channel_stride);
......@@ -155,9 +138,9 @@ struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> {
const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
typename Op::dst_ctype* dst, const Op& op, size_t batch,
size_t nr_channel_blocks, size_t channel_stride) {
ParamElemVisitor<src_ctype> vis0;
ParamElemVisitorBcast101x8<src_ctype> vis1;
ParamElemVisitor<src_ctype> vis2;
ParamElemVisitorV2<src_ctype> vis0;
ParamElemVisitorBcast101x8V2<src_ctype> vis1;
ParamElemVisitorV2<src_ctype> vis2;
OpCallerTernaryVecBcast101xDVec<src_ctype, 8>::run(
src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks,
channel_stride);
......
......@@ -36,66 +36,6 @@ enum BcastType {
UNKNOWN_BCAST_TYPE
};
///////////////////////////////// ParamElemVistor ///////////////////////////
template <typename ctype>
struct ParamElemVisitor;
//! visitor single elemwise, and dup to vector
template <typename ctype>
struct ParamElemVisitorDup;
template <typename ctype>
struct ParamElemVisitorBcast101x4;
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \
template <> \
struct ParamElemVisitor<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiLoad##_fun_suffix(src); \
} \
}; \
template <> \
struct ParamElemVisitorDup<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiBroadcast##_fun_suffix( \
*reinterpret_cast<const _inner_ctype*>(src)); \
} \
}
cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_qint8, int8_t, GI_INT8_t, Int8);
cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
cb(dt_int8, int8_t, GI_INT8_t, Int8);
#undef cb
template <typename ctype>
struct ParamElemVisitorBcast101x4;
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \
*reinterpret_cast<const _inner_ctype*>(src))); \
} \
}
cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32);
cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32);
#undef cb
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiLoad##_fun_suffix(src); \
} \
}
cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
#undef cb
///////////////////////////////// ParamElemVistor v2///////////////////////////
template <typename ctype>
struct ParamElemVisitorV2;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册