diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.cc b/paddle/fluid/operators/math/cpu_lstm_compute.cc index 84e387c974b17f5b4b8cb9c8dae8616f2b3609aa..f7c55c215bacdafc99da5fcd0b750a058dfed21c 100644 --- a/paddle/fluid/operators/math/cpu_lstm_compute.cc +++ b/paddle/fluid/operators/math/cpu_lstm_compute.cc @@ -13,9 +13,76 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/cpu_lstm_compute.h" +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/platform/cpu_info.h" +#ifdef __AVX__ +#include +#endif namespace paddle { namespace operators { -namespace math {} // namespace math +namespace math { + +// TODO(TJ): ugly workaround, clean me +template +void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) { + // gates: W_ch, W_ih, W_fh, W_oh + vec_sigmoid(24, gates + 8, gates + 8); + vec_tanh(8, gates, gates); + const T *i = gates + 8, *f = gates + 16, *o = gates + 24; + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int d = 0; d < 8; ++d) { + // C_t = C_t-1 * fgated + cand_gated * igated + ct[d] = ct_1[d] * f[d] + gates[d] * i[d]; + // H_t = act_cell(C_t) * ogated + T tmp = ct[d] * 2; + tmp = static_cast(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); + vec_exp(1, &tmp, &tmp); + tmp = static_cast(2) / (static_cast(1) + tmp) - static_cast(1); + ht[d] = tmp * o[d]; + } +} + +#ifdef __AVX__ +namespace detail { +namespace forward { +namespace avx { +__m256 Sigmoid(const __m256 a); +__m256 Tanh(const __m256 a); +} // namespace avx +} // namespace forward +} // namespace detail + +template <> +void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, + float* ht) { + namespace act = detail::forward::avx; + // gates: W_ch, W_ih, W_fh, W_oh + __m256 c, i, f, o; + c = _mm256_loadu_ps(gates); + i = _mm256_loadu_ps(gates + 8); + f = _mm256_loadu_ps(gates + 16); + o = _mm256_loadu_ps(gates + 24); + + /* C_t = C_t-1 * fgated + cand_gated * igated*/ + c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i)); + i = _mm256_loadu_ps(ct_1); + f = _mm256_mul_ps(i, act::Sigmoid(f)); + f = _mm256_add_ps(c, f); + _mm256_storeu_ps(ct, f); + + /* H_t = act_cell(C_t) * ogated */ + o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o)); + _mm256_storeu_ps(ht, o); +} +#endif + +template void lstm_compute_ctht(float* gates, const float* ct_1, + float* ct, float* ht); +template void lstm_compute_ctht(double* gates, const double* ct_1, + double* ct, double* ht); + +} // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.h b/paddle/fluid/operators/math/cpu_lstm_compute.h index 00e9e4f32ae790dcd4417d47db2b0db510a68549..244164f08c4bb70833a9bfc884982a4225945bf0 100644 --- a/paddle/fluid/operators/math/cpu_lstm_compute.h +++ b/paddle/fluid/operators/math/cpu_lstm_compute.h @@ -14,11 +14,6 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/operators/math/cpu_vec.h" -#include "paddle/fluid/platform/cpu_info.h" -#ifdef __AVX__ -#include -#endif namespace paddle { namespace operators { @@ -26,58 +21,7 @@ namespace math { // TODO(TJ): ugly workaround, clean me template -void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) { - // gates: W_ch, W_ih, W_fh, W_oh - vec_sigmoid(24, gates + 8, gates + 8); - vec_tanh(8, gates, gates); - const T *i = gates + 8, *f = gates + 16, *o = gates + 24; - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - for (int d = 0; d < 8; ++d) { - // C_t = C_t-1 * fgated + cand_gated * igated - ct[d] = ct_1[d] * f[d] + gates[d] * i[d]; - // H_t = act_cell(C_t) * ogated - T tmp = ct[d] * 2; - tmp = static_cast(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); - vec_exp(1, &tmp, &tmp); - tmp = static_cast(2) / (static_cast(1) + tmp) - static_cast(1); - ht[d] = tmp * o[d]; - } -} - -#ifdef __AVX__ -namespace detail { -namespace forward { -namespace avx { -__m256 Sigmoid(const __m256 a); -__m256 Tanh(const __m256 a); -} // namespace avx -} // namespace forward -} // namespace detail - -template <> -void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, - float* ht) { - namespace act = detail::forward::avx; - // gates: W_ch, W_ih, W_fh, W_oh - __m256 c, i, f, o; - c = _mm256_loadu_ps(gates); - i = _mm256_loadu_ps(gates + 8); - f = _mm256_loadu_ps(gates + 16); - o = _mm256_loadu_ps(gates + 24); - - /* C_t = C_t-1 * fgated + cand_gated * igated*/ - c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i)); - i = _mm256_loadu_ps(ct_1); - f = _mm256_mul_ps(i, act::Sigmoid(f)); - f = _mm256_add_ps(c, f); - _mm256_storeu_ps(ct, f); - - /* H_t = act_cell(C_t) * ogated */ - o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o)); - _mm256_storeu_ps(ht, o); -} -#endif +void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht); } // namespace math } // namespace operators