提交 cf5ea925 编写于 作者: T tensor-tang

fix bugs

上级 6ed20474
......@@ -15,12 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/attention_lstm_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
// #include "paddle/fluid/operators/math/detail/activation_functions.h"
// #include "paddle/fluid/operators/math/cpu_vec.h"
namespace paddle {
namespace operators {
......@@ -233,6 +230,13 @@ use lstm_x_t as input and compute as standard LSTM.
)DOC");
}
template <typename T>
inline void vec_relu(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
template <typename T>
inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
......@@ -240,14 +244,14 @@ inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] + bias[0];
}
vec_relu(n, y, y);
vec_relu<T>(n, y, y);
} else {
vec_relu(n, x, y);
vec_relu<T>(n, x, y);
}
}
template <typename DeviceContext, typename T>
inline void vec_softmax(const BlasT<DeviceContext, T>& blas, const int n,
inline void vec_softmax(const math::BlasT<DeviceContext, T>& blas, const int n,
const T* x, T* y) {
T scalar = x[0];
// max
......@@ -257,7 +261,7 @@ inline void vec_softmax(const BlasT<DeviceContext, T>& blas, const int n,
// sub
for (int i = 0; i < n; ++i) {
y[c] = x[c] - alpha;
y[i] = x[i] - scalar;
}
// exp
......@@ -270,57 +274,45 @@ inline void vec_softmax(const BlasT<DeviceContext, T>& blas, const int n,
}
// scale
blas.VSCAL(n, static_cast<T>(1) / scalar, y);
blas.SCAL(n, static_cast<T>(1) / scalar, y);
}
__m256 exp(__m256 a) { return exp256_ps(a); }
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
__m256 log(__m256 a) { return log256_ps(a); }
__m256 sin(__m256 a) { return sin256_ps(a); }
__m256 cos(__m256 a) { return cos256_ps(a); }
__m256 relu(const __m256 a) {
__m256 tmp = _mm256_set1_ps(0.0f);
return _mm256_max_ps(a, tmp);
template <typename T>
inline T sigmoid(T x) {
return 1. / (1. + exp(-x));
}
__m256 sigmoid(const __m256 a) {
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
__m256 tmp = _mm256_max_ps(a, min);
tmp = _mm256_min_ps(tmp, max);
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
tmp = exp(tmp);
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
return tmp;
template <typename T>
inline T tanh(T x) {
return 2. * sigmoid(2. * x) - 1.;
}
__m256 tanh(const __m256 a) {
__m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
__m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
tmp = _mm256_min_ps(tmp, max);
tmp = exp(tmp);
return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
_mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
_mm256_set1_ps(1.0f));
template <typename T>
inline void vec_sigmoid(const int n, const T* x, T* y) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = 1.0 / (1.0 + std::exp(-tmp));
}
}
__m256 linear(const __m256 a) { return a; }
inline void vec_sigmoid(const T* x, T* y) {
const real min = SIGMOID_THRESHOLD_MIN;
const real max = SIGMOID_THRESHOLD_MAX;
real tmp = (a < min) ? min : ((a > max) ? max : a);
return 1.0 / (1.0 + exp(-tmp));
template <typename T>
inline void vec_tanh(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = tanh<T>(x[i]);
}
}
template <typename DeviceContext, typename T>
template <typename T>
class AttentionLSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); // T x M
auto* h0 = ctx.Input<Tensor>("H0"); // N x D
auto* c0 = ctx.Input<Tensor>("C0"); // N x D
......@@ -334,7 +326,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); // TxD
auto* cell_out = ctx.Output<LoDTensor>("Cell"); // TxD
auto* atted_x = ctx.Output<Tensor>("AttentionedX"); // T x 1
auto* fc_out = ctx.Output<Tensor>('AttentionFCOut'); // max_seq_len x 1
auto* fc_out = ctx.Output<Tensor>("AttentionFCOut"); // max_seq_len x 1
auto* lstm_x = ctx.Output<Tensor>("LSTMX"); // 1 x M
auto* lstm_out = ctx.Output<Tensor>("LSTMOUT"); // 1 x 4D
......@@ -342,7 +334,8 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
auto x_lod = x->lod();
const int N = x_lod[0].size() - 1; // batch size
auto x_dims = x->dims(); // T x M
auto w_dims = w->dims(); // (D+M) x 4D
auto w_dims = lstm_w->dims(); // (D+M) x 4D
const int total_T = x_dims[0];
const int M = x_dims[1]; // x frame size
const int D = w_dims[1] / 4; // gate frame size
const int D2 = D * 2;
......@@ -357,6 +350,8 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
fc_out->Resize({max_seq_len, 1});
// TODO(TJ): act functor init here
const T* x_data = x->data<T>();
const T* h0_data = h0->data<T>();
const T* c0_data = c0->data<T>();
......@@ -368,16 +363,16 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
const T* atten_scalar_bias_data =
atten_scalar_bias ? atten_scalar_bias->data<T>() : NULL;
T* hidden_out_data = hidden_out->mutable_data<T>();
T* cell_out_data = cell_out->mutable_data<T>();
T* atted_x_data = atted_x->mutable_data<T>();
T* fc_out_data = fc_out->mutable_data<T>();
T* lstm_x_data = lstm_x->mutable_data<T>();
T* lstm_out_data = lstm_out->mutable_data<T>();
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
T* atted_x_data = atted_x->mutable_data<T>(ctx.GetPlace());
T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, T, 1, M, x_data, atten_w_data,
math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
atted_x_data, atten_b_data);
const T* cur_x_data = x_data;
......@@ -400,7 +395,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
// fc2: scalar
if (atten_scalar_data) {
// x = a*x
blas.SCAL(seq_len, atten_scalar_data, fc_out_data);
blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
fc_out_data);
}
......@@ -431,16 +426,16 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);
// b = input * tilde
blas.VMUL(D, lstm_out_data + D, lstm_out + D3, lstm_out_data + D);
blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D);
// cell_out = a + b
blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);
// state act tanh(cell_out) * output_gate
vec_tanh(D, cur_cell_out_data, lstm_out_data);
blas.VMUL(D, lstm_out_data, lstm_out + D2, cur_hidden_out_data);
blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
prev_hidden_data = hidden_out + i * gate_size;
prev_hidden_data = cur_hidden_out_data;
prev_cell_data = cur_cell_out_data;
cur_cell_out_data = cur_cell_out_data + D;
cur_hidden_out_data = cur_hidden_out_data + D;
......@@ -458,7 +453,5 @@ REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp,
ops::AttentionLSTMOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(
attention_lstm,
ops::AttentionLSTMKernel<paddle::platform::CPUDeviceContext, float>,
ops::AttentionLSTMKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel<float>,
ops::AttentionLSTMKernel<double>);
......@@ -160,7 +160,7 @@ class Blas {
T DOT(int n, const T* x, const T* y) const;
template <typename T>
void SCAL(int n, const T a, const T* x) const;
void SCAL(int n, const T a, T* x) const;
template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
......@@ -233,11 +233,26 @@ class BlasT : private Blas<DeviceContext> {
Base()->template VCOPY<T>(args...);
}
template <typename... ARGS>
void VEXP(ARGS... args) const {
Base()->template VEXP<T>(args...);
}
template <typename... ARGS>
void GEMV(ARGS... args) const {
Base()->template GEMV<T>(args...);
}
template <typename... ARGS>
T DOT(ARGS... args) const {
return Base()->template DOT<T>(args...);
}
template <typename... ARGS>
void SCAL(ARGS... args) const {
Base()->template SCAL<T>(args...);
}
template <typename... ARGS>
void BatchedGEMM(ARGS... args) const {
Base()->template BatchedGEMM<T>(args...);
......
......@@ -415,8 +415,7 @@ T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a,
const T *x) const {
void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::SCAL(n, a, x, 1);
#else
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册