Unverified commit 9362d85e, authored by Jack Zhou, committed by GitHub

Add LSTM, Simple RNN and GRU CPU kernels (#28577)

* add lstm, simple rnn op kernel

* fix the test_lstm for the rnn op

* change func name

* fix forward postprocess bug

* add gru forward, backward code

* remove unittest.skipIf; use one big rnn op instead of a combination of ops

* fix a bug where the input received no gradient

* add eigen lstm forward, backward
Co-authored-by: wawltor <fangzeyang0904@hotmail.com>
Parent 30ef3815
@@ -30,18 +30,24 @@ namespace detail {
 enum ActivationType {
   kSigmoid,
+  KSigmoidV2,
   kReLU,
   kTanh,
+  kTanhV2,
   kIdentity,
 };
 
 inline ActivationType GetActivationType(const std::string &type) {
   if (type == "sigmoid") {
     return ActivationType::kSigmoid;
+  } else if (type == "sigmoid_v2") {
+    return ActivationType::KSigmoidV2;
   } else if (type == "relu") {
     return ActivationType::kReLU;
   } else if (type == "tanh") {
     return ActivationType::kTanh;
+  } else if (type == "tanh_v2") {
+    return ActivationType::kTanhV2;
   } else if (type == "identity" || type == "") {
     return ActivationType::kIdentity;
   }
@@ -68,6 +74,14 @@ DEVICE T Sigmoid(const T a) {
   return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-tmp));
 }
 
+/*
+ * Don't limit the input to a threshold range.
+ */
+template <typename T>
+DEVICE T SigmoidV2(const T a) {
+  return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-a));
+}
+
 template <typename T>
 DEVICE T Tanh(const T a) {
   T tmp = -2.0 * a;
@@ -75,6 +89,15 @@ DEVICE T Tanh(const T a) {
   return (2.0 / (1.0 + exp(tmp))) - 1.0;
 }
 
+/*
+ * Don't limit the input to a threshold range.
+ */
+template <typename T>
+DEVICE T TanhV2(const T a) {
+  T tmp = -2.0 * a;
+  return (2.0 / (1.0 + exp(tmp))) - 1.0;
+}
+
 } // namespace forward
 
 namespace backward {
@@ -108,20 +131,24 @@ struct Active {
 };
 
 static DEVICE Active<float>::Act kActFloat[] = {
-    &forward::Sigmoid<float>, &forward::Relu<float>, &forward::Tanh<float>,
-    &forward::Identity<float>};
+    &forward::Sigmoid<float>, &forward::SigmoidV2<float>,
+    &forward::Relu<float>,    &forward::Tanh<float>,
+    &forward::TanhV2<float>,  &forward::Identity<float>};
 
 static DEVICE Active<float>::ActGrad kActGradFloat[] = {
-    &backward::Sigmoid<float>, &backward::Relu<float>, &backward::Tanh<float>,
-    &backward::Identity<float>};
+    &backward::Sigmoid<float>, &backward::Sigmoid<float>,
+    &backward::Relu<float>,    &backward::Tanh<float>,
+    &backward::Tanh<float>,    &backward::Identity<float>};
 
 static DEVICE Active<double>::Act kActDouble[] = {
-    &forward::Sigmoid<double>, &forward::Relu<double>, &forward::Tanh<double>,
-    &forward::Identity<double>};
+    &forward::Sigmoid<double>, &forward::SigmoidV2<double>,
+    &forward::Relu<double>,    &forward::Tanh<double>,
+    &forward::TanhV2<double>,  &forward::Identity<double>};
 
 static DEVICE Active<double>::ActGrad kActGradDouble[] = {
-    &backward::Sigmoid<double>, &backward::Relu<double>,
-    &backward::Tanh<double>, &backward::Identity<double>};
+    &backward::Sigmoid<double>, &backward::Sigmoid<double>,
+    &backward::Relu<double>,    &backward::Tanh<double>,
+    &backward::Tanh<double>,    &backward::Identity<double>};
 
 namespace forward {
 inline DEVICE float activation(float a, int index) {
@@ -149,7 +176,9 @@ namespace forward {
 namespace avx {
 __m256 Relu(const __m256 a);
 __m256 Sigmoid(const __m256 a);
+__m256 SigmoidV2(const __m256 a);
 __m256 Tanh(const __m256 a);
+__m256 TanhV2(const __m256 a);
 __m256 Identity(const __m256 a);
 } // namespace avx
 } // namespace forward
@@ -164,12 +193,12 @@ __m256 Identity(const __m256 a, const __m256 b);
 } // namespace backward
 
 static Active<__m256>::Act kActAvx[] = {
-    &forward::avx::Sigmoid, &forward::avx::Relu, &forward::avx::Tanh,
-    &forward::avx::Identity};
+    &forward::avx::Sigmoid, &forward::avx::SigmoidV2, &forward::avx::Relu,
+    &forward::avx::Tanh,    &forward::avx::TanhV2,    &forward::avx::Identity};
 
 static Active<__m256>::ActGrad kActGradAvx[] = {
-    &backward::avx::Sigmoid, &backward::avx::Relu, &backward::avx::Tanh,
-    &backward::avx::Identity};
+    &backward::avx::Sigmoid, &backward::avx::Sigmoid, &backward::avx::Relu,
+    &backward::avx::Tanh,    &backward::avx::Tanh,    &backward::avx::Identity};
 
 namespace forward {
 inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); }
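A reading note on the tables above: activation(a, index) indexes these static arrays directly, so each array must list its functions in exactly the ActivationType enum order, and the V2 entries deliberately reuse the V1 backward functions because skipping the input clamp changes only the forward pass, not the derivative formula. A minimal standalone sketch of the same dispatch pattern (plain C++, independent of the Paddle sources; the clamp bounds below are illustrative, not Paddle's exact constants):

    #include <cmath>
    #include <cstdio>

    // Table order must mirror the enum order, as with kActFloat above.
    enum ActivationType { kSigmoid, kSigmoidV2, kTanh, kIdentity };
    using Act = float (*)(float);

    float Sigmoid(float a) {
      // V1-style: clamp the input before exponentiating (illustrative bounds).
      a = a < -40.0f ? -40.0f : (a > 13.0f ? 13.0f : a);
      return 1.0f / (1.0f + std::exp(-a));
    }
    float SigmoidV2(float a) { return 1.0f / (1.0f + std::exp(-a)); }
    float Tanh(float a) { return 2.0f / (1.0f + std::exp(-2.0f * a)) - 1.0f; }
    float Identity(float a) { return a; }

    Act kAct[] = {&Sigmoid, &SigmoidV2, &Tanh, &Identity};
    float activation(float a, int index) { return kAct[index](a); }

    int main() {
      // Identical in the interior; they differ only where the clamp bites.
      std::printf("%f vs %f\n", activation(2.0f, kSigmoid),
                  activation(2.0f, kSigmoidV2));
    }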
......
@@ -43,6 +43,13 @@ __m256 Sigmoid(const __m256 a) {
   return tmp;
 }
 
+__m256 SigmoidV2(const __m256 a) {
+  __m256 tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), a);
+  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), exp256_ps(tmp));
+  tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
+  return tmp;
+}
+
 __m256 Tanh(const __m256 a) {
   __m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
   __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
@@ -53,6 +60,14 @@ __m256 Tanh(const __m256 a) {
                        _mm256_set1_ps(1.0f));
 }
 
+__m256 TanhV2(const __m256 a) {
+  __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
+  return _mm256_sub_ps(
+      _mm256_div_ps(_mm256_set1_ps(2.0f),
+                    _mm256_add_ps(_mm256_set1_ps(1.0f), exp256_ps(tmp))),
+      _mm256_set1_ps(1.0f));
+}
+
 __m256 Identity(const __m256 a) { return a; }
 
 } // namespace avx
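For reference, the scalar and AVX kernels all evaluate the same closed forms; SigmoidV2/TanhV2 differ from V1 only in skipping the input clamp (the EXP_MAX_INPUT bound above):

    \sigma(a) = \frac{1}{1 + e^{-a}}, \qquad
    \tanh(a)  = \frac{2}{1 + e^{-2a}} - 1 = 2\,\sigma(2a) - 1

The second identity is why both Tanh variants get away with a single call to exp256_ps.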
......
@@ -25,26 +25,38 @@ namespace detail {
 #ifndef __NVCC__
 
 template <class OpResetOutput, typename T>
-void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
-                                       T *gate_value, T *reset_output_value,
-                                       T *prev_output_value, int frame_size,
-                                       ActivationType active_gate) {
+void hl_naive_gru_forward_reset_output(
+    OpResetOutput op_reset_output, T *gate_value, T *reset_output_value,
+    const T *prev_output_value, int frame_size, ActivationType active_gate,
+    bool old_version = true, const T *reset_bias = nullptr) {
   T r_value_update_gate;
   T r_value_reset_gate;
   T r_value_reset_output;
   T r_prev_out = 0;
-  T *update_gate = gate_value;
-  T *reset_gate = gate_value + frame_size;
+  T r_reset_bias = 0;
+  T *update_gate = nullptr;
+  T *reset_gate = nullptr;
+  if (old_version) {
+    update_gate = gate_value;
+    reset_gate = gate_value + frame_size;
+  } else {
+    reset_gate = gate_value;
+    update_gate = gate_value + frame_size;
+  }
   for (int i = 0; i < frame_size; i++) {
     r_value_update_gate = update_gate[i];
    r_value_reset_gate = reset_gate[i];
+    if (!old_version) {
+      r_value_reset_output = reset_output_value[i];
+      r_reset_bias = reset_bias[i];
+    }
     if (prev_output_value) {
       r_prev_out = prev_output_value[i];
     }
 
     op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
-                    &r_value_reset_output, active_gate);
+                    &r_value_reset_output, active_gate, &r_reset_bias,
+                    old_version);
 
     update_gate[i] = r_value_update_gate;
     reset_gate[i] = r_value_reset_gate;
@@ -53,16 +65,20 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
 }
 
 template <class OpFinalOutput, typename T>
-void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
-                                       T *gate_value, T *prev_output_value,
-                                       T *output_value, int frame_size,
-                                       ActivationType active_node,
-                                       bool origin_mode) {
+void hl_naive_gru_forward_final_output(
+    OpFinalOutput op_final_output, T *gate_value, const T *prev_output_value,
+    T *output_value, int frame_size, ActivationType active_node,
+    bool origin_mode, bool old_version = true) {
   T r_value_update_gate;
   T r_value_frame_state;
   T r_prev_out = 0;
   T r_output;
-  T *update_gate = gate_value;
+  T *update_gate;
+  if (old_version) {
+    update_gate = gate_value;
+  } else {
+    update_gate = gate_value + frame_size;
+  }
   T *frame_state = gate_value + frame_size * 2;
 
   for (int i = 0; i < frame_size; i++) {
@@ -83,16 +99,26 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
 
 template <class OpResetOutput, typename T>
 void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
                                      T *gate_value, T *reset_output_value,
-                                     T *prev_output_value, int frame_size,
-                                     ActivationType active_gate) {
+                                     const T *prev_output_value, int frame_size,
+                                     ActivationType active_gate,
+                                     bool old_version = true,
+                                     const T *reset_bias = nullptr) {
 #ifdef __AVX__
   __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
   __m256 r_value_reset_gate, r_value_reset_gate_last = _mm256_set1_ps(0.0f);
   __m256 r_value_reset_output;
   __m256 r_prev_out = _mm256_set1_ps(0.0f),
          r_prev_out_last = _mm256_set1_ps(0.0f);
-  T *update_gate = gate_value;
-  T *reset_gate = gate_value + frame_size;
+  __m256 r_reset_bias = _mm256_set1_ps(0.0f);
+  T *update_gate;
+  T *reset_gate;
+  if (old_version) {
+    update_gate = gate_value;
+    reset_gate = gate_value + frame_size;
+  } else {
+    reset_gate = gate_value;
+    update_gate = gate_value + frame_size;
+  }
   int block = 8;
   const int n = frame_size;
   const int rest = n % block;
@@ -115,9 +141,15 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
     if (prev_output_value) {
       r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i));
     }
+    if (!old_version) {
+      r_reset_bias = _mm256_loadu_ps((const float *)(reset_bias + i));
+      r_value_reset_output =
+          _mm256_loadu_ps((const float *)(reset_output_value + i));
+    }
 
     op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
-                    &r_value_reset_output, active_gate);
+                    &r_value_reset_output, active_gate, &r_reset_bias,
+                    old_version);
 
     _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
                      r_value_update_gate);
@@ -131,7 +163,8 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
     i = n - block;
 
     op_reset_output(&r_value_update_gate_last, &r_value_reset_gate_last,
-                    &r_prev_out_last, &r_value_reset_output, active_gate);
+                    &r_prev_out_last, &r_value_reset_output, active_gate,
+                    &r_reset_bias, old_version);
 
     _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
                      r_value_update_gate_last);
@@ -145,17 +178,24 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
 
 template <class OpFinalOutput, typename T>
 void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
-                                     T *gate_value, T *prev_output_value,
+                                     T *gate_value, const T *prev_output_value,
                                      T *output_value, int frame_size,
                                      ActivationType active_node,
-                                     bool origin_mode) {
+                                     bool origin_mode,
+                                     bool old_version = true) {
 #ifdef __AVX__
   __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
   __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
   __m256 r_prev_out = _mm256_set1_ps(0.0f),
          r_prev_out_last = _mm256_set1_ps(0.0f);
   __m256 r_output;
-  T *update_gate = gate_value;
+  T *update_gate;
+  if (old_version) {
+    update_gate = gate_value;
+  } else {
+    update_gate = gate_value + frame_size;
+  }
   T *frame_state = gate_value + frame_size * 2;
   int block = 8;
   const int n = frame_size;
@@ -205,19 +245,21 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
 template <class OpResetOutput, typename T>
 inline void forward_reset_output(OpResetOutput op_reset_output,
                                  GRUMetaValue<T> value, int frame_size,
-                                 int batch_size, ActivationType active_gate) {
+                                 int batch_size, ActivationType active_gate,
+                                 bool old_version = true) {
   for (int b = 0; b < batch_size; b++) {
     if (OpResetOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
         (sizeof(T) == 4)) {
       hl_avx_gru_forward_reset_output(
           op_reset_output, value.gate_value, value.reset_output_value,
-          value.prev_out_value, frame_size, active_gate);
+          value.prev_out_value, frame_size, active_gate, old_version,
+          value.reset_bias);
     } else {
       hl_naive_gru_forward_reset_output(
          op_reset_output, value.gate_value, value.reset_output_value,
-          value.prev_out_value, frame_size, active_gate);
+          value.prev_out_value, frame_size, active_gate, old_version,
+          value.reset_bias);
     }
     value.gate_value += frame_size * 3;
     value.reset_output_value += frame_size;
     if (value.prev_out_value) {
@@ -230,17 +272,19 @@ template <class OpFinalOutput, typename T>
 inline void forward_final_output(OpFinalOutput op_final_output,
                                  GRUMetaValue<T> value, int frame_size,
                                  int batch_size, ActivationType active_node,
-                                 bool origin_mode) {
+                                 bool origin_mode, bool old_version = true) {
   for (int b = 0; b < batch_size; b++) {
     if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
         (sizeof(T) == 4)) {
       hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
                                       value.prev_out_value, value.output_value,
-                                      frame_size, active_node, origin_mode);
+                                      frame_size, active_node, origin_mode,
+                                      old_version);
     } else {
-      hl_naive_gru_forward_final_output(
-          op_final_output, value.gate_value, value.prev_out_value,
-          value.output_value, frame_size, active_node, origin_mode);
+      hl_naive_gru_forward_final_output(op_final_output, value.gate_value,
+                                        value.prev_out_value,
+                                        value.output_value, frame_size,
+                                        active_node, origin_mode, old_version);
     }
     value.gate_value += frame_size * 3;
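Worth spelling out, since every signature above now threads old_version through: the two GRU formulations disagree both on where each gate lives inside the packed gate_value buffer and on what reset_output_value means. A small standalone sketch of the layout rule implied by the pointer arithmetic above (illustrative names, not Paddle API):

    // gate_value packs 3 * frame_size entries per batch row.
    //   V1 (old_version == true):  [ update | reset | candidate ]
    //   V2 (old_version == false): [ reset | update | candidate ]
    template <typename T>
    struct GruGates {
      T *reset;
      T *update;
      T *candidate;
    };

    template <typename T>
    GruGates<T> SliceGates(T *gate_value, int frame_size, bool old_version) {
      GruGates<T> g;
      g.update = old_version ? gate_value : gate_value + frame_size;
      g.reset = old_version ? gate_value + frame_size : gate_value;
      g.candidate = gate_value + 2 * frame_size;
      return g;
    }

    int main() {
      float buf[24] = {0};
      // V2 places the reset gate first.
      return SliceGates(buf, 8, /*old_version=*/false).reset == buf ? 0 : 1;
    }

The meaning of reset_output_value changes too: V1 stores the gated previous state, while V2 expects the buffer to already hold the recurrent candidate term from a GEMM and rewrites it gated and bias-shifted; see gru_resetOutput and GRUUnitFunctorV2 further down.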
@@ -253,7 +297,7 @@ inline void forward_final_output(OpFinalOutput op_final_output,
 template <class OpStateGrad, typename T>
 void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
-                                      T *gate_grad, T *prev_out_value,
+                                      T *gate_grad, const T *prev_out_value,
                                       T *prev_out_grad, T *output_grad,
                                       int frame_size,
                                       ActivationType active_node,
@@ -295,7 +339,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
 template <class OpResetGrad, typename T>
 void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
-                                      T *gate_grad, T *prev_out_value,
+                                      T *gate_grad, const T *prev_out_value,
                                       T *prev_out_grad, T *reset_output_grad,
                                       int frame_size,
                                       ActivationType active_gate) {
@@ -340,7 +384,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
 template <class OpStateGrad, typename T>
 void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
-                                    T *gate_grad, T *prev_out_value,
+                                    T *gate_grad, const T *prev_out_value,
                                     T *prev_out_grad, T *output_grad,
                                     int frame_size, ActivationType active_node,
                                     bool origin_mode) {
@@ -364,7 +408,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
     r_frame_state_value = frame_state_value[i];
     r_out_grad = (reinterpret_cast<__m256 *>(output_grad))[i];
     if (prev_out_value) {
-      r_prev_out_value = (reinterpret_cast<__m256 *>(prev_out_value))[i];
+      r_prev_out_value = (reinterpret_cast<const __m256 *>(prev_out_value))[i];
     }
     if (prev_out_grad) {
       r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
@@ -385,7 +429,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
 template <class OpResetGrad, typename T>
 void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
-                                    T *gate_grad, T *prev_out_value,
+                                    T *gate_grad, const T *prev_out_value,
                                     T *prev_out_grad, T *reset_output_grad,
                                     int frame_size,
                                     ActivationType active_gate) {
@@ -412,7 +456,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
       r_reset_output_grad = (reinterpret_cast<__m256 *>(reset_output_grad))[i];
     }
     if (prev_out_value) {
-      r_prev_out_value = (reinterpret_cast<__m256 *>(prev_out_value))[i];
+      r_prev_out_value = (reinterpret_cast<const __m256 *>(prev_out_value))[i];
     }
     if (prev_out_grad) {
       r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
@@ -431,6 +475,135 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
 #endif
 }
+template <class OpGruGrad, typename T>
+inline void hl_naive_gru_backward(OpGruGrad op_gru_grad, T *gate_value,
+                                  T *gate_grad, const T *prev_out_value,
+                                  T *prev_out_grad, T *reset_output_value,
+                                  T *reset_output_grad, T *output_grad,
+                                  int frame_size, ActivationType active_node,
+                                  ActivationType active_gate) {
+  T r_value_reset_gate;
+  T r_grad_reset_gate;
+  T r_value_update_gate;
+  T r_grad_update_gate;
+  T r_value_frame_state;
+  T r_grad_frame_state;
+  T r_value_prev_out = 0;
+  T r_grad_prev_out = 0;
+  T r_grad_output;
+  T r_value_reset_output;
+  T r_grad_reset_output = 0;
+  T *reset_gate_value = gate_value;
+  T *reset_gate_grad = gate_grad;
+  T *update_gate_value = gate_value + frame_size;
+  T *update_gate_grad = gate_grad + frame_size;
+  T *frame_state_value = gate_value + 2 * frame_size;
+  T *frame_state_grad = gate_grad + 2 * frame_size;
+
+  for (int i = 0; i < frame_size; ++i) {
+    r_value_reset_gate = reset_gate_value[i];
+    r_grad_reset_gate = reset_gate_grad[i];
+    r_value_update_gate = update_gate_value[i];
+    r_grad_update_gate = update_gate_grad[i];
+    r_value_frame_state = frame_state_value[i];
+    r_grad_frame_state = frame_state_grad[i];
+    if (prev_out_value) {
+      r_value_prev_out = prev_out_value[i];
+    }
+    if (prev_out_grad) {
+      r_grad_prev_out = prev_out_grad[i];
+    }
+    r_grad_output = output_grad[i];
+    r_value_reset_output = reset_output_value[i];
+    if (prev_out_value && prev_out_grad) {
+      r_grad_reset_output = reset_output_grad[i];
+    }
+
+    op_gru_grad(&r_value_reset_gate, &r_grad_reset_gate, &r_value_update_gate,
+                &r_grad_update_gate, &r_value_frame_state, &r_grad_frame_state,
+                &r_value_prev_out, &r_grad_prev_out, &r_grad_output,
+                &r_value_reset_output, &r_grad_reset_output, active_node,
+                active_gate);
+
+    reset_gate_grad[i] = r_grad_reset_gate;
+    update_gate_grad[i] = r_grad_update_gate;
+    frame_state_grad[i] = r_grad_frame_state;
+    if (prev_out_grad) {
+      prev_out_grad[i] = r_grad_prev_out;
+    }
+    if (prev_out_value && prev_out_grad) {
+      reset_output_grad[i] = r_grad_reset_output;
+    }
+  }
+}
+
+template <class OpGruGrad, typename T>
+inline void hl_avx_gru_backward(OpGruGrad op_gru_grad, T *gate_value,
+                                T *gate_grad, const T *prev_out_value,
+                                T *prev_out_grad, T *reset_output_value,
+                                T *reset_output_grad, T *output_grad,
+                                int frame_size, ActivationType active_node,
+                                ActivationType active_gate) {
+#ifdef __AVX__
+  __m256 r_value_reset_gate;
+  __m256 r_grad_reset_gate;
+  __m256 r_value_update_gate;
+  __m256 r_grad_update_gate;
+  __m256 r_value_frame_state;
+  __m256 r_grad_frame_state;
+  __m256 r_value_prev_out = _mm256_set1_ps(0.0f);
+  __m256 r_grad_prev_out = _mm256_set1_ps(0.0f);
+  __m256 r_grad_output;
+  __m256 r_value_reset_output;
+  __m256 r_grad_reset_output = _mm256_set1_ps(0.0f);
+  __m256 *reset_gate_value = reinterpret_cast<__m256 *>(gate_value);
+  __m256 *reset_gate_grad = reinterpret_cast<__m256 *>(gate_grad);
+  __m256 *update_gate_value =
+      reinterpret_cast<__m256 *>(gate_value + frame_size);
+  __m256 *update_gate_grad = reinterpret_cast<__m256 *>(gate_grad + frame_size);
+  __m256 *frame_state_value =
+      reinterpret_cast<__m256 *>(gate_value + 2 * frame_size);
+  __m256 *frame_state_grad =
+      reinterpret_cast<__m256 *>(gate_grad + 2 * frame_size);
+
+  for (int i = 0; i < frame_size / 8; ++i) {
+    r_value_reset_gate = reset_gate_value[i];
+    r_grad_reset_gate = reset_gate_grad[i];
+    r_value_update_gate = update_gate_value[i];
+    r_grad_update_gate = update_gate_grad[i];
+    r_value_frame_state = frame_state_value[i];
+    r_grad_frame_state = frame_state_grad[i];
+    if (prev_out_value) {
+      r_value_prev_out = (reinterpret_cast<const __m256 *>(prev_out_value))[i];
+    }
+    if (prev_out_grad) {
+      r_grad_prev_out = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
+    }
+    r_grad_output = (reinterpret_cast<__m256 *>(output_grad))[i];
+    r_value_reset_output = (reinterpret_cast<__m256 *>(reset_output_value))[i];
+    if (prev_out_value && prev_out_grad) {
+      r_grad_reset_output = (reinterpret_cast<__m256 *>(reset_output_grad))[i];
+    }
+
+    op_gru_grad(&r_value_reset_gate, &r_grad_reset_gate, &r_value_update_gate,
+                &r_grad_update_gate, &r_value_frame_state, &r_grad_frame_state,
+                &r_value_prev_out, &r_grad_prev_out, &r_grad_output,
+                &r_value_reset_output, &r_grad_reset_output, active_node,
+                active_gate);
+
+    reset_gate_grad[i] = r_grad_reset_gate;
+    update_gate_grad[i] = r_grad_update_gate;
+    frame_state_grad[i] = r_grad_frame_state;
+    if (prev_out_grad) {
+      (reinterpret_cast<__m256 *>(prev_out_grad))[i] = r_grad_prev_out;
+    }
+    if (prev_out_value && prev_out_grad) {
+      (reinterpret_cast<__m256 *>(reset_output_grad))[i] = r_grad_reset_output;
+    }
+  }
+#endif
+}
+
 template <class OpStateGrad, typename T>
 inline void backward_state_grad(OpStateGrad op_state_grad,
                                 GRUMetaValue<T> value, GRUMetaGrad<T> grad,
@@ -491,6 +664,39 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad,
   }
 }
 
+template <class OpGruGrad, typename T>
+inline void cpu_gru_backward(OpGruGrad op_gru_grad, GRUMetaValue<T> value,
+                             GRUMetaGrad<T> grad, int frame_size,
+                             int batch_size, ActivationType active_node,
+                             ActivationType active_gate) {
+  for (int b = 0; b < batch_size; ++b) {
+    if (OpGruGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
+      hl_avx_gru_backward(
+          op_gru_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
+          grad.prev_out_grad, value.reset_output_value, grad.reset_output_grad,
+          grad.output_grad, frame_size, active_node, active_gate);
+    } else {
+      hl_naive_gru_backward(
+          op_gru_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
+          grad.prev_out_grad, value.reset_output_value, grad.reset_output_grad,
+          grad.output_grad, frame_size, active_node, active_gate);
+    }
+
+    value.gate_value += frame_size * 3;
+    value.reset_output_value += frame_size;
+    if (value.prev_out_value) {
+      value.prev_out_value += frame_size;
+    }
+
+    grad.gate_grad += frame_size * 3;
+    grad.output_grad += frame_size;
+    grad.reset_output_grad += frame_size;
+    if (grad.prev_out_grad) {
+      grad.prev_out_grad += frame_size;
+    }
+  }
+}
+
 #endif
 } // namespace detail
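cpu_gru_backward takes the AVX path only when the gate functor compiles one in, T is 4 bytes, and frame_size is a whole number of 8-float __m256 registers; the !(frame_size & (8 - 1)) expression is the usual bitmask test for "multiple of 8". A tiny sketch of that guard in isolation:

    #include <cstdio>

    // Mirrors the dispatch guard in cpu_gru_backward above.
    bool UseAvxPath(bool functor_has_avx, int frame_size, int elem_size) {
      return functor_has_avx && !(frame_size & (8 - 1)) && elem_size == 4;
    }

    int main() {
      std::printf("%d %d %d\n",
                  UseAvxPath(true, 64, 4),   // 1: eight full __m256 registers
                  UseAvxPath(true, 60, 4),   // 0: not a multiple of 8
                  UseAvxPath(true, 64, 8));  // 0: double falls back to naive
    }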
......
@@ -31,8 +31,8 @@ namespace detail {
 template <class OpResetOutput, bool is_batch, typename T>
 __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
                                         T *gate_value, T *reset_output_value,
-                                        T *prev_output_value, int frame_size,
-                                        int batch_size,
+                                        const T *prev_output_value,
+                                        int frame_size, int batch_size,
                                         ActivationType active_gate) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
@@ -68,12 +68,10 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
  * grid(frame_blocks, batch_blocks)
  */
 template <class OpFinalOutput, bool is_batch, typename T>
-__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
-                                        T *gate_value, T *prev_output_value,
-                                        T *output_value, int frame_size,
-                                        int batch_size,
-                                        ActivationType active_node,
-                                        bool origin_mode) {
+__global__ void KeGruForwardFinalOutput(
+    OpFinalOutput op_final_output, T *gate_value, const T *prev_output_value,
+    T *output_value, int frame_size, int batch_size, ActivationType active_node,
+    bool origin_mode) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
@@ -106,8 +104,9 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
  * grid(frame_blocks, 1)
  */
 template <class T, int Tiled_size>
-__global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
-                                        T *gate_weight, T *reset_output,
+__global__ void KeFastCollectiveGruGate(T *gate_value,
+                                        const T *prev_output_value,
+                                        const T *gate_weight, T *reset_output,
                                         int frame_size,
                                         ActivationType active_node) {
   T xt_0 = 0.0f;
@@ -164,10 +163,10 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
  * grid(frame_blocks, 1)
  */
 template <class T, int Tiled_size>
-__global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value,
-                                       T *output_value, T *gate_value,
-                                       T *reset_value, int frame_size,
-                                       ActivationType act_node,
+__global__ void KeFastCollectiveGruOut(const T *gate_weight,
+                                       const T *prev_out_value, T *output_value,
+                                       T *gate_value, T *reset_value,
+                                       int frame_size, ActivationType act_node,
                                        bool origin_mode) {
   int COL = blockIdx.x * blockDim.x + threadIdx.x;
@@ -223,7 +222,7 @@ __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value,
  */
 template <class OpStateGrad, bool is_batch, typename T>
 __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
-                                       T *gate_grad, T *prev_out_value,
+                                       T *gate_grad, const T *prev_out_value,
                                        T *prev_out_grad, T *output_grad,
                                        int frame_size, int batch_size,
                                        ActivationType active_node,
@@ -272,7 +271,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
  */
 template <class OpResetGrad, bool is_batch, typename T>
 __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
-                                       T *gate_grad, T *prev_out_value,
+                                       T *gate_grad, const T *prev_out_value,
                                        T *prev_out_grad, T *reset_output_grad,
                                        int frame_size, int batch_size,
                                        ActivationType active_gate) {
......
@@ -30,10 +30,17 @@ class gru_resetOutput {
  public:
   HOSTDEVICE void operator()(T *value_update_gate, T *value_reset_gate,
                              T *prev_out, T *value_reset_output,
-                             ActivationType act_gate) {
+                             ActivationType act_gate,
+                             T *value_reset_bias = nullptr,
+                             bool old_version = true) {
     *value_update_gate = activation(*value_update_gate, act_gate);
     *value_reset_gate = activation(*value_reset_gate, act_gate);
-    *value_reset_output = (*prev_out) * (*value_reset_gate);
+    if (old_version) {
+      *value_reset_output = (*prev_out) * (*value_reset_gate);
+    } else {
+      *value_reset_output =
+          (*value_reset_output + *value_reset_bias) * (*value_reset_gate);
+    }
   }
 #ifndef __NVCC__
 #ifndef __AVX__
@@ -43,10 +50,19 @@ class gru_resetOutput {
   HOSTDEVICE void operator()(__m256 *value_update_gate,
                              __m256 *value_reset_gate, __m256 *prev_out,
                              __m256 *value_reset_output,
-                             ActivationType act_gate) {
+                             ActivationType act_gate,
+                             __m256 *value_reset_bias = nullptr,
+                             bool old_version = true) {
     *value_update_gate = activation(*value_update_gate, act_gate);
     *value_reset_gate = activation(*value_reset_gate, act_gate);
-    *value_reset_output = _mm256_mul_ps(*prev_out, *value_reset_gate);
+    if (old_version) {
+      *value_reset_output = _mm256_mul_ps(*prev_out, *value_reset_gate);
+    } else {
+      *value_reset_output =
+          _mm256_add_ps(*value_reset_output, *value_reset_bias);
+      *value_reset_output =
+          _mm256_mul_ps(*value_reset_output, *value_reset_gate);
+    }
   }
 #endif
 #endif
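This branch is the crux of the V2 cell. In V1 the reset output is the gated previous hidden state, later multiplied into the state weights; in V2 the buffer arrives already holding the recurrent candidate term, and the gate is applied to that term plus its bias. In GRU notation, with r the reset gate (my reading of the code, matching the cuDNN/PyTorch-style formulation):

    \text{V1: } \; reset\_output = r \odot h_{t-1}
    \text{V2: } \; reset\_output = r \odot (U_h h_{t-1} + b_{hn})

where in V2 the U_h h_{t-1} part is written into value_reset_output beforehand by a GEMM in GRUUnitFunctorV2 (see gru_compute.cc below).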
@@ -192,6 +208,61 @@ class gru_resetGrad {
 #endif
 #endif
 };
+
+template <typename T>
+class gru {
+ public:
+  HOSTDEVICE void operator()(T *value_reset_gate, T *grad_reset_gate,
+                             T *value_update_gate, T *grad_update_gate,
+                             T *value_frame_state, T *grad_frame_state,
+                             T *value_prev_out, T *grad_prev_out,
+                             T *grad_output, T *value_reset_output,
+                             T *grad_reset_output, ActivationType act_node,
+                             ActivationType act_gate) {
+    *grad_update_gate =
+        activation((*grad_output) * ((*value_prev_out) - (*value_frame_state)),
+                   (*value_update_gate), act_gate);
+    *grad_prev_out += (*grad_output * (*value_update_gate));
+    *grad_frame_state =
+        activation(*grad_output * (static_cast<T>(1.0) - (*value_update_gate)),
+                   *value_frame_state, act_node);
+    T reset_output = (*value_reset_output) / (*value_reset_gate);
+    *grad_reset_gate = activation(reset_output * (*grad_frame_state),
+                                  *value_reset_gate, act_gate);
+    *grad_reset_output = (*value_reset_gate) * (*grad_frame_state);
+  }
+#ifndef __NVCC__
+#ifndef __AVX__
+  static const bool avx = false;
+#else
+  static const bool avx = true;
+  HOSTDEVICE void operator()(__m256 *value_reset_gate, __m256 *grad_reset_gate,
+                             __m256 *value_update_gate,
+                             __m256 *grad_update_gate,
+                             __m256 *value_frame_state,
+                             __m256 *grad_frame_state, __m256 *value_prev_out,
+                             __m256 *grad_prev_out, __m256 *grad_output,
+                             __m256 *value_reset_output,
+                             __m256 *grad_reset_output, ActivationType act_node,
+                             ActivationType act_gate) {
+    *grad_update_gate = activation(
+        _mm256_mul_ps(*grad_output,
+                      _mm256_sub_ps(*value_prev_out, *value_frame_state)),
+        *value_update_gate, act_gate);
+    *grad_prev_out = _mm256_add_ps(
+        *grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate));
+    *grad_frame_state = activation(
+        _mm256_mul_ps(*grad_output,
+                      _mm256_sub_ps(_mm256_set1_ps(1.0f), *value_update_gate)),
+        *value_frame_state, act_node);
+    __m256 reset_output = _mm256_div_ps(*value_reset_output, *value_reset_gate);
+    *grad_reset_gate =
+        activation(_mm256_mul_ps(reset_output, *grad_frame_state),
+                   *value_reset_gate, act_gate);
+    *grad_reset_output = _mm256_mul_ps(*value_reset_gate, *grad_frame_state);
+  }
+#endif
+#endif
+};
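Reading the scalar overload of backward::gru against the forward definitions, the gradients it implements are (my derivation; ∂x denotes dL/dx, u the update gate, r the reset gate, c the frame state, and the backward activation() helper multiplies its first argument by the activation derivative evaluated at the stored forward value):

    h_t = u \odot h_{t-1} + (1 - u) \odot c, \qquad
    c = \tanh(g_c + n), \quad n = r \odot (U_h h_{t-1} + b_{hn})

    \partial g_u = \sigma'(u) \odot \partial h_t \odot (h_{t-1} - c)
    \partial h_{t-1} \mathrel{+}= \partial h_t \odot u
    \partial g_c = \tanh'(c) \odot \partial h_t \odot (1 - u)
    \partial g_r = \sigma'(r) \odot \partial g_c \odot (U_h h_{t-1} + b_{hn})
    \partial (U_h h_{t-1} + b_{hn}) = r \odot \partial g_c

Note the division: the kernel recovers U_h h_{t-1} + b_{hn} as value_reset_output / r, which silently relies on r never being exactly zero (a sigmoid output is never zero in exact arithmetic, but it can underflow in float).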
 } // namespace backward
......
@@ -14,6 +14,8 @@ limitations under the License. */
 #pragma once
 #include <type_traits>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
 #include "paddle/fluid/operators/math/lstm_compute.h"
@@ -28,6 +30,11 @@ namespace operators {
 namespace math {
 namespace detail {
 
+using Array1 = Eigen::DSizes<int64_t, 1>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
 #ifndef __NVCC__
 
 template <class T, class Op>
@@ -35,7 +42,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
                                      int frame_size, T cell_clip,
                                      ActivationType active_node,
                                      ActivationType active_gate,
-                                     ActivationType active_state) {
+                                     ActivationType active_state,
+                                     bool old_api_version) {
   T r_value_in;
   T r_value_ig;
   T r_value_fg;
@@ -48,10 +56,15 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
   T r_state_atv;
   T r_out;
 
-  T *value_in = value.gate_value;
-  T *value_ig = value.gate_value + frame_size;
-  T *value_fg = value.gate_value + frame_size * 2;
+  T *value_ig = value.gate_value;
+  T *value_fg = value.gate_value + frame_size;
+  T *value_in = value.gate_value + frame_size * 2;
   T *value_og = value.gate_value + frame_size * 3;
+  if (old_api_version) {
+    value_in = value.gate_value;
+    value_ig = value.gate_value + frame_size;
+    value_fg = value.gate_value + frame_size * 2;
+  }
 
   for (int i = 0; i < frame_size; i++) {
     r_value_in = value_in[i];
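The reordering above moves the candidate (value_in) from slot 0 to slot 2: the legacy layout is [ c~ | i | f | o ], the new one [ i | f | c~ | o ], which appears to follow the i, f, g (candidate), o weight packing used by cuDNN-style fused RNN kernels (an inference from the code; the patch does not say so). A sketch of the offset rule:

    #include <cstddef>

    // Offsets of each LSTM gate inside the packed 4 * frame_size buffer.
    struct LstmGateOffsets {
      std::size_t in, ig, fg, og;  // candidate, input, forget, output
    };

    LstmGateOffsets GateOffsets(std::size_t frame_size, bool old_api_version) {
      if (old_api_version) {
        return {0, frame_size, 2 * frame_size, 3 * frame_size};  // [c~|i|f|o]
      }
      return {2 * frame_size, 0, frame_size, 3 * frame_size};  // [i|f|c~|o]
    }

    int main() { return GateOffsets(8, false).ig == 0 ? 0 : 1; }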
@@ -85,7 +98,8 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
                                       LstmMetaGrad<T> grad, int frame_size,
                                       T cell_clip, ActivationType active_node,
                                       ActivationType active_gate,
-                                      ActivationType active_state) {
+                                      ActivationType active_state,
+                                      bool old_api_version) {
   T r_value_in;
   T r_value_ig;
   T r_value_fg;
@@ -107,14 +121,25 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
   T r_checkFGrad;
   T r_checkOGrad;
 
-  T *value_in = value.gate_value;
-  T *value_ig = value.gate_value + frame_size;
-  T *value_fg = value.gate_value + frame_size * 2;
+  T *value_ig = value.gate_value;
+  T *value_fg = value.gate_value + frame_size;
+  T *value_in = value.gate_value + frame_size * 2;
   T *value_og = value.gate_value + frame_size * 3;
-  T *grad_in = grad.gate_grad;
-  T *grad_ig = grad.gate_grad + frame_size;
-  T *grad_fg = grad.gate_grad + frame_size * 2;
+  if (old_api_version) {
+    value_in = value.gate_value;
+    value_ig = value.gate_value + frame_size;
+    value_fg = value.gate_value + frame_size * 2;
+  }
+
+  T *grad_ig = grad.gate_grad;
+  T *grad_fg = grad.gate_grad + frame_size;
+  T *grad_in = grad.gate_grad + frame_size * 2;
   T *grad_og = grad.gate_grad + frame_size * 3;
+  if (old_api_version) {
+    grad_in = grad.gate_grad;
+    grad_ig = grad.gate_grad + frame_size;
+    grad_fg = grad.gate_grad + frame_size * 2;
+  }
 
   for (int i = 0; i < frame_size; i++) {
     r_value_in = value_in[i];
@@ -158,7 +183,8 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
                                    int frame_size, T cell_clip,
                                    ActivationType active_node,
                                    ActivationType active_gate,
-                                   ActivationType active_state) {
+                                   ActivationType active_state,
+                                   bool old_api_version) {
 #ifdef __AVX__
   __m256 r_value_in;
   __m256 r_value_ig;
@@ -172,12 +198,17 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
   __m256 r_state_atv;
   __m256 r_out;
 
-  __m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value);
-  __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
-  __m256 *value_fg =
+  __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value);
+  __m256 *value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
+  __m256 *value_in =
       reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2);
   __m256 *value_og =
       reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3);
+  if (old_api_version) {
+    value_in = reinterpret_cast<__m256 *>(value.gate_value);
+    value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
+    value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2);
+  }
 
   for (int i = 0; i < frame_size / 8; i++) {
     r_value_in = value_in[i];
@@ -191,7 +222,8 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
     }
 
     if (value.prev_state_value) {
-      r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
+      r_prev_state =
+          (reinterpret_cast<__m256 const *>(value.prev_state_value))[i];
     }
 
     op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
@@ -214,7 +246,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
                                     LstmMetaGrad<T> grad, int frame_size,
                                     T cell_clip, ActivationType active_node,
                                     ActivationType active_gate,
-                                    ActivationType active_state) {
+                                    ActivationType active_state,
+                                    bool old_api_version) {
 #ifdef __AVX__
   __m256 r_value_in;
   __m256 r_value_ig;
@@ -237,16 +270,27 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
   __m256 r_checkFGrad;
   __m256 r_checkOGrad;
 
-  __m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value);
-  __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
-  __m256 *value_fg =
+  __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value);
+  __m256 *value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
+  __m256 *value_in =
      reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2);
   __m256 *value_og =
      reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3);
-  __m256 *grad_in = reinterpret_cast<__m256 *>(grad.gate_grad);
-  __m256 *grad_ig = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size);
-  __m256 *grad_fg = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 2);
+  if (old_api_version) {
+    value_in = reinterpret_cast<__m256 *>(value.gate_value);
+    value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
+    value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2);
+  }
+
+  __m256 *grad_ig = reinterpret_cast<__m256 *>(grad.gate_grad);
+  __m256 *grad_fg = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size);
+  __m256 *grad_in = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 2);
   __m256 *grad_og = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 3);
+  if (old_api_version) {
+    grad_in = reinterpret_cast<__m256 *>(grad.gate_grad);
+    grad_ig = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size);
+    grad_fg = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 2);
+  }
 
   for (int i = 0; i < frame_size / 8; i++) {
     r_value_in = value_in[i];
@@ -263,7 +307,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
     r_output_grad = (reinterpret_cast<__m256 *>(grad.output_grad))[i];
     r_state_grad = (reinterpret_cast<__m256 *>(grad.state_grad))[i];
     if (value.prev_state_value) {
-      r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
+      r_prev_state =
+          (reinterpret_cast<__m256 const *>(value.prev_state_value))[i];
     }
 
     op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
@@ -292,30 +337,133 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 #endif
 }
 
+template <class T>
+void eigen_lstm_forward_one_sequence(const platform::CPUDeviceContext &context,
+                                     LstmMetaValue<T> value, int frame_size) {
+  auto eigen_value_ig =
+      typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
+  auto eigen_value_fg = typename EigenVector<T>::Type(
+      value.gate_value + frame_size, Array1(frame_size));
+  auto eigen_value_in = typename EigenVector<T>::Type(
+      value.gate_value + frame_size * 2, Array1(frame_size));
+  auto eigen_value_og = typename EigenVector<T>::Type(
+      value.gate_value + frame_size * 3, Array1(frame_size));
+  auto eigen_state =
+      typename EigenVector<T>::Type(value.state_value, Array1(frame_size));
+  auto eigen_state_act = typename EigenVector<T>::Type(value.state_active_value,
+                                                       Array1(frame_size));
+  auto eigen_output =
+      typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
+
+  auto &place = *context.eigen_device();
+  TanhFunctor<T>()(place, eigen_value_in, eigen_value_in);
+  SigmoidFunctor<T>()(place, eigen_value_ig, eigen_value_ig);
+  SigmoidFunctor<T>()(place, eigen_value_fg, eigen_value_fg);
+  SigmoidFunctor<T>()(place, eigen_value_og, eigen_value_og);
+
+  eigen_state.device(place) = eigen_value_in * eigen_value_ig;
+  if (value.prev_state_value) {
+    auto eigen_prev_state = typename EigenVector<T>::ConstType(
+        value.prev_state_value, Array1(frame_size));
+    eigen_state.device(place) = eigen_state + eigen_prev_state * eigen_value_fg;
+  }
+  TanhFunctor<T>()(place, eigen_state, eigen_state_act);
+  eigen_output.device(place) = eigen_value_og * eigen_state_act;
+}
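eigen_lstm_forward_one_sequence hard-codes the activations (tanh for the candidate and cell, sigmoid for the three gates), so the new-API path computes exactly the textbook cell, with no cell clip and no peephole terms:

    i = \sigma(g_i), \quad f = \sigma(g_f), \quad o = \sigma(g_o), \quad
    \tilde{c} = \tanh(g_c)
    c_t = i \odot \tilde{c} + f \odot c_{t-1}
    h_t = o \odot \tanh(c_t)

where the g_* are the pre-activation slices of gate_value, and state_active_value caches tanh(c_t) for reuse in the backward pass.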
+template <class T>
+void eigen_lstm_backward_one_sequence(const platform::CPUDeviceContext &context,
+                                      LstmMetaValue<T> value,
+                                      LstmMetaGrad<T> grad, int frame_size) {
+  auto eigen_value_ig =
+      typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
+  auto eigen_value_fg = typename EigenVector<T>::Type(
+      value.gate_value + frame_size, Array1(frame_size));
+  auto eigen_value_in = typename EigenVector<T>::Type(
+      value.gate_value + frame_size * 2, Array1(frame_size));
+  auto eigen_value_og = typename EigenVector<T>::Type(
+      value.gate_value + frame_size * 3, Array1(frame_size));
+  auto eigen_state_act = typename EigenVector<T>::Type(value.state_active_value,
+                                                       Array1(frame_size));
+
+  auto eigen_grad_ig =
+      typename EigenVector<T>::Type(grad.gate_grad, Array1(frame_size));
+  auto eigen_grad_fg = typename EigenVector<T>::Type(
+      grad.gate_grad + frame_size, Array1(frame_size));
+  auto eigen_grad_in = typename EigenVector<T>::Type(
+      grad.gate_grad + frame_size * 2, Array1(frame_size));
+  auto eigen_grad_og = typename EigenVector<T>::Type(
+      grad.gate_grad + frame_size * 3, Array1(frame_size));
+  auto eigen_grad_output =
+      typename EigenVector<T>::Type(grad.output_grad, Array1(frame_size));
+  auto eigen_grad_state =
+      typename EigenVector<T>::Type(grad.state_grad, Array1(frame_size));
+
+  auto &place = *context.eigen_device();
+  SigmoidGradFunctor<T>()(place, 1 /*useless*/, eigen_value_og,
+                          eigen_grad_output * eigen_state_act, eigen_grad_og);
+  eigen_grad_state.device(place) =
+      eigen_grad_state +
+      eigen_grad_output * eigen_value_og *
+          (static_cast<T>(1) - eigen_state_act * eigen_state_act);
+  TanhGradFunctor<T>()(place, 1, eigen_value_in,
+                       eigen_grad_state * eigen_value_ig, eigen_grad_in);
+  SigmoidGradFunctor<T>()(place, 1, eigen_value_ig,
+                          eigen_grad_state * eigen_value_in, eigen_grad_ig);
+  if (value.prev_state_value) {
+    auto eigen_prev_state = typename EigenVector<T>::ConstType(
+        value.prev_state_value, Array1(frame_size));
+    SigmoidGradFunctor<T>()(place, 1, eigen_value_fg,
+                            eigen_grad_state * eigen_prev_state, eigen_grad_fg);
+  } else {
+    SigmoidGradFunctor<T>()(place, 1, eigen_value_fg, 0, eigen_grad_fg);
+  }
+  if (grad.prev_state_grad) {
+    auto eigen_grad_pre_state =
+        typename EigenVector<T>::Type(grad.prev_state_grad, Array1(frame_size));
+    eigen_grad_pre_state.device(place) = eigen_grad_state * eigen_value_fg;
+  }
+}
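The backward function is the line-by-line differentiation of those three equations; the explicit (1 - eigen_state_act * eigen_state_act) factor is tanh's derivative applied through h_t = o ⊙ tanh(c_t):

    \partial g_o = \sigma'(g_o) \odot \partial h_t \odot \tanh(c_t)
    \partial c_t \mathrel{+}= \partial h_t \odot o \odot (1 - \tanh^2(c_t))
    \partial g_c = \tanh'(g_c) \odot \partial c_t \odot i, \qquad
    \partial g_i = \sigma'(g_i) \odot \partial c_t \odot \tilde{c}
    \partial g_f = \sigma'(g_f) \odot \partial c_t \odot c_{t-1}, \qquad
    \partial c_{t-1} = \partial c_t \odot f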
 template <class T, class Op>
-void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
-                      T cell_clip, ActivationType active_node,
-                      ActivationType active_gate, ActivationType active_state) {
-  if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
-    avx_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip,
-                                     active_node, active_gate, active_state);
-  } else {
-    naive_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip,
-                                       active_node, active_gate, active_state);
+void cpu_lstm_forward(const platform::CPUDeviceContext &context, Op op,
+                      LstmMetaValue<T> value, int frame_size, T cell_clip,
+                      ActivationType active_node, ActivationType active_gate,
+                      ActivationType active_state, bool old_api_version) {
+  if (!old_api_version) {
+    eigen_lstm_forward_one_sequence<T>(context, value, frame_size);
+  } else {
+    if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
+      avx_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip,
+                                       active_node, active_gate, active_state,
+                                       old_api_version);
+    } else {
+      naive_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip,
+                                         active_node, active_gate,
+                                         active_state, old_api_version);
+    }
   }
 }
 
 template <class T, class Op>
-void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
+void cpu_lstm_backward(const platform::CPUDeviceContext &context, Op op,
+                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                        int frame_size, T cell_clip, ActivationType active_node,
-                       ActivationType active_gate,
-                       ActivationType active_state) {
-  if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
-    avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, cell_clip,
-                                      active_node, active_gate, active_state);
-  } else {
-    naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size, cell_clip,
-                                        active_node, active_gate, active_state);
+                       ActivationType active_gate, ActivationType active_state,
+                       bool old_api_version) {
+  if (!old_api_version) {
+    eigen_lstm_backward_one_sequence<T>(context, value, grad, frame_size);
+  } else {
+    if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
+      avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, cell_clip,
+                                        active_node, active_gate, active_state,
+                                        old_api_version);
+    } else {
+      naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size,
+                                          cell_clip, active_node, active_gate,
+                                          active_state, old_api_version);
+    }
   }
 }
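Net effect of the two wrappers: the new API unconditionally takes the Eigen kernels (which also fixes the activations, so the activation arguments are unused on that path), and only the legacy API still chooses between the AVX and naive loops. A condensed sketch of the decision (illustrative, outside Paddle):

    enum class LstmPath { kEigen, kAvx, kNaive };

    // Mirrors cpu_lstm_forward / cpu_lstm_backward above.
    LstmPath ChooseLstmPath(bool old_api_version, bool op_has_avx,
                            int frame_size, bool is_float) {
      if (!old_api_version) return LstmPath::kEigen;  // fixed sigmoid/tanh
      if (op_has_avx && !(frame_size & (8 - 1)) && is_float)
        return LstmPath::kAvx;
      return LstmPath::kNaive;
    }

    int main() {
      return ChooseLstmPath(false, true, 64, true) == LstmPath::kEigen ? 0 : 1;
    }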
......
@@ -11,6 +11,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/gru_compute.h"
+#include <string>
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
 #include "paddle/fluid/operators/math/detail/gru_kernel.h"
@@ -101,11 +102,64 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
   }
 };
 
+template <typename T>
+struct GRUUnitFunctorV2<platform::CPUDeviceContext, T> {
+  static void compute(const platform::CPUDeviceContext &context,
+                      GRUMetaValue<T> value, int frame_size, int batch_size,
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate) {
+#ifndef __NVCC__
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+    if (value.prev_out_value) {
+      blas.GEMM(CblasNoTrans, CblasTrans, batch_size, frame_size, frame_size, 1,
+                value.prev_out_value, value.state_weight, 0,
+                value.reset_output_value);
+    }
+    detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
+                                 frame_size, batch_size, active_gate, false);
+
+    T *cell_state_value = value.gate_value + 2 * frame_size;
+    T *reset_output_value = value.reset_output_value;
+    for (int b = 0; b < batch_size; ++b) {
+      blas.VADD(frame_size, cell_state_value, reset_output_value,
+                cell_state_value);
+      cell_state_value += frame_size * 3;
+      reset_output_value += frame_size;
+    }
+
+    detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
+                                 frame_size, batch_size, active_node, true,
+                                 false);
+#endif
+  }
+};
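Putting GRUUnitFunctorV2::compute into equations (my reading; the input-side terms W x + b_i are assumed to be pre-accumulated into gate_value by the calling op): the GEMM writes U_h h_{t-1} into reset_output_value, forward_reset_output with old_version = false turns it into r ⊙ (U_h h_{t-1} + b_hn), the VADD loop adds that into the candidate slice, and forward_final_output finishes with origin_mode fixed to true:

    r = \sigma(g_r), \qquad u = \sigma(g_u)
    n = r \odot (U_h h_{t-1} + b_{hn})
    \tilde{h} = \tanh(g_c + n)
    h_t = u \odot h_{t-1} + (1 - u) \odot \tilde{h}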
+template <typename T>
+struct GRUUnitGradFunctorV2<platform::CPUDeviceContext, T> {
+  static void compute(const platform::CPUDeviceContext &context,
+                      GRUMetaValue<T> value, GRUMetaGrad<T> grad,
+                      int frame_size, int batch_size,
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate) {
+#ifndef __NVCC__
+    // calculate grad_update_gate, grad_frame_state,
+    // grad_reset_output, grad_reset_gate
+    detail::cpu_gru_backward(detail::backward::gru<T>(), value, grad,
+                             frame_size, batch_size, active_node, active_gate);
+#endif
+  }
+};
+
 template struct GRUUnitFunctor<platform::CPUDeviceContext, float>;
 template struct GRUUnitFunctor<platform::CPUDeviceContext, double>;
 template struct GRUUnitGradFunctor<platform::CPUDeviceContext, float>;
 template struct GRUUnitGradFunctor<platform::CPUDeviceContext, double>;
 
+template struct GRUUnitFunctorV2<platform::CPUDeviceContext, float>;
+template struct GRUUnitFunctorV2<platform::CPUDeviceContext, double>;
+template struct GRUUnitGradFunctorV2<platform::CPUDeviceContext, float>;
+template struct GRUUnitGradFunctorV2<platform::CPUDeviceContext, double>;
+
 } // namespace math
 } // namespace operators
 } // namespace paddle
...@@ -21,12 +21,13 @@ namespace math { ...@@ -21,12 +21,13 @@ namespace math {
template <typename T> template <typename T>
struct GRUMetaValue { struct GRUMetaValue {
T *gate_weight; const T *gate_weight;
T *state_weight; const T *state_weight;
const T *reset_bias;
T *gate_value; T *gate_value;
T *reset_output_value; T *reset_output_value;
T *output_value; T *output_value;
T *prev_out_value; const T *prev_out_value;
}; };
template <typename T> template <typename T>
...@@ -37,6 +38,7 @@ struct GRUMetaGrad { ...@@ -37,6 +38,7 @@ struct GRUMetaGrad {
T *reset_output_grad; T *reset_output_grad;
T *output_grad; T *output_grad;
T *prev_out_grad; T *prev_out_grad;
T *state_bias_grad;
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -57,6 +59,22 @@ struct GRUUnitGradFunctor { ...@@ -57,6 +59,22 @@ struct GRUUnitGradFunctor {
bool origin_mode); bool origin_mode);
}; };
template <typename DeviceContext, typename T>
struct GRUUnitFunctorV2 {
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
};
template <typename DeviceContext, typename T>
struct GRUUnitGradFunctorV2 {
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
GRUMetaGrad<T> grad, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
};
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -33,10 +33,12 @@ struct LstmUnitFunctor<platform::CPUDeviceContext, T> { ...@@ -33,10 +33,12 @@ struct LstmUnitFunctor<platform::CPUDeviceContext, T> {
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
T cell_clip, const detail::ActivationType& gate_act, T cell_clip, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act,
bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size, detail::cpu_lstm_forward(context, detail::forward::lstm<T>(), value,
cell_clip, cand_act, gate_act, cell_act); frame_size, cell_clip, cand_act, gate_act,
cell_act, old_api_version);
value.gate_value += frame_size * 4; value.gate_value += frame_size * 4;
value.state_value += frame_size; value.state_value += frame_size;
value.state_active_value += frame_size; value.state_active_value += frame_size;
...@@ -55,11 +57,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> { ...@@ -55,11 +57,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
int frame_size, int batch_size, T cell_clip, int frame_size, int batch_size, T cell_clip,
const detail::ActivationType& gate_act, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act,
bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad, detail::cpu_lstm_backward(context, detail::backward::lstm<T>(), value,
frame_size, cell_clip, cand_act, gate_act, grad, frame_size, cell_clip, cand_act, gate_act,
cell_act); cell_act, old_api_version);
value.gate_value += frame_size * 4; value.gate_value += frame_size * 4;
value.state_value += frame_size; value.state_value += frame_size;
......
...@@ -26,7 +26,8 @@ struct LstmUnitFunctor<platform::CUDADeviceContext, T> { ...@@ -26,7 +26,8 @@ struct LstmUnitFunctor<platform::CUDADeviceContext, T> {
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
T cell_clip, const detail::ActivationType& gate_act, T cell_clip, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act,
bool old_api_version = true) {
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value, detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
frame_size, batch_size, cell_clip, cand_act, frame_size, batch_size, cell_clip, cand_act,
gate_act, cell_act); gate_act, cell_act);
...@@ -40,7 +41,8 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> { ...@@ -40,7 +41,8 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
int frame_size, int batch_size, T cell_clip, int frame_size, int batch_size, T cell_clip,
const detail::ActivationType& gate_act, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act,
bool old_api_version = true) {
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad, detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
frame_size, batch_size, cell_clip, cand_act, frame_size, batch_size, cell_clip, cand_act,
gate_act, cell_act); gate_act, cell_act);
......
...@@ -25,7 +25,7 @@ namespace math { ...@@ -25,7 +25,7 @@ namespace math {
template <class T> template <class T>
struct LstmMetaValue { struct LstmMetaValue {
T *gate_value; T *gate_value;
T *prev_state_value; const T *prev_state_value;
T *state_value; T *state_value;
T *state_active_value; T *state_active_value;
T *output_value; T *output_value;
...@@ -53,7 +53,8 @@ class LstmUnitFunctor { ...@@ -53,7 +53,8 @@ class LstmUnitFunctor {
int frame_size, int batch_size, T cell_clip, int frame_size, int batch_size, T cell_clip,
const detail::ActivationType &gate_act, const detail::ActivationType &gate_act,
const detail::ActivationType &cell_act, const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act); const detail::ActivationType &cand_act,
bool old_api_version = true);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -63,7 +64,8 @@ class LstmUnitGradFunctor { ...@@ -63,7 +64,8 @@ class LstmUnitGradFunctor {
LstmMetaGrad<T> grad, int frame_size, int batch_size, LstmMetaGrad<T> grad, int frame_size, int batch_size,
T cell_clip, const detail::ActivationType &gate_act, T cell_clip, const detail::ActivationType &gate_act,
const detail::ActivationType &cell_act, const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act); const detail::ActivationType &cand_act,
bool old_api_version = true);
}; };
} // namespace math } // namespace math
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/rnn_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -251,5 +252,10 @@ REGISTER_OPERATOR(rnn, ops::RNNOp, ops::RNNOpMaker, ...@@ -251,5 +252,10 @@ REGISTER_OPERATOR(rnn, ops::RNNOp, ops::RNNOpMaker,
ops::RNNGradOpMaker<paddle::imperative::OpBase>); ops::RNNGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp); REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp);
REGISTER_OP_CPU_KERNEL(rnn, ops::NotImpleKernel<float>); REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL(rnn_grad, ops::NotImpleKernel<float>); rnn, ops::RNNCPUKernel<paddle::platform::CPUDeviceContext, float>,
ops::RNNCPUKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
rnn_grad, ops::RNNCPUGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::RNNCPUGradKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -524,6 +524,12 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> { ...@@ -524,6 +524,12 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> {
offset += len; offset += len;
} }
Tensor input_grad_value;
if (!in_grad) {
in_grad = &input_grad_value;
in_grad->Resize(input->dims());
}
auto *init_h_data = pre_state[0]->data<T>(); auto *init_h_data = pre_state[0]->data<T>();
// auto *last_h_data = state[0]->data<T>(); // auto *last_h_data = state[0]->data<T>();
auto *last_h_grad_data = state_grad[0]->data<T>(); auto *last_h_grad_data = state_grad[0]->data<T>();
......
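The block added in the hunk above handles a caller that does not bind an Input@GRAD output: the cuDNN backward routine presumably still needs a destination buffer, so a local tensor with the input's shape serves as scratch. The same pattern as a hedged Python sketch (names illustrative):

    import numpy as np

    def ensure_input_grad(in_grad, input_shape):
        # fall back to a throwaway buffer when no gradient output is bound
        return in_grad if in_grad is not None else np.empty(input_shape)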
(This diff is collapsed.)
...@@ -65,10 +65,18 @@ class TestLstm(unittest.TestCase): ...@@ -65,10 +65,18 @@ class TestLstm(unittest.TestCase):
paddle.jit.ProgramTranslator().enable(True) paddle.jit.ProgramTranslator().enable(True)
net = Net(12, 2) net = Net(12, 2)
x = paddle.randn((2, 10, 12)) x = paddle.randn((2, 10, 12))
x.stop_gradient = False
dygraph_out = net(x) dygraph_out = net(x)
loss = paddle.mean(dygraph_out)
sgd = paddle.optimizer.SGD(learning_rate=0.001,
parameters=net.parameters())
loss.backward()
sgd.step()
# switch eval mode firstly # switch eval mode firstly
net.eval() net.eval()
x = paddle.randn((2, 10, 12))
dygraph_out = net(x)
dropout_out = net(x)
net = paddle.jit.to_static( net = paddle.jit.to_static(
net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])]) net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])])
paddle.jit.save(net, 'simple_lstm') paddle.jit.save(net, 'simple_lstm')
...@@ -106,6 +114,14 @@ class TestSaveInEvalMode(unittest.TestCase): ...@@ -106,6 +114,14 @@ class TestSaveInEvalMode(unittest.TestCase):
def test_save_in_eval(self): def test_save_in_eval(self):
paddle.jit.ProgramTranslator().enable(True) paddle.jit.ProgramTranslator().enable(True)
net = LinearNet() net = LinearNet()
x = paddle.randn((2, 10))
x.stop_gradient = False
dygraph_out = net(x)
loss = paddle.mean(dygraph_out)
sgd = paddle.optimizer.SGD(learning_rate=0.001,
parameters=net.parameters())
loss.backward()
sgd.step()
# switch eval mode firstly # switch eval mode firstly
net.eval() net.eval()
# save directly # save directly
...@@ -129,6 +145,14 @@ class TestEvalAfterSave(unittest.TestCase): ...@@ -129,6 +145,14 @@ class TestEvalAfterSave(unittest.TestCase):
def test_eval_after_save(self): def test_eval_after_save(self):
x = paddle.randn((2, 10, 12)).astype('float32') x = paddle.randn((2, 10, 12)).astype('float32')
net = Net(12, 2) net = Net(12, 2)
x.stop_gradient = False
dy_out = net(x)
loss = paddle.mean(dy_out)
sgd = paddle.optimizer.SGD(learning_rate=0.001,
parameters=net.parameters())
loss.backward()
sgd.step()
x = paddle.randn((2, 10, 12)).astype('float32')
dy_out = net(x) dy_out = net(x)
# save model # save model
paddle.jit.save(net, 'jit.save/lstm', input_spec=[x]) paddle.jit.save(net, 'jit.save/lstm', input_spec=[x])
......
...@@ -49,3 +49,34 @@ def convert_params_for_net_static(np_net, paddle_net, place): ...@@ -49,3 +49,34 @@ def convert_params_for_net_static(np_net, paddle_net, place):
paddle_layer.cell_fw, place) paddle_layer.cell_fw, place)
convert_params_for_cell_static(np_layer.cell_bw, convert_params_for_cell_static(np_layer.cell_bw,
paddle_layer.cell_bw, place) paddle_layer.cell_bw, place)
def get_params_for_cell(np_cell, layer_idx, idx):
    state = np_cell.parameters
    weight_list = [
        ('{}.weight_{}'.format(layer_idx, idx), state['weight_ih']),
        ('{}.weight_{}'.format(layer_idx, idx + 1), state['weight_hh'])
    ]
    bias_list = [('{}.bias_{}'.format(layer_idx, idx), state['bias_ih']),
                 ('{}.bias_{}'.format(layer_idx, idx + 1), state['bias_hh'])]
    return weight_list, bias_list
def get_params_for_net(np_net):
weight_list = []
bias_list = []
for layer_idx, np_layer in enumerate(np_net):
if hasattr(np_layer, "cell"):
weight, bias = get_params_for_cell(np_layer.cell, layer_idx, 0)
for w, b in zip(weight, bias):
weight_list.append(w)
bias_list.append(b)
else:
for count, cell in enumerate([np_layer.cell_fw, np_layer.cell_bw]):
weight, bias = get_params_for_cell(cell, layer_idx, count * 2)
for w, b in zip(weight, bias):
weight_list.append(w)
bias_list.append(b)
weight_list.extend(bias_list)
return weight_list
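These helpers flatten the numpy reference net into the (name, ndarray) pairs expected by the rnn op's WeightList input: all weights first, then all biases, named '{layer}.weight_{idx}' and '{layer}.bias_{idx}', with the index advancing by two per cell (ih, hh) and a bidirectional layer using offsets 0 and 2 for its forward and backward cells. For a hypothetical two-layer bidirectional net the ordering would be:

    # illustrative only; mirrors get_params_for_net's naming convention
    expected_names = [
        '0.weight_0', '0.weight_1', '0.weight_2', '0.weight_3',
        '1.weight_0', '1.weight_1', '1.weight_2', '1.weight_3',
        '0.bias_0', '0.bias_1', '0.bias_2', '0.bias_3',
        '1.bias_0', '1.bias_1', '1.bias_2', '1.bias_3',
    ]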
...@@ -33,11 +33,16 @@ class LayerListMixin(LayerMixin): ...@@ -33,11 +33,16 @@ class LayerListMixin(LayerMixin):
class SimpleRNNCell(LayerMixin): class SimpleRNNCell(LayerMixin):
def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"): def __init__(self,
input_size,
hidden_size,
bias=True,
nonlinearity="RNN_TANH",
dtype="float64"):
self.input_size = input_size self.input_size = input_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.bias = bias self.bias = bias
if nonlinearity == 'tanh': if nonlinearity == 'RNN_TANH':
self.nonlinearity = np.tanh self.nonlinearity = np.tanh
else: else:
self.nonlinearity = lambda x: np.maximum(x, 0.) self.nonlinearity = lambda x: np.maximum(x, 0.)
...@@ -45,16 +50,16 @@ class SimpleRNNCell(LayerMixin): ...@@ -45,16 +50,16 @@ class SimpleRNNCell(LayerMixin):
self.parameters = dict() self.parameters = dict()
std = 1.0 / math.sqrt(hidden_size) std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = np.random.uniform(-std, std, ( self.weight_ih = np.random.uniform(-std, std, (
hidden_size, input_size)).astype('float64') hidden_size, input_size)).astype(dtype)
self.weight_hh = np.random.uniform(-std, std, ( self.weight_hh = np.random.uniform(-std, std, (
hidden_size, hidden_size)).astype('float64') hidden_size, hidden_size)).astype(dtype)
self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_ih'] = self.weight_ih
self.parameters['weight_hh'] = self.weight_hh self.parameters['weight_hh'] = self.weight_hh
if bias: if bias:
self.bias_ih = np.random.uniform(-std, std, self.bias_ih = np.random.uniform(-std, std,
(hidden_size, )).astype('float64') (hidden_size, )).astype(dtype)
self.bias_hh = np.random.uniform(-std, std, self.bias_hh = np.random.uniform(-std, std,
(hidden_size, )).astype('float64') (hidden_size, )).astype(dtype)
self.parameters['bias_ih'] = self.bias_ih self.parameters['bias_ih'] = self.bias_ih
self.parameters['bias_hh'] = self.bias_hh self.parameters['bias_hh'] = self.bias_hh
else: else:
...@@ -80,23 +85,23 @@ class SimpleRNNCell(LayerMixin): ...@@ -80,23 +85,23 @@ class SimpleRNNCell(LayerMixin):
class GRUCell(LayerMixin): class GRUCell(LayerMixin):
def __init__(self, input_size, hidden_size, bias=True): def __init__(self, input_size, hidden_size, bias=True, dtype="float64"):
self.input_size = input_size self.input_size = input_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.bias = bias self.bias = bias
self.parameters = dict() self.parameters = dict()
std = 1.0 / math.sqrt(hidden_size) std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = np.random.uniform(-std, std, ( self.weight_ih = np.random.uniform(-std, std, (
3 * hidden_size, input_size)).astype('float64') 3 * hidden_size, input_size)).astype(dtype)
self.weight_hh = np.random.uniform(-std, std, ( self.weight_hh = np.random.uniform(-std, std, (
3 * hidden_size, hidden_size)).astype('float64') 3 * hidden_size, hidden_size)).astype(dtype)
self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_ih'] = self.weight_ih
self.parameters['weight_hh'] = self.weight_hh self.parameters['weight_hh'] = self.weight_hh
if bias: if bias:
self.bias_ih = np.random.uniform(-std, std, ( self.bias_ih = np.random.uniform(-std, std,
3 * hidden_size)).astype('float64') (3 * hidden_size)).astype(dtype)
self.bias_hh = np.random.uniform(-std, std, ( self.bias_hh = np.random.uniform(-std, std,
3 * hidden_size)).astype('float64') (3 * hidden_size)).astype(dtype)
self.parameters['bias_ih'] = self.bias_ih self.parameters['bias_ih'] = self.bias_ih
self.parameters['bias_hh'] = self.bias_hh self.parameters['bias_hh'] = self.bias_hh
else: else:
...@@ -128,23 +133,23 @@ class GRUCell(LayerMixin): ...@@ -128,23 +133,23 @@ class GRUCell(LayerMixin):
class LSTMCell(LayerMixin): class LSTMCell(LayerMixin):
def __init__(self, input_size, hidden_size, bias=True): def __init__(self, input_size, hidden_size, bias=True, dtype="float64"):
self.input_size = input_size self.input_size = input_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.bias = bias self.bias = bias
self.parameters = dict() self.parameters = dict()
std = 1.0 / math.sqrt(hidden_size) std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = np.random.uniform(-std, std, ( self.weight_ih = np.random.uniform(-std, std, (
4 * hidden_size, input_size)).astype('float64') 4 * hidden_size, input_size)).astype(dtype)
self.weight_hh = np.random.uniform(-std, std, ( self.weight_hh = np.random.uniform(-std, std, (
4 * hidden_size, hidden_size)).astype('float64') 4 * hidden_size, hidden_size)).astype(dtype)
self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_ih'] = self.weight_ih
self.parameters['weight_hh'] = self.weight_hh self.parameters['weight_hh'] = self.weight_hh
if bias: if bias:
self.bias_ih = np.random.uniform(-std, std, ( self.bias_ih = np.random.uniform(-std, std,
4 * hidden_size)).astype('float64') (4 * hidden_size)).astype(dtype)
self.bias_hh = np.random.uniform(-std, std, ( self.bias_hh = np.random.uniform(-std, std,
4 * hidden_size)).astype('float64') (4 * hidden_size)).astype(dtype)
self.parameters['bias_ih'] = self.bias_ih self.parameters['bias_ih'] = self.bias_ih
self.parameters['bias_hh'] = self.bias_hh self.parameters['bias_hh'] = self.bias_hh
else: else:
...@@ -403,28 +408,36 @@ class SimpleRNN(RNNMixin): ...@@ -403,28 +408,36 @@ class SimpleRNN(RNNMixin):
input_size, input_size,
hidden_size, hidden_size,
num_layers=1, num_layers=1,
nonlinearity="tanh", nonlinearity="RNN_TANH",
direction="forward", direction="forward",
dropout=0., dropout=0.,
time_major=False): time_major=False,
dtype="float64"):
super(SimpleRNN, self).__init__() super(SimpleRNN, self).__init__()
if direction in ["forward", "backward"]: if direction in ["forward", "backward"]:
is_reverse = direction == "backward" is_reverse = direction == "backward"
cell = SimpleRNNCell(input_size, hidden_size, nonlinearity) cell = SimpleRNNCell(
input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell = SimpleRNNCell(hidden_size, hidden_size, nonlinearity) cell = SimpleRNNCell(
hidden_size,
hidden_size,
nonlinearity=nonlinearity,
dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional": elif direction == "bidirectional":
cell_fw = SimpleRNNCell(input_size, hidden_size, nonlinearity) cell_fw = SimpleRNNCell(
cell_bw = SimpleRNNCell(input_size, hidden_size, nonlinearity) input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
cell_bw = SimpleRNNCell(
input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell_fw = SimpleRNNCell(2 * hidden_size, hidden_size, cell_fw = SimpleRNNCell(
nonlinearity) 2 * hidden_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
cell_bw = SimpleRNNCell(2 * hidden_size, hidden_size, cell_bw = SimpleRNNCell(
nonlinearity) 2 * hidden_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
else: else:
raise ValueError( raise ValueError(
...@@ -447,23 +460,24 @@ class LSTM(RNNMixin): ...@@ -447,23 +460,24 @@ class LSTM(RNNMixin):
num_layers=1, num_layers=1,
direction="forward", direction="forward",
dropout=0., dropout=0.,
time_major=False): time_major=False,
dtype="float64"):
super(LSTM, self).__init__() super(LSTM, self).__init__()
if direction in ["forward", "backward"]: if direction in ["forward", "backward"]:
is_reverse = direction == "backward" is_reverse = direction == "backward"
cell = LSTMCell(input_size, hidden_size) cell = LSTMCell(input_size, hidden_size, dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell = LSTMCell(hidden_size, hidden_size) cell = LSTMCell(hidden_size, hidden_size, dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional": elif direction == "bidirectional":
cell_fw = LSTMCell(input_size, hidden_size) cell_fw = LSTMCell(input_size, hidden_size, dtype=dtype)
cell_bw = LSTMCell(input_size, hidden_size) cell_bw = LSTMCell(input_size, hidden_size, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell_fw = LSTMCell(2 * hidden_size, hidden_size) cell_fw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype)
cell_bw = LSTMCell(2 * hidden_size, hidden_size) cell_bw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
else: else:
raise ValueError( raise ValueError(
...@@ -486,23 +500,24 @@ class GRU(RNNMixin): ...@@ -486,23 +500,24 @@ class GRU(RNNMixin):
num_layers=1, num_layers=1,
direction="forward", direction="forward",
dropout=0., dropout=0.,
time_major=False): time_major=False,
dtype="float64"):
super(GRU, self).__init__() super(GRU, self).__init__()
if direction in ["forward", "backward"]: if direction in ["forward", "backward"]:
is_reverse = direction == "backward" is_reverse = direction == "backward"
cell = GRUCell(input_size, hidden_size) cell = GRUCell(input_size, hidden_size, dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell = GRUCell(hidden_size, hidden_size) cell = GRUCell(hidden_size, hidden_size, dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional": elif direction == "bidirectional":
cell_fw = GRUCell(input_size, hidden_size) cell_fw = GRUCell(input_size, hidden_size, dtype=dtype)
cell_bw = GRUCell(input_size, hidden_size) cell_bw = GRUCell(input_size, hidden_size, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell_fw = GRUCell(2 * hidden_size, hidden_size) cell_fw = GRUCell(2 * hidden_size, hidden_size, dtype=dtype)
cell_bw = GRUCell(2 * hidden_size, hidden_size) cell_bw = GRUCell(2 * hidden_size, hidden_size, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
else: else:
raise ValueError( raise ValueError(
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import math
from op_test import OpTest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import random
import sys
sys.path.append("./rnn")
from rnn_numpy import GRU
from convert import get_params_for_net
random.seed(2)
np.set_printoptions(threshold=np.inf)
paddle.enable_static()
class TestGRUOp(OpTest):
def get_weight_names(self):
weight_names = []
for i in range(self.num_layers):
for j in range(0, 2 * self.direction_num):
weight_names.append("{}.weight_{}".format(i, j))
for i in range(self.num_layers):
for j in range(0, 2 * self.direction_num):
weight_names.append("{}.bias_{}".format(i, j))
return weight_names
def setUp(self):
self.op_type = "rnn"
self.dtype = "float64"
self.sequence_length = np.array(
[12, 11, 10, 9, 8, 7, 6, 5], dtype=np.int32)
self.num_layers = 1
self.is_bidirec = False
self.is_test = False
self.mode = "GRU"
self.dropout = 0.
seq_length = 12
batch_size = 8
input_size = 4
self.hidden_size = 2
self.set_attrs()
self.direction_num = 2 if self.is_bidirec else 1
direction = "bidirectional" if self.is_bidirec else "forward"
input = np.random.uniform(
low=-0.1, high=0.1,
size=(seq_length, batch_size, input_size)).astype(self.dtype)
if self.sequence_length is not None:
input[3][1:][:] = 0
input[4][2:][:] = 0
input[2][3:][:] = 0
input[1][4:][:] = 0
rnn1 = GRU(input_size,
self.hidden_size,
num_layers=self.num_layers,
time_major=True,
direction=direction,
dropout=self.dropout,
dtype=self.dtype)
flat_w = get_params_for_net(rnn1)
output, last_hidden = rnn1(input, sequence_length=self.sequence_length)
init_h = np.zeros((self.num_layers * self.direction_num, batch_size,
self.hidden_size)).astype(self.dtype)
state_out = np.ndarray((300)).astype("uint8")
self.inputs = {
'Input': input,
'WeightList': flat_w,
'PreState': [('init_h', init_h)],
'SequenceLength': self.sequence_length
}
if self.sequence_length is None:
self.inputs = {
'Input': input,
'WeightList': flat_w,
'PreState': [('init_h', init_h)],
}
self.attrs = {
'dropout_prob': self.dropout,
'is_bidirec': self.is_bidirec,
'input_size': input_size,
'hidden_size': self.hidden_size,
'num_layers': self.num_layers,
'is_test': self.is_test,
'mode': self.mode
}
self.outputs = {
'Out': output,
'State': [('last_hidden', last_hidden)],
'Reserve': np.ndarray((400)).astype("uint8"),
'DropoutState': state_out
}
def set_attrs(self):
pass
def test_output(self):
self.check_output(no_check_set=['Reserve', 'DropoutState'])
def test_grad(self):
if not self.is_test:
var_name_list = self.get_weight_names()
grad_check_list = ['Input', 'init_h']
grad_check_list.extend(var_name_list)
self.check_grad(set(grad_check_list), ['Out', 'last_hidden'])
class TestGRUOp1(TestGRUOp):
def set_attrs(self):
self.sequence_length = None
class TestGRUOp2(TestGRUOp):
def set_attrs(self):
self.sequence_length = None
self.is_bidirec = True
class TestGRUOp3(TestGRUOp):
def set_attrs(self):
self.sequence_length = None
self.is_test = True
class TestGRUOp4(TestGRUOp):
def set_attrs(self):
self.sequence_length = None
self.is_bidirec = True
self.is_test = True
class TestGRUOpAvx(TestGRUOp):
def set_attrs(self):
self.dtype = "float32"
self.hidden_size = 8
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import paddle.fluid.core as core
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import random
import sys
from op_test import OpTest
sys.path.append("./rnn")
from rnn_numpy import SimpleRNN, LSTM, GRU
from convert import get_params_for_net
random.seed(2)
np.set_printoptions(threshold=np.inf)
paddle.enable_static()
class TestRNNOp(OpTest):
def get_weight_names(self):
weight_names = []
for i in range(self.num_layers):
for j in range(0, 2 * self.direction_num):
weight_names.append("{}.weight_{}".format(i, j))
for i in range(self.num_layers):
for j in range(0, 2 * self.direction_num):
weight_names.append("{}.bias_{}".format(i, j))
return weight_names
def setUp(self):
self.op_type = "rnn"
self.dtype = np.float64
self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
self.num_layers = 1
self.is_bidirec = False
self.mode = "LSTM"
self.is_test = False
self.set_attrs()
self.direction_num = 2 if self.is_bidirec else 1
direction = "bidirectional" if self.is_bidirec else "forward"
seq_length = 12
batch_size = 5
input_size = 3
hidden_size = 2
input = np.random.uniform(
low=-0.1, high=0.1,
size=(seq_length, batch_size, input_size)).astype(self.dtype)
if self.sequence_length is not None:
input[11][1:][:] = 0
input[10][2:][:] = 0
input[9][3:][:] = 0
input[8][4:][:] = 0
rnn1 = LSTM(
input_size,
hidden_size,
num_layers=self.num_layers,
time_major=True,
direction=direction)
flat_w = get_params_for_net(rnn1)
output, (last_hidden, last_cell) = rnn1(
input, sequence_length=self.sequence_length)
init_h = np.zeros((self.num_layers * self.direction_num, batch_size,
hidden_size)).astype(self.dtype)
init_c = np.zeros((self.num_layers * self.direction_num, batch_size,
hidden_size)).astype(self.dtype)
state_out = np.ndarray((300)).astype("uint8")
self.inputs = {
'Input': input,
'WeightList': flat_w,
'PreState': [('init_h', init_h), ('init_c', init_c)],
'SequenceLength': self.sequence_length
}
if self.sequence_length is None:
self.inputs = {
'Input': input,
'WeightList': flat_w,
'PreState': [('init_h', init_h), ('init_c', init_c)],
}
self.attrs = {
'dropout_prob': 0.0,
'is_bidirec': self.is_bidirec,
'input_size': input_size,
'hidden_size': hidden_size,
'num_layers': self.num_layers,
'mode': self.mode,
'is_test': self.is_test
}
self.outputs = {
'Out': output,
"State": [('last_hidden', last_hidden), ('last_cell', last_cell)],
'Reserve': np.ndarray((400)).astype("uint8"),
'DropoutState': state_out
}
def test_output(self):
self.check_output(no_check_set=['Reserve', 'DropoutState'])
def set_attrs(self):
pass
def test_grad(self):
if not self.is_test:
var_name_list = self.get_weight_names()
grad_check_list = ['Input', 'init_h', 'init_c']
grad_check_list.extend(var_name_list)
self.check_grad(
set(grad_check_list), ['Out', 'last_hidden', 'last_cell'])
class TestRNNOp1(TestRNNOp):
def set_attrs(self):
self.sequence_length = None
class TestRNNOp2(TestRNNOp):
def set_attrs(self):
self.sequence_length = None
self.is_bidirec = True
class TestRNNOp3(TestRNNOp):
def set_attrs(self):
self.is_test = True
self.sequence_length = None
class TestRNNOp4(TestRNNOp):
def set_attrs(self):
self.is_test = True
self.sequence_length = None
self.is_bidirec = True
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import math
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import random
import sys
sys.path.append("./rnn")
from rnn_numpy import SimpleRNN
from convert import get_params_for_net
random.seed(2)
np.set_printoptions(threshold=np.inf)
paddle.enable_static()
class TestSimpleRNNOp(OpTest):
def get_weight_names(self):
weight_names = []
for i in range(self.num_layers):
for j in range(0, 2 * self.direction_num):
weight_names.append("{}.weight_{}".format(i, j))
for i in range(self.num_layers):
for j in range(0, 2 * self.direction_num):
weight_names.append("{}.bias_{}".format(i, j))
return weight_names
def setUp(self):
self.op_type = "rnn"
self.dtype = np.float64
self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
self.num_layers = 1
self.is_bidirec = False
self.is_test = False
self.mode = "RNN_TANH"
self.dropout = 0.
self.set_attrs()
self.direction_num = 2 if self.is_bidirec else 1
direction = "bidirectional" if self.is_bidirec else "forward"
seq_length = 12
batch_size = 5
input_size = 3
hidden_size = 2
input = np.random.uniform(
low=-0.1, high=0.1,
size=(seq_length, batch_size, input_size)).astype(self.dtype)
if self.sequence_length is not None:
input[11][1:][:] = 0
input[10][2:][:] = 0
input[9][3:][:] = 0
input[8][4:][:] = 0
rnn1 = SimpleRNN(
input_size,
hidden_size,
num_layers=self.num_layers,
time_major=True,
direction=direction,
dropout=self.dropout,
nonlinearity=self.mode)
flat_w = get_params_for_net(rnn1)
output, last_hidden = rnn1(input, sequence_length=self.sequence_length)
init_h = np.zeros((self.num_layers * self.direction_num, batch_size,
hidden_size)).astype(self.dtype)
state_out = np.ndarray((300)).astype("uint8")
self.inputs = {
'Input': input,
'WeightList': flat_w,
'PreState': [('init_h', init_h)],
'SequenceLength': self.sequence_length
}
if self.sequence_length is None:
self.inputs = {
'Input': input,
'WeightList': flat_w,
'PreState': [('init_h', init_h)]
}
self.attrs = {
'dropout_prob': self.dropout,
'is_bidirec': self.is_bidirec,
'input_size': input_size,
'hidden_size': hidden_size,
'num_layers': self.num_layers,
'is_test': self.is_test,
'mode': self.mode
}
self.outputs = {
'Out': output,
'State': [('last_hidden', last_hidden)],
'Reserve': np.ndarray((400)).astype("uint8"),
'DropoutState': state_out
}
def set_attrs(self):
pass
def test_output(self):
self.check_output(no_check_set=['Reserve', 'DropoutState'])
def test_grad(self):
if not self.is_test:
var_name_list = self.get_weight_names()
grad_check_list = ['Input', 'init_h']
grad_check_list.extend(var_name_list)
self.check_grad(set(grad_check_list), ['Out', 'last_hidden'])
class TestSimpleRNNOp1(TestSimpleRNNOp):
def set_attrs(self):
self.sequence_length = None
class TestSimpleRNNOp2(TestSimpleRNNOp):
def set_attrs(self):
self.sequence_length = None
self.is_bidirec = True
class TestSimpleRNNOp3(TestSimpleRNNOp):
def set_attrs(self):
self.sequence_length = None
self.is_test = True
class TestSimpleRNNOp4(TestSimpleRNNOp):
def set_attrs(self):
self.sequence_length = None
self.is_bidirec = True
self.is_test = True
class TestSimpleRNNOp5(TestSimpleRNNOp):
def set_attrs(self):
self.mode = "RNN_RELU"
if __name__ == '__main__':
unittest.main()
...@@ -27,4 +27,5 @@ NEED_TO_FIX_OP_LIST = [ ...@@ -27,4 +27,5 @@ NEED_TO_FIX_OP_LIST = [
'tree_conv', 'tree_conv',
'cvm', 'cvm',
'cudnn_lstm', 'cudnn_lstm',
'rnn',
] ]
...@@ -28,4 +28,5 @@ no_check_set_white_list = [ ...@@ -28,4 +28,5 @@ no_check_set_white_list = [
'check_finite_and_unscale', 'check_finite_and_unscale',
'update_loss_scaling', 'update_loss_scaling',
'cudnn_lstm', 'cudnn_lstm',
'rnn',
] ]
...@@ -43,7 +43,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ ...@@ -43,7 +43,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'yolov3_loss', \ 'yolov3_loss', \
'inverse', \ 'inverse', \
'bilateral_slice',\ 'bilateral_slice',\
'cudnn_lstm' 'cudnn_lstm', \
'rnn', \
] ]
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\ NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\
......
...@@ -985,8 +985,7 @@ class RNNBase(LayerList): ...@@ -985,8 +985,7 @@ class RNNBase(LayerList):
"direction should be forward, backward or bidirectional, " "direction should be forward, backward or bidirectional, "
"received direction = {}".format(direction)) "received direction = {}".format(direction))
self.could_use_cudnn = get_device().startswith( self.could_use_cudnn = True
"gpu:") and get_cudnn_version()
self.could_use_cudnn &= direction != "backward" self.could_use_cudnn &= direction != "backward"
self.could_use_cudnn &= len(self.parameters()) == num_layers * 4 * ( self.could_use_cudnn &= len(self.parameters()) == num_layers * 4 * (
2 if direction == "bidirectional" else 1) 2 if direction == "bidirectional" else 1)
......
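With CPU kernels now registered for the rnn op, the fused path no longer needs to be gated on a cuDNN-capable GPU; only the direction and the flattened parameter layout still decide could_use_cudnn. A condensed, illustrative sketch of the remaining check:

    def could_use_fused_rnn(direction, num_params, num_layers):
        expected = num_layers * 4 * (2 if direction == "bidirectional" else 1)
        return direction != "backward" and num_params == expected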