Commit 8dea07f2 authored by tensor-tang

fix compile

Parent 612ba41a
@@ -396,15 +396,15 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
         }
       } else {
         // TODO(TJ): unly workaround, clean me
-        std::function<void(const T*, const T*, T*, T*)> compute_ctht;
+        std::function<void(T*, const T*, T*, T*)> compute_ctht;
         if (platform::jit::MayIUse(platform::jit::avx) &&
             act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
             act_cell_str == "tanh" && D == 8) {
           compute_ctht = math::lstm_compute_ctht<T>;
         } else {
-          compute_ctht = [&](const T* gates, const T* ct_1, T* ct, T* ht) {
+          compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
             COMPUTE_CtHt(gates, ct_1, ct, ht);
-          }
+          };
         }
         for (int i = 0; i < N; ++i) {
           PROCESS_H0C0
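
A minimal, self-contained sketch of the two C++ points behind this hunk (hypothetical names such as compute_inplace; this is not the operator code): math::lstm_compute_ctht<T> now takes a mutable gates pointer because it rewrites the gate buffer in place, so the std::function wrapper cannot declare that parameter as const T*, and the statement that assigns a lambda to the wrapper must end with a semicolon.

#include <functional>

// Stand-in for math::lstm_compute_ctht<float>: it writes into `gates`,
// so its first parameter must be float*, not const float*.
void compute_inplace(float* gates, const float* ct_1, float* ct, float* ht) {
  gates[0] += ct_1[0];  // placeholder arithmetic
  ct[0] = gates[0];
  ht[0] = ct[0];
}

int main() {
  // The wrapper must also expose float*; with const float* the first
  // assignment below would not compile.
  std::function<void(float*, const float*, float*, float*)> compute_ctht;
  compute_ctht = compute_inplace;
  compute_ctht = [&](float* gates, const float* ct_1, float* ct, float* ht) {
    compute_inplace(gates, ct_1, ct, ht);
  };  // <- the assignment statement needs this terminating semicolon

  float gates[1] = {1.f}, ct_1[1] = {2.f}, ct[1], ht[1];
  compute_ctht(gates, ct_1, ct, ht);
  return 0;
}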
@@ -25,12 +25,15 @@ namespace math {
 
 namespace detail {
 namespace forward {
-namespace avx {}  // namespace avx
+namespace avx {
+__m256 Sigmoid(const __m256 a);
+__m256 Tanh(const __m256 a);
+}  // namespace avx
 }  // namespace forward
 }  // namespace detail
 
 template <>
-void lstm_compute_ctht<float>(const float* gates, const float* ct_1, float* ct,
+void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
                               float* ht) {
   namespace act = detail::forward::avx;
   // gates: W_ch, W_ih, W_fh, W_oh
@@ -52,6 +55,7 @@ void lstm_compute_ctht<float>(const float* gates, const float* ct_1, float* ct,
   _mm256_storeu_ps(ht, o);
 }
 #endif
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
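
The declarations added above matter because the float specialization reaches the AVX helpers through the alias namespace act = detail::forward::avx; with an empty avx namespace there is nothing for act::Sigmoid and act::Tanh to resolve to. A minimal sketch of that pattern (compile with -mavx), using placeholder bodies that stand in for Paddle's real AVX activations, which live in a separate translation unit and are therefore only declared in the .cc:

#include <immintrin.h>

namespace detail {
namespace forward {
namespace avx {
// Placeholder bodies for illustration; the real activations are defined
// elsewhere, so the source file above only needs their declarations.
inline __m256 Sigmoid(const __m256 a) { return a; }
inline __m256 Tanh(const __m256 a) { return a; }
}  // namespace avx
}  // namespace forward
}  // namespace detail

void lstm_tail(const float* ct, float* ht) {
  namespace act = detail::forward::avx;  // same alias as in the kernel
  // Mirrors how the kernel applies act::Tanh to a loaded vector for H_t.
  __m256 c = _mm256_loadu_ps(ct);
  _mm256_storeu_ps(ht, act::Tanh(c));
}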
@@ -23,22 +23,19 @@ namespace math {
 
 // TODO(TJ): ugly workaround, clean me
 template <typename T>
-void lstm_compute_ctht(const T* gates, const T* ct_1, T* ct, T* ht) {
+void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
   // gates: W_ch, W_ih, W_fh, W_oh
   vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
   vec_tanh<T, platform::jit::avx>(8, gates, gates);
   const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
+  const T min = SIGMOID_THRESHOLD_MIN;
+  const T max = SIGMOID_THRESHOLD_MAX;
   for (int d = 0; d < 8; ++d) {
     // C_t = C_t-1 * fgated + cand_gated * igated
     ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
     // H_t = act_cell(C_t) * ogated
     T tmp = ct[d] * 2;
-    tmp = static_cast<T>(0) - (tmp < static_cast<T>(SIGMOID_THRESHOLD_MIN))
-              ? min
-              : ((tmp > static_cast<T>(SIGMOID_THRESHOLD_MAX))
-                     ? static_cast<T>(SIGMOID_THRESHOLD_MAX)
-                     : tmp);
+    tmp = static_cast<T>(0) - (tmp < min) ? min : ((tmp > max) ? max : tmp);
     vec_exp<T>(1, &tmp, &tmp);
     tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
     ht[d] = tmp * o[d];
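
For reference, a scalar sketch of what each loop iteration above is intended to compute (my reading, not part of the commit): C_t = C_{t-1} * f + cand * i, then H_t = tanh(C_t) * o, with tanh evaluated through exp after clamping 2*C_t to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX].

#include <algorithm>
#include <cmath>

// Reference element-wise step, using the identity
// tanh(x) = 2 / (1 + exp(-2x)) - 1 with the argument clamped first.
template <typename T>
void lstm_ctht_scalar(const T* cand, const T* i, const T* f, const T* o,
                      const T* ct_1, T* ct, T* ht, int d, T min, T max) {
  for (int k = 0; k < d; ++k) {
    // C_t = C_{t-1} * f + cand * i
    ct[k] = ct_1[k] * f[k] + cand[k] * i[k];
    // Clamp 2*C_t into [min, max] before exponentiating, as the kernel does.
    T t = std::min(std::max(ct[k] * static_cast<T>(2), min), max);
    // H_t = tanh(C_t) * o
    ht[k] = (static_cast<T>(2) / (static_cast<T>(1) + std::exp(-t)) -
             static_cast<T>(1)) * o[k];
  }
}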