未验证 提交 0a9f5f17 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13968 from tensor-tang/fix/jit/exp

Fix jit exp
......@@ -18,12 +18,12 @@ namespace paddle {
namespace inference {
using namespace framework; // NOLINT
static std::vector<float> result_data;
struct DataRecord {
std::vector<std::vector<std::vector<float>>> link_step_data_all;
std::vector<size_t> lod;
std::vector<std::vector<float>> rnn_link_data;
std::vector<float> result_data;
size_t num_samples; // total number of samples
size_t batch_iter{0};
size_t batch_size{1};
......@@ -57,6 +57,7 @@ struct DataRecord {
std::ifstream file(path);
std::string line;
int num_lines = 0;
result_data.clear();
while (std::getline(file, line)) {
num_lines++;
std::vector<std::string> data;
......@@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) {
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
// the first inference result
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
PADDLE_ENFORCE_GT(outputs.size(), 0);
size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(outputs[0].data.data());
for (size_t i = 0; i < size; i++) {
EXPECT_NEAR(result[i], data.result_data[i], 1e-3);
EXPECT_NEAR(result[i], result_data[i], 1e-3);
}
}
}
......
......@@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
DEPS cpu_info cblas activation_functions)
DEPS cpu_info cblas)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
......@@ -27,13 +27,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
#ifdef __AVX__
namespace detail {
__m256 Exp(__m256 a);
} // namespace detail
#endif
namespace jitkernel {
namespace jit = platform::jit;
......@@ -69,37 +62,186 @@ FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
#endif
#define INTRI8_FLOAT(isa) \
namespace detail {
#ifdef __AVX__
#define ALIGN32 __attribute__((aligned(32)))
#define _PS256_CONST(Name, Val) \
static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
#define _PI256_CONST(Name, Val) \
static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
_PI256_CONST(0x7f, 0x7f);
_PS256_CONST(one, 1.f);
_PS256_CONST(0p5, 0.5f);
_PS256_CONST(exp_hi, 88.3762626647949f);
_PS256_CONST(exp_lo, -88.3762626647949f);
_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
_PS256_CONST(cephes_exp_C1, 0.693359375);
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
typedef union imm_xmm_union {
__m256i imm;
__m128i xmm[2];
} imm_xmm_union;
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
{ \
imm_xmm_union u ALIGN32; \
u.imm = imm_; \
xmm0_ = u.xmm[0]; \
xmm1_ = u.xmm[1]; \
}
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
{ \
imm_xmm_union u ALIGN32; \
u.xmm[0] = xmm0_; \
u.xmm[1] = xmm1_; \
imm_ = u.imm; \
}
#define AVX2_BITOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \
/* use SSE2 to perform the bitop AVX2 */ \
__m128i x1, x2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
x1 = _mm_##fn(x1, y); \
x2 = _mm_##fn(x2, y); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
#define AVX2_INTOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_add_epi32(__m256i x, __m256i y) { \
/* use SSE2 to perform the AVX2 integer operation */ \
__m128i x1, x2; \
__m128i y1, y2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
COPY_IMM_TO_XMM(y, y1, y2); \
x1 = _mm_##fn(x1, y1); \
x2 = _mm_##fn(x2, y2); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
AVX2_BITOP_USING_SSE2(slli_epi32);
AVX2_INTOP_USING_SSE2(add_epi32);
#define AVXEXP_BASE \
__m256 tmp = _mm256_setzero_ps(), fx; \
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one); \
__m256i imm0; \
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi)); \
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo)); \
/* express exp(x) as exp(g + n*log(2)) */ \
fx = _mm256_mul_ps(x, \
*reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF)); \
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5)); \
tmp = _mm256_floor_ps(fx); \
/* if greater, substract 1 */ \
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); \
mask = _mm256_and_ps(mask, one); \
fx = _mm256_sub_ps(tmp, mask); \
tmp = _mm256_mul_ps(fx, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1)); \
__m256 z = _mm256_mul_ps( \
fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2)); \
x = _mm256_sub_ps(x, tmp); \
x = _mm256_sub_ps(x, z); \
z = _mm256_mul_ps(x, x); \
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5)); \
y = _mm256_mul_ps(y, z); \
y = _mm256_add_ps(y, x); \
y = _mm256_add_ps(y, one); \
/* build 2^n */ \
imm0 = _mm256_cvttps_epi32(fx)
__m256 ExpAVX(__m256 x) {
AVXEXP_BASE;
// two AVX2 instructions using SSE2
imm0 = avx2_mm256_add_epi32(imm0,
*reinterpret_cast<const __m256i*>(_pi256_0x7f));
imm0 = avx2_mm256_slli_epi32(imm0, 23);
__m256 pow2n = _mm256_castsi256_ps(imm0);
y = _mm256_mul_ps(y, pow2n);
return y;
}
#endif
#ifdef __AVX2__
__m256 ExpAVX2(__m256 x) {
AVXEXP_BASE;
// two AVX2 instructions
imm0 = _mm256_add_epi32(imm0, *reinterpret_cast<const __m256i*>(_pi256_0x7f));
imm0 = _mm256_slli_epi32(imm0, 23);
__m256 pow2n = _mm256_castsi256_ps(imm0);
y = _mm256_mul_ps(y, pow2n);
return y;
}
#endif
} // namespace detail
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
__m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y, detail::Exp(tmp)); \
_mm256_storeu_ps(y, expisa(tmp)); \
}
#define INTRI16_FLOAT(isa) \
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VExpKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = detail::Exp(tmp0); \
tmp1 = detail::Exp(tmp1); \
tmp0 = expisa(tmp0); \
tmp1 = expisa(tmp1); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
#endif
// TODO(TJ): eq16 test and complete avx512
......@@ -135,26 +277,27 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
std::shared_ptr<const VExpKernel<T>> vexp_;
};
#define INTRI_SIGMOID(tmp, min, max) \
#define INTRI_SIGMOID(tmp, min, max, expisa) \
tmp = _mm256_max_ps(tmp, min); \
tmp = _mm256_min_ps(tmp, max); \
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
tmp = detail::Exp(tmp); \
tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
#define INTRI8_FLOAT(isa) \
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
/* TODO(TJ): try to use static const*/ \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa) \
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ16>::Compute(const float* x, \
float* y) const { \
......@@ -162,13 +305,13 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_SIGMOID(tmp0, min, max); \
INTRI_SIGMOID(tmp1, min, max); \
INTRI_SIGMOID(tmp0, min, max, expisa); \
INTRI_SIGMOID(tmp1, min, max, expisa); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#define INTRI_GT8LT16_FLOAT(isa) \
#define INTRI_GT8LT16_FLOAT(isa, expisa) \
template <> \
VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
......@@ -184,7 +327,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y, tmp); \
const float min_ = SIGMOID_THRESHOLD_MIN; \
const float max_ = SIGMOID_THRESHOLD_MAX; \
......@@ -198,7 +341,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
} \
}
#define INTRI_GT16_FLOAT(isa) \
#define INTRI_GT16_FLOAT(isa, expisa) \
template <> \
VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
......@@ -215,7 +358,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
INTRI_SIGMOID(tmp, min, max); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y + i, tmp); \
} \
const float min_ = SIGMOID_THRESHOLD_MIN; \
......@@ -231,22 +374,20 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
INTRI_GT8LT16_FLOAT(jit::avx);
INTRI_GT16_FLOAT(jit::avx);
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
// INTRI_GT8LT16_FLOAT(jit::avx2);
// INTRI_GT16_FLOAT(jit::avx2);
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
// INTRI_GT8LT16_FLOAT(jit::avx512f);
// INTRI_GT16_FLOAT(jit::avx512f);
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
// maybe use avx2 at gt8lt16 and gt16
#endif
#undef INTRI8_FLOAT
......@@ -280,36 +421,36 @@ class VTanhKernelImpl : public VTanhKernel<T> {
std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
};
#define INTRI_VTANH(tmp) \
#define INTRI_VTANH(tmp, expisa) \
tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \
tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
tmp = detail::Exp(tmp); \
tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \
tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
#define INTRI8_FLOAT(isa) \
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VTanhKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa) \
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VTanhKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_VTANH(tmp0); \
INTRI_VTANH(tmp1); \
INTRI_VTANH(tmp0, expisa); \
INTRI_VTANH(tmp1, expisa); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#define INTRI_GT8LT16_FLOAT(isa) \
#define INTRI_GT8LT16_FLOAT(isa, expisa) \
template <> \
VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
......@@ -327,7 +468,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
void VTanhKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
......@@ -337,7 +478,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
vaddbias_->Compute(-1.f, y, y); \
}
#define INTRI_GT16_FLOAT(isa) \
#define INTRI_GT16_FLOAT(isa, expisa) \
template <> \
VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
......@@ -356,7 +497,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
const { \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
INTRI_VTANH(tmp); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y + i, tmp); \
} \
x += this->end_; \
......@@ -368,19 +509,19 @@ class VTanhKernelImpl : public VTanhKernel<T> {
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
INTRI_GT8LT16_FLOAT(jit::avx);
INTRI_GT16_FLOAT(jit::avx);
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
// maybe use avx at gt8lt16 and gt16
#endif
......
......@@ -25,13 +25,18 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
#ifdef __AVX__
namespace jitkernel {
namespace detail {
__m256 Exp(__m256 a);
} // namespace detail
#ifdef __AVX__
__m256 ExpAVX(__m256 x);
#endif
namespace jitkernel {
#ifdef __AVX2__
__m256 ExpAVX2(__m256 x);
#endif
} // namespace detail
namespace jit = platform::jit;
#ifdef __AVX__
......@@ -43,43 +48,72 @@ class AVXAct {
virtual __m256 Compute(__m256 x) const = 0;
};
template <act_type type>
template <act_type type, jit::cpu_isa_t isa>
class AVXActImpl : public AVXAct {
public:
__m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); }
};
template <>
__m256 AVXActImpl<kSigmoid>::Compute(__m256 x) const {
__m256 ones = _mm256_set1_ps(1.0f);
x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN));
x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX));
x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x);
x = detail::Exp(x);
x = _mm256_add_ps(ones, x);
return _mm256_div_ps(ones, x);
}
#define AVX_SIGMOID(isa, expisa) \
template <> \
__m256 AVXActImpl<kSigmoid, isa>::Compute(__m256 x) const { \
__m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); \
x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); \
x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); \
x = expisa(x); \
x = _mm256_add_ps(ones, x); \
return _mm256_div_ps(ones, x); \
}
template <>
__m256 AVXActImpl<kTanh>::Compute(__m256 x) const {
__m256 ones = _mm256_set1_ps(1.0f);
x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x);
x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT));
x = detail::Exp(x);
x = _mm256_add_ps(ones, x);
x = _mm256_div_ps(_mm256_set1_ps(2.0f), x);
return _mm256_sub_ps(x, ones);
}
#define AVX_TANH(isa, expisa) \
template <> \
__m256 AVXActImpl<kTanh, isa>::Compute(__m256 x) const { \
__m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); \
x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); \
x = expisa(x); \
x = _mm256_add_ps(ones, x); \
x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); \
return _mm256_sub_ps(x, ones); \
}
template <>
__m256 AVXActImpl<kRelu>::Compute(__m256 x) const {
return _mm256_max_ps(x, _mm256_setzero_ps());
}
#define AVX_RELU(isa) \
template <> \
__m256 AVXActImpl<kRelu, isa>::Compute(__m256 x) const { \
return _mm256_max_ps(x, _mm256_setzero_ps()); \
}
#define AVX_IDENTITY(isa) \
template <> \
__m256 AVXActImpl<kIdentity, isa>::Compute(__m256 x) const { \
return x; \
}
#define FOR_EACH_AVX_ISA(macro_) \
macro_(jit::avx); \
macro_(jit::avx2); \
macro_(jit::avx512f)
FOR_EACH_AVX_ISA(AVX_RELU);
FOR_EACH_AVX_ISA(AVX_IDENTITY);
AVX_SIGMOID(jit::avx, detail::ExpAVX);
AVX_TANH(jit::avx, detail::ExpAVX);
#ifdef __AVX2__
AVX_SIGMOID(jit::avx2, detail::ExpAVX2);
AVX_SIGMOID(jit::avx512f, detail::ExpAVX2);
AVX_TANH(jit::avx2, detail::ExpAVX2);
AVX_TANH(jit::avx512f, detail::ExpAVX2);
#endif
#undef FOR_EACH_AVX_ISA
#undef AVX_IDENTITY
#undef AVX_RELU
#undef AVX_TANH
#undef AVX_SIGMOID
template <>
__m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
return x;
}
#endif
template <typename T>
......@@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel<T> {
act_cell_d_ = GetActKernel<T>(act_cell, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
#ifdef __AVX__
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> {
if (type == "sigmoid") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid>());
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu>());
} else if (type == "tanh") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh>());
} else if (type == "identity" || type == "") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity>());
}
PADDLE_THROW("Not support type: %s", type);
};
avx_act_gate_ = GetAVXAct(act_gate);
avx_act_cand_ = GetAVXAct(act_cand);
avx_act_cell_ = GetAVXAct(act_cell);
#endif
}
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
......@@ -175,26 +192,61 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif
};
#define INTRI8_FLOAT(isa) \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
const float* wp_data, float* checked) const { \
/* gates: W_ch, W_ih, W_fh, W_oh */ \
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_loadu_ps(gates + 16); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
i = _mm256_loadu_ps(ct_1); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
f = _mm256_add_ps(c, f); \
_mm256_storeu_ps(ct, f); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
#define INTRI8_FLOAT(isa) \
template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \
: LSTMKernel<float>() { \
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
if (type == "sigmoid") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \
} else if (type == "relu") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
} else if (type == "tanh") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
} else if (type == "identity" || type == "") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
} \
PADDLE_THROW("Not support type: %s", type); \
}; \
avx_act_gate_ = GetAVXAct(act_gate); \
avx_act_cand_ = GetAVXAct(act_cand); \
avx_act_cell_ = GetAVXAct(act_cell); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
const float* wp_data, float* checked) const { \
/* gates: W_ch, W_ih, W_fh, W_oh */ \
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_loadu_ps(gates + 16); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
i = _mm256_loadu_ps(ct_1); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
f = _mm256_add_ps(c, f); \
_mm256_storeu_ps(ct, f); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/ \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
}
// TODO(TJ): optimize keq16
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册