未验证 提交 0a9f5f17 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13968 from tensor-tang/fix/jit/exp

Fix jit exp
...@@ -18,12 +18,12 @@ namespace paddle { ...@@ -18,12 +18,12 @@ namespace paddle {
namespace inference { namespace inference {
using namespace framework; // NOLINT using namespace framework; // NOLINT
static std::vector<float> result_data;
struct DataRecord { struct DataRecord {
std::vector<std::vector<std::vector<float>>> link_step_data_all; std::vector<std::vector<std::vector<float>>> link_step_data_all;
std::vector<size_t> lod; std::vector<size_t> lod;
std::vector<std::vector<float>> rnn_link_data; std::vector<std::vector<float>> rnn_link_data;
std::vector<float> result_data;
size_t num_samples; // total number of samples size_t num_samples; // total number of samples
size_t batch_iter{0}; size_t batch_iter{0};
size_t batch_size{1}; size_t batch_size{1};
...@@ -57,6 +57,7 @@ struct DataRecord { ...@@ -57,6 +57,7 @@ struct DataRecord {
std::ifstream file(path); std::ifstream file(path);
std::string line; std::string line;
int num_lines = 0; int num_lines = 0;
result_data.clear();
while (std::getline(file, line)) { while (std::getline(file, line)) {
num_lines++; num_lines++;
std::vector<std::string> data; std::vector<std::string> data;
...@@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) { ...@@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) {
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
// the first inference result // the first inference result
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
PADDLE_ENFORCE_GT(outputs.size(), 0); PADDLE_ENFORCE_GT(outputs.size(), 0);
size_t size = GetSize(outputs[0]); size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(size, 0); PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(outputs[0].data.data()); float *result = static_cast<float *>(outputs[0].data.data());
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
EXPECT_NEAR(result[i], data.result_data[i], 1e-3); EXPECT_NEAR(result[i], result_data[i], 1e-3);
} }
} }
} }
......
...@@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat) ...@@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel cc_library(jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
DEPS cpu_info cblas activation_functions) DEPS cpu_info cblas)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
...@@ -25,13 +25,18 @@ limitations under the License. */ ...@@ -25,13 +25,18 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
#ifdef __AVX__ namespace jitkernel {
namespace detail { namespace detail {
__m256 Exp(__m256 a); #ifdef __AVX__
} // namespace detail __m256 ExpAVX(__m256 x);
#endif #endif
namespace jitkernel { #ifdef __AVX2__
__m256 ExpAVX2(__m256 x);
#endif
} // namespace detail
namespace jit = platform::jit; namespace jit = platform::jit;
#ifdef __AVX__ #ifdef __AVX__
...@@ -43,43 +48,72 @@ class AVXAct { ...@@ -43,43 +48,72 @@ class AVXAct {
virtual __m256 Compute(__m256 x) const = 0; virtual __m256 Compute(__m256 x) const = 0;
}; };
template <act_type type> template <act_type type, jit::cpu_isa_t isa>
class AVXActImpl : public AVXAct { class AVXActImpl : public AVXAct {
public: public:
__m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); } __m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); }
}; };
template <> #define AVX_SIGMOID(isa, expisa) \
__m256 AVXActImpl<kSigmoid>::Compute(__m256 x) const { template <> \
__m256 ones = _mm256_set1_ps(1.0f); __m256 AVXActImpl<kSigmoid, isa>::Compute(__m256 x) const { \
x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); __m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); \
x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); \
x = detail::Exp(x); x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); \
x = _mm256_add_ps(ones, x); x = expisa(x); \
return _mm256_div_ps(ones, x); x = _mm256_add_ps(ones, x); \
} return _mm256_div_ps(ones, x); \
}
template <> #define AVX_TANH(isa, expisa) \
__m256 AVXActImpl<kTanh>::Compute(__m256 x) const { template <> \
__m256 ones = _mm256_set1_ps(1.0f); __m256 AVXActImpl<kTanh, isa>::Compute(__m256 x) const { \
x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); __m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); \
x = detail::Exp(x); x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); \
x = _mm256_add_ps(ones, x); x = expisa(x); \
x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); x = _mm256_add_ps(ones, x); \
return _mm256_sub_ps(x, ones); x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); \
} return _mm256_sub_ps(x, ones); \
}
template <> #define AVX_RELU(isa) \
__m256 AVXActImpl<kRelu>::Compute(__m256 x) const { template <> \
return _mm256_max_ps(x, _mm256_setzero_ps()); __m256 AVXActImpl<kRelu, isa>::Compute(__m256 x) const { \
} return _mm256_max_ps(x, _mm256_setzero_ps()); \
}
#define AVX_IDENTITY(isa) \
template <> \
__m256 AVXActImpl<kIdentity, isa>::Compute(__m256 x) const { \
return x; \
}
#define FOR_EACH_AVX_ISA(macro_) \
macro_(jit::avx); \
macro_(jit::avx2); \
macro_(jit::avx512f)
FOR_EACH_AVX_ISA(AVX_RELU);
FOR_EACH_AVX_ISA(AVX_IDENTITY);
AVX_SIGMOID(jit::avx, detail::ExpAVX);
AVX_TANH(jit::avx, detail::ExpAVX);
#ifdef __AVX2__
AVX_SIGMOID(jit::avx2, detail::ExpAVX2);
AVX_SIGMOID(jit::avx512f, detail::ExpAVX2);
AVX_TANH(jit::avx2, detail::ExpAVX2);
AVX_TANH(jit::avx512f, detail::ExpAVX2);
#endif
#undef FOR_EACH_AVX_ISA
#undef AVX_IDENTITY
#undef AVX_RELU
#undef AVX_TANH
#undef AVX_SIGMOID
template <>
__m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
return x;
}
#endif #endif
template <typename T> template <typename T>
...@@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel<T> {
act_cell_d_ = GetActKernel<T>(act_cell, d); act_cell_d_ = GetActKernel<T>(act_cell, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d); vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d); vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
#ifdef __AVX__
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> {
if (type == "sigmoid") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid>());
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu>());
} else if (type == "tanh") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh>());
} else if (type == "identity" || type == "") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity>());
}
PADDLE_THROW("Not support type: %s", type);
};
avx_act_gate_ = GetAVXAct(act_gate);
avx_act_cand_ = GetAVXAct(act_cand);
avx_act_cell_ = GetAVXAct(act_cell);
#endif
} }
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
...@@ -175,26 +192,61 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -175,26 +192,61 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif #endif
}; };
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \ LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
float* gates, const float* ct_1, float* ct, float* ht, \ const std::string& act_gate, const std::string& act_cand, \
const float* wp_data, float* checked) const { \ const std::string& act_cell, int d) \
/* gates: W_ch, W_ih, W_fh, W_oh */ \ : LSTMKernel<float>() { \
__m256 c, i, f, o; \ auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
c = _mm256_loadu_ps(gates); \ if (type == "sigmoid") { \
i = _mm256_loadu_ps(gates + 8); \ return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \
f = _mm256_loadu_ps(gates + 16); \ } else if (type == "relu") { \
o = _mm256_loadu_ps(gates + 24); \ return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \ } else if (type == "tanh") { \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
i = _mm256_loadu_ps(ct_1); \ } else if (type == "identity" || type == "") { \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
f = _mm256_add_ps(c, f); \ } \
_mm256_storeu_ps(ct, f); \ PADDLE_THROW("Not support type: %s", type); \
/* H_t = act_cell(C_t) * ogated */ \ }; \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ avx_act_gate_ = GetAVXAct(act_gate); \
_mm256_storeu_ps(ht, o); \ avx_act_cand_ = GetAVXAct(act_cand); \
avx_act_cell_ = GetAVXAct(act_cell); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
const float* wp_data, float* checked) const { \
/* gates: W_ch, W_ih, W_fh, W_oh */ \
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_loadu_ps(gates + 16); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
i = _mm256_loadu_ps(ct_1); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
f = _mm256_add_ps(c, f); \
_mm256_storeu_ps(ct, f); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/ \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} }
// TODO(TJ): optimize keq16 // TODO(TJ): optimize keq16
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册