提交 19d36fa0 作者: Megvii Engine Team

feat(gi): make pooling apply gi class type

GitOrigin-RevId: e60c6a2e7640cdfa75c1e7ad89f2ed78865847f5
上级 8546c15d
...@@ -8,11 +8,11 @@ ...@@ -8,11 +8,11 @@
namespace megdnn { namespace megdnn {
namespace fallback { namespace fallback {
#define GI_UZP(s0, s1, d0, d1) \ #define GI_UZP(s0, s1, d0, d1) \
do { \ do { \
auto tmp__ = GiUzpqFloat32(s0, s1); \ auto tmp__ = GiUzpqFloat32(s0, s1); \
d0 = tmp__.val[0]; \ d0 = GiGetSubVectorFloat32V2(tmp__, 0); \
d1 = tmp__.val[1]; \ d1 = GiGetSubVectorFloat32V2(tmp__, 1); \
} while (0) } while (0)
void do_max_pooling_3x3_s2x2_float_gi( void do_max_pooling_3x3_s2x2_float_gi(
......
...@@ -29,17 +29,33 @@ void calculate_xsx_nchw44(T1 result, T2 src) { ...@@ -29,17 +29,33 @@ void calculate_xsx_nchw44(T1 result, T2 src) {
CalXsXNchw44<filter, stride, ow_step, mode, T1, T2>::impl(result, src); CalXsXNchw44<filter, stride, ow_step, mode, T1, T2>::impl(result, src);
}; };
#define CALCULATE_MAX_CB(step) \ #define CALCULATE_MAX_CB(step) \
result[0] = GiMaximumFloat32(result[0], src[0 * stride + step]); \ result[0] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
result[1] = GiMaximumFloat32(result[1], src[1 * stride + step]); \ GiFixLenType2GiFloat32Type(result[0]), \
result[2] = GiMaximumFloat32(result[2], src[2 * stride + step]); \ GiFixLenType2GiFloat32Type(src[0 * stride + step]))); \
result[3] = GiMaximumFloat32(result[3], src[3 * stride + step]); result[1] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
GiFixLenType2GiFloat32Type(result[1]), \
GiFixLenType2GiFloat32Type(src[1 * stride + step]))); \
result[2] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
GiFixLenType2GiFloat32Type(result[2]), \
GiFixLenType2GiFloat32Type(src[2 * stride + step]))); \
result[3] = GiFloat32Type2FixLenType(GiMaximumFloat32( \
GiFixLenType2GiFloat32Type(result[3]), \
GiFixLenType2GiFloat32Type(src[3 * stride + step])));
#define CALCULATE_AVG_CB(step) \ #define CALCULATE_AVG_CB(step) \
result[0] = GiAddFloat32(result[0], src[0 * stride + step]); \ result[0] = GiFloat32Type2FixLenType(GiAddFloat32( \
result[1] = GiAddFloat32(result[1], src[1 * stride + step]); \ GiFixLenType2GiFloat32Type(result[0]), \
result[2] = GiAddFloat32(result[2], src[2 * stride + step]); \ GiFixLenType2GiFloat32Type(src[0 * stride + step]))); \
result[3] = GiAddFloat32(result[3], src[3 * stride + step]); result[1] = GiFloat32Type2FixLenType(GiAddFloat32( \
GiFixLenType2GiFloat32Type(result[1]), \
GiFixLenType2GiFloat32Type(src[1 * stride + step]))); \
result[2] = GiFloat32Type2FixLenType(GiAddFloat32( \
GiFixLenType2GiFloat32Type(result[2]), \
GiFixLenType2GiFloat32Type(src[2 * stride + step]))); \
result[3] = GiFloat32Type2FixLenType(GiAddFloat32( \
GiFixLenType2GiFloat32Type(result[3]), \
GiFixLenType2GiFloat32Type(src[3 * stride + step])));
#define INSTANCE_CAL(filter) \ #define INSTANCE_CAL(filter) \
template <int stride, typename T1, typename T2> \ template <int stride, typename T1, typename T2> \
...@@ -78,13 +94,13 @@ struct KerPoolingFilterXStrideXNchw44<filter, stride, ow_step, PoolingBase::Mode ...@@ -78,13 +94,13 @@ struct KerPoolingFilterXStrideXNchw44<filter, stride, ow_step, PoolingBase::Mode
constexpr int packed_ic = 4; constexpr int packed_ic = 4;
constexpr int simd_len = 4; constexpr int simd_len = 4;
constexpr float default_float = std::numeric_limits<float>::lowest(); constexpr float default_float = std::numeric_limits<float>::lowest();
GI_FLOAT32_t result[ow_step]; GI_FLOAT32_FIXLEN_t result[ow_step];
GI_FLOAT32_t src[src_reg_size]; GI_FLOAT32_FIXLEN_t src[src_reg_size];
result[0] = GiBroadcastFloat32(default_float); result[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[1] = GiBroadcastFloat32(default_float); result[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[2] = GiBroadcastFloat32(default_float); result[2] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[3] = GiBroadcastFloat32(default_float); result[3] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
for (int fh_idx = 0; fh_idx < filter; ++fh_idx) { for (int fh_idx = 0; fh_idx < filter; ++fh_idx) {
load_helper<src_reg_size, 0, simd_len, 0, GiD1Qf32>( load_helper<src_reg_size, 0, simd_len, 0, GiD1Qf32>(
...@@ -93,10 +109,10 @@ struct KerPoolingFilterXStrideXNchw44<filter, stride, ow_step, PoolingBase::Mode ...@@ -93,10 +109,10 @@ struct KerPoolingFilterXStrideXNchw44<filter, stride, ow_step, PoolingBase::Mode
result, src); result, src);
} }
GiStoreFloat32(dst_ptr + 0 * packed_ic, result[0]); GiStoreFloat32(dst_ptr + 0 * packed_ic, GiFixLenType2GiFloat32Type(result[0]));
GiStoreFloat32(dst_ptr + 1 * packed_ic, result[1]); GiStoreFloat32(dst_ptr + 1 * packed_ic, GiFixLenType2GiFloat32Type(result[1]));
GiStoreFloat32(dst_ptr + 2 * packed_ic, result[2]); GiStoreFloat32(dst_ptr + 2 * packed_ic, GiFixLenType2GiFloat32Type(result[2]));
GiStoreFloat32(dst_ptr + 3 * packed_ic, result[3]); GiStoreFloat32(dst_ptr + 3 * packed_ic, GiFixLenType2GiFloat32Type(result[3]));
} }
}; };
...@@ -110,28 +126,36 @@ struct KerPoolingFilterXStrideXNchw44< ...@@ -110,28 +126,36 @@ struct KerPoolingFilterXStrideXNchw44<
constexpr float default_float = 0; constexpr float default_float = 0;
constexpr float div_filter_size = 1.f / (filter * filter); constexpr float div_filter_size = 1.f / (filter * filter);
const GI_FLOAT32_t div_filter_size_vec = GiBroadcastFloat32(div_filter_size); const GI_FLOAT32_t div_filter_size_vec = GiBroadcastFloat32(div_filter_size);
GI_FLOAT32_t result[ow_step]; GI_FLOAT32_FIXLEN_t result[ow_step];
GI_FLOAT32_t src[src_reg_size]; GI_FLOAT32_FIXLEN_t src[src_reg_size];
result[0] = GiBroadcastFloat32(default_float); result[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[1] = GiBroadcastFloat32(default_float); result[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[2] = GiBroadcastFloat32(default_float); result[2] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
result[3] = GiBroadcastFloat32(default_float); result[3] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float));
for (int fh_idx = 0; fh_idx < filter; ++fh_idx) { for (int fh_idx = 0; fh_idx < filter; ++fh_idx) {
load_helper<src_reg_size, 0, simd_len, 0, GiD1Qf32>( load_helper<src_reg_size, 0, simd_len, 0, GiD1Qf32>(
src, src_ptr + fh_idx * iw * packed_ic, 0); src, src_ptr + fh_idx * iw * packed_ic, 0);
calculate_xsx_nchw44<filter, stride, ow_step, PoolingBase::Mode::AVERAGE>( calculate_xsx_nchw44<filter, stride, ow_step, PoolingBase::Mode::AVERAGE>(
result, src); result, src);
} };
result[0] = GiMultiplyFloat32(result[0], div_filter_size_vec); GiStoreFloat32(
result[1] = GiMultiplyFloat32(result[1], div_filter_size_vec); dst_ptr + 0 * packed_ic,
result[2] = GiMultiplyFloat32(result[2], div_filter_size_vec); GiMultiplyFloat32(
result[3] = GiMultiplyFloat32(result[3], div_filter_size_vec); GiFixLenType2GiFloat32Type(result[0]), div_filter_size_vec));
GiStoreFloat32(dst_ptr + 0 * packed_ic, result[0]); GiStoreFloat32(
GiStoreFloat32(dst_ptr + 1 * packed_ic, result[1]); dst_ptr + 1 * packed_ic,
GiStoreFloat32(dst_ptr + 2 * packed_ic, result[2]); GiMultiplyFloat32(
GiStoreFloat32(dst_ptr + 3 * packed_ic, result[3]); GiFixLenType2GiFloat32Type(result[1]), div_filter_size_vec));
GiStoreFloat32(
dst_ptr + 2 * packed_ic,
GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(result[2]), div_filter_size_vec));
GiStoreFloat32(
dst_ptr + 3 * packed_ic,
GiMultiplyFloat32(
GiFixLenType2GiFloat32Type(result[3]), div_filter_size_vec));
} }
}; };
......
...@@ -56,18 +56,20 @@ struct GiMeanPooler<area, dt_float32, float, float> { ...@@ -56,18 +56,20 @@ struct GiMeanPooler<area, dt_float32, float, float> {
static constexpr int MIDOUT_CASE_NUM = 1; static constexpr int MIDOUT_CASE_NUM = 1;
static constexpr int SIMD_WIDTH = 4; static constexpr int SIMD_WIDTH = 4;
static const GI_FLOAT32_t coef; GI_FLOAT32_FIXLEN_t res, coef;
GI_FLOAT32_t res; GiMeanPooler(DType)
GiMeanPooler(DType) : res(GiBroadcastFloat32(0.0f)) {} : res(GiFloat32Type2FixLenType(GiBroadcastFloat32(0.0f))),
void feed(const float* val) { res = GiAddFloat32(res, GiLoadFloat32(val)); } coef(GiFloat32Type2FixLenType(GiBroadcastFloat32(1.0f / area))) {}
void feed(const float* val) {
res = GiFloat32Type2FixLenType(
GiAddFloat32(GiFixLenType2GiFloat32Type(res), GiLoadFloat32(val)));
}
void post(float* dst) { void post(float* dst) {
res = GiMultiplyFloat32(res, coef); res = GiFloat32Type2FixLenType(GiMultiplyFloat32(
GiStoreFloat32(dst, res); GiFixLenType2GiFloat32Type(res), GiFixLenType2GiFloat32Type(coef)));
GiStoreFloat32(dst, GiFixLenType2GiFloat32Type(res));
} }
}; };
template <int area>
const GI_FLOAT32_t GiMeanPooler<area, dt_float32, float, float>::coef =
GiBroadcastFloat32(1.0f / area);
/* ======================= MaxPooler ======================== */ /* ======================= MaxPooler ======================== */
...@@ -96,10 +98,15 @@ struct GiMaxPooler<area, dt_float32, float, float> { ...@@ -96,10 +98,15 @@ struct GiMaxPooler<area, dt_float32, float, float> {
static constexpr int MIDOUT_CASE_NUM = 11; static constexpr int MIDOUT_CASE_NUM = 11;
static constexpr int SIMD_WIDTH = 4; static constexpr int SIMD_WIDTH = 4;
GI_FLOAT32_t res; GI_FLOAT32_FIXLEN_t res;
GiMaxPooler(DType) : res(GiBroadcastFloat32(DTypeTrait<dt_float32>::min())) {} GiMaxPooler(DType)
void feed(const float* val) { res = GiMaximumFloat32(res, GiLoadFloat32(val)); } : res(GiFloat32Type2FixLenType(
void post(float* dst) { GiStoreFloat32(dst, res); } GiBroadcastFloat32(DTypeTrait<dt_float32>::min()))) {}
void feed(const float* val) {
res = GiFloat32Type2FixLenType(
GiMaximumFloat32(GiFixLenType2GiFloat32Type(res), GiLoadFloat32(val)));
}
void post(float* dst) { GiStoreFloat32(dst, GiFixLenType2GiFloat32Type(res)); }
}; };
template <typename Pooler, int window> template <typename Pooler, int window>
...@@ -137,7 +144,8 @@ struct do_pxl_2x2_pack_proxy< ...@@ -137,7 +144,8 @@ struct do_pxl_2x2_pack_proxy<
const int IW, const int OH, const int OW, const int PH, const int PW) { const int IW, const int OH, const int OW, const int PH, const int PW) {
MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(IH);
MEGDNN_MARK_USED_VAR(OH); MEGDNN_MARK_USED_VAR(OH);
static const auto avg_coef = GiBroadcastFloat32(0.25f); static const auto avg_coef =
GiFloat32Type2FixLenType(GiBroadcastFloat32(0.25f));
int ih = -PH + 2 * oh; int ih = -PH + 2 * oh;
int iw = -PW + 2 * ow; int iw = -PW + 2 * ow;
auto i00 = GiLoadFloat32(src + (ih + 0) * IW + (iw + 0)), auto i00 = GiLoadFloat32(src + (ih + 0) * IW + (iw + 0)),
...@@ -148,7 +156,7 @@ struct do_pxl_2x2_pack_proxy< ...@@ -148,7 +156,7 @@ struct do_pxl_2x2_pack_proxy<
auto vlow = GiPaddFloat32(GiGetLowFloat32(sum0), GiGetHighFloat32(sum0)); auto vlow = GiPaddFloat32(GiGetLowFloat32(sum0), GiGetHighFloat32(sum0));
auto vhigh = GiPaddFloat32(GiGetLowFloat32(sum1), GiGetHighFloat32(sum1)); auto vhigh = GiPaddFloat32(GiGetLowFloat32(sum1), GiGetHighFloat32(sum1));
auto comb = GiCombineFloat32(vlow, vhigh); auto comb = GiCombineFloat32(vlow, vhigh);
auto result = GiMultiplyFloat32(comb, avg_coef); auto result = GiMultiplyFloat32(comb, GiFixLenType2GiFloat32Type(avg_coef));
GiStoreFloat32(dst + oh * OW + ow, result); GiStoreFloat32(dst + oh * OW + ow, result);
} }
}; };
...@@ -327,8 +335,8 @@ void do_max_pooling_w5x5_s2x2_gi( ...@@ -327,8 +335,8 @@ void do_max_pooling_w5x5_s2x2_gi(
auto s0 = GiLoadFloat32(sptr + iw + 0); auto s0 = GiLoadFloat32(sptr + iw + 0);
auto s1 = GiLoadFloat32(sptr + iw + MEGDNN_SIMD_WIDTH); auto s1 = GiLoadFloat32(sptr + iw + MEGDNN_SIMD_WIDTH);
auto d = GiUzpqFloat32(s0, s1); auto d = GiUzpqFloat32(s0, s1);
GiStoreFloat32(even + even_offset, d.val[0]); GiStoreFloat32(even + even_offset, GiGetSubVectorFloat32V2(d, 0));
GiStoreFloat32(odd + odd_offset, d.val[1]); GiStoreFloat32(odd + odd_offset, GiGetSubVectorFloat32V2(d, 1));
even_offset += MEGDNN_SIMD_WIDTH; even_offset += MEGDNN_SIMD_WIDTH;
odd_offset += MEGDNN_SIMD_WIDTH; odd_offset += MEGDNN_SIMD_WIDTH;
} }
...@@ -464,8 +472,8 @@ void do_average_pooling_3x3_s2x2_gi( ...@@ -464,8 +472,8 @@ void do_average_pooling_3x3_s2x2_gi(
for (; iw + 2 * MEGDNN_SIMD_WIDTH <= IW; iw += 2 * MEGDNN_SIMD_WIDTH) { for (; iw + 2 * MEGDNN_SIMD_WIDTH <= IW; iw += 2 * MEGDNN_SIMD_WIDTH) {
auto s0 = GiLd2qFloat32(sptr + iw); auto s0 = GiLd2qFloat32(sptr + iw);
GiStoreFloat32(even + even_offset, s0.val[0]); GiStoreFloat32(even + even_offset, GiGetSubVectorFloat32V2(s0, 0));
GiStoreFloat32(odd + odd_offset, s0.val[1]); GiStoreFloat32(odd + odd_offset, GiGetSubVectorFloat32V2(s0, 1));
even_offset += MEGDNN_SIMD_WIDTH; even_offset += MEGDNN_SIMD_WIDTH;
odd_offset += MEGDNN_SIMD_WIDTH; odd_offset += MEGDNN_SIMD_WIDTH;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册