From cf5ea925c3eea2f63b099513b85eaf5032db38fa Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Aug 2018 16:10:55 +0800 Subject: [PATCH] fix bugs --- paddle/fluid/operators/attention_lstm_op.cc | 123 +++++++++----------- paddle/fluid/operators/math/blas.h | 17 ++- paddle/fluid/operators/math/blas_impl.h | 3 +- 3 files changed, 75 insertions(+), 68 deletions(-) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 636deb04a1..87fda12ea6 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -15,12 +15,9 @@ limitations under the License. */ #include "paddle/fluid/operators/attention_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/fc_compute.h" -#include "paddle/fluid/operators/math/lstm_compute.h" -#include "paddle/fluid/operators/math/sequence2batch.h" - -#include "paddle/fluid/operators/math/cpu_vec.h" +// #include "paddle/fluid/operators/math/detail/activation_functions.h" +// #include "paddle/fluid/operators/math/cpu_vec.h" namespace paddle { namespace operators { @@ -233,6 +230,13 @@ use lstm_x_t as input and compute as standard LSTM. )DOC"); } +template +inline void vec_relu(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] > 0 ? x[i] : 0; + } +} + // y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0; template inline void bias_relu(const int n, const T* x, const T* bias, T* y) { @@ -240,14 +244,14 @@ inline void bias_relu(const int n, const T* x, const T* bias, T* y) { for (int i = 0; i < n; ++i) { y[i] = x[i] + bias[0]; } - vec_relu(n, y, y); + vec_relu(n, y, y); } else { - vec_relu(n, x, y); + vec_relu(n, x, y); } } template -inline void vec_softmax(const BlasT& blas, const int n, +inline void vec_softmax(const math::BlasT& blas, const int n, const T* x, T* y) { T scalar = x[0]; // max @@ -257,7 +261,7 @@ inline void vec_softmax(const BlasT& blas, const int n, // sub for (int i = 0; i < n; ++i) { - y[c] = x[c] - alpha; + y[i] = x[i] - scalar; } // exp @@ -270,57 +274,45 @@ inline void vec_softmax(const BlasT& blas, const int n, } // scale - blas.VSCAL(n, static_cast(1) / scalar, y); + blas.SCAL(n, static_cast(1) / scalar, y); } -__m256 exp(__m256 a) { return exp256_ps(a); } +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 -__m256 log(__m256 a) { return log256_ps(a); } - -__m256 sin(__m256 a) { return sin256_ps(a); } - -__m256 cos(__m256 a) { return cos256_ps(a); } - -__m256 relu(const __m256 a) { - __m256 tmp = _mm256_set1_ps(0.0f); - return _mm256_max_ps(a, tmp); +template +inline T sigmoid(T x) { + return 1. / (1. + exp(-x)); } -__m256 sigmoid(const __m256 a) { - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); - __m256 tmp = _mm256_max_ps(a, min); - tmp = _mm256_min_ps(tmp, max); - tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); - tmp = exp(tmp); - tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); - tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp); - return tmp; +template +inline T tanh(T x) { + return 2. * sigmoid(2. * x) - 1.; } -__m256 tanh(const __m256 a) { - __m256 max = _mm256_set1_ps(EXP_MAX_INPUT); - __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a); - tmp = _mm256_min_ps(tmp, max); - tmp = exp(tmp); - return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f), - _mm256_add_ps(_mm256_set1_ps(1.0f), tmp)), - _mm256_set1_ps(1.0f)); +template +inline void vec_sigmoid(const int n, const T* x, T* y) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = 1.0 / (1.0 + std::exp(-tmp)); + } } -__m256 linear(const __m256 a) { return a; } - -inline void vec_sigmoid(const T* x, T* y) { - const real min = SIGMOID_THRESHOLD_MIN; - const real max = SIGMOID_THRESHOLD_MAX; - real tmp = (a < min) ? min : ((a > max) ? max : a); - return 1.0 / (1.0 + exp(-tmp)); +template +inline void vec_tanh(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = tanh(x[i]); + } } -template +template class AttentionLSTMKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + using DeviceContext = paddle::platform::CPUDeviceContext; auto* x = ctx.Input("X"); // T x M auto* h0 = ctx.Input("H0"); // N x D auto* c0 = ctx.Input("C0"); // N x D @@ -334,7 +326,7 @@ class AttentionLSTMKernel : public framework::OpKernel { auto* hidden_out = ctx.Output("Hidden"); // TxD auto* cell_out = ctx.Output("Cell"); // TxD auto* atted_x = ctx.Output("AttentionedX"); // T x 1 - auto* fc_out = ctx.Output('AttentionFCOut'); // max_seq_len x 1 + auto* fc_out = ctx.Output("AttentionFCOut"); // max_seq_len x 1 auto* lstm_x = ctx.Output("LSTMX"); // 1 x M auto* lstm_out = ctx.Output("LSTMOUT"); // 1 x 4D @@ -342,9 +334,10 @@ class AttentionLSTMKernel : public framework::OpKernel { auto x_lod = x->lod(); const int N = x_lod[0].size() - 1; // batch size auto x_dims = x->dims(); // T x M - auto w_dims = w->dims(); // (D+M) x 4D - const int M = x_dims[1]; // x frame size - const int D = w_dims[1] / 4; // gate frame size + auto w_dims = lstm_w->dims(); // (D+M) x 4D + const int total_T = x_dims[0]; + const int M = x_dims[1]; // x frame size + const int D = w_dims[1] / 4; // gate frame size const int D2 = D * 2; const int D3 = D * 3; const int D4 = w_dims[1]; @@ -357,6 +350,8 @@ class AttentionLSTMKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D); fc_out->Resize({max_seq_len, 1}); + // TODO(TJ): act functor init here + const T* x_data = x->data(); const T* h0_data = h0->data(); const T* c0_data = c0->data(); @@ -368,16 +363,16 @@ class AttentionLSTMKernel : public framework::OpKernel { const T* atten_scalar_bias_data = atten_scalar_bias ? atten_scalar_bias->data() : NULL; - T* hidden_out_data = hidden_out->mutable_data(); - T* cell_out_data = cell_out->mutable_data(); - T* atted_x_data = atted_x->mutable_data(); - T* fc_out_data = fc_out->mutable_data(); - T* lstm_x_data = lstm_x->mutable_data(); - T* lstm_out_data = lstm_out->mutable_data(); + T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); + T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); + T* atted_x_data = atted_x->mutable_data(ctx.GetPlace()); + T* fc_out_data = fc_out->mutable_data(ctx.GetPlace()); + T* lstm_x_data = lstm_x->mutable_data(ctx.GetPlace()); + T* lstm_out_data = lstm_out->mutable_data(ctx.GetPlace()); // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1 auto blas = math::GetBlas(ctx); - math::FCCompute(blas, T, 1, M, x_data, atten_w_data, + math::FCCompute(blas, total_T, 1, M, x_data, atten_w_data, atted_x_data, atten_b_data); const T* cur_x_data = x_data; @@ -400,7 +395,7 @@ class AttentionLSTMKernel : public framework::OpKernel { // fc2: scalar if (atten_scalar_data) { // x = a*x - blas.SCAL(seq_len, atten_scalar_data, fc_out_data); + blas.SCAL(seq_len, *atten_scalar_data, fc_out_data); bias_relu(seq_len, fc_out_data, atten_scalar_bias_data, fc_out_data); } @@ -431,16 +426,16 @@ class AttentionLSTMKernel : public framework::OpKernel { blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data); // b = input * tilde - blas.VMUL(D, lstm_out_data + D, lstm_out + D3, lstm_out_data + D); + blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D); // cell_out = a + b blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data); // state act tanh(cell_out) * output_gate vec_tanh(D, cur_cell_out_data, lstm_out_data); - blas.VMUL(D, lstm_out_data, lstm_out + D2, cur_hidden_out_data); + blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data); - prev_hidden_data = hidden_out + i * gate_size; + prev_hidden_data = cur_hidden_out_data; prev_cell_data = cur_cell_out_data; cur_cell_out_data = cur_cell_out_data + D; cur_hidden_out_data = cur_hidden_out_data + D; @@ -458,7 +453,5 @@ REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp, ops::AttentionLSTMOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL( - attention_lstm, - ops::AttentionLSTMKernel, - ops::AttentionLSTMKernel); +REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel, + ops::AttentionLSTMKernel); diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 5aba170221..da185d93c0 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -160,7 +160,7 @@ class Blas { T DOT(int n, const T* x, const T* y) const; template - void SCAL(int n, const T a, const T* x) const; + void SCAL(int n, const T a, T* x) const; template void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, @@ -233,11 +233,26 @@ class BlasT : private Blas { Base()->template VCOPY(args...); } + template + void VEXP(ARGS... args) const { + Base()->template VEXP(args...); + } + template void GEMV(ARGS... args) const { Base()->template GEMV(args...); } + template + T DOT(ARGS... args) const { + return Base()->template DOT(args...); + } + + template + void SCAL(ARGS... args) const { + Base()->template SCAL(args...); + } + template void BatchedGEMM(ARGS... args) const { Base()->template BatchedGEMM(args...); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index eaad83ba18..e1df78d11e 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -415,8 +415,7 @@ T Blas::DOT(int n, const T *x, const T *y) const { template <> template -void Blas::SCAL(int n, const T a, - const T *x) const { +void Blas::SCAL(int n, const T a, T *x) const { #ifdef PADDLE_WITH_MKLML CBlas::SCAL(n, a, x, 1); #else -- GitLab