未验证 提交 0a9f5f17 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13968 from tensor-tang/fix/jit/exp

Fix jit exp
......@@ -18,12 +18,12 @@ namespace paddle {
namespace inference {
using namespace framework; // NOLINT
static std::vector<float> result_data;
struct DataRecord {
std::vector<std::vector<std::vector<float>>> link_step_data_all;
std::vector<size_t> lod;
std::vector<std::vector<float>> rnn_link_data;
std::vector<float> result_data;
size_t num_samples; // total number of samples
size_t batch_iter{0};
size_t batch_size{1};
......@@ -57,6 +57,7 @@ struct DataRecord {
std::ifstream file(path);
std::string line;
int num_lines = 0;
result_data.clear();
while (std::getline(file, line)) {
num_lines++;
std::vector<std::string> data;
......@@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) {
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
// the first inference result
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
PADDLE_ENFORCE_GT(outputs.size(), 0);
size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(outputs[0].data.data());
for (size_t i = 0; i < size; i++) {
EXPECT_NEAR(result[i], data.result_data[i], 1e-3);
EXPECT_NEAR(result[i], result_data[i], 1e-3);
}
}
}
......
......@@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
DEPS cpu_info cblas activation_functions)
DEPS cpu_info cblas)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
......@@ -25,13 +25,18 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
#ifdef __AVX__
namespace jitkernel {
namespace detail {
__m256 Exp(__m256 a);
} // namespace detail
#ifdef __AVX__
__m256 ExpAVX(__m256 x);
#endif
namespace jitkernel {
#ifdef __AVX2__
__m256 ExpAVX2(__m256 x);
#endif
} // namespace detail
namespace jit = platform::jit;
#ifdef __AVX__
......@@ -43,43 +48,72 @@ class AVXAct {
virtual __m256 Compute(__m256 x) const = 0;
};
template <act_type type>
template <act_type type, jit::cpu_isa_t isa>
class AVXActImpl : public AVXAct {
public:
__m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); }
};
template <>
__m256 AVXActImpl<kSigmoid>::Compute(__m256 x) const {
__m256 ones = _mm256_set1_ps(1.0f);
x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN));
x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX));
x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x);
x = detail::Exp(x);
x = _mm256_add_ps(ones, x);
return _mm256_div_ps(ones, x);
}
#define AVX_SIGMOID(isa, expisa) \
template <> \
__m256 AVXActImpl<kSigmoid, isa>::Compute(__m256 x) const { \
__m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); \
x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); \
x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); \
x = expisa(x); \
x = _mm256_add_ps(ones, x); \
return _mm256_div_ps(ones, x); \
}
template <>
__m256 AVXActImpl<kTanh>::Compute(__m256 x) const {
__m256 ones = _mm256_set1_ps(1.0f);
x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x);
x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT));
x = detail::Exp(x);
x = _mm256_add_ps(ones, x);
x = _mm256_div_ps(_mm256_set1_ps(2.0f), x);
return _mm256_sub_ps(x, ones);
}
#define AVX_TANH(isa, expisa) \
template <> \
__m256 AVXActImpl<kTanh, isa>::Compute(__m256 x) const { \
__m256 ones = _mm256_set1_ps(1.0f); \
x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); \
x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); \
x = expisa(x); \
x = _mm256_add_ps(ones, x); \
x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); \
return _mm256_sub_ps(x, ones); \
}
template <>
__m256 AVXActImpl<kRelu>::Compute(__m256 x) const {
return _mm256_max_ps(x, _mm256_setzero_ps());
}
#define AVX_RELU(isa) \
template <> \
__m256 AVXActImpl<kRelu, isa>::Compute(__m256 x) const { \
return _mm256_max_ps(x, _mm256_setzero_ps()); \
}
#define AVX_IDENTITY(isa) \
template <> \
__m256 AVXActImpl<kIdentity, isa>::Compute(__m256 x) const { \
return x; \
}
#define FOR_EACH_AVX_ISA(macro_) \
macro_(jit::avx); \
macro_(jit::avx2); \
macro_(jit::avx512f)
FOR_EACH_AVX_ISA(AVX_RELU);
FOR_EACH_AVX_ISA(AVX_IDENTITY);
AVX_SIGMOID(jit::avx, detail::ExpAVX);
AVX_TANH(jit::avx, detail::ExpAVX);
#ifdef __AVX2__
AVX_SIGMOID(jit::avx2, detail::ExpAVX2);
AVX_SIGMOID(jit::avx512f, detail::ExpAVX2);
AVX_TANH(jit::avx2, detail::ExpAVX2);
AVX_TANH(jit::avx512f, detail::ExpAVX2);
#endif
#undef FOR_EACH_AVX_ISA
#undef AVX_IDENTITY
#undef AVX_RELU
#undef AVX_TANH
#undef AVX_SIGMOID
template <>
__m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
return x;
}
#endif
template <typename T>
......@@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel<T> {
act_cell_d_ = GetActKernel<T>(act_cell, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
#ifdef __AVX__
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> {
if (type == "sigmoid") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid>());
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu>());
} else if (type == "tanh") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh>());
} else if (type == "identity" || type == "") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity>());
}
PADDLE_THROW("Not support type: %s", type);
};
avx_act_gate_ = GetAVXAct(act_gate);
avx_act_cand_ = GetAVXAct(act_cand);
avx_act_cell_ = GetAVXAct(act_cell);
#endif
}
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
......@@ -176,6 +193,27 @@ class LSTMKernelImpl : public LSTMKernel<T> {
};
#define INTRI8_FLOAT(isa) \
template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \
: LSTMKernel<float>() { \
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
if (type == "sigmoid") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \
} else if (type == "relu") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
} else if (type == "tanh") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
} else if (type == "identity" || type == "") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
} \
PADDLE_THROW("Not support type: %s", type); \
}; \
avx_act_gate_ = GetAVXAct(act_gate); \
avx_act_cand_ = GetAVXAct(act_cand); \
avx_act_cell_ = GetAVXAct(act_cell); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
......@@ -195,6 +233,20 @@ class LSTMKernelImpl : public LSTMKernel<T> {
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/ \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
}
// TODO(TJ): optimize keq16
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册