提交 2d806f9c 编写于 作者: M Megvii Engine Team

feat(gi): make conv_bias apply gi class type

GitOrigin-RevId: daa40f61c1649433b65a6d76f81d602d13ad382e
上级 19d36fa0
......@@ -12,8 +12,8 @@ using namespace fallback;
namespace {
template <int shift>
static inline void shift_src(GI_FLOAT32_t rsrc[3][4]) {
GI_FLOAT32_t t[4];
static inline void shift_src(GI_FLOAT32_FIXLEN_t rsrc[3][4]) {
GI_FLOAT32_FIXLEN_t t[4];
t[0] = rsrc[0][(shift + 0) % 4];
t[1] = rsrc[0][(shift + 1) % 4];
......@@ -57,32 +57,51 @@ struct compute_element {
template <typename Op>
static inline void call(
const float*& src0, const float*& src1, const float*& src2, float*& dst,
const float*& bias, const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[3][4],
GI_FLOAT32_t rfilter[3][3], const Op& op) {
const float*& bias, const GI_FLOAT32_t& init,
GI_FLOAT32_FIXLEN_t rsrc[3][4], GI_FLOAT32_FIXLEN_t rfilter[3][3],
const Op& op) {
#define RSRC(i, j) rsrc[i][((j) + bw) % 4]
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init);
if (has_top) {
RSRC(0, 3) = GiLoadFloat32(src0 + 8);
RSRC(0, 3) = GiFloat32Type2FixLenType(GiLoadFloat32(src0 + 8));
}
{ RSRC(1, 3) = GiLoadFloat32(src1 + 8); }
{ RSRC(1, 3) = GiFloat32Type2FixLenType(GiLoadFloat32(src1 + 8)); }
if (has_bottom) {
RSRC(2, 3) = GiLoadFloat32(src2 + 8);
RSRC(2, 3) = GiFloat32Type2FixLenType(GiLoadFloat32(src2 + 8));
}
if (has_top) {
rdst = GiMlaqFloat32(rdst, RSRC(0, 0), rfilter[0][0]);
rdst = GiMlaqFloat32(rdst, RSRC(0, 1), rfilter[0][1]);
rdst = GiMlaqFloat32(rdst, RSRC(0, 2), rfilter[0][2]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(0, 0)),
GiFixLenType2GiFloat32Type(rfilter[0][0]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(0, 1)),
GiFixLenType2GiFloat32Type(rfilter[0][1]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(0, 2)),
GiFixLenType2GiFloat32Type(rfilter[0][2]));
}
{
rdst = GiMlaqFloat32(rdst, RSRC(1, 0), rfilter[1][0]);
rdst = GiMlaqFloat32(rdst, RSRC(1, 1), rfilter[1][1]);
rdst = GiMlaqFloat32(rdst, RSRC(1, 2), rfilter[1][2]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(1, 0)),
GiFixLenType2GiFloat32Type(rfilter[1][0]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(1, 1)),
GiFixLenType2GiFloat32Type(rfilter[1][1]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(1, 2)),
GiFixLenType2GiFloat32Type(rfilter[1][2]));
}
if (has_bottom) {
rdst = GiMlaqFloat32(rdst, RSRC(2, 0), rfilter[2][0]);
rdst = GiMlaqFloat32(rdst, RSRC(2, 1), rfilter[2][1]);
rdst = GiMlaqFloat32(rdst, RSRC(2, 2), rfilter[2][2]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(2, 0)),
GiFixLenType2GiFloat32Type(rfilter[2][0]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(2, 1)),
GiFixLenType2GiFloat32Type(rfilter[2][1]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(2, 2)),
GiFixLenType2GiFloat32Type(rfilter[2][2]));
}
GiStoreFloat32(dst, op(rdst));
......@@ -113,23 +132,42 @@ struct compute_element_right {
template <typename Op>
static inline void call(
float*& dst, const float*& bias, const GI_FLOAT32_t& init,
GI_FLOAT32_t rsrc[3][4], GI_FLOAT32_t rfilter[3][3], const Op& op) {
GI_FLOAT32_FIXLEN_t rsrc[3][4], GI_FLOAT32_FIXLEN_t rfilter[3][3],
const Op& op) {
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init);
if (has_top) {
rdst = GiMlaqFloat32(rdst, rsrc[0][0], rfilter[0][0]);
rdst = GiMlaqFloat32(rdst, rsrc[0][1], rfilter[0][1]);
rdst = GiMlaqFloat32(rdst, rsrc[0][2], rfilter[0][2]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[0][0]),
GiFixLenType2GiFloat32Type(rfilter[0][0]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[0][1]),
GiFixLenType2GiFloat32Type(rfilter[0][1]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[0][2]),
GiFixLenType2GiFloat32Type(rfilter[0][2]));
}
{
rdst = GiMlaqFloat32(rdst, rsrc[1][0], rfilter[1][0]);
rdst = GiMlaqFloat32(rdst, rsrc[1][1], rfilter[1][1]);
rdst = GiMlaqFloat32(rdst, rsrc[1][2], rfilter[1][2]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[1][0]),
GiFixLenType2GiFloat32Type(rfilter[1][0]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[1][1]),
GiFixLenType2GiFloat32Type(rfilter[1][1]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[1][2]),
GiFixLenType2GiFloat32Type(rfilter[1][2]));
}
if (has_bottom) {
rdst = GiMlaqFloat32(rdst, rsrc[2][0], rfilter[2][0]);
rdst = GiMlaqFloat32(rdst, rsrc[2][1], rfilter[2][1]);
rdst = GiMlaqFloat32(rdst, rsrc[2][2], rfilter[2][2]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[2][0]),
GiFixLenType2GiFloat32Type(rfilter[2][0]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[2][1]),
GiFixLenType2GiFloat32Type(rfilter[2][1]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[2][2]),
GiFixLenType2GiFloat32Type(rfilter[2][2]));
}
GiStoreFloat32(dst, op(rdst));
......@@ -144,20 +182,33 @@ struct compute_element_right_pad {
template <typename Op>
static inline void call(
float*& dst, const float*& bias, const GI_FLOAT32_t& init,
GI_FLOAT32_t rsrc[3][4], GI_FLOAT32_t rfilter[3][3], const Op& op) {
GI_FLOAT32_FIXLEN_t rsrc[3][4], GI_FLOAT32_FIXLEN_t rfilter[3][3],
const Op& op) {
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init);
if (has_top) {
rdst = GiMlaqFloat32(rdst, rsrc[0][1], rfilter[0][0]);
rdst = GiMlaqFloat32(rdst, rsrc[0][2], rfilter[0][1]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[0][1]),
GiFixLenType2GiFloat32Type(rfilter[0][0]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[0][2]),
GiFixLenType2GiFloat32Type(rfilter[0][1]));
}
{
rdst = GiMlaqFloat32(rdst, rsrc[1][1], rfilter[1][0]);
rdst = GiMlaqFloat32(rdst, rsrc[1][2], rfilter[1][1]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[1][1]),
GiFixLenType2GiFloat32Type(rfilter[1][0]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[1][2]),
GiFixLenType2GiFloat32Type(rfilter[1][1]));
}
if (has_bottom) {
rdst = GiMlaqFloat32(rdst, rsrc[2][1], rfilter[2][0]);
rdst = GiMlaqFloat32(rdst, rsrc[2][2], rfilter[2][1]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[2][1]),
GiFixLenType2GiFloat32Type(rfilter[2][0]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[2][2]),
GiFixLenType2GiFloat32Type(rfilter[2][1]));
}
GiStoreFloat32(dst, op(rdst));
......@@ -171,22 +222,23 @@ struct compute_row {
template <typename Op>
static inline void call(
const float*& src0, const float*& src1, const float*& src2, float*& dst,
const float*& bias, const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[3][4],
GI_FLOAT32_t rfilter[3][3], int W, const Op& op) {
const float*& bias, const GI_FLOAT32_t& init,
GI_FLOAT32_FIXLEN_t rsrc[3][4], GI_FLOAT32_FIXLEN_t rfilter[3][3], int W,
const Op& op) {
if (has_top) {
rsrc[0][0] = GiZeroFloat32();
rsrc[0][1] = GiLoadFloat32(src0 + 0);
rsrc[0][2] = GiLoadFloat32(src0 + 4);
rsrc[0][0] = GiFloat32Type2FixLenType(GiZeroFloat32());
rsrc[0][1] = GiFloat32Type2FixLenType(GiLoadFloat32(src0 + 0));
rsrc[0][2] = GiFloat32Type2FixLenType(GiLoadFloat32(src0 + 4));
}
{
rsrc[1][0] = GiZeroFloat32();
rsrc[1][1] = GiLoadFloat32(src1 + 0);
rsrc[1][2] = GiLoadFloat32(src1 + 4);
rsrc[1][0] = GiFloat32Type2FixLenType(GiZeroFloat32());
rsrc[1][1] = GiFloat32Type2FixLenType(GiLoadFloat32(src1 + 0));
rsrc[1][2] = GiFloat32Type2FixLenType(GiLoadFloat32(src1 + 4));
}
if (has_bottom) {
rsrc[2][0] = GiZeroFloat32();
rsrc[2][1] = GiLoadFloat32(src2 + 0);
rsrc[2][2] = GiLoadFloat32(src2 + 4);
rsrc[2][0] = GiFloat32Type2FixLenType(GiZeroFloat32());
rsrc[2][1] = GiFloat32Type2FixLenType(GiLoadFloat32(src2 + 0));
rsrc[2][2] = GiFloat32Type2FixLenType(GiLoadFloat32(src2 + 4));
}
int w = 0;
......@@ -246,18 +298,18 @@ void channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1(
const float* src1 = src;
const float* src2 = src + W * 4;
GI_FLOAT32_t rfilter[3][3];
rfilter[0][0] = GiLoadFloat32(filter + 0);
rfilter[0][1] = GiLoadFloat32(filter + 4);
rfilter[0][2] = GiLoadFloat32(filter + 8);
rfilter[1][0] = GiLoadFloat32(filter + 12);
rfilter[1][1] = GiLoadFloat32(filter + 16);
rfilter[1][2] = GiLoadFloat32(filter + 20);
rfilter[2][0] = GiLoadFloat32(filter + 24);
rfilter[2][1] = GiLoadFloat32(filter + 28);
rfilter[2][2] = GiLoadFloat32(filter + 32);
GI_FLOAT32_t rsrc[3][4];
GI_FLOAT32_FIXLEN_t rfilter[3][3];
rfilter[0][0] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 0));
rfilter[0][1] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 4));
rfilter[0][2] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 8));
rfilter[1][0] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 12));
rfilter[1][1] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 16));
rfilter[1][2] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 20));
rfilter[2][0] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 24));
rfilter[2][1] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 28));
rfilter[2][2] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 32));
GI_FLOAT32_FIXLEN_t rsrc[3][4];
compute_row<false, true, bias_mode>::call(
src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op);
......
......@@ -12,8 +12,8 @@ using namespace fallback;
namespace {
template <int shift>
static inline void shift_src(GI_FLOAT32_t rsrc[6]) {
GI_FLOAT32_t t[6];
static inline void shift_src(GI_FLOAT32_FIXLEN_t rsrc[6]) {
GI_FLOAT32_FIXLEN_t t[6];
t[0] = rsrc[(shift + 0) % 6];
t[1] = rsrc[(shift + 1) % 6];
......@@ -29,12 +29,12 @@ static inline void shift_src(GI_FLOAT32_t rsrc[6]) {
rsrc[5] = t[5];
}
static inline void load_filter(const float* filter, GI_FLOAT32_t rfilter[5]) {
rfilter[0] = GiLoadFloat32(filter + 0);
rfilter[1] = GiLoadFloat32(filter + 4);
rfilter[2] = GiLoadFloat32(filter + 8);
rfilter[3] = GiLoadFloat32(filter + 12);
rfilter[4] = GiLoadFloat32(filter + 16);
static inline void load_filter(const float* filter, GI_FLOAT32_FIXLEN_t rfilter[5]) {
rfilter[0] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 0));
rfilter[1] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 4));
rfilter[2] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 8));
rfilter[3] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 12));
rfilter[4] = GiFloat32Type2FixLenType(GiLoadFloat32(filter + 16));
}
template <BiasMode bias_mode>
......@@ -51,8 +51,8 @@ struct compute_element {
template <typename Op>
static inline void call(
const float*& src, float*& dst, const float*& bias,
const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5],
const Op& op) {
const GI_FLOAT32_t& init, GI_FLOAT32_FIXLEN_t rsrc[6],
GI_FLOAT32_FIXLEN_t rfilter[5], const Op& op) {
#define RSRC(i) rsrc[((i) + bw) % 6]
GI_FLOAT32_t rdst;
if (need_load_bias) {
......@@ -60,13 +60,23 @@ struct compute_element {
} else {
rdst = GiLoadFloat32(dst);
}
RSRC(5) = GiLoadFloat32(src + 12);
rdst = GiMlaqFloat32(rdst, RSRC(0), rfilter[0]);
rdst = GiMlaqFloat32(rdst, RSRC(1), rfilter[1]);
rdst = GiMlaqFloat32(rdst, RSRC(2), rfilter[2]);
rdst = GiMlaqFloat32(rdst, RSRC(3), rfilter[3]);
rdst = GiMlaqFloat32(rdst, RSRC(4), rfilter[4]);
RSRC(5) = GiFloat32Type2FixLenType(GiLoadFloat32(src + 12));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(0)),
GiFixLenType2GiFloat32Type(rfilter[0]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(1)),
GiFixLenType2GiFloat32Type(rfilter[1]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(2)),
GiFixLenType2GiFloat32Type(rfilter[2]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(3)),
GiFixLenType2GiFloat32Type(rfilter[3]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(RSRC(4)),
GiFixLenType2GiFloat32Type(rfilter[4]));
if (need_do_op) {
rdst = op(rdst);
......@@ -93,7 +103,7 @@ struct compute_element_right {
template <typename Op>
static inline void call(
float*& dst, const float*& bias, const GI_FLOAT32_t& init,
GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], const Op& op) {
GI_FLOAT32_FIXLEN_t rsrc[6], GI_FLOAT32_FIXLEN_t rfilter[5], const Op& op) {
GI_FLOAT32_t rdst;
if (need_load_bias) {
rdst = load_bias<bias_mode>(bias, init);
......@@ -101,14 +111,24 @@ struct compute_element_right {
rdst = GiLoadFloat32(dst);
}
rdst = GiMlaqFloat32(rdst, rsrc[0 + padding], rfilter[0]);
rdst = GiMlaqFloat32(rdst, rsrc[1 + padding], rfilter[1]);
rdst = GiMlaqFloat32(rdst, rsrc[2 + padding], rfilter[2]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[0 + padding]),
GiFixLenType2GiFloat32Type(rfilter[0]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[1 + padding]),
GiFixLenType2GiFloat32Type(rfilter[1]));
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[2 + padding]),
GiFixLenType2GiFloat32Type(rfilter[2]));
if (padding < 2) {
rdst = GiMlaqFloat32(rdst, rsrc[3 + padding], rfilter[3]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[3 + padding]),
GiFixLenType2GiFloat32Type(rfilter[3]));
}
if (padding < 1) {
rdst = GiMlaqFloat32(rdst, rsrc[4 + padding], rfilter[4]);
rdst = GiMlaqFloat32(
rdst, GiFixLenType2GiFloat32Type(rsrc[4 + padding]),
GiFixLenType2GiFloat32Type(rfilter[4]));
}
if (need_do_op) {
......@@ -126,12 +146,13 @@ struct compute_row_src_1x5 {
template <typename Op>
static inline void call(
const float* src, float* dst, const float* bias, const GI_FLOAT32_t& init,
GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], int W, const Op& op) {
rsrc[0] = GiZeroFloat32();
rsrc[1] = GiZeroFloat32();
rsrc[2] = GiLoadFloat32(src + 0);
rsrc[3] = GiLoadFloat32(src + 4);
rsrc[4] = GiLoadFloat32(src + 8);
GI_FLOAT32_FIXLEN_t rsrc[6], GI_FLOAT32_FIXLEN_t rfilter[5], int W,
const Op& op) {
rsrc[0] = GiFloat32Type2FixLenType(GiZeroFloat32());
rsrc[1] = GiFloat32Type2FixLenType(GiZeroFloat32());
rsrc[2] = GiFloat32Type2FixLenType(GiLoadFloat32(src + 0));
rsrc[3] = GiFloat32Type2FixLenType(GiLoadFloat32(src + 4));
rsrc[4] = GiFloat32Type2FixLenType(GiLoadFloat32(src + 8));
int w = 0;
......@@ -172,8 +193,8 @@ struct compute_row {
template <typename Op>
static inline void call(
const float*& src, float*& dst, const float* filter, const float*& bias,
const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5],
int W, const Op& op) {
const GI_FLOAT32_t& init, GI_FLOAT32_FIXLEN_t rsrc[6],
GI_FLOAT32_FIXLEN_t rfilter[5], int W, const Op& op) {
if (top_padding < 1) {
load_filter(filter + 0, rfilter);
compute_row_src_1x5<bias_mode, top_padding == 0, false>::call(
......@@ -222,8 +243,8 @@ void channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2(
init = GiLoadFloat32(bias);
}
GI_FLOAT32_t rsrc[6];
GI_FLOAT32_t rfilter[5];
GI_FLOAT32_FIXLEN_t rsrc[6];
GI_FLOAT32_FIXLEN_t rfilter[5];
compute_row<2, 0, bias_mode>::call(
src, dst, filter, bias, init, rsrc, rfilter, W, op);
......
......@@ -93,7 +93,12 @@ struct do_pixel_proxy<1, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp;
GI_FLOAT32_t zero = GiZeroFloat32();
GI_FLOAT32_t out0, out1, out2, out3, kr0, inp;
out0 = zero;
out1 = zero;
out2 = zero;
out3 = zero;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
......@@ -134,7 +139,12 @@ struct do_pixel_proxy<2, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp;
GI_FLOAT32_t zero = GiZeroFloat32();
GI_FLOAT32_t out0, out1, out2, out3, kr0, kr1, inp;
out0 = zero;
out1 = zero;
out2 = zero;
out3 = zero;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
......@@ -187,7 +197,12 @@ struct do_pixel_proxy<3, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp;
GI_FLOAT32_t zero = GiZeroFloat32();
GI_FLOAT32_t out0, out1, out2, out3, kr0, kr1, kr2, inp;
out0 = zero;
out1 = zero;
out2 = zero;
out3 = zero;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
......@@ -252,7 +267,12 @@ struct do_pixel_proxy<4, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp;
GI_FLOAT32_t zero = GiZeroFloat32();
GI_FLOAT32_t out0, out1, out2, out3, kr0, kr1, kr2, kr3, inp;
out0 = zero;
out1 = zero;
out2 = zero;
out3 = zero;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
......@@ -329,7 +349,12 @@ struct do_pixel_proxy<5, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, inp;
GI_FLOAT32_t zero = GiZeroFloat32();
GI_FLOAT32_t out0, out1, out2, out3, kr0, kr1, kr2, kr3, kr4, inp;
out0 = zero;
out1 = zero;
out2 = zero;
out3 = zero;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
......@@ -418,8 +443,12 @@ struct do_pixel_proxy<6, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5,
inp;
GI_FLOAT32_t zero = GiZeroFloat32();
GI_FLOAT32_t out0, out1, out2, out3, kr0, kr1, kr2, kr3, kr4, kr5, inp;
out0 = zero;
out1 = zero;
out2 = zero;
out3 = zero;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
......@@ -520,8 +549,12 @@ struct do_pixel_proxy<7, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5,
kr6, inp;
GI_FLOAT32_t zero = GiZeroFloat32();
GI_FLOAT32_t out0, out1, out2, out3, kr0, kr1, kr2, kr3, kr4, kr5, kr6, inp;
out0 = zero;
out1 = zero;
out2 = zero;
out3 = zero;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
......
......@@ -38,23 +38,23 @@ static inline void odd_even_split_iw8_even(
const int src_offset = src_idx * ic_step;
const int even_offset = iw_idx / 2 * ic_step;
const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
GI_FLOAT32_t temp[8];
temp[0] = GiLoadFloat32(sptr + src_offset + 0 * ic_step);
temp[1] = GiLoadFloat32(sptr + src_offset + 1 * ic_step);
temp[2] = GiLoadFloat32(sptr + src_offset + 2 * ic_step);
temp[3] = GiLoadFloat32(sptr + src_offset + 3 * ic_step);
temp[4] = GiLoadFloat32(sptr + src_offset + 4 * ic_step);
temp[5] = GiLoadFloat32(sptr + src_offset + 5 * ic_step);
temp[6] = GiLoadFloat32(sptr + src_offset + 6 * ic_step);
temp[7] = GiLoadFloat32(sptr + src_offset + 7 * ic_step);
GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, temp[0]);
GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, temp[2]);
GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, temp[4]);
GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, temp[6]);
GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, temp[1]);
GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, temp[3]);
GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, temp[5]);
GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, temp[7]);
GI_FLOAT32_t a0, a1, a2, a3, a4, a5, a6, a7;
a0 = GiLoadFloat32(sptr + src_offset + 0 * ic_step);
a1 = GiLoadFloat32(sptr + src_offset + 1 * ic_step);
a2 = GiLoadFloat32(sptr + src_offset + 2 * ic_step);
a3 = GiLoadFloat32(sptr + src_offset + 3 * ic_step);
a4 = GiLoadFloat32(sptr + src_offset + 4 * ic_step);
a5 = GiLoadFloat32(sptr + src_offset + 5 * ic_step);
a6 = GiLoadFloat32(sptr + src_offset + 6 * ic_step);
a7 = GiLoadFloat32(sptr + src_offset + 7 * ic_step);
GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, a0);
GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, a2);
GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, a4);
GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, a6);
GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, a1);
GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, a3);
GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, a5);
GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, a7);
}
static inline void odd_even_split_iw8_odd(
......@@ -64,23 +64,23 @@ static inline void odd_even_split_iw8_odd(
const int src_offset = src_idx * ic_step;
const int even_offset = (iw_idx + 1) / 2 * ic_step;
const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
GI_FLOAT32_t temp[8];
temp[0] = GiLoadFloat32(sptr + src_offset + 0 * ic_step);
temp[1] = GiLoadFloat32(sptr + src_offset + 1 * ic_step);
temp[2] = GiLoadFloat32(sptr + src_offset + 2 * ic_step);
temp[3] = GiLoadFloat32(sptr + src_offset + 3 * ic_step);
temp[4] = GiLoadFloat32(sptr + src_offset + 4 * ic_step);
temp[5] = GiLoadFloat32(sptr + src_offset + 5 * ic_step);
temp[6] = GiLoadFloat32(sptr + src_offset + 6 * ic_step);
temp[7] = GiLoadFloat32(sptr + src_offset + 7 * ic_step);
GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, temp[0]);
GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, temp[2]);
GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, temp[4]);
GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, temp[6]);
GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, temp[1]);
GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, temp[3]);
GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, temp[5]);
GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, temp[7]);
GI_FLOAT32_t a0, a1, a2, a3, a4, a5, a6, a7;
a0 = GiLoadFloat32(sptr + src_offset + 0 * ic_step);
a1 = GiLoadFloat32(sptr + src_offset + 1 * ic_step);
a2 = GiLoadFloat32(sptr + src_offset + 2 * ic_step);
a3 = GiLoadFloat32(sptr + src_offset + 3 * ic_step);
a4 = GiLoadFloat32(sptr + src_offset + 4 * ic_step);
a5 = GiLoadFloat32(sptr + src_offset + 5 * ic_step);
a6 = GiLoadFloat32(sptr + src_offset + 6 * ic_step);
a7 = GiLoadFloat32(sptr + src_offset + 7 * ic_step);
GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, a0);
GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, a2);
GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, a4);
GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, a6);
GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, a1);
GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, a3);
GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, a5);
GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, a7);
}
} // namespace
......
......@@ -25,14 +25,20 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, 0, T, T2, T3, T4> {
};
#define cb2(step, lane, ow_block) \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \
c[1][step] = GiSimdFmaLane( \
c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane);
#define cb(step, lane, ow_block) \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane);
c[0][step] = GiFloat32Type2FixLenType(GiSimdFmaLane( \
GiFixLenType2GiFloat32Type(c[0][step]), \
GiFixLenType2GiFloat32Type(weight[0][lane]), \
GiFixLenType2GiFloat32Type(src[(step + src_idx) % ow_block]), lane)); \
c[1][step] = GiFloat32Type2FixLenType(GiSimdFmaLane( \
GiFixLenType2GiFloat32Type(c[1][step]), \
GiFixLenType2GiFloat32Type(weight[1][lane]), \
GiFixLenType2GiFloat32Type(src[(step + src_idx) % ow_block]), lane));
#define cb(step, lane, ow_block) \
c[0][step] = GiFloat32Type2FixLenType(GiSimdFmaLane( \
GiFixLenType2GiFloat32Type(c[0][step]), \
GiFixLenType2GiFloat32Type(weight[0][lane]), \
GiFixLenType2GiFloat32Type(src[(step + src_idx) % ow_block]), lane));
#define SHIFT_CAL_HELPER(ow_block, remain_w) \
template < \
......@@ -133,19 +139,20 @@ struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][ow_block];
GI_FLOAT32_FIXLEN_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][ic_step];
GI_FLOAT32_FIXLEN_t src[ow_block];
GI_FLOAT32_FIXLEN_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step);
src[0] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block)*ic_step));
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
......@@ -175,23 +182,25 @@ struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][ow_block];
GI_FLOAT32_FIXLEN_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][ic_step];
GI_FLOAT32_FIXLEN_t src[ow_block];
GI_FLOAT32_FIXLEN_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step);
src[0] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block)*ic_step));
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step);
src[1] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step));
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);
......@@ -220,35 +229,39 @@ struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][ow_block];
GI_FLOAT32_FIXLEN_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][ic_step];
GI_FLOAT32_FIXLEN_t src[ow_block];
GI_FLOAT32_FIXLEN_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step);
src[0] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block)*ic_step));
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step);
src[1] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step));
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step);
src[2] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step));
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[3] = GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step);
src[3] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step));
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight);
......@@ -278,45 +291,51 @@ struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][ow_block];
GI_FLOAT32_FIXLEN_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][ic_step];
GI_FLOAT32_FIXLEN_t src[ow_block];
GI_FLOAT32_FIXLEN_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step);
src[0] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block)*ic_step));
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step);
src[1] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step));
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step);
src[2] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step));
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[3] = GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step);
src[3] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step));
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[4] = GiLoadFloat32(src_ptr + (ow_block + 4) * ic_step);
src[4] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 4) * ic_step));
load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<5, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[5] = GiLoadFloat32(src_ptr + (ow_block + 5) * ic_step);
src[5] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 5) * ic_step));
load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<6, 0, c_dim, ow_block, remain_w>(c, src, weight);
......
......@@ -25,14 +25,20 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, 0, T, T2, T3, T4> {
};
#define cb2(step, lane, ow_block) \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \
c[1][step] = GiSimdFmaLane( \
c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane);
#define cb(step, lane, ow_block) \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane);
c[0][step] = GiFloat32Type2FixLenType(GiSimdFmaLane( \
GiFixLenType2GiFloat32Type(c[0][step]), \
GiFixLenType2GiFloat32Type(weight[0][lane]), \
GiFixLenType2GiFloat32Type(src[(step + src_idx) % ow_block]), lane)); \
c[1][step] = GiFloat32Type2FixLenType(GiSimdFmaLane( \
GiFixLenType2GiFloat32Type(c[1][step]), \
GiFixLenType2GiFloat32Type(weight[1][lane]), \
GiFixLenType2GiFloat32Type(src[(step + src_idx) % ow_block]), lane));
#define cb(step, lane, ow_block) \
c[0][step] = GiFloat32Type2FixLenType(GiSimdFmaLane( \
GiFixLenType2GiFloat32Type(c[0][step]), \
GiFixLenType2GiFloat32Type(weight[0][lane]), \
GiFixLenType2GiFloat32Type(src[(step + src_idx) % ow_block]), lane));
#define SHIFT_CAL_HELPER(ow_block, remain_w) \
template < \
......@@ -133,15 +139,15 @@ struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][ow_block];
GI_FLOAT32_FIXLEN_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic;
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][4];
GI_FLOAT32_FIXLEN_t src[ow_block];
GI_FLOAT32_FIXLEN_t weight[c_dim][4];
/////////row 0/////////////
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
......@@ -191,21 +197,22 @@ struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][ow_block];
GI_FLOAT32_FIXLEN_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic;
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][4];
GI_FLOAT32_FIXLEN_t src[ow_block];
GI_FLOAT32_FIXLEN_t weight[c_dim][4];
/////////row 0/////////////
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len);
src[0] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + ow_block * simd_len));
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
......@@ -222,7 +229,8 @@ struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len);
src[0] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + ow_block * simd_len));
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
......@@ -239,7 +247,8 @@ struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len);
src[0] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + ow_block * simd_len));
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
......@@ -275,7 +284,7 @@ struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][ow_block];
GI_FLOAT32_FIXLEN_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
......@@ -283,18 +292,20 @@ struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic;
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][4];
GI_FLOAT32_FIXLEN_t src[ow_block];
GI_FLOAT32_FIXLEN_t weight[c_dim][4];
// even element
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len);
src[0] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + ow_block * simd_len));
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len);
src[1] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len));
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);
......@@ -303,7 +314,8 @@ struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = GiLoadFloat32(src_ptr_odd + ow_block * simd_len);
src[0] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr_odd + ow_block * simd_len));
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
......@@ -340,7 +352,7 @@ struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][ow_block];
GI_FLOAT32_FIXLEN_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
......@@ -348,22 +360,25 @@ struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic;
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][4];
GI_FLOAT32_FIXLEN_t src[ow_block];
GI_FLOAT32_FIXLEN_t weight[c_dim][4];
// even element
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len);
src[0] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + ow_block * simd_len));
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len);
src[1] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len));
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * simd_len);
src[2] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr + (ow_block + 2) * simd_len));
load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight);
......@@ -372,11 +387,13 @@ struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = GiLoadFloat32(src_ptr_odd + ow_block * simd_len);
src[0] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr_odd + ow_block * simd_len));
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[1] = GiLoadFloat32(src_ptr_odd + (ow_block + 1) * simd_len);
src[1] = GiFloat32Type2FixLenType(
GiLoadFloat32(src_ptr_odd + (ow_block + 1) * simd_len));
load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);
......
......@@ -37,18 +37,24 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, stride, 0, T, T2, T3> {
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&) {}
};
#define cb(step) \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4); \
c[1][step] = GiSimdFmaLane( \
c[1][step], weight[1][weight_idx], src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4);
#define cb2(step) \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4);
#define cb(step) \
c[0][step] = GiFloat32Type2FixLenType(GiSimdFmaLane( \
GiFixLenType2GiFloat32Type(c[0][step]), \
GiFixLenType2GiFloat32Type(weight[0][weight_idx]), \
GiFixLenType2GiFloat32Type(src[(step * stride + src_idx) / 4]), \
(step * stride + src_idx) % 4)); \
c[1][step] = GiFloat32Type2FixLenType(GiSimdFmaLane( \
GiFixLenType2GiFloat32Type(c[1][step]), \
GiFixLenType2GiFloat32Type(weight[1][weight_idx]), \
GiFixLenType2GiFloat32Type(src[(step * stride + src_idx) / 4]), \
(step * stride + src_idx) % 4));
#define cb2(step) \
c[0][step] = GiFloat32Type2FixLenType(GiSimdFmaLane( \
GiFixLenType2GiFloat32Type(c[0][step]), \
GiFixLenType2GiFloat32Type(weight[0][weight_idx]), \
GiFixLenType2GiFloat32Type(src[(step * stride + src_idx) / 4]), \
(step * stride + src_idx) % 4));
#define SHIFT_CAL_HELPER(ow_remain) \
template < \
......@@ -141,12 +147,12 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride, ow_
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][8];
GI_FLOAT32_FIXLEN_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_t weight[c_dim][filter_size];
GI_FLOAT32_FIXLEN_t src[src_reg_size];
GI_FLOAT32_FIXLEN_t weight[c_dim][filter_size];
#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0); \
......@@ -190,12 +196,12 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride, ow_
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][8];
GI_FLOAT32_FIXLEN_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_t weight[c_dim][filter_size];
GI_FLOAT32_FIXLEN_t src[src_reg_size];
GI_FLOAT32_FIXLEN_t weight[c_dim][filter_size];
#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0); \
......@@ -236,12 +242,12 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride, ow_
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][8];
GI_FLOAT32_FIXLEN_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_t weight[c_dim][filter_size];
GI_FLOAT32_FIXLEN_t src[src_reg_size];
GI_FLOAT32_FIXLEN_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>(
......@@ -295,7 +301,7 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> {
const int ld_src_ic_skip_bytes =
iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[1][8];
GI_FLOAT32_FIXLEN_t c[1][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
const int img_stride = ih * iw;
constexpr int filter_stride = filter_size * filter_size * oc_step;
......@@ -467,7 +473,7 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::DEFAULT_CPU
const int ld_src_ic_skip_bytes =
iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[1][8];
GI_FLOAT32_FIXLEN_t c[1][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
/**
* c q8-q15
......@@ -627,12 +633,12 @@ struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride, ow_
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
GI_FLOAT32_t c[c_dim][8];
GI_FLOAT32_FIXLEN_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_t weight[c_dim][filter_size];
GI_FLOAT32_FIXLEN_t src[src_reg_size];
GI_FLOAT32_FIXLEN_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>(
......
......@@ -38,16 +38,16 @@ void conv_stride2::do_conv_2x2_stride2(
GI_FLOAT32_V2_t _r0 = GiLd2qFloat32(r0);
GI_FLOAT32_t _r00 = _r0.val[0]; // 0 2 4 6
GI_FLOAT32_t _r01 = _r0.val[1]; // 1 3 5 7
GI_FLOAT32_t _r00 = GiGetSubVectorFloat32V2(_r0, 0); // 0 2 4 6
GI_FLOAT32_t _r01 = GiGetSubVectorFloat32V2(_r0, 1); // 1 3 5 7
_outp = GiSimdFmaLane(_outp, _r00, _k0123, 0);
_outp = GiSimdFmaLane(_outp, _r01, _k0123, 1);
GI_FLOAT32_V2_t _r1 = GiLd2qFloat32(r1);
GI_FLOAT32_t _r10 = _r1.val[0];
GI_FLOAT32_t _r11 = _r1.val[1];
GI_FLOAT32_t _r10 = GiGetSubVectorFloat32V2(_r1, 0);
GI_FLOAT32_t _r11 = GiGetSubVectorFloat32V2(_r1, 1);
_outp = GiSimdFmaLane(_outp, _r10, _k0123, 2);
_outp = GiSimdFmaLane(_outp, _r11, _k0123, 3);
......@@ -97,9 +97,10 @@ void conv_stride2::do_conv_3x3_stride2(
GI_FLOAT32_V2_t _r0 = GiLd2qFloat32(r0);
GI_FLOAT32_V2_t _r0n = GiLd2qFloat32(r0 + 8);
GI_FLOAT32_t _r00 = _r0.val[0]; // 0 2 4 6
GI_FLOAT32_t _r01 = _r0.val[1]; // 1 3 5 7
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0n.val[0], 1); // 2 4 6 8
GI_FLOAT32_t _r00 = GiGetSubVectorFloat32V2(_r0, 0); // 0 2 4 6
GI_FLOAT32_t _r01 = GiGetSubVectorFloat32V2(_r0, 1); // 1 3 5 7
GI_FLOAT32_t _r02 = GiExtqFloat32(
_r00, GiGetSubVectorFloat32V2(_r0n, 0), 1); // 2 4 6 8
_outp = GiSimdFmaLane(_outp, _r00, _k0123, 0);
_outp = GiSimdFmaLane(_outp, _r01, _k0123, 1);
......@@ -108,9 +109,10 @@ void conv_stride2::do_conv_3x3_stride2(
GI_FLOAT32_V2_t _r1 = GiLd2qFloat32(r1);
GI_FLOAT32_V2_t _r1n = GiLd2qFloat32(r1 + 8);
GI_FLOAT32_t _r10 = _r1.val[0];
GI_FLOAT32_t _r11 = _r1.val[1];
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1n.val[0], 1);
GI_FLOAT32_t _r10 = GiGetSubVectorFloat32V2(_r1, 0);
GI_FLOAT32_t _r11 = GiGetSubVectorFloat32V2(_r1, 1);
GI_FLOAT32_t _r12 =
GiExtqFloat32(_r10, GiGetSubVectorFloat32V2(_r1n, 0), 1);
_outp = GiSimdFmaLane(_outp, _r10, _k3456, 0);
_outp = GiSimdFmaLane(_outp, _r11, _k3456, 1);
......@@ -119,9 +121,10 @@ void conv_stride2::do_conv_3x3_stride2(
GI_FLOAT32_V2_t _r2 = GiLd2qFloat32(r2);
GI_FLOAT32_V2_t _r2n = GiLd2qFloat32(r2 + 8);
GI_FLOAT32_t _r20 = _r2.val[0];
GI_FLOAT32_t _r21 = _r2.val[1];
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2n.val[0], 1);
GI_FLOAT32_t _r20 = GiGetSubVectorFloat32V2(_r2, 0);
GI_FLOAT32_t _r21 = GiGetSubVectorFloat32V2(_r2, 1);
GI_FLOAT32_t _r22 =
GiExtqFloat32(_r20, GiGetSubVectorFloat32V2(_r2n, 0), 1);
_outp = GiSimdFmaLane(_outp, _r20, _k6789, 0);
_outp = GiSimdFmaLane(_outp, _r21, _k6789, 1);
......@@ -175,50 +178,54 @@ void conv_stride2::do_conv_5x5_stride2(
GI_FLOAT32_V2_t _r00_02461357 = GiLd2qFloat32(r0);
GI_FLOAT32_V2_t _r00nx2 = GiLd2qFloat32(r0 + 8);
GI_FLOAT32_t _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14
GI_FLOAT32_t _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15
GI_FLOAT32_t _r00 = _r00_02461357.val[0]; // 0 2 4 6
GI_FLOAT32_t _r01 = _r00_02461357.val[1]; // 1 3 5 7
GI_FLOAT32_t _r0_8101214 =
GiGetSubVectorFloat32V2(_r00nx2, 0); // 8 10 12 14
GI_FLOAT32_t _r0_9111315 =
GiGetSubVectorFloat32V2(_r00nx2, 1); // 9 11 13 15
GI_FLOAT32_t _r00 =
GiGetSubVectorFloat32V2(_r00_02461357, 0); // 0 2 4 6
GI_FLOAT32_t _r01 =
GiGetSubVectorFloat32V2(_r00_02461357, 1); // 1 3 5 7
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0_8101214, 1); // 2 4 6 8
GI_FLOAT32_t _r03 = GiExtqFloat32(_r01, _r0_9111315, 1); // 3 5 7 9
GI_FLOAT32_t _r04 = GiExtqFloat32(_r00, _r0_8101214, 2); // 4 6 8 10
GI_FLOAT32_V2_t _r10_02461357 = GiLd2qFloat32(r1);
GI_FLOAT32_V2_t _r10nx2 = GiLd2qFloat32(r1 + 8);
GI_FLOAT32_t _r1_8101214 = _r10nx2.val[0];
GI_FLOAT32_t _r1_9111315 = _r10nx2.val[1];
GI_FLOAT32_t _r10 = _r10_02461357.val[0];
GI_FLOAT32_t _r11 = _r10_02461357.val[1];
GI_FLOAT32_t _r1_8101214 = GiGetSubVectorFloat32V2(_r10nx2, 0);
GI_FLOAT32_t _r1_9111315 = GiGetSubVectorFloat32V2(_r10nx2, 1);
GI_FLOAT32_t _r10 = GiGetSubVectorFloat32V2(_r10_02461357, 0);
GI_FLOAT32_t _r11 = GiGetSubVectorFloat32V2(_r10_02461357, 1);
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1_8101214, 1);
GI_FLOAT32_t _r13 = GiExtqFloat32(_r11, _r1_9111315, 1);
GI_FLOAT32_t _r14 = GiExtqFloat32(_r10, _r1_8101214, 2);
GI_FLOAT32_V2_t _r20_02461357 = GiLd2qFloat32(r2);
GI_FLOAT32_V2_t _r20nx2 = GiLd2qFloat32(r2 + 8);
GI_FLOAT32_t _r2_8101214 = _r20nx2.val[0];
GI_FLOAT32_t _r2_9111315 = _r20nx2.val[1];
GI_FLOAT32_t _r20 = _r20_02461357.val[0];
GI_FLOAT32_t _r21 = _r20_02461357.val[1];
GI_FLOAT32_t _r2_8101214 = GiGetSubVectorFloat32V2(_r20nx2, 0);
GI_FLOAT32_t _r2_9111315 = GiGetSubVectorFloat32V2(_r20nx2, 1);
GI_FLOAT32_t _r20 = GiGetSubVectorFloat32V2(_r20_02461357, 0);
GI_FLOAT32_t _r21 = GiGetSubVectorFloat32V2(_r20_02461357, 1);
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2_8101214, 1);
GI_FLOAT32_t _r23 = GiExtqFloat32(_r21, _r2_9111315, 1);
GI_FLOAT32_t _r24 = GiExtqFloat32(_r20, _r2_8101214, 2);
GI_FLOAT32_V2_t _r30_02461357 = GiLd2qFloat32(r3);
GI_FLOAT32_V2_t _r30nx2 = GiLd2qFloat32(r3 + 8);
GI_FLOAT32_t _r3_8101214 = _r30nx2.val[0];
GI_FLOAT32_t _r3_9111315 = _r30nx2.val[1];
GI_FLOAT32_t _r30 = _r30_02461357.val[0];
GI_FLOAT32_t _r31 = _r30_02461357.val[1];
GI_FLOAT32_t _r3_8101214 = GiGetSubVectorFloat32V2(_r30nx2, 0);
GI_FLOAT32_t _r3_9111315 = GiGetSubVectorFloat32V2(_r30nx2, 1);
GI_FLOAT32_t _r30 = GiGetSubVectorFloat32V2(_r30_02461357, 0);
GI_FLOAT32_t _r31 = GiGetSubVectorFloat32V2(_r30_02461357, 1);
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r3_8101214, 1);
GI_FLOAT32_t _r33 = GiExtqFloat32(_r31, _r3_9111315, 1);
GI_FLOAT32_t _r34 = GiExtqFloat32(_r30, _r3_8101214, 2);
GI_FLOAT32_V2_t _r40_02461357 = GiLd2qFloat32(r4);
GI_FLOAT32_V2_t _r40nx2 = GiLd2qFloat32(r4 + 8);
GI_FLOAT32_t _r4_8101214 = _r40nx2.val[0];
GI_FLOAT32_t _r4_9111315 = _r40nx2.val[1];
GI_FLOAT32_t _r40 = _r40_02461357.val[0];
GI_FLOAT32_t _r41 = _r40_02461357.val[1];
GI_FLOAT32_t _r4_8101214 = GiGetSubVectorFloat32V2(_r40nx2, 0);
GI_FLOAT32_t _r4_9111315 = GiGetSubVectorFloat32V2(_r40nx2, 1);
GI_FLOAT32_t _r40 = GiGetSubVectorFloat32V2(_r40_02461357, 0);
GI_FLOAT32_t _r41 = GiGetSubVectorFloat32V2(_r40_02461357, 1);
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r4_8101214, 1);
GI_FLOAT32_t _r43 = GiExtqFloat32(_r41, _r4_9111315, 1);
GI_FLOAT32_t _r44 = GiExtqFloat32(_r40, _r4_8101214, 2);
......@@ -310,10 +317,14 @@ void conv_stride2::do_conv_7x7_stride2(
GI_FLOAT32_V2_t _r00_02461357 = GiLd2qFloat32(r0);
GI_FLOAT32_V2_t _r00nx2 = GiLd2qFloat32(r0 + 8);
GI_FLOAT32_t _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14
GI_FLOAT32_t _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15
GI_FLOAT32_t _r00 = _r00_02461357.val[0]; // 0 2 4 6
GI_FLOAT32_t _r01 = _r00_02461357.val[1]; // 1 3 5 7
GI_FLOAT32_t _r0_8101214 =
GiGetSubVectorFloat32V2(_r00nx2, 0); // 8 10 12 14
GI_FLOAT32_t _r0_9111315 =
GiGetSubVectorFloat32V2(_r00nx2, 1); // 9 11 13 15
GI_FLOAT32_t _r00 =
GiGetSubVectorFloat32V2(_r00_02461357, 0); // 0 2 4 6
GI_FLOAT32_t _r01 =
GiGetSubVectorFloat32V2(_r00_02461357, 1); // 1 3 5 7
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0_8101214, 1); // 2 4 6 8
GI_FLOAT32_t _r03 = GiExtqFloat32(_r01, _r0_9111315, 1); // 3 5 7 9
GI_FLOAT32_t _r04 = GiExtqFloat32(_r00, _r0_8101214, 2); // 4 6 8 10
......@@ -333,10 +344,10 @@ void conv_stride2::do_conv_7x7_stride2(
GI_FLOAT32_V2_t _r10_02461357 = GiLd2qFloat32(r1);
GI_FLOAT32_V2_t _r10nx2 = GiLd2qFloat32(r1 + 8);
GI_FLOAT32_t _r1_8101214 = _r10nx2.val[0];
GI_FLOAT32_t _r1_9111315 = _r10nx2.val[1];
GI_FLOAT32_t _r10 = _r10_02461357.val[0];
GI_FLOAT32_t _r11 = _r10_02461357.val[1];
GI_FLOAT32_t _r1_8101214 = GiGetSubVectorFloat32V2(_r10nx2, 0);
GI_FLOAT32_t _r1_9111315 = GiGetSubVectorFloat32V2(_r10nx2, 1);
GI_FLOAT32_t _r10 = GiGetSubVectorFloat32V2(_r10_02461357, 0);
GI_FLOAT32_t _r11 = GiGetSubVectorFloat32V2(_r10_02461357, 1);
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1_8101214, 1);
GI_FLOAT32_t _r13 = GiExtqFloat32(_r11, _r1_9111315, 1);
GI_FLOAT32_t _r14 = GiExtqFloat32(_r10, _r1_8101214, 2);
......@@ -356,10 +367,10 @@ void conv_stride2::do_conv_7x7_stride2(
GI_FLOAT32_V2_t _r20_02461357 = GiLd2qFloat32(r2);
GI_FLOAT32_V2_t _r20nx2 = GiLd2qFloat32(r2 + 8);
GI_FLOAT32_t _r2_8101214 = _r20nx2.val[0];
GI_FLOAT32_t _r2_9111315 = _r20nx2.val[1];
GI_FLOAT32_t _r20 = _r20_02461357.val[0];
GI_FLOAT32_t _r21 = _r20_02461357.val[1];
GI_FLOAT32_t _r2_8101214 = GiGetSubVectorFloat32V2(_r20nx2, 0);
GI_FLOAT32_t _r2_9111315 = GiGetSubVectorFloat32V2(_r20nx2, 1);
GI_FLOAT32_t _r20 = GiGetSubVectorFloat32V2(_r20_02461357, 0);
GI_FLOAT32_t _r21 = GiGetSubVectorFloat32V2(_r20_02461357, 1);
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2_8101214, 1);
GI_FLOAT32_t _r23 = GiExtqFloat32(_r21, _r2_9111315, 1);
GI_FLOAT32_t _r24 = GiExtqFloat32(_r20, _r2_8101214, 2);
......@@ -379,10 +390,10 @@ void conv_stride2::do_conv_7x7_stride2(
GI_FLOAT32_V2_t _r30_02461357 = GiLd2qFloat32(r3);
GI_FLOAT32_V2_t _r30nx2 = GiLd2qFloat32(r3 + 8);
GI_FLOAT32_t _r3_8101214 = _r30nx2.val[0];
GI_FLOAT32_t _r3_9111315 = _r30nx2.val[1];
GI_FLOAT32_t _r30 = _r30_02461357.val[0];
GI_FLOAT32_t _r31 = _r30_02461357.val[1];
GI_FLOAT32_t _r3_8101214 = GiGetSubVectorFloat32V2(_r30nx2, 0);
GI_FLOAT32_t _r3_9111315 = GiGetSubVectorFloat32V2(_r30nx2, 1);
GI_FLOAT32_t _r30 = GiGetSubVectorFloat32V2(_r30_02461357, 0);
GI_FLOAT32_t _r31 = GiGetSubVectorFloat32V2(_r30_02461357, 1);
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r3_8101214, 1);
GI_FLOAT32_t _r33 = GiExtqFloat32(_r31, _r3_9111315, 1);
GI_FLOAT32_t _r34 = GiExtqFloat32(_r30, _r3_8101214, 2);
......@@ -402,10 +413,10 @@ void conv_stride2::do_conv_7x7_stride2(
GI_FLOAT32_V2_t _r40_02461357 = GiLd2qFloat32(r4);
GI_FLOAT32_V2_t _r40nx2 = GiLd2qFloat32(r4 + 8);
GI_FLOAT32_t _r4_8101214 = _r40nx2.val[0];
GI_FLOAT32_t _r4_9111315 = _r40nx2.val[1];
GI_FLOAT32_t _r40 = _r40_02461357.val[0];
GI_FLOAT32_t _r41 = _r40_02461357.val[1];
GI_FLOAT32_t _r4_8101214 = GiGetSubVectorFloat32V2(_r40nx2, 0);
GI_FLOAT32_t _r4_9111315 = GiGetSubVectorFloat32V2(_r40nx2, 1);
GI_FLOAT32_t _r40 = GiGetSubVectorFloat32V2(_r40_02461357, 0);
GI_FLOAT32_t _r41 = GiGetSubVectorFloat32V2(_r40_02461357, 1);
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r4_8101214, 1);
GI_FLOAT32_t _r43 = GiExtqFloat32(_r41, _r4_9111315, 1);
GI_FLOAT32_t _r44 = GiExtqFloat32(_r40, _r4_8101214, 2);
......@@ -425,10 +436,10 @@ void conv_stride2::do_conv_7x7_stride2(
GI_FLOAT32_V2_t _r50_02461357 = GiLd2qFloat32(r5);
GI_FLOAT32_V2_t _r50nx2 = GiLd2qFloat32(r5 + 8);
GI_FLOAT32_t _r5_8101214 = _r50nx2.val[0];
GI_FLOAT32_t _r5_9111315 = _r50nx2.val[1];
GI_FLOAT32_t _r50 = _r50_02461357.val[0];
GI_FLOAT32_t _r51 = _r50_02461357.val[1];
GI_FLOAT32_t _r5_8101214 = GiGetSubVectorFloat32V2(_r50nx2, 0);
GI_FLOAT32_t _r5_9111315 = GiGetSubVectorFloat32V2(_r50nx2, 1);
GI_FLOAT32_t _r50 = GiGetSubVectorFloat32V2(_r50_02461357, 0);
GI_FLOAT32_t _r51 = GiGetSubVectorFloat32V2(_r50_02461357, 1);
GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r5_8101214, 1);
GI_FLOAT32_t _r53 = GiExtqFloat32(_r51, _r5_9111315, 1);
GI_FLOAT32_t _r54 = GiExtqFloat32(_r50, _r5_8101214, 2);
......@@ -448,10 +459,10 @@ void conv_stride2::do_conv_7x7_stride2(
GI_FLOAT32_V2_t _r60_02461357 = GiLd2qFloat32(r6);
GI_FLOAT32_V2_t _r60nx2 = GiLd2qFloat32(r6 + 8);
GI_FLOAT32_t _r6_8101214 = _r60nx2.val[0];
GI_FLOAT32_t _r6_9111315 = _r60nx2.val[1];
GI_FLOAT32_t _r60 = _r60_02461357.val[0];
GI_FLOAT32_t _r61 = _r60_02461357.val[1];
GI_FLOAT32_t _r6_8101214 = GiGetSubVectorFloat32V2(_r60nx2, 0);
GI_FLOAT32_t _r6_9111315 = GiGetSubVectorFloat32V2(_r60nx2, 1);
GI_FLOAT32_t _r60 = GiGetSubVectorFloat32V2(_r60_02461357, 0);
GI_FLOAT32_t _r61 = GiGetSubVectorFloat32V2(_r60_02461357, 1);
GI_FLOAT32_t _r62 = GiExtqFloat32(_r60, _r6_8101214, 1);
GI_FLOAT32_t _r63 = GiExtqFloat32(_r61, _r6_9111315, 1);
GI_FLOAT32_t _r64 = GiExtqFloat32(_r60, _r6_8101214, 2);
......
......@@ -54,7 +54,8 @@ struct FilterTransform6X3 {
Vector<float, 4> g2 = Vector<float, 4>::load(fptr + 6 - 1);
GI_FLOAT32_t zeros = GiZeroFloat32();
g2.value = GiExtqFloat32(g2.value, zeros, 1);
g2.value = GiFloat32Type2FixLenType(
GiExtqFloat32(GiFixLenType2GiFloat32Type(g2.value), zeros, 1));
#define cb(i) Vector<float, 4> wd##i;
UNROLL_CALL_NOWRAPPER(8, cb);
......@@ -115,7 +116,8 @@ struct FilterTransform6X3 {
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \
mid_buf1 += 8; \
} while (0);
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value)
#define GET_VECTOR_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value))
float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
......
......@@ -6,18 +6,22 @@ namespace megdnn {
namespace fallback {
inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) {
GI_FLOAT32_V2_t a0, a1;
a0.val[0] = GiLoadFloat32(src + 0 * lda);
a0.val[1] = GiLoadFloat32(src + 1 * lda);
a1.val[0] = GiLoadFloat32(src + 2 * lda);
a1.val[1] = GiLoadFloat32(src + 3 * lda);
GI_FLOAT32_V2_t b0 = GiZipqFloat32(a0.val[0], a1.val[0]);
GI_FLOAT32_V2_t b1 = GiZipqFloat32(a0.val[1], a1.val[1]);
GI_FLOAT32_V2_t c0 = GiZipqFloat32(b0.val[0], b1.val[0]);
GI_FLOAT32_V2_t c1 = GiZipqFloat32(b0.val[1], b1.val[1]);
GiStoreFloat32(dst + 0 * ldb, c0.val[0]);
GiStoreFloat32(dst + 1 * ldb, c0.val[1]);
GiStoreFloat32(dst + 2 * ldb, c1.val[0]);
GiStoreFloat32(dst + 3 * ldb, c1.val[1]);
GiSetSubVectorFloat32V2(a0, 0, GiLoadFloat32(src + 0 * lda));
GiSetSubVectorFloat32V2(a0, 1, GiLoadFloat32(src + 1 * lda));
GiSetSubVectorFloat32V2(a1, 0, GiLoadFloat32(src + 2 * lda));
GiSetSubVectorFloat32V2(a1, 1, GiLoadFloat32(src + 3 * lda));
GI_FLOAT32_V2_t b0 = GiZipqFloat32(
GiGetSubVectorFloat32V2(a0, 0), GiGetSubVectorFloat32V2(a1, 0));
GI_FLOAT32_V2_t b1 = GiZipqFloat32(
GiGetSubVectorFloat32V2(a0, 1), GiGetSubVectorFloat32V2(a1, 1));
GI_FLOAT32_V2_t c0 = GiZipqFloat32(
GiGetSubVectorFloat32V2(b0, 0), GiGetSubVectorFloat32V2(b1, 0));
GI_FLOAT32_V2_t c1 = GiZipqFloat32(
GiGetSubVectorFloat32V2(b0, 1), GiGetSubVectorFloat32V2(b1, 1));
GiStoreFloat32(dst + 0 * ldb, GiGetSubVectorFloat32V2(c0, 0));
GiStoreFloat32(dst + 1 * ldb, GiGetSubVectorFloat32V2(c0, 1));
GiStoreFloat32(dst + 2 * ldb, GiGetSubVectorFloat32V2(c1, 0));
GiStoreFloat32(dst + 3 * ldb, GiGetSubVectorFloat32V2(c1, 1));
}
} // namespace fallback
} // namespace megdnn
......@@ -159,27 +163,43 @@ inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) {
GiReinterpretqFloat32ToS64(b3.val[1])));
#else
#define TRANSPOSE_8x4(a, ret) \
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \
CONCAT(ret, 0).value.val[0] = \
GiCombineFloat32(GiGetLowFloat32(b0.val[0]), GiGetLowFloat32(b1.val[0])); \
CONCAT(ret, 1).value.val[0] = GiCombineFloat32( \
GiGetHighFloat32(b0.val[0]), GiGetHighFloat32(b1.val[0])); \
CONCAT(ret, 2).value.val[0] = \
GiCombineFloat32(GiGetLowFloat32(b0.val[1]), GiGetLowFloat32(b1.val[1])); \
CONCAT(ret, 3).value.val[0] = GiCombineFloat32( \
GiGetHighFloat32(b0.val[1]), GiGetHighFloat32(b1.val[1])); \
CONCAT(ret, 0).value.val[1] = \
GiCombineFloat32(GiGetLowFloat32(b2.val[0]), GiGetLowFloat32(b3.val[0])); \
CONCAT(ret, 1).value.val[1] = GiCombineFloat32( \
GiGetHighFloat32(b2.val[0]), GiGetHighFloat32(b3.val[0])); \
CONCAT(ret, 2).value.val[1] = \
GiCombineFloat32(GiGetLowFloat32(b2.val[1]), GiGetLowFloat32(b3.val[1])); \
CONCAT(ret, 3).value.val[1] = GiCombineFloat32( \
GiGetHighFloat32(b2.val[1]), GiGetHighFloat32(b3.val[1]));
#define TRANSPOSE_8x4(a, ret) \
auto b0 = GiZipqFloat32( \
GiFixLenType2GiFloat32Type(CONCAT(a, 0).value), \
GiFixLenType2GiFloat32Type(CONCAT(a, 1).value)); \
auto b1 = GiZipqFloat32( \
GiFixLenType2GiFloat32Type(CONCAT(a, 2).value), \
GiFixLenType2GiFloat32Type(CONCAT(a, 3).value)); \
auto b2 = GiZipqFloat32( \
GiFixLenType2GiFloat32Type(CONCAT(a, 4).value), \
GiFixLenType2GiFloat32Type(CONCAT(a, 5).value)); \
auto b3 = GiZipqFloat32( \
GiFixLenType2GiFloat32Type(CONCAT(a, 6).value), \
GiFixLenType2GiFloat32Type(CONCAT(a, 7).value)); \
CONCAT(ret, 0).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 0)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \
CONCAT(ret, 1).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 0)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 0)))); \
CONCAT(ret, 2).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b0, 1)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \
CONCAT(ret, 3).value.val[0] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b0, 1)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b1, 1)))); \
CONCAT(ret, 0).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 0)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \
CONCAT(ret, 1).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 0)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 0)))); \
CONCAT(ret, 2).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b2, 1)), \
GiGetLowFloat32(GiGetSubVectorFloat32V2(b3, 1)))); \
CONCAT(ret, 3).value.val[1] = GiFloat32Type2FixLenType(GiCombineFloat32( \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b2, 1)), \
GiGetHighFloat32(GiGetSubVectorFloat32V2(b3, 1))));
#endif
// vim: syntax=cpp.doxygen
......@@ -155,10 +155,10 @@ struct OutputTransform2X3 {
v11 += vbias;
}
if (bmode != BiasMode::BIAS) {
v00 = op(v00.value);
v01 = op(v01.value);
v10 = op(v10.value);
v11 = op(v11.value);
v00 = op(GiFixLenType2GiFloat32Type(v00.value));
v01 = op(GiFixLenType2GiFloat32Type(v01.value));
v10 = op(GiFixLenType2GiFloat32Type(v10.value));
v11 = op(GiFixLenType2GiFloat32Type(v11.value));
}
v00.save(transform_mid_buf + (0 * 2 + 0) * 4);
......@@ -194,10 +194,28 @@ void winograd_gi_2x3_4x4_f::filter(
size_t OC, size_t IC, size_t oc_start, size_t oc_end) {
constexpr int alpha = 2 + 3 - 1;
//! G * g * GT
GI_FLOAT32_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0},
g3{0, 0, 1, 0};
GI_FLOAT32_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1},
gt3{0, 0, 0, 0};
float tmp[4];
auto init_g = [&](float a0, float a1, float a2, float a3) {
tmp[0] = a0;
tmp[1] = a1;
tmp[2] = a2;
tmp[3] = a3;
};
init_g(1.f, 0, 0, 0);
GI_FLOAT32_t g0 = GiLoadFloat32(tmp);
init_g(0.5, 0.5, 0.5, 0);
GI_FLOAT32_t g1 = GiLoadFloat32(tmp);
init_g(0.5, -0.5, 0.5, 0);
GI_FLOAT32_t g2 = GiLoadFloat32(tmp);
init_g(0, 0, 1, 0);
GI_FLOAT32_t g3 = GiLoadFloat32(tmp);
init_g(1, 0.5, 0.5, 0);
GI_FLOAT32_t gt0 = GiLoadFloat32(tmp);
init_g(0, 0.5, -0.5, 0);
GI_FLOAT32_t gt1 = GiLoadFloat32(tmp);
init_g(0, 0.5, 0.5, 1);
GI_FLOAT32_t gt2 = GiLoadFloat32(tmp);
GI_FLOAT32_t gt3 = GiZeroFloat32();
size_t OCB = OC / 4;
size_t ICB = IC / 4;
......@@ -217,15 +235,15 @@ void winograd_gi_2x3_4x4_f::filter(
GI_FLOAT32_t vf1 = GiLoadFloat32(filter_ptr + 4);
GI_FLOAT32_t vf2 = GiBroadcastFloat32(filter_ptr[8]);
GI_FLOAT32_t v3(GiBroadcastFloat32(0));
GI_FLOAT32_t v3 = GiBroadcastFloat32(0);
auto vtmp = GiExtqFloat32(vf1, vf2, 2);
vtmp = GiSetqLaneFloat32(0, vtmp, 3);
GI_FLOAT32_t v2(vtmp);
GI_FLOAT32_t v2 = vtmp;
vtmp = GiExtqFloat32(vf0, vf1, 3);
vtmp = GiSetqLaneFloat32(0, vtmp, 3);
GI_FLOAT32_t v1(vtmp);
GI_FLOAT32_t v1 = vtmp;
vtmp = GiSetqLaneFloat32(0, vf0, 3);
GI_FLOAT32_t v0(vtmp);
GI_FLOAT32_t v0 = vtmp;
GI_FLOAT32_t vsum0 = GiBroadcastFloat32(0), vsum1 = GiBroadcastFloat32(0),
vsum2 = GiBroadcastFloat32(0), vsum3 = GiBroadcastFloat32(0);
......
......@@ -115,10 +115,19 @@ struct FilterTransform4X5 {
FILTER_TRANSFORM(g, Gg)
GI_FLOAT32_V2_t vgr;
GI_FLOAT32_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3};
GI_FLOAT32_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7};
vgr.val[0] = vgr0; //{Ggr0, Ggr1, Ggr2, Ggr3};
vgr.val[1] = vgr1; //{Ggr4, Ggr5, Ggr6, Ggr7};
float tmp[4];
tmp[0] = Ggr0;
tmp[1] = Ggr1;
tmp[2] = Ggr2;
tmp[3] = Ggr3;
GI_FLOAT32_t vgr0 = GiLoadFloat32(tmp);
tmp[0] = Ggr4;
tmp[1] = Ggr5;
tmp[2] = Ggr6;
tmp[3] = Ggr7;
GI_FLOAT32_t vgr1 = GiLoadFloat32(tmp);
GiSetSubVectorFloat32V2(vgr, 0, vgr0); //{Ggr0, Ggr1, Ggr2, Ggr3};
GiSetSubVectorFloat32V2(vgr, 1, vgr1); //{Ggr4, Ggr5, Ggr6, Ggr7};
Vector<float, 8> Ggt4(vgr);
TRANSPOSE_8x4(Gg, Ggt);
FILTER_TRANSFORM_FINAL(Ggt, result);
......@@ -155,10 +164,12 @@ struct InputTransform4X5 {
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \
} while (0)
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1])
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0])
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2( \
GiFixLenType2GiFloat32V2Type(CONCAT(s, i).value), 1))
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiGetSubVectorFloat32V2( \
GiFixLenType2GiFloat32V2Type(CONCAT(s, i).value), 0))
template <bool inner>
static void transform(
......
......@@ -104,7 +104,8 @@ struct FilterTransform5X4 {
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \
mid_buf1 += 8; \
} while (0);
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value)
#define GET_VECTOR_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value))
float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
......@@ -142,9 +143,9 @@ struct InputTransform5X4 {
} while (0)
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1])
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[1]))
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0])
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[0]))
template <bool inner>
static void transform(
......
......@@ -46,9 +46,9 @@ namespace {
} while (0);
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1])
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[1]))
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0])
GiExtractLane##idx##Float32(GiFixLenType2GiFloat32Type(CONCAT(s, i).value.val[0]))
struct InputTransform6X3 {
template <bool inner>
static void transform(
......
......@@ -215,7 +215,7 @@ struct OutputTransform6X3 {
#undef cb
}
if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value));
UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
}
......
......@@ -153,7 +153,7 @@ struct OutputTransformF23_NCHW44 {
#undef cb
}
if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value));
UNROLL_CALL_RAW_D2(2, 2, cb);
#undef cb
}
......@@ -165,7 +165,7 @@ struct OutputTransformF23_NCHW44 {
if (bmode == BiasMode::BIAS) { \
v##oho##owo += Vector<float, 4>::load( \
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
v##oho##owo = op(v##oho##owo.value); \
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \
} \
v##oho##owo.save( \
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
......
......@@ -102,17 +102,17 @@ struct InputTransformF63_NCHW44 {
auto t##i##4 = d6; \
auto t##i##5 = d6; \
auto t##i##6 = d6; \
t##i##0 = t##i##0 - d6; \
t##i##1 = t##i##1 + d1; \
t##i##2 = t##i##2 - d1; \
t##i##0 = GiSubtractFloat32(t##i##0, d6); \
t##i##1 = GiAddFloat32(t##i##1, d1); \
t##i##2 = GiSubtractFloat32(t##i##2, d1); \
t##i##3 = GiSimdFmaLane(t##i##3, d1, v0, 2); \
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d1, v0, 2); \
t##i##5 = GiSimdFmaLane(t##i##5, d1, v1, 2); \
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d1, v1, 2); \
t##i##7 = t##i##7 - d1; \
t##i##7 = GiSubtractFloat32(t##i##7, d1); \
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d2, v0, 0); \
t##i##1 = t##i##1 + d2; \
t##i##2 = t##i##2 + d2; \
t##i##1 = GiAddFloat32(t##i##1, d2); \
t##i##2 = GiAddFloat32(t##i##2, d2); \
t##i##3 = GiSimdFmaLane(t##i##3, d2, v0, 3); \
t##i##4 = GiSimdFmaLane(t##i##4, d2, v0, 3); \
t##i##5 = GiSimdFmaLane(t##i##5, d2, v1, 3); \
......@@ -131,8 +131,8 @@ struct InputTransformF63_NCHW44 {
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d4, v1, 1); \
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d4, v2, 0); \
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d4, v2, 0); \
t##i##1 = t##i##1 + d5; \
t##i##2 = t##i##2 - d5; \
t##i##1 = GiAddFloat32(t##i##1, d5); \
t##i##2 = GiSubtractFloat32(t##i##2, d5); \
t##i##3 = GiSimdFmaLane(t##i##3, d5, v1, 2); \
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d5, v1, 2); \
t##i##5 = GiSimdFmaLane(t##i##5, d5, v0, 2); \
......@@ -150,17 +150,17 @@ struct InputTransformF63_NCHW44 {
d5 = t6##i; \
d6 = t6##i; \
d7 = t7##i; \
d0 = d0 - t6##i; \
d1 = d1 + t1##i; \
d2 = d2 - t1##i; \
d0 = GiSubtractFloat32(d0, t6##i); \
d1 = GiAddFloat32(d1, t1##i); \
d2 = GiSubtractFloat32(d2, t1##i); \
d3 = GiSimdFmaLane(d3, t1##i, v0, 2); \
d4 = GiFmsqLaneQFloat32(d4, t1##i, v0, 2); \
d5 = GiSimdFmaLane(d5, t1##i, v1, 2); \
d6 = GiFmsqLaneQFloat32(d6, t1##i, v1, 2); \
d7 = d7 - t1##i; \
d7 = GiSubtractFloat32(d7, t1##i); \
d0 = GiFmsqLaneQFloat32(d0, t2##i, v0, 0); \
d1 = d1 + t2##i; \
d2 = d2 + t2##i; \
d1 = GiAddFloat32(d1, t2##i); \
d2 = GiAddFloat32(d2, t2##i); \
d3 = GiSimdFmaLane(d3, t2##i, v0, 3); \
d4 = GiSimdFmaLane(d4, t2##i, v0, 3); \
d5 = GiSimdFmaLane(d5, t2##i, v1, 3); \
......@@ -179,8 +179,8 @@ struct InputTransformF63_NCHW44 {
d4 = GiFmsqLaneQFloat32(d4, t4##i, v1, 1); \
d5 = GiFmsqLaneQFloat32(d5, t4##i, v2, 0); \
d6 = GiFmsqLaneQFloat32(d6, t4##i, v2, 0); \
d1 = d1 + t5##i; \
d2 = d2 - t5##i; \
d1 = GiAddFloat32(d1, t5##i); \
d2 = GiSubtractFloat32(d2, t5##i); \
d3 = GiSimdFmaLane(d3, t5##i, v1, 2); \
d4 = GiFmsqLaneQFloat32(d4, t5##i, v1, 2); \
d5 = GiSimdFmaLane(d5, t5##i, v0, 2); \
......@@ -311,7 +311,7 @@ struct OutputTransformF63_NCHW44 {
#undef cb
}
if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value));
UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
}
......@@ -323,7 +323,7 @@ struct OutputTransformF63_NCHW44 {
if (bmode == BiasMode::BIAS) { \
v##oho##owo += Vector<float, 4>::load( \
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
v##oho##owo = op(v##oho##owo.value); \
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \
} \
v##oho##owo.save( \
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
......
......@@ -121,14 +121,14 @@ struct InputTransformF73_NCHW44 {
auto t##i##6 = d7; \
auto t##i##7 = d7; \
t##i##8 = GiFmsqLaneQFloat32(t##i##8, d7, v0, 0); \
t##i##0 = t##i##0 - d1; \
t##i##0 = GiSubtractFloat32(t##i##0, d1); \
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d1, v0, 0); \
t##i##2 = GiSimdFmaLane(t##i##2, d1, v0, 0); \
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d1, v0, 1); \
t##i##4 = GiSimdFmaLane(t##i##4, d1, v0, 1); \
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d1, v0, 2); \
t##i##6 = GiSimdFmaLane(t##i##6, d1, v0, 2); \
t##i##7 = t##i##7 - d1; \
t##i##7 = GiSubtractFloat32(t##i##7, d1); \
t##i##8 = GiSimdFmaLane(t##i##8, d1, v0, 0); \
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d2, v0, 3); \
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d2, v1, 0); \
......@@ -137,7 +137,7 @@ struct InputTransformF73_NCHW44 {
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d2, v1, 3); \
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d2, v2, 0); \
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d2, v2, 1); \
t##i##8 = t##i##8 - d2; \
t##i##8 = GiSubtractFloat32(t##i##8, d2); \
t##i##0 = GiSimdFmaLane(t##i##0, d3, v2, 2); \
t##i##1 = GiSimdFmaLane(t##i##1, d3, v2, 3); \
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d3, v3, 0); \
......@@ -169,7 +169,7 @@ struct InputTransformF73_NCHW44 {
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d6, v1, 1); \
t##i##3 = GiSimdFmaLane(t##i##3, d6, v1, 0); \
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d6, v3, 1); \
t##i##5 = t##i##5 - d6; \
t##i##5 = GiSubtractFloat32(t##i##5, d6); \
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d6, v6, 2); \
t##i##8 = GiFmsqLaneQFloat32(t##i##8, d6, v2, 2); \
t##i##0 = GiSimdFmaLane(t##i##0, d0, v0, 0);
......@@ -188,14 +188,14 @@ struct InputTransformF73_NCHW44 {
d6 = t7##i; \
d7 = t7##i; \
d8 = GiFmsqLaneQFloat32(d8, t7##i, v0, 0); \
d0 = d0 - t1##i; \
d0 = GiSubtractFloat32(d0, t1##i); \
d1 = GiFmsqLaneQFloat32(d1, t1##i, v0, 0); \
d2 = GiSimdFmaLane(d2, t1##i, v0, 0); \
d3 = GiFmsqLaneQFloat32(d3, t1##i, v0, 1); \
d4 = GiSimdFmaLane(d4, t1##i, v0, 1); \
d5 = GiFmsqLaneQFloat32(d5, t1##i, v0, 2); \
d6 = GiSimdFmaLane(d6, t1##i, v0, 2); \
d7 = d7 - t1##i; \
d7 = GiSubtractFloat32(d7, t1##i); \
d8 = GiSimdFmaLane(d8, t1##i, v0, 0); \
d0 = GiFmsqLaneQFloat32(d0, t2##i, v0, 3); \
d1 = GiFmsqLaneQFloat32(d1, t2##i, v1, 0); \
......@@ -204,7 +204,7 @@ struct InputTransformF73_NCHW44 {
d4 = GiFmsqLaneQFloat32(d4, t2##i, v1, 3); \
d5 = GiFmsqLaneQFloat32(d5, t2##i, v2, 0); \
d6 = GiFmsqLaneQFloat32(d6, t2##i, v2, 1); \
d8 = d8 - t2##i; \
d8 = GiSubtractFloat32(d8, t2##i); \
d0 = GiSimdFmaLane(d0, t3##i, v2, 2); \
d1 = GiSimdFmaLane(d1, t3##i, v2, 3); \
d2 = GiFmsqLaneQFloat32(d2, t3##i, v3, 0); \
......@@ -236,7 +236,7 @@ struct InputTransformF73_NCHW44 {
d2 = GiFmsqLaneQFloat32(d2, t6##i, v1, 1); \
d3 = GiSimdFmaLane(d3, t6##i, v1, 0); \
d4 = GiFmsqLaneQFloat32(d4, t6##i, v3, 1); \
d5 = d5 - t6##i; \
d5 = GiSubtractFloat32(d5, t6##i); \
d6 = GiFmsqLaneQFloat32(d6, t6##i, v6, 2); \
d8 = GiFmsqLaneQFloat32(d8, t6##i, v2, 2); \
d0 = GiSimdFmaLane(d0, t0##i, v0, 0); \
......@@ -377,7 +377,7 @@ struct OutputTransformF73_NCHW44 {
#undef cb
}
if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
#define cb(m, n) v##m##n = op(GiFixLenType2GiFloat32Type(CONCAT(v##m, n).value));
UNROLL_CALL_RAW_D2(7, 7, cb);
#undef cb
}
......@@ -389,7 +389,7 @@ struct OutputTransformF73_NCHW44 {
if (bmode == BiasMode::BIAS) { \
v##oho##owo += Vector<float, 4>::load( \
bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
v##oho##owo = op(v##oho##owo.value); \
v##oho##owo = op(GiFixLenType2GiFloat32Type(v##oho##owo.value)); \
} \
v##oho##owo.save( \
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \
......
......@@ -12,49 +12,64 @@ struct Vector;
template <>
struct Vector<float, 4> {
GI_FLOAT32_t value;
GI_FLOAT32_FIXLEN_t value;
Vector() {}
Vector(const float v) { value = GiBroadcastFloat32(v); }
Vector(const float v) { value = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); }
Vector(const Vector& lr) { value = lr.value; }
Vector(const Vector&& lr) { value = std::move(lr.value); }
Vector(const GI_FLOAT32_t& v) { value = v; }
Vector(const GI_FLOAT32_t& v) { value = GiFloat32Type2FixLenType(v); }
static Vector load(const float* addr) {
Vector v;
v.value = GiLoadFloat32(addr);
v.value = GiFloat32Type2FixLenType(GiLoadFloat32(addr));
return v;
}
static void save(float* addr, const Vector& v) { GiStoreFloat32(addr, v.value); }
static void save(float* addr, const Vector& v) {
GiStoreFloat32(addr, GiFixLenType2GiFloat32Type(v.value));
}
void save(float* addr) { save(addr, *this); }
Vector operator+(const Vector& lr) {
Vector dst;
dst.value = GiAddFloat32(value, lr.value);
dst.value = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return dst;
}
Vector& operator+=(const Vector& lr) {
value = GiAddFloat32(value, lr.value);
value = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return *this;
}
Vector operator-(const Vector& lr) {
Vector dst;
dst.value = GiSubtractFloat32(value, lr.value);
dst.value = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return dst;
}
Vector& operator-=(const Vector& lr) {
value = GiSubtractFloat32(value, lr.value);
value = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return *this;
}
Vector operator*(float lr) {
Vector dst;
dst.value = GiMultiplyScalerFloat32(value, lr);
dst.value = GiFloat32Type2FixLenType(
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value), lr));
return dst;
}
Vector operator*(const Vector& lr) {
Vector dst;
dst.value = GiMultiplyFloat32(value, lr.value);
dst.value = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return dst;
}
Vector& operator*=(const Vector& lr) {
value = GiMultiplyFloat32(value, lr.value);
value = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value),
GiFixLenType2GiFloat32Type(lr.value)));
return *this;
}
Vector& operator=(const Vector& lr) {
......@@ -74,72 +89,108 @@ struct Vector<float, 4> {
template <>
struct Vector<float, 8> {
GI_FLOAT32_V2_t value;
GI_FLOAT32_FIXLEN_V2_t value;
Vector() {}
Vector(const float v) {
value.val[0] = GiBroadcastFloat32(v);
value.val[1] = GiBroadcastFloat32(v);
value.val[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v));
value.val[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v));
}
Vector(const Vector& lr) { value = lr.value; }
Vector(const Vector&& lr) { value = std::move(lr.value); }
Vector(const GI_FLOAT32_V2_t& v) { value = v; }
Vector(const GI_FLOAT32_V2_t& v) { value = GiFloat32Type2FixLenV2Type(v); }
static Vector load(const float* addr) {
Vector v;
v.value = GiLoadFloat32V2(addr);
v.value = GiFloat32Type2FixLenV2Type(GiLoadFloat32V2(addr));
return v;
}
static void save(float* addr, const Vector& v) { GiStoreFloat32V2(addr, v.value); }
static void save(float* addr, const Vector& v) {
GiStoreFloat32V2(addr, GiFixLenType2GiFloat32V2Type(v.value));
}
void save(float* addr) { save(addr, *this); }
Vector operator+(const Vector& lr) {
Vector dst;
dst.value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]);
dst.value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]);
dst.value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
dst.value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return dst;
}
Vector& operator+=(const Vector& lr) {
value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]);
value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]);
value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return *this;
}
Vector& add(const Vector& lr) {
value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]);
value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]);
value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return *this;
}
Vector operator-(const Vector& lr) {
Vector dst;
dst.value.val[0] = GiSubtractFloat32(value.val[0], lr.value.val[0]);
dst.value.val[1] = GiSubtractFloat32(value.val[1], lr.value.val[1]);
dst.value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
dst.value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return dst;
}
Vector& operator-=(const Vector& lr) {
value.val[0] = GiSubtractFloat32(value.val[0], lr.value.val[0]);
value.val[1] = GiSubtractFloat32(value.val[1], lr.value.val[1]);
value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return *this;
}
Vector operator*(float lr) {
Vector dst;
dst.value.val[0] = GiMultiplyScalerFloat32(value.val[0], lr);
dst.value.val[1] = GiMultiplyScalerFloat32(value.val[1], lr);
dst.value.val[0] = GiFloat32Type2FixLenType(
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[0]), lr));
dst.value.val[1] = GiFloat32Type2FixLenType(
GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[1]), lr));
return dst;
}
//! val + lr * n
Vector& mla(const Vector& lr, float n) {
value.val[0] = GiMultiplyAddScalarFloat32(value.val[0], lr.value.val[0], n);
value.val[1] = GiMultiplyAddScalarFloat32(value.val[1], lr.value.val[1], n);
value.val[0] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0]), n));
value.val[1] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1]), n));
return *this;
}
Vector operator*(const Vector& lr) {
Vector dst;
dst.value.val[0] = GiMultiplyFloat32(value.val[0], lr.value.val[0]);
dst.value.val[1] = GiMultiplyFloat32(value.val[1], lr.value.val[1]);
dst.value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
dst.value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return dst;
}
Vector& operator*=(const Vector& lr) {
value.val[0] = GiMultiplyFloat32(value.val[0], lr.value.val[0]);
value.val[1] = GiMultiplyFloat32(value.val[1], lr.value.val[1]);
value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value.val[0]),
GiFixLenType2GiFloat32Type(lr.value.val[0])));
value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(value.val[1]),
GiFixLenType2GiFloat32Type(lr.value.val[1])));
return *this;
}
Vector& operator=(const Vector& lr) {
......
......@@ -515,26 +515,26 @@ typedef GI_INT8_t GI_INT8_FIXLEN_t;
typedef GI_INT16_t GI_INT16_FIXLEN_t;
typedef GI_INT32_t GI_INT32_FIXLEN_t;
typedef GI_UINT32_t GI_UINT32_FIXLEN_t;
#define GiFloat32Type2FixLenType(s) (s)
#define GiFixLenType2GiFloat32Type(s) (s)
#define GiFloat32Type2FixLenType(s) s
#define GiFixLenType2GiFloat32Type(s) s
#define GiFloat32Type2FixLenV2Type(s) (s)
#define GiFixLenType2GiFloat32V2Type(s) (s)
#define GiFloat32Type2FixLenV2Type(s) s
#define GiFixLenType2GiFloat32V2Type(s) s
#define GiUint8Type2FixLenType(s) (s)
#define GiFixLenType2GiUint8Type(s) (s)
#define GiUint8Type2FixLenType(s) s
#define GiFixLenType2GiUint8Type(s) s
#define GiInt8Type2FixLenType(s) (s)
#define GiFixLenType2GiInt8Type(s) (s)
#define GiInt8Type2FixLenType(s) s
#define GiFixLenType2GiInt8Type(s) s
#define GiInt16Type2FixLenType(s) (s)
#define GiFixLenType2GiInt16Type(s) (s)
#define GiInt16Type2FixLenType(s) s
#define GiFixLenType2GiInt16Type(s) s
#define GiInt32Type2FixLenType(s) (s)
#define GiFixLenType2GiInt32Type(s) (s)
#define GiInt32Type2FixLenType(s) s
#define GiFixLenType2GiInt32Type(s) s
#define GiUint32Type2FixLenType(s) (s)
#define GiFixLenType2GiUint32Type(s) (s)
#define GiUint32Type2FixLenType(s) s
#define GiFixLenType2GiUint32Type(s) s
//! get subvector
#define GiGetSubVectorFloat32V2(s, index) s.val[index]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册