Commit cf5ea925 authored by T tensor-tang

fix bugs

Parent 6ed20474
paddle/fluid/operators/attention_lstm_op.cc

@@ -15,12 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/attention_lstm_op.h"
 #include <string>
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/detail/activation_functions.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
-#include "paddle/fluid/operators/math/lstm_compute.h"
-#include "paddle/fluid/operators/math/sequence2batch.h"
-#include "paddle/fluid/operators/math/cpu_vec.h"
+// #include "paddle/fluid/operators/math/detail/activation_functions.h"
+// #include "paddle/fluid/operators/math/cpu_vec.h"

 namespace paddle {
 namespace operators {
@@ -233,6 +230,13 @@ use lstm_x_t as input and compute as standard LSTM.
 )DOC");
 }

+template <typename T>
+inline void vec_relu(const int n, const T* x, T* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] > 0 ? x[i] : 0;
+  }
+}
+
 // y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
 template <typename T>
 inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
@@ -240,14 +244,14 @@ inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
     for (int i = 0; i < n; ++i) {
       y[i] = x[i] + bias[0];
     }
-    vec_relu(n, y, y);
+    vec_relu<T>(n, y, y);
   } else {
-    vec_relu(n, x, y);
+    vec_relu<T>(n, x, y);
   }
 }
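Note: the calls above now pin the template argument explicitly (`vec_relu<T>`). A minimal standalone check of the helper (hypothetical harness, not part of the patch); in-place use is safe because each output element reads only its own input:

    #include <cstdio>

    template <typename T>
    inline void vec_relu(const int n, const T* x, T* y) {
      for (int i = 0; i < n; ++i) {
        y[i] = x[i] > 0 ? x[i] : 0;
      }
    }

    int main() {
      float v[4] = {-1.f, 2.f, -3.f, 4.f};
      vec_relu<float>(4, v, v);  // in-place: y[i] depends only on x[i]
      std::printf("%g %g %g %g\n", v[0], v[1], v[2], v[3]);  // 0 2 0 4
      return 0;
    }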
 template <typename DeviceContext, typename T>
-inline void vec_softmax(const BlasT<DeviceContext, T>& blas, const int n,
+inline void vec_softmax(const math::BlasT<DeviceContext, T>& blas, const int n,
                         const T* x, T* y) {
   T scalar = x[0];
   // max
@@ -257,7 +261,7 @@ inline void vec_softmax(const BlasT<DeviceContext, T>& blas, const int n,
   // sub
   for (int i = 0; i < n; ++i) {
-    y[c] = x[c] - alpha;
+    y[i] = x[i] - scalar;
   }
   // exp
@@ -270,57 +274,45 @@ inline void vec_softmax(const BlasT<DeviceContext, T>& blas, const int n,
   }
   // scale
-  blas.VSCAL(n, static_cast<T>(1) / scalar, y);
+  blas.SCAL(n, static_cast<T>(1) / scalar, y);
 }
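Note: with the `// max` and `// exp` loops elided by the diff view, the full flow of `vec_softmax` is max-shift, subtract, exponentiate, sum, then scale by 1/sum. A plain-loop sketch of that pattern (reference semantics only; the kernel itself uses the blas calls shown above):

    #include <algorithm>
    #include <cmath>

    template <typename T>
    void softmax_ref(const int n, const T* x, T* y) {
      T m = x[0];
      for (int i = 1; i < n; ++i) m = std::max(m, x[i]);  // max
      T sum = 0;
      for (int i = 0; i < n; ++i) {
        y[i] = std::exp(x[i] - m);  // sub + exp, shifted for stability
        sum += y[i];
      }
      for (int i = 0; i < n; ++i) y[i] /= sum;  // scale: outputs sum to 1
    }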
-__m256 exp(__m256 a) { return exp256_ps(a); }
-
-__m256 log(__m256 a) { return log256_ps(a); }
-
-__m256 sin(__m256 a) { return sin256_ps(a); }
-
-__m256 cos(__m256 a) { return cos256_ps(a); }
-
-__m256 relu(const __m256 a) {
-  __m256 tmp = _mm256_set1_ps(0.0f);
-  return _mm256_max_ps(a, tmp);
-}
-
-__m256 sigmoid(const __m256 a) {
-  __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
-  __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
-  __m256 tmp = _mm256_max_ps(a, min);
-  tmp = _mm256_min_ps(tmp, max);
-  tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
-  tmp = exp(tmp);
-  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
-  tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
-  return tmp;
-}
-
-__m256 tanh(const __m256 a) {
-  __m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
-  __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
-  tmp = _mm256_min_ps(tmp, max);
-  tmp = exp(tmp);
-  return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
-                                     _mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
-                       _mm256_set1_ps(1.0f));
-}
-
-__m256 linear(const __m256 a) { return a; }
-
-inline void vec_sigmoid(const T* x, T* y) {
-  const real min = SIGMOID_THRESHOLD_MIN;
-  const real max = SIGMOID_THRESHOLD_MAX;
-  real tmp = (a < min) ? min : ((a > max) ? max : a);
-  return 1.0 / (1.0 + exp(-tmp));
-}
+#define SIGMOID_THRESHOLD_MIN -40.0
+#define SIGMOID_THRESHOLD_MAX 13.0
+#define EXP_MAX_INPUT 40.0
+
+template <typename T>
+inline T sigmoid(T x) {
+  return 1. / (1. + exp(-x));
+}
+
+template <typename T>
+inline T tanh(T x) {
+  return 2. * sigmoid(2. * x) - 1.;
+}
+
+template <typename T>
+inline void vec_sigmoid(const int n, const T* x, T* y) {
+  const T min = SIGMOID_THRESHOLD_MIN;
+  const T max = SIGMOID_THRESHOLD_MAX;
+  for (int i = 0; i < n; ++i) {
+    T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
+    y[i] = 1.0 / (1.0 + std::exp(-tmp));
+  }
+}
+
+template <typename T>
+inline void vec_tanh(const int n, const T* x, T* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = tanh<T>(x[i]);
+  }
+}
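Note: the new scalar `tanh` relies on the identity tanh(x) = 2·sigmoid(2x) − 1, and the SIGMOID_THRESHOLD_MIN/MAX clamp in `vec_sigmoid` keeps `exp` out of its overflow range. A standalone sanity check of the identity (not part of the patch):

    #include <cmath>
    #include <cstdio>

    int main() {
      const double xs[] = {-2.0, -0.5, 0.0, 1.0, 3.0};
      for (double x : xs) {
        double via_sigmoid = 2.0 / (1.0 + std::exp(-2.0 * x)) - 1.0;
        std::printf("x=%5.2f  std::tanh=%.6f  2*sigmoid(2x)-1=%.6f\n", x,
                    std::tanh(x), via_sigmoid);  // the two columns agree
      }
      return 0;
    }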
-template <typename DeviceContext, typename T>
+template <typename T>
 class AttentionLSTMKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+
     auto* x = ctx.Input<LoDTensor>("X");  // T x M
     auto* h0 = ctx.Input<Tensor>("H0");   // N x D
     auto* c0 = ctx.Input<Tensor>("C0");   // N x D
@@ -334,7 +326,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto* hidden_out = ctx.Output<LoDTensor>("Hidden");   // TxD
     auto* cell_out = ctx.Output<LoDTensor>("Cell");       // TxD
     auto* atted_x = ctx.Output<Tensor>("AttentionedX");   // T x 1
-    auto* fc_out = ctx.Output<Tensor>('AttentionFCOut');  // max_seq_len x 1
+    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");  // max_seq_len x 1
     auto* lstm_x = ctx.Output<Tensor>("LSTMX");           // 1 x M
     auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");       // 1 x 4D
@@ -342,7 +334,8 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto x_lod = x->lod();
     const int N = x_lod[0].size() - 1;  // batch size
     auto x_dims = x->dims();            // T x M
-    auto w_dims = w->dims();            // (D+M) x 4D
+    auto w_dims = lstm_w->dims();       // (D+M) x 4D
+    const int total_T = x_dims[0];
     const int M = x_dims[1];      // x frame size
     const int D = w_dims[1] / 4;  // gate frame size
     const int D2 = D * 2;
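Note: `total_T` is the number of time steps summed over every sequence in the batch, recoverable either from `x_dims[0]` or from the last level-0 LoD offset. A worked example with hypothetical values:

    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical level-0 LoD offsets for a 2-sequence batch.
      std::vector<size_t> lod0 = {0, 4, 9};
      const int N = static_cast<int>(lod0.size()) - 1;    // batch size: 2
      const int total_T = static_cast<int>(lod0.back());  // summed steps: 9
      std::printf("N=%d total_T=%d\n", N, total_T);
      return 0;
    }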
@@ -357,6 +350,8 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
     fc_out->Resize({max_seq_len, 1});

+    // TODO(TJ): act functor init here
+
     const T* x_data = x->data<T>();
     const T* h0_data = h0->data<T>();
     const T* c0_data = c0->data<T>();
@@ -368,16 +363,16 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     const T* atten_scalar_bias_data =
         atten_scalar_bias ? atten_scalar_bias->data<T>() : NULL;

-    T* hidden_out_data = hidden_out->mutable_data<T>();
-    T* cell_out_data = cell_out->mutable_data<T>();
-    T* atted_x_data = atted_x->mutable_data<T>();
-    T* fc_out_data = fc_out->mutable_data<T>();
-    T* lstm_x_data = lstm_x->mutable_data<T>();
-    T* lstm_out_data = lstm_out->mutable_data<T>();
+    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
+    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
+    T* atted_x_data = atted_x->mutable_data<T>(ctx.GetPlace());
+    T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
+    T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
+    T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());

     // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
-    math::FCCompute<DeviceContext, T>(blas, T, 1, M, x_data, atten_w_data,
+    math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
                                       atted_x_data, atten_b_data);

     const T* cur_x_data = x_data;
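Note: the old call passed the type name `T` where the row count belongs; the fix passes `total_T`, so the FC produces a `total_T x 1` score column from the `total_T x M` input. (The `mutable_data` calls above also gain the required `ctx.GetPlace()` argument, since allocation needs a target place.) Assuming `FCCompute(blas, M, N, K, X, W, Y, B)` has the usual Y = X·W (+ bias) semantics, a plain reference version would look like:

    template <typename T>
    void fc_ref(int rows, int cols, int k, const T* x, const T* w, T* y,
                const T* b) {
      // y (rows x cols) = x (rows x k) * w (k x cols), plus broadcast bias.
      for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
          T acc = b ? b[c] : T(0);
          for (int i = 0; i < k; ++i) acc += x[r * k + i] * w[i * cols + c];
          y[r * cols + c] = acc;
        }
      }
    }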
@@ -400,7 +395,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
       // fc2: scalar
       if (atten_scalar_data) {
         // x = a*x
-        blas.SCAL(seq_len, atten_scalar_data, fc_out_data);
+        blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
         bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
                      fc_out_data);
       }
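Note: `SCAL(n, a, x)` takes the scalar by value, while `atten_scalar_data` is a pointer into the 1x1 scalar tensor's buffer, hence the added dereference. A standalone illustration of the in-place semantics:

    #include <cstdio>

    template <typename T>
    void scal_ref(const int n, const T a, T* x) {
      for (int i = 0; i < n; ++i) x[i] *= a;  // x = a * x, in place
    }

    int main() {
      float scalar_data[1] = {0.5f};  // stand-in for the 1x1 tensor's buffer
      float fc_out[3] = {2.f, 4.f, 6.f};
      scal_ref(3, *scalar_data, fc_out);  // pass the value, not the pointer
      std::printf("%g %g %g\n", fc_out[0], fc_out[1], fc_out[2]);  // 1 2 3
      return 0;
    }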
@@ -431,16 +426,16 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
         blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);

         // b = input * tilde
-        blas.VMUL(D, lstm_out_data + D, lstm_out + D3, lstm_out_data + D);
+        blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D);

         // cell_out = a + b
         blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);

         // state act tanh(cell_out) * output_gate
         vec_tanh(D, cur_cell_out_data, lstm_out_data);
-        blas.VMUL(D, lstm_out_data, lstm_out + D2, cur_hidden_out_data);
+        blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);

-        prev_hidden_data = hidden_out + i * gate_size;
+        prev_hidden_data = cur_hidden_out_data;
         prev_cell_data = cur_cell_out_data;
         cur_cell_out_data = cur_cell_out_data + D;
         cur_hidden_out_data = cur_hidden_out_data + D;
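Note: the two VMUL fixes replace `lstm_out` (a `Tensor*`) with `lstm_out_data` (the underlying `T*`): pointer arithmetic on the tensor object advances by whole objects, not by elements, so the old code indexed the wrong memory entirely. A minimal illustration with a stand-in type:

    #include <cstdio>

    struct Tensor {  // stand-in for the framework tensor type
      float buf[256];
    };

    int main() {
      Tensor t;
      Tensor* tp = &t;
      float* data = t.buf;
      // tp + 1 skips sizeof(Tensor) bytes; data + 1 skips sizeof(float).
      std::printf("tensor step: %zu bytes, element step: %zu bytes\n",
                  sizeof(*tp), sizeof(*data));
      return 0;
    }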
@@ -458,7 +453,5 @@ REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp,
                   ops::AttentionLSTMOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);

-REGISTER_OP_CPU_KERNEL(
-    attention_lstm,
-    ops::AttentionLSTMKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::AttentionLSTMKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel<float>,
+                       ops::AttentionLSTMKernel<double>);
paddle/fluid/operators/math/blas.h

@@ -160,7 +160,7 @@ class Blas {
   T DOT(int n, const T* x, const T* y) const;

   template <typename T>
-  void SCAL(int n, const T a, const T* x) const;
+  void SCAL(int n, const T a, T* x) const;

   template <typename T>
   void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
@@ -233,11 +233,26 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template VCOPY<T>(args...);
   }

+  template <typename... ARGS>
+  void VEXP(ARGS... args) const {
+    Base()->template VEXP<T>(args...);
+  }
+
   template <typename... ARGS>
   void GEMV(ARGS... args) const {
     Base()->template GEMV<T>(args...);
   }

+  template <typename... ARGS>
+  T DOT(ARGS... args) const {
+    return Base()->template DOT<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void SCAL(ARGS... args) const {
+    Base()->template SCAL<T>(args...);
+  }
+
   template <typename... ARGS>
   void BatchedGEMM(ARGS... args) const {
     Base()->template BatchedGEMM<T>(args...);
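Note: the new VEXP/DOT/SCAL wrappers let a kernel drive the whole softmax pipeline through one typed blas handle. A hedged usage sketch (inside an op kernel's Compute; `T`, `n`, `y`, and an all-ones vector `ones` are assumed to be in scope):

    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
    blas.VEXP(n, y, y);                        // y[i] = exp(y[i])
    T sum = blas.DOT(n, y, ones);              // reduce: dot with all-ones
    blas.SCAL(n, static_cast<T>(1) / sum, y);  // scale in place: y /= sum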
paddle/fluid/operators/math/blas_impl.h

@@ -415,8 +415,7 @@ T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {

 template <>
 template <typename T>
-void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a,
-                                            const T *x) const {
+void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
 #ifdef PADDLE_WITH_MKLML
   CBlas<T>::SCAL(n, a, x, 1);
 #else
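Note: the signature change from `const T *x` to `T *x` is what makes an in-place scale legal here. Under PADDLE_WITH_MKLML, `CBlas<T>::SCAL(n, a, x, 1)` resolves to the standard CBLAS scaling routines with unit stride; the non-MKL `#else` branch is elided by the diff view. A sketch of that per-type dispatch pattern (`CBlasRef` is hypothetical, not the Paddle implementation):

    extern "C" {  // standard CBLAS signatures
    void cblas_sscal(int n, float alpha, float* x, int incx);
    void cblas_dscal(int n, double alpha, double* x, int incx);
    }

    template <typename T>
    struct CBlasRef;

    template <>
    struct CBlasRef<float> {
      static void SCAL(int n, float a, float* x, int incx) {
        cblas_sscal(n, a, x, incx);  // float path
      }
    };

    template <>
    struct CBlasRef<double> {
      static void SCAL(int n, double a, double* x, int incx) {
        cblas_dscal(n, a, x, incx);  // double path
      }
    };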