提交 19d36fa0 编写于 作者: M Megvii Engine Team

feat(gi): make pooling apply gi class type

GitOrigin-RevId: e60c6a2e7640cdfa75c1e7ad89f2ed78865847f5
上级 8546c15d
......@@ -8,11 +8,11 @@
namespace megdnn {
namespace fallback {
// GI_UZP(s0, s1, d0, d1): calls GiUzpqFloat32(s0, s1) and assigns the two
// sub-vectors of the returned value to d0 and d1, in that order. The body is
// wrapped in do { ... } while (0) so the expansion is a single statement.
// NOTE(review): the definition appears twice because this text is a diff whose
// +/- markers were lost. The first body reads the halves through `.val[0]` /
// `.val[1]`; the second reads them through GiGetSubVectorFloat32V2(tmp__, 0/1)
// -- presumably the pre-change and post-change versions; confirm against the
// real file before relying on either.
#define GI_UZP(s0, s1, d0, d1) \
do { \
auto tmp__ = GiUzpqFloat32(s0, s1); \
d0 = tmp__.val[0]; \
d1 = tmp__.val[1]; \
#define GI_UZP(s0, s1, d0, d1) \
do { \
auto tmp__ = GiUzpqFloat32(s0, s1); \
d0 = GiGetSubVectorFloat32V2(tmp__, 0); \
d1 = GiGetSubVectorFloat32V2(tmp__, 1); \
} while (0)
void do_max_pooling_3x3_s2x2_float_gi(
......
......@@ -29,17 +29,33 @@ void calculate_xsx_nchw44(T1 result, T2 src) {
CalXsXNchw44<filter, stride, ow_step, mode, T1, T2>::impl(result, src);
};
// CALCULATE_MAX_CB(step): for i = 0..3, replaces result[i] with
// GiMaximumFloat32(result[i], src[i * stride + step]).
// `result`, `src` and `stride` are not macro parameters; they must already be
// in scope where the macro is expanded.
// NOTE(review): two definitions are shown because the diff markers were lost.
// The first passes the array elements to GiMaximumFloat32 directly.
#define CALCULATE_MAX_CB(step) \
result[0] = GiMaximumFloat32(result[0], src[0 * stride + step]); \
result[1] = GiMaximumFloat32(result[1], src[1 * stride + step]); \
result[2] = GiMaximumFloat32(result[2], src[2 * stride + step]); \
result[3] = GiMaximumFloat32(result[3], src[3 * stride + step]);
// Second variant: each operand is passed through GiFixLenType2GiFloat32Type
// before the maximum, and the value is passed back through
// GiFloat32Type2FixLenType before being stored in result[i].
#define CALCULATE_MAX_CB(step) \
result[0] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
GiFixLenType2GiFloat32Type(result[0]), \
GiFixLenType2GiFloat32Type(src[0 * stride + step]))); \
result[1] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
GiFixLenType2GiFloat32Type(result[1]), \
GiFixLenType2GiFloat32Type(src[1 * stride + step]))); \
result[2] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
GiFixLenType2GiFloat32Type(result[2]), \
GiFixLenType2GiFloat32Type(src[2 * stride + step]))); \
result[3] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
GiFixLenType2GiFloat32Type(result[3]), \
GiFixLenType2GiFloat32Type(src[3 * stride + step])));
// CALCULATE_AVG_CB(step): for i = 0..3, replaces result[i] with
// GiAddFloat32(result[i], src[i * stride + step]). Only the running sum is
// formed here; no division by the window size happens inside this macro.
// `result`, `src` and `stride` are not macro parameters; they must already be
// in scope where the macro is expanded.
// NOTE(review): two definitions are shown because the diff markers were lost.
// The first passes the array elements to GiAddFloat32 directly.
#define CALCULATE_AVG_CB(step) \
result[0] = GiAddFloat32(result[0], src[0 * stride + step]); \
result[1] = GiAddFloat32(result[1], src[1 * stride + step]); \
result[2] = GiAddFloat32(result[2], src[2 * stride + step]); \
result[3] = GiAddFloat32(result[3], src[3 * stride + step]);
// Second variant: each operand is passed through GiFixLenType2GiFloat32Type
// before the addition, and the sum is passed back through
// GiFloat32Type2FixLenType before being stored in result[i].
#define CALCULATE_AVG_CB(step) \
result[0] = GiFloat32Type2FixLenType(GiAddFloat32( \
GiFixLenType2GiFloat32Type(result[0]), \
GiFixLenType2GiFloat32Type(src[0 * stride + step]))); \
result[1] = GiFloat32Type2FixLenType(GiAddFloat32( \
GiFixLenType2GiFloat32Type(result[1]), \
GiFixLenType2GiFloat32Type(src[1 * stride + step]))); \
result[2] = GiFloat32Type2FixLenType(GiAddFloat32( \
GiFixLenType2GiFloat32Type(result[2]), \
GiFixLenType2GiFloat32Type(src[2 * stride + step]))); \
result[3] = GiFloat32Type2FixLenType(GiAddFloat32( \
GiFixLenType2GiFloat32Type(result[3]), \
GiFixLenType2GiFloat32Type(src[3 * stride + step])));
#define INSTANCE_CAL(filter) \
template <int stride, typename T1, typename T2> \
......@@ -78,13 +94,13 @@ struct KerPoolingFilterXStrideXNchw44<filter, stride, ow_step, PoolingBase::Mode
constexpr int packed_ic = 4;
constexpr int simd_len = 4;
constexpr float default_float = std::numeric_limits<float>::lowest();
GI_FLOAT32_t result[ow_step];
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_FIXLEN_t result[ow_step];
GI_FLOAT32_FIXLEN_t src[src_reg_size];
result[0] = GiBroadcastFloat32(default_float);
result[1] = GiBroadcastFloat32(default_float);
result[2] = GiBroadcastFloat32(default_float);
result[3] = GiBroadcastFloat32(default_float);
result[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[2] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[3] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
for (int fh_idx = 0; fh_idx < filter; ++fh_idx) {
load_helper<src_reg_size, 0, simd_len, 0, GiD1Qf32>(
......@@ -93,10 +109,10 @@ struct KerPoolingFilterXStrideXNchw44<filter, stride, ow_step, PoolingBase::Mode
result, src);
}
GiStoreFloat32(dst_ptr + 0 * packed_ic, result[0]);
GiStoreFloat32(dst_ptr + 1 * packed_ic, result[1]);
GiStoreFloat32(dst_ptr + 2 * packed_ic, result[2]);
GiStoreFloat32(dst_ptr + 3 * packed_ic, result[3]);
GiStoreFloat32(dst_ptr + 0 * packed_ic, GiFixLenType2GiFloat32Type(result[0]));
GiStoreFloat32(dst_ptr + 1 * packed_ic, GiFixLenType2GiFloat32Type(result[1]));
GiStoreFloat32(dst_ptr + 2 * packed_ic, GiFixLenType2GiFloat32Type(result[2]));
GiStoreFloat32(dst_ptr + 3 * packed_ic, GiFixLenType2GiFloat32Type(result[3]));
}
};
......@@ -110,28 +126,36 @@ struct KerPoolingFilterXStrideXNchw44<
constexpr float default_float = 0;
constexpr float div_filter_size = 1.f / (filter * filter);
const GI_FLOAT32_t div_filter_size_vec = GiBroadcastFloat32(div_filter_size);
GI_FLOAT32_t result[ow_step];
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_FIXLEN_t result[ow_step];
GI_FLOAT32_FIXLEN_t src[src_reg_size];
result[0] = GiBroadcastFloat32(default_float);
result[1] = GiBroadcastFloat32(default_float);
result[2] = GiBroadcastFloat32(default_float);
result[3] = GiBroadcastFloat32(default_float);
result[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[2] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[3] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
for (int fh_idx = 0; fh_idx < filter; ++fh_idx) {
load_helper<src_reg_size, 0, simd_len, 0, GiD1Qf32>(
src, src_ptr + fh_idx * iw * packed_ic, 0);
calculate_xsx_nchw44<filter, stride, ow_step, PoolingBase::Mode::AVERAGE>(
result, src);
}
result[0] = GiMultiplyFloat32(result[0], div_filter_size_vec);
result[1] = GiMultiplyFloat32(result[1], div_filter_size_vec);
result[2] = GiMultiplyFloat32(result[2], div_filter_size_vec);
result[3] = GiMultiplyFloat32(result[3], div_filter_size_vec);
GiStoreFloat32(dst_ptr + 0 * packed_ic, result[0]);
GiStoreFloat32(dst_ptr + 1 * packed_ic, result[1]);
GiStoreFloat32(dst_ptr + 2 * packed_ic, result[2]);
GiStoreFloat32(dst_ptr + 3 * packed_ic, result[3]);
};
GiStoreFloat32(
dst_ptr + 0 * packed_ic,
GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(result[0]), div_filter_size_vec));
GiStoreFloat32(
dst_ptr + 1 * packed_ic,
GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(result[1]), div_filter_size_vec));
GiStoreFloat32(
dst_ptr + 2 * packed_ic,
GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(result[2]), div_filter_size_vec));
GiStoreFloat32(
dst_ptr + 3 * packed_ic,
GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(result[3]), div_filter_size_vec));
}
};
......
......@@ -56,18 +56,20 @@ struct GiMeanPooler<area, dt_float32, float, float> {
static constexpr int MIDOUT_CASE_NUM = 1;
static constexpr int SIMD_WIDTH = 4;
static const GI_FLOAT32_t coef;
GI_FLOAT32_t res;
GiMeanPooler(DType) : res(GiBroadcastFloat32(0.0f)) {}
void feed(const float* val) { res = GiAddFloat32(res, GiLoadFloat32(val)); }
GI_FLOAT32_FIXLEN_t res, coef;
// Constructor: the DType argument is unnamed and unused. Initialises the
// accumulator `res` from a broadcast 0.0f and `coef` from a broadcast
// 1.0f / area, each passed through GiFloat32Type2FixLenType.
GiMeanPooler(DType)
: res(GiFloat32Type2FixLenType(GiBroadcastFloat32(0.0f))),
coef(GiFloat32Type2FixLenType(GiBroadcastFloat32(1.0f / area))) {}
// Adds the vector loaded from `val` (GiLoadFloat32) to the running sum `res`.
// `res` is converted with GiFixLenType2GiFloat32Type for the addition and the
// sum is converted back with GiFloat32Type2FixLenType before being stored.
// NOTE(review): presumably reads SIMD_WIDTH floats from `val` -- confirm
// against GiLoadFloat32.
void feed(const float* val) {
res = GiFloat32Type2FixLenType(
GiAddFloat32(GiFixLenType2GiFloat32Type(res), GiLoadFloat32(val)));
}
// Multiplies the accumulated `res` by `coef` and stores the product to `dst`
// with GiStoreFloat32.
// NOTE(review): this body interleaves two versions of the same two statements
// because the diff markers were lost. The first pair uses `res`/`coef`
// directly; the second pair wraps them in the FixLen conversion helpers.
// Presumably only one pair exists in the real file -- confirm before editing.
void post(float* dst) {
res = GiMultiplyFloat32(res, coef);
GiStoreFloat32(dst, res);
res = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(res), GiFixLenType2GiFloat32Type(coef)));
GiStoreFloat32(dst, GiFixLenType2GiFloat32Type(res));
}
};
template <int area>
const GI_FLOAT32_t GiMeanPooler<area, dt_float32, float, float>::coef =
GiBroadcastFloat32(1.0f / area);
/* ======================= MaxPooler ======================== */
......@@ -96,10 +98,15 @@ struct GiMaxPooler<area, dt_float32, float, float> {
static constexpr int MIDOUT_CASE_NUM = 11;
static constexpr int SIMD_WIDTH = 4;
GI_FLOAT32_t res;
GiMaxPooler(DType) : res(GiBroadcastFloat32(DTypeTrait<dt_float32>::min())) {}
void feed(const float* val) { res = GiMaximumFloat32(res, GiLoadFloat32(val)); }
void post(float* dst) { GiStoreFloat32(dst, res); }
GI_FLOAT32_FIXLEN_t res;
// Constructor: the DType argument is unnamed and unused. Seeds the running
// maximum `res` with DTypeTrait<dt_float32>::min() broadcast to a vector and
// passed through GiFloat32Type2FixLenType.
GiMaxPooler(DType)
: res(GiFloat32Type2FixLenType(
GiBroadcastFloat32(DTypeTrait<dt_float32>::min()))) {}
// Updates `res` to GiMaximumFloat32 of the current `res` and the vector
// loaded from `val` (GiLoadFloat32). `res` is converted with
// GiFixLenType2GiFloat32Type for the comparison and the result is converted
// back with GiFloat32Type2FixLenType before being stored.
void feed(const float* val) {
res = GiFloat32Type2FixLenType(
GiMaximumFloat32(GiFixLenType2GiFloat32Type(res), GiLoadFloat32(val)));
}
// Stores `res` to `dst` with GiStoreFloat32 after converting it with
// GiFixLenType2GiFloat32Type; `res` itself is not modified.
void post(float* dst) { GiStoreFloat32(dst, GiFixLenType2GiFloat32Type(res)); }
};
template <typename Pooler, int window>
......@@ -137,7 +144,8 @@ struct do_pxl_2x2_pack_proxy<
const int IW, const int OH, const int OW, const int PH, const int PW) {
MEGDNN_MARK_USED_VAR(IH);
MEGDNN_MARK_USED_VAR(OH);
static const auto avg_coef = GiBroadcastFloat32(0.25f);
static const auto avg_coef =
GiFloat32Type2FixLenType(GiBroadcastFloat32(0.25f));
int ih = -PH + 2 * oh;
int iw = -PW + 2 * ow;
auto i00 = GiLoadFloat32(src + (ih + 0) * IW + (iw + 0)),
......@@ -148,7 +156,7 @@ struct do_pxl_2x2_pack_proxy<
auto vlow = GiPaddFloat32(GiGetLowFloat32(sum0), GiGetHighFloat32(sum0));
auto vhigh = GiPaddFloat32(GiGetLowFloat32(sum1), GiGetHighFloat32(sum1));
auto comb = GiCombineFloat32(vlow, vhigh);
auto result = GiMultiplyFloat32(comb, avg_coef);
auto result = GiMultiplyFloat32(comb, GiFixLenType2GiFloat32Type(avg_coef));
GiStoreFloat32(dst + oh * OW + ow, result);
}
};
......@@ -327,8 +335,8 @@ void do_max_pooling_w5x5_s2x2_gi(
auto s0 = GiLoadFloat32(sptr + iw + 0);
auto s1 = GiLoadFloat32(sptr + iw + MEGDNN_SIMD_WIDTH);
auto d = GiUzpqFloat32(s0, s1);
GiStoreFloat32(even + even_offset, d.val[0]);
GiStoreFloat32(odd + odd_offset, d.val[1]);
GiStoreFloat32(even + even_offset, GiGetSubVectorFloat32V2(d, 0));
GiStoreFloat32(odd + odd_offset, GiGetSubVectorFloat32V2(d, 1));
even_offset += MEGDNN_SIMD_WIDTH;
odd_offset += MEGDNN_SIMD_WIDTH;
}
......@@ -464,8 +472,8 @@ void do_average_pooling_3x3_s2x2_gi(
for (; iw + 2 * MEGDNN_SIMD_WIDTH <= IW; iw += 2 * MEGDNN_SIMD_WIDTH) {
auto s0 = GiLd2qFloat32(sptr + iw);
GiStoreFloat32(even + even_offset, s0.val[0]);
GiStoreFloat32(odd + odd_offset, s0.val[1]);
GiStoreFloat32(even + even_offset, GiGetSubVectorFloat32V2(s0, 0));
GiStoreFloat32(odd + odd_offset, GiGetSubVectorFloat32V2(s0, 1));
even_offset += MEGDNN_SIMD_WIDTH;
odd_offset += MEGDNN_SIMD_WIDTH;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册