diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index d1a0a05c7092235f67bf9e096684405d29a59168..d67029a392e5d973424fd9322b9000ea5f34ac3f 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -172,19 +172,19 @@ class FusionGRUKernel : public framework::OpKernel { bool is_reverse = ctx.Attr("is_reverse"); std::function act_gate, act_state; - std::function bias_sub; + std::function cross; auto& act_gate_str = ctx.Attr("gate_activation"); auto& act_state_str = ctx.Attr("activation"); if (platform::jit::MayIUse(platform::jit::avx)) { math::VecActivations act_functor; act_gate = act_functor(act_gate_str); act_state = act_functor(act_state_str); - bias_sub = math::vec_bias_sub; + cross = math::vec_cross; } else { math::VecActivations act_functor; act_gate = act_functor(act_gate_str); act_state = act_functor(act_state_str); - bias_sub = math::vec_bias_sub; + cross = math::vec_cross; } const T* x_data = x->data(); @@ -288,15 +288,9 @@ class FusionGRUKernel : public framework::OpKernel { for (int i = 0; i < cur_bs; ++i) { // ht~ = act_state(...) act_state(D, cur_batched_data + D2, cur_batched_data + D2); - // ht~~ = zt*ht~ inplace result - blas.VMUL(D, cur_batched_data, cur_batched_data + D2, - cur_batched_data + D2); - // zt = 1 - zt inplace result - bias_sub(D, static_cast(1), cur_batched_data, cur_batched_data); - // zt = ht_1 * zt - blas.VMUL(D, cur_prev_hidden_data, cur_batched_data, cur_batched_data); - // out = zt + ht~~ - blas.VADD(D, cur_batched_data, cur_batched_data + D2, cur_out_data); + // out = zt*ht~ + (1-zt)*ht_1 + cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data, + cur_out_data); cur_batched_data += D3; cur_prev_hidden_data += D; diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index bf6f552ad3ca5128f8ac850edf54e416acf0f04a..9560e3a3c15ca63892fbe3552679a22f027f11e2 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -188,6 +188,65 @@ inline void vec_bias_sub(const int n, vec_bias_sub(n, a, x, y); } +// out = x*y + (1-x)*z +template +inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) { + for (int i = 0; i < n; ++i) { + out[i] = x[i] * y[i] + (static_cast(1) - x[i]) * z[i]; + } +} + +template <> +inline void vec_cross(const int n, const float* x, + const float* y, const float* z, + float* out) { +#ifdef __AVX__ + constexpr int block = AVX_FLOAT_BLOCK; + if (n < block) { + vec_cross(n, x, y, z, out); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 bias = _mm256_set1_ps(1.f); + __m256 tmpx, tmpy, tmpz; + for (i = 0; i < end; i += block) { + tmpx = _mm256_loadu_ps(x + i); + tmpy = _mm256_loadu_ps(y + i); + tmpz = _mm256_loadu_ps(z + i); + tmpy = _mm256_mul_ps(tmpx, tmpy); + tmpx = _mm256_sub_ps(bias, tmpx); + tmpz = _mm256_mul_ps(tmpx, tmpz); + tmpz = _mm256_add_ps(tmpy, tmpz); + _mm256_storeu_ps(out + i, tmpz); + } + if (rest == 0) { + return; + } + // can not continue move step if src and dst are inplace + for (i = n - rest; i < n; ++i) { + out[i] = x[i] * y[i] + (1.f - x[i]) * z[i]; + } +#else + vec_cross(n, x, y, z, out); +#endif +} + +template <> +inline void vec_cross(const int n, const float* x, + const float* y, + const float* z, float* out) { + vec_cross(n, x, y, z, out); +} + +template <> +inline void vec_cross( + const int n, const float* x, const float* y, const float* z, float* out) { + // TODO(TJ): enable me + vec_cross(n, x, y, z, out); +} + template inline void vec_add_bias(const int n, const T a, const T* x, T* y) { for (int i = 0; i < n; ++i) {