未验证 提交 b37fe304 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13690 from wangguibao/fix_cpu_lstm_compute_cc

Avoid multiple definitions of lstm_compute_ctht when linking libpaddle_fluid.so
@@ -13,6 +13,31 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
#ifdef __AVX__
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
                              float* ht) {
  namespace act = detail::forward::avx;
  // `gates` holds four pre-activation gate vectors of 8 floats each,
  // laid out contiguously as: W_ch, W_ih, W_fh, W_oh.
  __m256 cand_gate = _mm256_loadu_ps(gates);
  __m256 in_gate = _mm256_loadu_ps(gates + 8);
  __m256 forget_gate = _mm256_loadu_ps(gates + 16);
  __m256 out_gate = _mm256_loadu_ps(gates + 24);
  /* C_t = C_{t-1} * sigmoid(f) + tanh(cand) * sigmoid(i) */
  __m256 cand_part = _mm256_mul_ps(act::Tanh(cand_gate), act::Sigmoid(in_gate));
  __m256 prev_c = _mm256_loadu_ps(ct_1);
  __m256 keep_part = _mm256_mul_ps(prev_c, act::Sigmoid(forget_gate));
  __m256 new_c = _mm256_add_ps(cand_part, keep_part);
  _mm256_storeu_ps(ct, new_c);
  /* H_t = tanh(C_t) * sigmoid(o) — hidden state from the gated cell state. */
  __m256 new_h = _mm256_mul_ps(act::Tanh(new_c), act::Sigmoid(out_gate));
  _mm256_storeu_ps(ht, new_h);
}
#endif
} // namespace math
}  // namespace operators
}  // namespace paddle
@@ -48,32 +48,15 @@ namespace forward {
namespace avx {
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);
}  // namespace avx
}  // namespace forward
}  // namespace detail
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
                              float* ht);
#endif
} // namespace math } // namespace math
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册