From 1940bc2d8392a1f41d3e0f6678afc0f6a77bd4be Mon Sep 17 00:00:00 2001 From: wangguibao Date: Sat, 29 Sep 2018 16:14:40 +0800 Subject: [PATCH] Avoid multiple definitions of lstm_compute_ctht when linking libpaddle_fluid.so test=develop --- .../fluid/operators/math/cpu_lstm_compute.cc | 27 ++++++++++++++++++- .../fluid/operators/math/cpu_lstm_compute.h | 21 ++------------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.cc b/paddle/fluid/operators/math/cpu_lstm_compute.cc index 58e6512021..e96d187933 100644 --- a/paddle/fluid/operators/math/cpu_lstm_compute.cc +++ b/paddle/fluid/operators/math/cpu_lstm_compute.cc @@ -13,6 +13,31 @@ limitations under the License. */ namespace paddle { namespace operators { -namespace math {} // namespace math +namespace math { +#ifdef __AVX__ +template <> +void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, + float* ht) { + namespace act = detail::forward::avx; + // gates: W_ch, W_ih, W_fh, W_oh + __m256 c, i, f, o; + c = _mm256_loadu_ps(gates); + i = _mm256_loadu_ps(gates + 8); + f = _mm256_loadu_ps(gates + 16); + o = _mm256_loadu_ps(gates + 24); + + /* C_t = C_t-1 * fgated + cand_gated * igated*/ + c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i)); + i = _mm256_loadu_ps(ct_1); + f = _mm256_mul_ps(i, act::Sigmoid(f)); + f = _mm256_add_ps(c, f); + _mm256_storeu_ps(ct, f); + + /* H_t = act_cell(C_t) * ogated */ + o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o)); + _mm256_storeu_ps(ht, o); +} +#endif +} // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.h b/paddle/fluid/operators/math/cpu_lstm_compute.h index 28b6f71729..169a9e4b47 100644 --- a/paddle/fluid/operators/math/cpu_lstm_compute.h +++ b/paddle/fluid/operators/math/cpu_lstm_compute.h @@ -48,32 +48,15 @@ namespace forward { namespace avx { __m256 Sigmoid(const __m256 a); __m256 Tanh(const __m256 a); + } // namespace avx } // namespace forward } // namespace detail template <> void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, - float* ht) { - namespace act = detail::forward::avx; - // gates: W_ch, W_ih, W_fh, W_oh - __m256 c, i, f, o; - c = _mm256_loadu_ps(gates); - i = _mm256_loadu_ps(gates + 8); - f = _mm256_loadu_ps(gates + 16); - o = _mm256_loadu_ps(gates + 24); + float* ht); - /* C_t = C_t-1 * fgated + cand_gated * igated*/ - c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i)); - i = _mm256_loadu_ps(ct_1); - f = _mm256_mul_ps(i, act::Sigmoid(f)); - f = _mm256_add_ps(c, f); - _mm256_storeu_ps(ct, f); - - /* H_t = act_cell(C_t) * ogated */ - o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o)); - _mm256_storeu_ps(ht, o); -} #endif } // namespace math -- GitLab