未验证 提交 9362d85e 编写于 作者: J Jack Zhou 提交者: GitHub

Add LSTM, Simple RNN and GRU CPU kernel (#28577)

* add lstm, simple rnn op kernel

* fix the test_lstm for the rnn op

* change func name

* fix forward postprocess bug

* add gru forward, backward code

* remove unittest.skipIf; use a big rnn op instead of combination op

* fix input doesn't have gradient bug

* add eigen lstm forward, backward
Co-authored-by: Nwawltor <fangzeyang0904@hotmail.com>
上级 30ef3815
...@@ -30,18 +30,24 @@ namespace detail { ...@@ -30,18 +30,24 @@ namespace detail {
enum ActivationType { enum ActivationType {
kSigmoid, kSigmoid,
KSigmoidV2,
kReLU, kReLU,
kTanh, kTanh,
kTanhV2,
kIdentity, kIdentity,
}; };
inline ActivationType GetActivationType(const std::string &type) { inline ActivationType GetActivationType(const std::string &type) {
if (type == "sigmoid") { if (type == "sigmoid") {
return ActivationType::kSigmoid; return ActivationType::kSigmoid;
} else if (type == "sigmoid_v2") {
return ActivationType::KSigmoidV2;
} else if (type == "relu") { } else if (type == "relu") {
return ActivationType::kReLU; return ActivationType::kReLU;
} else if (type == "tanh") { } else if (type == "tanh") {
return ActivationType::kTanh; return ActivationType::kTanh;
} else if (type == "tanh_v2") {
return ActivationType::kTanhV2;
} else if (type == "identity" || type == "") { } else if (type == "identity" || type == "") {
return ActivationType::kIdentity; return ActivationType::kIdentity;
} }
...@@ -68,6 +74,14 @@ DEVICE T Sigmoid(const T a) { ...@@ -68,6 +74,14 @@ DEVICE T Sigmoid(const T a) {
return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-tmp)); return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-tmp));
} }
/*
 * Logistic sigmoid, 1 / (1 + exp(-a)), evaluated directly on the input:
 * the argument is not limited to a threshold range before exponentiation.
 */
template <typename T>
DEVICE T SigmoidV2(const T a) {
  const T one = static_cast<T>(1.0);
  return one / (one + exp(-a));
}
template <typename T> template <typename T>
DEVICE T Tanh(const T a) { DEVICE T Tanh(const T a) {
T tmp = -2.0 * a; T tmp = -2.0 * a;
...@@ -75,6 +89,15 @@ DEVICE T Tanh(const T a) { ...@@ -75,6 +89,15 @@ DEVICE T Tanh(const T a) {
return (2.0 / (1.0 + exp(tmp))) - 1.0; return (2.0 / (1.0 + exp(tmp))) - 1.0;
} }
/*
 * Hyperbolic tangent written as 2 / (1 + exp(-2a)) - 1 and evaluated
 * directly on the input: the argument is not limited to a threshold range
 * before exponentiation.
 */
template <typename T>
DEVICE T TanhV2(const T a) {
  // Kept as T so exp() sees the same argument type as before.
  const T scaled = -2.0 * a;
  // The literals are doubles, so this intermediate is evaluated in double.
  const auto denominator = 1.0 + exp(scaled);
  return (2.0 / denominator) - 1.0;
}
} // namespace forward } // namespace forward
namespace backward { namespace backward {
...@@ -108,20 +131,24 @@ struct Active { ...@@ -108,20 +131,24 @@ struct Active {
}; };
static DEVICE Active<float>::Act kActFloat[] = { static DEVICE Active<float>::Act kActFloat[] = {
&forward::Sigmoid<float>, &forward::Relu<float>, &forward::Tanh<float>, &forward::Sigmoid<float>, &forward::SigmoidV2<float>,
&forward::Identity<float>}; &forward::Relu<float>, &forward::Tanh<float>,
&forward::TanhV2<float>, &forward::Identity<float>};
static DEVICE Active<float>::ActGrad kActGradFloat[] = { static DEVICE Active<float>::ActGrad kActGradFloat[] = {
&backward::Sigmoid<float>, &backward::Relu<float>, &backward::Tanh<float>, &backward::Sigmoid<float>, &backward::Sigmoid<float>,
&backward::Identity<float>}; &backward::Relu<float>, &backward::Tanh<float>,
&backward::Tanh<float>, &backward::Identity<float>};
static DEVICE Active<double>::Act kActDouble[] = { static DEVICE Active<double>::Act kActDouble[] = {
&forward::Sigmoid<double>, &forward::Relu<double>, &forward::Tanh<double>, &forward::Sigmoid<double>, &forward::SigmoidV2<double>,
&forward::Identity<double>}; &forward::Relu<double>, &forward::Tanh<double>,
&forward::TanhV2<double>, &forward::Identity<double>};
static DEVICE Active<double>::ActGrad kActGradDouble[] = { static DEVICE Active<double>::ActGrad kActGradDouble[] = {
&backward::Sigmoid<double>, &backward::Relu<double>, &backward::Sigmoid<double>, &backward::Sigmoid<double>,
&backward::Tanh<double>, &backward::Identity<double>}; &backward::Relu<double>, &backward::Tanh<double>,
&backward::Tanh<double>, &backward::Identity<double>};
namespace forward { namespace forward {
inline DEVICE float activation(float a, int index) { inline DEVICE float activation(float a, int index) {
...@@ -149,7 +176,9 @@ namespace forward { ...@@ -149,7 +176,9 @@ namespace forward {
namespace avx { namespace avx {
__m256 Relu(const __m256 a); __m256 Relu(const __m256 a);
__m256 Sigmoid(const __m256 a); __m256 Sigmoid(const __m256 a);
__m256 SigmoidV2(const __m256 a);
__m256 Tanh(const __m256 a); __m256 Tanh(const __m256 a);
__m256 TanhV2(const __m256 a);
__m256 Identity(const __m256 a); __m256 Identity(const __m256 a);
} // namespace avx } // namespace avx
} // namespace forward } // namespace forward
...@@ -164,12 +193,12 @@ __m256 Identity(const __m256 a, const __m256 b); ...@@ -164,12 +193,12 @@ __m256 Identity(const __m256 a, const __m256 b);
} // namespace backward } // namespace backward
static Active<__m256>::Act kActAvx[] = { static Active<__m256>::Act kActAvx[] = {
&forward::avx::Sigmoid, &forward::avx::Relu, &forward::avx::Tanh, &forward::avx::Sigmoid, &forward::avx::SigmoidV2, &forward::avx::Relu,
&forward::avx::Identity}; &forward::avx::Tanh, &forward::avx::TanhV2, &forward::avx::Identity};
static Active<__m256>::ActGrad kActGradAvx[] = { static Active<__m256>::ActGrad kActGradAvx[] = {
&backward::avx::Sigmoid, &backward::avx::Relu, &backward::avx::Tanh, &backward::avx::Sigmoid, &backward::avx::Sigmoid, &backward::avx::Relu,
&backward::avx::Identity}; &backward::avx::Tanh, &backward::avx::Tanh, &backward::avx::Identity};
namespace forward { namespace forward {
inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); } inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); }
......
...@@ -43,6 +43,13 @@ __m256 Sigmoid(const __m256 a) { ...@@ -43,6 +43,13 @@ __m256 Sigmoid(const __m256 a) {
return tmp; return tmp;
} }
// AVX sigmoid, 1 / (1 + exp(-a)) per float lane, with no clamping of the
// input. exp256_ps is a project helper (presumably a lane-wise exp).
__m256 SigmoidV2(const __m256 a) {
  const __m256 one = _mm256_set1_ps(1.0f);
  // Negate as (0 - a), exactly as the scalar formula is written.
  const __m256 negated = _mm256_sub_ps(_mm256_set1_ps(0.0f), a);
  const __m256 denominator = _mm256_add_ps(one, exp256_ps(negated));
  return _mm256_div_ps(one, denominator);
}
__m256 Tanh(const __m256 a) { __m256 Tanh(const __m256 a) {
__m256 max = _mm256_set1_ps(EXP_MAX_INPUT); __m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
__m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a); __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
...@@ -53,6 +60,14 @@ __m256 Tanh(const __m256 a) { ...@@ -53,6 +60,14 @@ __m256 Tanh(const __m256 a) {
_mm256_set1_ps(1.0f)); _mm256_set1_ps(1.0f));
} }
// AVX tanh, 2 / (1 + exp(-2a)) - 1 per float lane, with no clamping of the
// input. exp256_ps is a project helper (presumably a lane-wise exp).
__m256 TanhV2(const __m256 a) {
  const __m256 one = _mm256_set1_ps(1.0f);
  const __m256 two = _mm256_set1_ps(2.0f);
  const __m256 scaled = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
  const __m256 denominator = _mm256_add_ps(one, exp256_ps(scaled));
  return _mm256_sub_ps(_mm256_div_ps(two, denominator), one);
}
__m256 Identity(const __m256 a) { return a; } __m256 Identity(const __m256 a) { return a; }
} // namespace avx } // namespace avx
......
...@@ -25,26 +25,38 @@ namespace detail { ...@@ -25,26 +25,38 @@ namespace detail {
#ifndef __NVCC__ #ifndef __NVCC__
template <class OpResetOutput, typename T> template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, void hl_naive_gru_forward_reset_output(
T *gate_value, T *reset_output_value, OpResetOutput op_reset_output, T *gate_value, T *reset_output_value,
T *prev_output_value, int frame_size, const T *prev_output_value, int frame_size, ActivationType active_gate,
ActivationType active_gate) { bool old_version = true, const T *reset_bias = nullptr) {
T r_value_update_gate; T r_value_update_gate;
T r_value_reset_gate; T r_value_reset_gate;
T r_value_reset_output; T r_value_reset_output;
T r_prev_out = 0; T r_prev_out = 0;
T *update_gate = gate_value; T r_reset_bias = 0;
T *reset_gate = gate_value + frame_size; T *update_gate = nullptr;
T *reset_gate = nullptr;
if (old_version) {
update_gate = gate_value;
reset_gate = gate_value + frame_size;
} else {
reset_gate = gate_value;
update_gate = gate_value + frame_size;
}
for (int i = 0; i < frame_size; i++) { for (int i = 0; i < frame_size; i++) {
r_value_update_gate = update_gate[i]; r_value_update_gate = update_gate[i];
r_value_reset_gate = reset_gate[i]; r_value_reset_gate = reset_gate[i];
if (!old_version) {
r_value_reset_output = reset_output_value[i];
r_reset_bias = reset_bias[i];
}
if (prev_output_value) { if (prev_output_value) {
r_prev_out = prev_output_value[i]; r_prev_out = prev_output_value[i];
} }
op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
&r_value_reset_output, active_gate); &r_value_reset_output, active_gate, &r_reset_bias,
old_version);
update_gate[i] = r_value_update_gate; update_gate[i] = r_value_update_gate;
reset_gate[i] = r_value_reset_gate; reset_gate[i] = r_value_reset_gate;
...@@ -53,16 +65,20 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, ...@@ -53,16 +65,20 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
} }
template <class OpFinalOutput, typename T> template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, void hl_naive_gru_forward_final_output(
T *gate_value, T *prev_output_value, OpFinalOutput op_final_output, T *gate_value, const T *prev_output_value,
T *output_value, int frame_size, T *output_value, int frame_size, ActivationType active_node,
ActivationType active_node, bool origin_mode, bool old_version = true) {
bool origin_mode) {
T r_value_update_gate; T r_value_update_gate;
T r_value_frame_state; T r_value_frame_state;
T r_prev_out = 0; T r_prev_out = 0;
T r_output; T r_output;
T *update_gate = gate_value; T *update_gate;
if (old_version) {
update_gate = gate_value;
} else {
update_gate = gate_value + frame_size;
}
T *frame_state = gate_value + frame_size * 2; T *frame_state = gate_value + frame_size * 2;
for (int i = 0; i < frame_size; i++) { for (int i = 0; i < frame_size; i++) {
...@@ -83,16 +99,26 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -83,16 +99,26 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
template <class OpResetOutput, typename T> template <class OpResetOutput, typename T>
void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
T *gate_value, T *reset_output_value, T *gate_value, T *reset_output_value,
T *prev_output_value, int frame_size, const T *prev_output_value, int frame_size,
ActivationType active_gate) { ActivationType active_gate,
bool old_version = true,
const T *reset_bias = nullptr) {
#ifdef __AVX__ #ifdef __AVX__
__m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f); __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
__m256 r_value_reset_gate, r_value_reset_gate_last = _mm256_set1_ps(0.0f); __m256 r_value_reset_gate, r_value_reset_gate_last = _mm256_set1_ps(0.0f);
__m256 r_value_reset_output; __m256 r_value_reset_output;
__m256 r_prev_out = _mm256_set1_ps(0.0f), __m256 r_prev_out = _mm256_set1_ps(0.0f),
r_prev_out_last = _mm256_set1_ps(0.0f); r_prev_out_last = _mm256_set1_ps(0.0f);
T *update_gate = gate_value; __m256 r_reset_bias = _mm256_set1_ps(0.0f);
T *reset_gate = gate_value + frame_size; T *update_gate;
T *reset_gate;
if (old_version) {
update_gate = gate_value;
reset_gate = gate_value + frame_size;
} else {
reset_gate = gate_value;
update_gate = gate_value + frame_size;
}
int block = 8; int block = 8;
const int n = frame_size; const int n = frame_size;
const int rest = n % block; const int rest = n % block;
...@@ -115,9 +141,15 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, ...@@ -115,9 +141,15 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
if (prev_output_value) { if (prev_output_value) {
r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i)); r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i));
} }
if (!old_version) {
r_reset_bias = _mm256_loadu_ps((const float *)(reset_bias + i));
r_value_reset_output =
_mm256_loadu_ps((const float *)(reset_output_value + i));
}
op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
&r_value_reset_output, active_gate); &r_value_reset_output, active_gate, &r_reset_bias,
old_version);
_mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i), _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
r_value_update_gate); r_value_update_gate);
...@@ -131,7 +163,8 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, ...@@ -131,7 +163,8 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
i = n - block; i = n - block;
op_reset_output(&r_value_update_gate_last, &r_value_reset_gate_last, op_reset_output(&r_value_update_gate_last, &r_value_reset_gate_last,
&r_prev_out_last, &r_value_reset_output, active_gate); &r_prev_out_last, &r_value_reset_output, active_gate,
&r_reset_bias, old_version);
_mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i), _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
r_value_update_gate_last); r_value_update_gate_last);
...@@ -145,17 +178,24 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, ...@@ -145,17 +178,24 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
template <class OpFinalOutput, typename T> template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value, T *gate_value, const T *prev_output_value,
T *output_value, int frame_size, T *output_value, int frame_size,
ActivationType active_node, ActivationType active_node,
bool origin_mode) { bool origin_mode,
bool old_version = true) {
#ifdef __AVX__ #ifdef __AVX__
__m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f); __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
__m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f); __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
__m256 r_prev_out = _mm256_set1_ps(0.0f), __m256 r_prev_out = _mm256_set1_ps(0.0f),
r_prev_out_last = _mm256_set1_ps(0.0f); r_prev_out_last = _mm256_set1_ps(0.0f);
__m256 r_output; __m256 r_output;
T *update_gate = gate_value; T *update_gate;
if (old_version) {
update_gate = gate_value;
} else {
update_gate = gate_value + frame_size;
}
T *frame_state = gate_value + frame_size * 2; T *frame_state = gate_value + frame_size * 2;
int block = 8; int block = 8;
const int n = frame_size; const int n = frame_size;
...@@ -205,19 +245,21 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -205,19 +245,21 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
template <class OpResetOutput, typename T> template <class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput op_reset_output, inline void forward_reset_output(OpResetOutput op_reset_output,
GRUMetaValue<T> value, int frame_size, GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_gate) { int batch_size, ActivationType active_gate,
bool old_version = true) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
if (OpResetOutput::avx && (frame_size > static_cast<int>(8 - 1)) && if (OpResetOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
(sizeof(T) == 4)) { (sizeof(T) == 4)) {
hl_avx_gru_forward_reset_output( hl_avx_gru_forward_reset_output(
op_reset_output, value.gate_value, value.reset_output_value, op_reset_output, value.gate_value, value.reset_output_value,
value.prev_out_value, frame_size, active_gate); value.prev_out_value, frame_size, active_gate, old_version,
value.reset_bias);
} else { } else {
hl_naive_gru_forward_reset_output( hl_naive_gru_forward_reset_output(
op_reset_output, value.gate_value, value.reset_output_value, op_reset_output, value.gate_value, value.reset_output_value,
value.prev_out_value, frame_size, active_gate); value.prev_out_value, frame_size, active_gate, old_version,
value.reset_bias);
} }
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
value.reset_output_value += frame_size; value.reset_output_value += frame_size;
if (value.prev_out_value) { if (value.prev_out_value) {
...@@ -230,17 +272,19 @@ template <class OpFinalOutput, typename T> ...@@ -230,17 +272,19 @@ template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output, inline void forward_final_output(OpFinalOutput op_final_output,
GRUMetaValue<T> value, int frame_size, GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node, int batch_size, ActivationType active_node,
bool origin_mode) { bool origin_mode, bool old_version = true) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) && if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
(sizeof(T) == 4)) { (sizeof(T) == 4)) {
hl_avx_gru_forward_final_output(op_final_output, value.gate_value, hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
value.prev_out_value, value.output_value, value.prev_out_value, value.output_value,
frame_size, active_node, origin_mode); frame_size, active_node, origin_mode,
old_version);
} else { } else {
hl_naive_gru_forward_final_output( hl_naive_gru_forward_final_output(op_final_output, value.gate_value,
op_final_output, value.gate_value, value.prev_out_value, value.prev_out_value,
value.output_value, frame_size, active_node, origin_mode); value.output_value, frame_size,
active_node, origin_mode, old_version);
} }
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
...@@ -253,7 +297,7 @@ inline void forward_final_output(OpFinalOutput op_final_output, ...@@ -253,7 +297,7 @@ inline void forward_final_output(OpFinalOutput op_final_output,
template <class OpStateGrad, typename T> template <class OpStateGrad, typename T>
void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, const T *prev_out_value,
T *prev_out_grad, T *output_grad, T *prev_out_grad, T *output_grad,
int frame_size, int frame_size,
ActivationType active_node, ActivationType active_node,
...@@ -295,7 +339,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, ...@@ -295,7 +339,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
template <class OpResetGrad, typename T> template <class OpResetGrad, typename T>
void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, const T *prev_out_value,
T *prev_out_grad, T *reset_output_grad, T *prev_out_grad, T *reset_output_grad,
int frame_size, int frame_size,
ActivationType active_gate) { ActivationType active_gate) {
...@@ -340,7 +384,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, ...@@ -340,7 +384,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
template <class OpStateGrad, typename T> template <class OpStateGrad, typename T>
void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, const T *prev_out_value,
T *prev_out_grad, T *output_grad, T *prev_out_grad, T *output_grad,
int frame_size, ActivationType active_node, int frame_size, ActivationType active_node,
bool origin_mode) { bool origin_mode) {
...@@ -364,7 +408,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, ...@@ -364,7 +408,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
r_frame_state_value = frame_state_value[i]; r_frame_state_value = frame_state_value[i];
r_out_grad = (reinterpret_cast<__m256 *>(output_grad))[i]; r_out_grad = (reinterpret_cast<__m256 *>(output_grad))[i];
if (prev_out_value) { if (prev_out_value) {
r_prev_out_value = (reinterpret_cast<__m256 *>(prev_out_value))[i]; r_prev_out_value = (reinterpret_cast<const __m256 *>(prev_out_value))[i];
} }
if (prev_out_grad) { if (prev_out_grad) {
r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i]; r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
...@@ -385,7 +429,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, ...@@ -385,7 +429,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
template <class OpResetGrad, typename T> template <class OpResetGrad, typename T>
void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, const T *prev_out_value,
T *prev_out_grad, T *reset_output_grad, T *prev_out_grad, T *reset_output_grad,
int frame_size, int frame_size,
ActivationType active_gate) { ActivationType active_gate) {
...@@ -412,7 +456,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, ...@@ -412,7 +456,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
r_reset_output_grad = (reinterpret_cast<__m256 *>(reset_output_grad))[i]; r_reset_output_grad = (reinterpret_cast<__m256 *>(reset_output_grad))[i];
} }
if (prev_out_value) { if (prev_out_value) {
r_prev_out_value = (reinterpret_cast<__m256 *>(prev_out_value))[i]; r_prev_out_value = (reinterpret_cast<const __m256 *>(prev_out_value))[i];
} }
if (prev_out_grad) { if (prev_out_grad) {
r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i]; r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
...@@ -431,6 +475,135 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, ...@@ -431,6 +475,135 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
#endif #endif
} }
/*
 * Scalar backward pass of one GRU step over a single frame.
 *
 * gate_value / gate_grad each hold three consecutive segments of frame_size
 * elements, read here in the order: reset gate, update gate, frame state.
 * For every element the values are copied into locals, op_gru_grad is invoked
 * on pointers to those locals, and the resulting gradients are written back.
 *
 * prev_out_value and prev_out_grad may be null; reset_output_grad is only
 * read and written when both of them are non-null.
 */
template <class OpGruGrad, typename T>
inline void hl_naive_gru_backward(OpGruGrad op_gru_grad, T *gate_value,
                                  T *gate_grad, const T *prev_out_value,
                                  T *prev_out_grad, T *reset_output_value,
                                  T *reset_output_grad, T *output_grad,
                                  int frame_size, ActivationType active_node,
                                  ActivationType active_gate) {
  // Segment views into the packed gate buffers.
  T *const reset_val = gate_value;
  T *const reset_grd = gate_grad;
  T *const update_val = gate_value + frame_size;
  T *const update_grd = gate_grad + frame_size;
  T *const state_val = gate_value + 2 * frame_size;
  T *const state_grd = gate_grad + 2 * frame_size;

  // These three are only refreshed when their buffers are present, so they
  // are declared outside the loop and keep whatever the previous iteration
  // (including op_gru_grad) left in them.
  T prev_val = 0;
  T prev_grd = 0;
  T reset_out_grd = 0;

  for (int idx = 0; idx < frame_size; ++idx) {
    T r_val = reset_val[idx];
    T r_grd = reset_grd[idx];
    T u_val = update_val[idx];
    T u_grd = update_grd[idx];
    T s_val = state_val[idx];
    T s_grd = state_grd[idx];

    if (prev_out_value) {
      prev_val = prev_out_value[idx];
    }
    if (prev_out_grad) {
      prev_grd = prev_out_grad[idx];
    }

    T out_grd = output_grad[idx];
    T reset_out_val = reset_output_value[idx];
    if (prev_out_value && prev_out_grad) {
      reset_out_grd = reset_output_grad[idx];
    }

    op_gru_grad(&r_val, &r_grd, &u_val, &u_grd, &s_val, &s_grd, &prev_val,
                &prev_grd, &out_grd, &reset_out_val, &reset_out_grd,
                active_node, active_gate);

    // Only gradients are stored back; gate values are left untouched.
    reset_grd[idx] = r_grd;
    update_grd[idx] = u_grd;
    state_grd[idx] = s_grd;
    if (prev_out_grad) {
      prev_out_grad[idx] = prev_grd;
    }
    if (prev_out_value && prev_out_grad) {
      reset_output_grad[idx] = reset_out_grd;
    }
  }
}
/*
 * AVX backward pass of one GRU step over a single frame, 8 floats at a time.
 *
 * gate_value / gate_grad each hold three consecutive segments of frame_size
 * elements, read here in the order: reset gate, update gate, frame state.
 * For every group of 8 elements the values are loaded into __m256 locals,
 * op_gru_grad is invoked on pointers to those locals, and the resulting
 * gradients are stored back.
 *
 * Only frame_size / 8 full groups are processed; any remainder is ignored, so
 * callers are expected to pass a frame_size that is a multiple of 8 (the
 * dispatcher cpu_gru_backward checks this). The buffers are reinterpreted as
 * float, so this path is only meaningful when T is a 4-byte float.
 *
 * Memory is accessed with unaligned load/store intrinsics, matching the
 * forward AVX kernels in this file, so the buffers need not be 32-byte
 * aligned (dereferencing a reinterpret_cast<__m256 *> would require that).
 *
 * prev_out_value and prev_out_grad may be null; reset_output_grad is only
 * read and written when both of them are non-null.
 *
 * Without __AVX__ the function body is empty.
 */
template <class OpGruGrad, typename T>
inline void hl_avx_gru_backward(OpGruGrad op_gru_grad, T *gate_value,
                                T *gate_grad, const T *prev_out_value,
                                T *prev_out_grad, T *reset_output_value,
                                T *reset_output_grad, T *output_grad,
                                int frame_size, ActivationType active_node,
                                ActivationType active_gate) {
#ifdef __AVX__
  const int kLanes = 8;  // floats per __m256

  // Float views of the packed gate buffers (offsets are taken in units of T
  // before the cast, exactly as the element layout is defined).
  float *reset_gate_value = reinterpret_cast<float *>(gate_value);
  float *reset_gate_grad = reinterpret_cast<float *>(gate_grad);
  float *update_gate_value =
      reinterpret_cast<float *>(gate_value + frame_size);
  float *update_gate_grad = reinterpret_cast<float *>(gate_grad + frame_size);
  float *frame_state_value =
      reinterpret_cast<float *>(gate_value + 2 * frame_size);
  float *frame_state_grad =
      reinterpret_cast<float *>(gate_grad + 2 * frame_size);

  const float *prev_value = reinterpret_cast<const float *>(prev_out_value);
  float *prev_grad = reinterpret_cast<float *>(prev_out_grad);
  const float *reset_out_value =
      reinterpret_cast<const float *>(reset_output_value);
  float *reset_out_grad = reinterpret_cast<float *>(reset_output_grad);
  const float *out_grad = reinterpret_cast<const float *>(output_grad);

  // These three are only refreshed when their buffers are present, so they
  // are declared outside the loop and keep whatever the previous iteration
  // (including op_gru_grad) left in them.
  __m256 r_value_prev_out = _mm256_set1_ps(0.0f);
  __m256 r_grad_prev_out = _mm256_set1_ps(0.0f);
  __m256 r_grad_reset_output = _mm256_set1_ps(0.0f);

  for (int i = 0; i < frame_size / 8; ++i) {
    const int offset = i * kLanes;

    __m256 r_value_reset_gate = _mm256_loadu_ps(reset_gate_value + offset);
    __m256 r_grad_reset_gate = _mm256_loadu_ps(reset_gate_grad + offset);
    __m256 r_value_update_gate = _mm256_loadu_ps(update_gate_value + offset);
    __m256 r_grad_update_gate = _mm256_loadu_ps(update_gate_grad + offset);
    __m256 r_value_frame_state = _mm256_loadu_ps(frame_state_value + offset);
    __m256 r_grad_frame_state = _mm256_loadu_ps(frame_state_grad + offset);

    if (prev_out_value) {
      r_value_prev_out = _mm256_loadu_ps(prev_value + offset);
    }
    if (prev_out_grad) {
      r_grad_prev_out = _mm256_loadu_ps(prev_grad + offset);
    }

    __m256 r_grad_output = _mm256_loadu_ps(out_grad + offset);
    __m256 r_value_reset_output = _mm256_loadu_ps(reset_out_value + offset);
    if (prev_out_value && prev_out_grad) {
      r_grad_reset_output = _mm256_loadu_ps(reset_out_grad + offset);
    }

    op_gru_grad(&r_value_reset_gate, &r_grad_reset_gate, &r_value_update_gate,
                &r_grad_update_gate, &r_value_frame_state, &r_grad_frame_state,
                &r_value_prev_out, &r_grad_prev_out, &r_grad_output,
                &r_value_reset_output, &r_grad_reset_output, active_node,
                active_gate);

    // Only gradients are stored back; gate values are left untouched.
    _mm256_storeu_ps(reset_gate_grad + offset, r_grad_reset_gate);
    _mm256_storeu_ps(update_gate_grad + offset, r_grad_update_gate);
    _mm256_storeu_ps(frame_state_grad + offset, r_grad_frame_state);
    if (prev_out_grad) {
      _mm256_storeu_ps(prev_grad + offset, r_grad_prev_out);
    }
    if (prev_out_value && prev_out_grad) {
      _mm256_storeu_ps(reset_out_grad + offset, r_grad_reset_output);
    }
  }
#endif
}
template <class OpStateGrad, typename T> template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad op_state_grad, inline void backward_state_grad(OpStateGrad op_state_grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad, GRUMetaValue<T> value, GRUMetaGrad<T> grad,
...@@ -491,6 +664,39 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad, ...@@ -491,6 +664,39 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad,
} }
} }
/*
 * Runs the GRU backward step for every sample in a batch on the CPU.
 *
 * Each iteration handles one sample and then advances the buffer pointers in
 * the (by-value) value/grad structs: gate buffers by 3 * frame_size, all other
 * buffers by frame_size. The prev_out pointers are only advanced when set.
 *
 * The AVX kernel is used when the op advertises AVX support, frame_size is a
 * multiple of 8 and T is 4 bytes wide; otherwise the scalar kernel is used.
 */
template <class OpGruGrad, typename T>
inline void cpu_gru_backward(OpGruGrad op_gru_grad, GRUMetaValue<T> value,
                             GRUMetaGrad<T> grad, int frame_size,
                             int batch_size, ActivationType active_node,
                             ActivationType active_gate) {
  // The dispatch decision does not depend on the sample, so take it once.
  const bool frame_is_multiple_of_8 = (frame_size & (8 - 1)) == 0;
  const bool use_avx =
      OpGruGrad::avx && frame_is_multiple_of_8 && (sizeof(T) == 4);

  for (int sample = 0; sample < batch_size; ++sample) {
    if (use_avx) {
      hl_avx_gru_backward(
          op_gru_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
          grad.prev_out_grad, value.reset_output_value, grad.reset_output_grad,
          grad.output_grad, frame_size, active_node, active_gate);
    } else {
      hl_naive_gru_backward(
          op_gru_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
          grad.prev_out_grad, value.reset_output_value, grad.reset_output_grad,
          grad.output_grad, frame_size, active_node, active_gate);
    }

    // Step the forward-value pointers to the next sample.
    value.gate_value += frame_size * 3;
    value.reset_output_value += frame_size;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }

    // Step the gradient pointers to the next sample.
    grad.gate_grad += frame_size * 3;
    grad.output_grad += frame_size;
    grad.reset_output_grad += frame_size;
    if (grad.prev_out_grad) {
      grad.prev_out_grad += frame_size;
    }
  }
}
#endif #endif
} // namespace detail } // namespace detail
......
...@@ -31,8 +31,8 @@ namespace detail { ...@@ -31,8 +31,8 @@ namespace detail {
template <class OpResetOutput, bool is_batch, typename T> template <class OpResetOutput, bool is_batch, typename T>
__global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
T *gate_value, T *reset_output_value, T *gate_value, T *reset_output_value,
T *prev_output_value, int frame_size, const T *prev_output_value,
int batch_size, int frame_size, int batch_size,
ActivationType active_gate) { ActivationType active_gate) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return; if (frame_idx >= frame_size) return;
...@@ -68,12 +68,10 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, ...@@ -68,12 +68,10 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
* grid(frame_blocks, batch_blocks) * grid(frame_blocks, batch_blocks)
*/ */
template <class OpFinalOutput, bool is_batch, typename T> template <class OpFinalOutput, bool is_batch, typename T>
__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, __global__ void KeGruForwardFinalOutput(
T *gate_value, T *prev_output_value, OpFinalOutput op_final_output, T *gate_value, const T *prev_output_value,
T *output_value, int frame_size, T *output_value, int frame_size, int batch_size, ActivationType active_node,
int batch_size, bool origin_mode) {
ActivationType active_node,
bool origin_mode) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return; if (frame_idx >= frame_size) return;
int batch_idx = 0; int batch_idx = 0;
...@@ -106,8 +104,9 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, ...@@ -106,8 +104,9 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
* grid(frame_blocks, 1) * grid(frame_blocks, 1)
*/ */
template <class T, int Tiled_size> template <class T, int Tiled_size>
__global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, __global__ void KeFastCollectiveGruGate(T *gate_value,
T *gate_weight, T *reset_output, const T *prev_output_value,
const T *gate_weight, T *reset_output,
int frame_size, int frame_size,
ActivationType active_node) { ActivationType active_node) {
T xt_0 = 0.0f; T xt_0 = 0.0f;
...@@ -164,10 +163,10 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, ...@@ -164,10 +163,10 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
* grid(frame_blocks, 1) * grid(frame_blocks, 1)
*/ */
template <class T, int Tiled_size> template <class T, int Tiled_size>
__global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value, __global__ void KeFastCollectiveGruOut(const T *gate_weight,
T *output_value, T *gate_value, const T *prev_out_value, T *output_value,
T *reset_value, int frame_size, T *gate_value, T *reset_value,
ActivationType act_node, int frame_size, ActivationType act_node,
bool origin_mode) { bool origin_mode) {
int COL = blockIdx.x * blockDim.x + threadIdx.x; int COL = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -223,7 +222,7 @@ __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value, ...@@ -223,7 +222,7 @@ __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value,
*/ */
template <class OpStateGrad, bool is_batch, typename T> template <class OpStateGrad, bool is_batch, typename T>
__global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, const T *prev_out_value,
T *prev_out_grad, T *output_grad, T *prev_out_grad, T *output_grad,
int frame_size, int batch_size, int frame_size, int batch_size,
ActivationType active_node, ActivationType active_node,
...@@ -272,7 +271,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, ...@@ -272,7 +271,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
*/ */
template <class OpResetGrad, bool is_batch, typename T> template <class OpResetGrad, bool is_batch, typename T>
__global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value, __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, const T *prev_out_value,
T *prev_out_grad, T *reset_output_grad, T *prev_out_grad, T *reset_output_grad,
int frame_size, int batch_size, int frame_size, int batch_size,
ActivationType active_gate) { ActivationType active_gate) {
......
...@@ -30,10 +30,17 @@ class gru_resetOutput { ...@@ -30,10 +30,17 @@ class gru_resetOutput {
public: public:
HOSTDEVICE void operator()(T *value_update_gate, T *value_reset_gate, HOSTDEVICE void operator()(T *value_update_gate, T *value_reset_gate,
T *prev_out, T *value_reset_output, T *prev_out, T *value_reset_output,
ActivationType act_gate) { ActivationType act_gate,
T *value_reset_bias = nullptr,
bool old_version = true) {
*value_update_gate = activation(*value_update_gate, act_gate); *value_update_gate = activation(*value_update_gate, act_gate);
*value_reset_gate = activation(*value_reset_gate, act_gate); *value_reset_gate = activation(*value_reset_gate, act_gate);
*value_reset_output = (*prev_out) * (*value_reset_gate); if (old_version) {
*value_reset_output = (*prev_out) * (*value_reset_gate);
} else {
*value_reset_output =
(*value_reset_output + *value_reset_bias) * (*value_reset_gate);
}
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
...@@ -43,10 +50,19 @@ class gru_resetOutput { ...@@ -43,10 +50,19 @@ class gru_resetOutput {
HOSTDEVICE void operator()(__m256 *value_update_gate, HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 *value_reset_gate, __m256 *prev_out, __m256 *value_reset_gate, __m256 *prev_out,
__m256 *value_reset_output, __m256 *value_reset_output,
ActivationType act_gate) { ActivationType act_gate,
__m256 *value_reset_bias = nullptr,
bool old_version = true) {
*value_update_gate = activation(*value_update_gate, act_gate); *value_update_gate = activation(*value_update_gate, act_gate);
*value_reset_gate = activation(*value_reset_gate, act_gate); *value_reset_gate = activation(*value_reset_gate, act_gate);
*value_reset_output = _mm256_mul_ps(*prev_out, *value_reset_gate); if (old_version) {
*value_reset_output = _mm256_mul_ps(*prev_out, *value_reset_gate);
} else {
*value_reset_output =
_mm256_add_ps(*value_reset_output, *value_reset_bias);
*value_reset_output =
_mm256_mul_ps(*value_reset_output, *value_reset_gate);
}
} }
#endif #endif
#endif #endif
...@@ -192,6 +208,61 @@ class gru_resetGrad { ...@@ -192,6 +208,61 @@ class gru_resetGrad {
#endif #endif
#endif #endif
}; };
/*
 * Fused backward functor for a single GRU element (scalar overload) or a
 * pack of eight floats (AVX overload).
 *
 * The gradient expressions below correspond to an output of the form
 *   out = update_gate * prev_out + (1 - update_gate) * frame_state.
 *
 * All arguments are passed by pointer. The value_* arguments are only read;
 * grad_update_gate, grad_frame_state, grad_reset_gate and grad_reset_output
 * are overwritten, while grad_prev_out is accumulated into.
 *
 * activation(grad, value, type) is defined outside this block; presumably it
 * scales `grad` by the derivative of activation `type` evaluated at the
 * already-activated `value` -- TODO confirm against activation_functions.h.
 *
 * NOTE(review): the ungated reset output is recovered by dividing
 * value_reset_output by value_reset_gate. If a gate value can be exactly
 * zero this yields inf/NaN -- confirm that callers rule this out.
 */
template <typename T>
class gru {
 public:
  HOSTDEVICE void operator()(T *value_reset_gate, T *grad_reset_gate,
                             T *value_update_gate, T *grad_update_gate,
                             T *value_frame_state, T *grad_frame_state,
                             T *value_prev_out, T *grad_prev_out,
                             T *grad_output, T *value_reset_output,
                             T *grad_reset_output, ActivationType act_node,
                             ActivationType act_gate) {
    // d(out)/d(update_gate) = prev_out - frame_state.
    const T update_pre_act =
        (*grad_output) * ((*value_prev_out) - (*value_frame_state));
    *grad_update_gate =
        activation(update_pre_act, *value_update_gate, act_gate);

    // Direct path from the output to the previous output (accumulated).
    const T prev_out_contrib = (*grad_output) * (*value_update_gate);
    *grad_prev_out += prev_out_contrib;

    // d(out)/d(frame_state) = 1 - update_gate.
    const T state_pre_act =
        (*grad_output) * (static_cast<T>(1.0) - (*value_update_gate));
    *grad_frame_state =
        activation(state_pre_act, *value_frame_state, act_node);

    // Undo the multiplication by the reset gate to get the ungated value.
    // This must run after grad_frame_state has been written, because both
    // results below read it.
    const T ungated_reset_output =
        (*value_reset_output) / (*value_reset_gate);
    *grad_reset_gate =
        activation(ungated_reset_output * (*grad_frame_state),
                   *value_reset_gate, act_gate);
    *grad_reset_output = (*value_reset_gate) * (*grad_frame_state);
  }
#ifndef __NVCC__
#ifdef __AVX__
  static const bool avx = true;
  // Same computation as the scalar overload, on eight packed floats.
  HOSTDEVICE void operator()(__m256 *value_reset_gate, __m256 *grad_reset_gate,
                             __m256 *value_update_gate,
                             __m256 *grad_update_gate,
                             __m256 *value_frame_state,
                             __m256 *grad_frame_state, __m256 *value_prev_out,
                             __m256 *grad_prev_out, __m256 *grad_output,
                             __m256 *value_reset_output,
                             __m256 *grad_reset_output, ActivationType act_node,
                             ActivationType act_gate) {
    const __m256 prev_minus_state =
        _mm256_sub_ps(*value_prev_out, *value_frame_state);
    const __m256 update_pre_act =
        _mm256_mul_ps(*grad_output, prev_minus_state);
    *grad_update_gate =
        activation(update_pre_act, *value_update_gate, act_gate);

    const __m256 prev_out_contrib =
        _mm256_mul_ps(*grad_output, *value_update_gate);
    *grad_prev_out = _mm256_add_ps(*grad_prev_out, prev_out_contrib);

    const __m256 one_minus_update =
        _mm256_sub_ps(_mm256_set1_ps(1.0f), *value_update_gate);
    const __m256 state_pre_act = _mm256_mul_ps(*grad_output, one_minus_update);
    *grad_frame_state =
        activation(state_pre_act, *value_frame_state, act_node);

    const __m256 ungated_reset_output =
        _mm256_div_ps(*value_reset_output, *value_reset_gate);
    *grad_reset_gate =
        activation(_mm256_mul_ps(ungated_reset_output, *grad_frame_state),
                   *value_reset_gate, act_gate);
    *grad_reset_output = _mm256_mul_ps(*value_reset_gate, *grad_frame_state);
  }
#else
  static const bool avx = false;
#endif
#endif
};
} // namespace backward } // namespace backward
......
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <type_traits> #include <type_traits>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/lstm_compute.h"
...@@ -28,6 +30,11 @@ namespace operators { ...@@ -28,6 +30,11 @@ namespace operators {
namespace math { namespace math {
namespace detail { namespace detail {
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
#ifndef __NVCC__ #ifndef __NVCC__
template <class T, class Op> template <class T, class Op>
...@@ -35,7 +42,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -35,7 +42,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size, T cell_clip, int frame_size, T cell_clip,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state,
bool old_api_version) {
T r_value_in; T r_value_in;
T r_value_ig; T r_value_ig;
T r_value_fg; T r_value_fg;
...@@ -48,10 +56,15 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -48,10 +56,15 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
T r_state_atv; T r_state_atv;
T r_out; T r_out;
T *value_in = value.gate_value; T *value_ig = value.gate_value;
T *value_ig = value.gate_value + frame_size; T *value_fg = value.gate_value + frame_size;
T *value_fg = value.gate_value + frame_size * 2; T *value_in = value.gate_value + frame_size * 2;
T *value_og = value.gate_value + frame_size * 3; T *value_og = value.gate_value + frame_size * 3;
if (old_api_version) {
value_in = value.gate_value;
value_ig = value.gate_value + frame_size;
value_fg = value.gate_value + frame_size * 2;
}
for (int i = 0; i < frame_size; i++) { for (int i = 0; i < frame_size; i++) {
r_value_in = value_in[i]; r_value_in = value_in[i];
...@@ -85,7 +98,8 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -85,7 +98,8 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, LstmMetaGrad<T> grad, int frame_size,
T cell_clip, ActivationType active_node, T cell_clip, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state,
bool old_api_version) {
T r_value_in; T r_value_in;
T r_value_ig; T r_value_ig;
T r_value_fg; T r_value_fg;
...@@ -107,14 +121,25 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -107,14 +121,25 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
T r_checkFGrad; T r_checkFGrad;
T r_checkOGrad; T r_checkOGrad;
T *value_in = value.gate_value; T *value_ig = value.gate_value;
T *value_ig = value.gate_value + frame_size; T *value_fg = value.gate_value + frame_size;
T *value_fg = value.gate_value + frame_size * 2; T *value_in = value.gate_value + frame_size * 2;
T *value_og = value.gate_value + frame_size * 3; T *value_og = value.gate_value + frame_size * 3;
T *grad_in = grad.gate_grad; if (old_api_version) {
T *grad_ig = grad.gate_grad + frame_size; value_in = value.gate_value;
T *grad_fg = grad.gate_grad + frame_size * 2; value_ig = value.gate_value + frame_size;
value_fg = value.gate_value + frame_size * 2;
}
T *grad_ig = grad.gate_grad;
T *grad_fg = grad.gate_grad + frame_size;
T *grad_in = grad.gate_grad + frame_size * 2;
T *grad_og = grad.gate_grad + frame_size * 3; T *grad_og = grad.gate_grad + frame_size * 3;
if (old_api_version) {
grad_in = grad.gate_grad;
grad_ig = grad.gate_grad + frame_size;
grad_fg = grad.gate_grad + frame_size * 2;
}
for (int i = 0; i < frame_size; i++) { for (int i = 0; i < frame_size; i++) {
r_value_in = value_in[i]; r_value_in = value_in[i];
...@@ -158,7 +183,8 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -158,7 +183,8 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size, T cell_clip, int frame_size, T cell_clip,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state,
bool old_api_version) {
#ifdef __AVX__ #ifdef __AVX__
__m256 r_value_in; __m256 r_value_in;
__m256 r_value_ig; __m256 r_value_ig;
...@@ -172,12 +198,17 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -172,12 +198,17 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
__m256 r_state_atv; __m256 r_state_atv;
__m256 r_out; __m256 r_out;
__m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value); __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value);
__m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size); __m256 *value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
__m256 *value_fg = __m256 *value_in =
reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2); reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2);
__m256 *value_og = __m256 *value_og =
reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3); reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3);
if (old_api_version) {
value_in = reinterpret_cast<__m256 *>(value.gate_value);
value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2);
}
for (int i = 0; i < frame_size / 8; i++) { for (int i = 0; i < frame_size / 8; i++) {
r_value_in = value_in[i]; r_value_in = value_in[i];
...@@ -191,7 +222,8 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -191,7 +222,8 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
} }
if (value.prev_state_value) { if (value.prev_state_value) {
r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i]; r_prev_state =
(reinterpret_cast<__m256 const *>(value.prev_state_value))[i];
} }
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
...@@ -214,7 +246,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -214,7 +246,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, LstmMetaGrad<T> grad, int frame_size,
T cell_clip, ActivationType active_node, T cell_clip, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state,
bool old_api_version) {
#ifdef __AVX__ #ifdef __AVX__
__m256 r_value_in; __m256 r_value_in;
__m256 r_value_ig; __m256 r_value_ig;
...@@ -237,16 +270,27 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -237,16 +270,27 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
__m256 r_checkFGrad; __m256 r_checkFGrad;
__m256 r_checkOGrad; __m256 r_checkOGrad;
__m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value); __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value);
__m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size); __m256 *value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
__m256 *value_fg = __m256 *value_in =
reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2); reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2);
__m256 *value_og = __m256 *value_og =
reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3); reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3);
__m256 *grad_in = reinterpret_cast<__m256 *>(grad.gate_grad); if (old_api_version) {
__m256 *grad_ig = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size); value_in = reinterpret_cast<__m256 *>(value.gate_value);
__m256 *grad_fg = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 2); value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2);
}
__m256 *grad_ig = reinterpret_cast<__m256 *>(grad.gate_grad);
__m256 *grad_fg = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size);
__m256 *grad_in = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 2);
__m256 *grad_og = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 3); __m256 *grad_og = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 3);
if (old_api_version) {
grad_in = reinterpret_cast<__m256 *>(grad.gate_grad);
grad_ig = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size);
grad_fg = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 2);
}
for (int i = 0; i < frame_size / 8; i++) { for (int i = 0; i < frame_size / 8; i++) {
r_value_in = value_in[i]; r_value_in = value_in[i];
...@@ -263,7 +307,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -263,7 +307,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
r_output_grad = (reinterpret_cast<__m256 *>(grad.output_grad))[i]; r_output_grad = (reinterpret_cast<__m256 *>(grad.output_grad))[i];
r_state_grad = (reinterpret_cast<__m256 *>(grad.state_grad))[i]; r_state_grad = (reinterpret_cast<__m256 *>(grad.state_grad))[i];
if (value.prev_state_value) { if (value.prev_state_value) {
r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i]; r_prev_state =
(reinterpret_cast<__m256 const *>(value.prev_state_value))[i];
} }
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
...@@ -292,30 +337,133 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -292,30 +337,133 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
#endif #endif
} }
// Eigen-based LSTM forward pass for one sample of `frame_size` units.
//
// value.gate_value is read as four consecutive chunks of frame_size elements
// in the order [input gate, forget gate, candidate ("in"), output gate].
// The chunks are activated in place, then state_value, state_active_value
// and output_value are written.
//
// The activations are fixed (tanh for the candidate and the cell state,
// sigmoid for the three gates); this function takes no activation-type or
// cell_clip arguments.
template <class T>
void eigen_lstm_forward_one_sequence(const platform::CPUDeviceContext &context,
LstmMetaValue<T> value, int frame_size) {
// Wrap the raw buffers as 1-D Eigen tensor maps of length frame_size.
auto eigen_value_ig =
typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
auto eigen_value_fg = typename EigenVector<T>::Type(
value.gate_value + frame_size, Array1(frame_size));
auto eigen_value_in = typename EigenVector<T>::Type(
value.gate_value + frame_size * 2, Array1(frame_size));
auto eigen_value_og = typename EigenVector<T>::Type(
value.gate_value + frame_size * 3, Array1(frame_size));
auto eigen_state =
typename EigenVector<T>::Type(value.state_value, Array1(frame_size));
auto eigen_state_act = typename EigenVector<T>::Type(value.state_active_value,
Array1(frame_size));
auto eigen_output =
typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
auto &place = *context.eigen_device();
// Activate the gate buffers in place (input and output are the same map).
TanhFunctor<T>()(place, eigen_value_in, eigen_value_in);
SigmoidFunctor<T>()(place, eigen_value_ig, eigen_value_ig);
SigmoidFunctor<T>()(place, eigen_value_fg, eigen_value_fg);
SigmoidFunctor<T>()(place, eigen_value_og, eigen_value_og);
// state = candidate * input_gate (+ prev_state * forget_gate when present).
eigen_state.device(place) = eigen_value_in * eigen_value_ig;
if (value.prev_state_value) {
auto eigen_prev_state = typename EigenVector<T>::ConstType(
value.prev_state_value, Array1(frame_size));
eigen_state.device(place) = eigen_state + eigen_prev_state * eigen_value_fg;
}
// state_act = tanh(state); output = output_gate * state_act.
TanhFunctor<T>()(place, eigen_state, eigen_state_act);
eigen_output.device(place) = eigen_value_og * eigen_state_act;
}
// Eigen-based LSTM backward pass for one sample of `frame_size` units.
//
// value.gate_value and grad.gate_grad are both read as four consecutive
// chunks of frame_size elements in the order
// [input gate, forget gate, candidate ("in"), output gate].
// The gate values are passed to the *GradFunctor calls in the `out` position,
// so they are presumably the already-activated values left by the forward
// pass -- TODO confirm.
//
// Writes all four chunks of grad.gate_grad, updates grad.state_grad in place,
// and writes grad.prev_state_grad when that pointer is non-null.
//
// The *GradFunctor arguments appear to be (device, x, out, dout, dx); `x` is
// always passed as the literal 1 and is marked as unused by the original
// author -- TODO confirm against activation_op.h.
template <class T>
void eigen_lstm_backward_one_sequence(const platform::CPUDeviceContext &context,
LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size) {
// Wrap the raw buffers as 1-D Eigen tensor maps of length frame_size.
auto eigen_value_ig =
typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
auto eigen_value_fg = typename EigenVector<T>::Type(
value.gate_value + frame_size, Array1(frame_size));
auto eigen_value_in = typename EigenVector<T>::Type(
value.gate_value + frame_size * 2, Array1(frame_size));
auto eigen_value_og = typename EigenVector<T>::Type(
value.gate_value + frame_size * 3, Array1(frame_size));
auto eigen_state_act = typename EigenVector<T>::Type(value.state_active_value,
Array1(frame_size));
auto eigen_grad_ig =
typename EigenVector<T>::Type(grad.gate_grad, Array1(frame_size));
auto eigen_grad_fg = typename EigenVector<T>::Type(
grad.gate_grad + frame_size, Array1(frame_size));
auto eigen_grad_in = typename EigenVector<T>::Type(
grad.gate_grad + frame_size * 2, Array1(frame_size));
auto eigen_grad_og = typename EigenVector<T>::Type(
grad.gate_grad + frame_size * 3, Array1(frame_size));
auto eigen_grad_output =
typename EigenVector<T>::Type(grad.output_grad, Array1(frame_size));
auto eigen_grad_state =
typename EigenVector<T>::Type(grad.state_grad, Array1(frame_size));
auto &place = *context.eigen_device();
// Output gate: the upstream gradient is grad_output * state_act.
SigmoidGradFunctor<T>()(place, 1 /*useless*/, eigen_value_og,
eigen_grad_output * eigen_state_act, eigen_grad_og);
// Accumulate the gradient flowing through state_act into grad_state;
// (1 - state_act^2) is the tanh derivative. Every statement below reads the
// updated grad_state, so this must stay ahead of them.
eigen_grad_state.device(place) =
eigen_grad_state +
eigen_grad_output * eigen_value_og *
(static_cast<T>(1) - eigen_state_act * eigen_state_act);
// Candidate and input gate: state = in * ig (+ prev_state * fg).
TanhGradFunctor<T>()(place, 1, eigen_value_in,
eigen_grad_state * eigen_value_ig, eigen_grad_in);
SigmoidGradFunctor<T>()(place, 1, eigen_value_ig,
eigen_grad_state * eigen_value_in, eigen_grad_ig);
// Forget gate: it only has an upstream gradient when a previous state exists.
if (value.prev_state_value) {
auto eigen_prev_state = typename EigenVector<T>::ConstType(
value.prev_state_value, Array1(frame_size));
SigmoidGradFunctor<T>()(place, 1, eigen_value_fg,
eigen_grad_state * eigen_prev_state, eigen_grad_fg);
} else {
// The upstream gradient is passed as the literal 0 here, which presumably
// makes the functor write zeros into grad_fg -- TODO confirm.
SigmoidGradFunctor<T>()(place, 1, eigen_value_fg, 0, eigen_grad_fg);
}
// Gradient for the previous cell state, written only if the caller asked for
// it.
if (grad.prev_state_grad) {
auto eigen_grad_pre_state =
typename EigenVector<T>::Type(grad.prev_state_grad, Array1(frame_size));
eigen_grad_pre_state.device(place) = eigen_grad_state * eigen_value_fg;
}
}
template <class T, class Op> template <class T, class Op>
void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size, void cpu_lstm_forward(const platform::CPUDeviceContext &context, Op op,
T cell_clip, ActivationType active_node, LstmMetaValue<T> value, int frame_size, T cell_clip,
ActivationType active_gate, ActivationType active_state) { ActivationType active_node, ActivationType active_gate,
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) { ActivationType active_state, bool old_api_version) {
avx_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip, if (!old_api_version) {
active_node, active_gate, active_state); eigen_lstm_forward_one_sequence<T>(context, value, frame_size);
} else { } else {
naive_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip, if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
active_node, active_gate, active_state); avx_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip,
active_node, active_gate, active_state,
old_api_version);
} else {
naive_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip,
active_node, active_gate, active_state,
old_api_version);
}
} }
} }
template <class T, class Op> template <class T, class Op>
void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad, void cpu_lstm_backward(const platform::CPUDeviceContext &context, Op op,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, T cell_clip, ActivationType active_node, int frame_size, T cell_clip, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate, ActivationType active_state,
ActivationType active_state) { bool old_api_version) {
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) { if (!old_api_version) {
avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, cell_clip, eigen_lstm_backward_one_sequence<T>(context, value, grad, frame_size);
active_node, active_gate, active_state);
} else { } else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size, cell_clip, if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
active_node, active_gate, active_state); avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, cell_clip,
active_node, active_gate, active_state,
old_api_version);
} else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size,
cell_clip, active_node, active_gate,
active_state, old_api_version);
}
} }
} }
......
...@@ -11,6 +11,7 @@ limitations under the License. */ ...@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/gru_compute.h" #include "paddle/fluid/operators/math/gru_compute.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" #include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h" #include "paddle/fluid/operators/math/detail/gru_kernel.h"
...@@ -101,11 +102,64 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> { ...@@ -101,11 +102,64 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
} }
}; };
/*
 * CPU specialization of the "V2" GRU forward step for a whole batch.
 *
 * Steps, in order:
 *   1. When a previous output exists, run a GEMM with dimensions
 *      (batch_size x frame_size x frame_size) on prev_out_value and the
 *      transposed state_weight, with alpha = 1 and beta = 0, overwriting
 *      reset_output_value.
 *   2. Run forward_reset_output with the trailing flag set to false
 *      (presumably selecting the non-legacy reset-output formula -- TODO
 *      confirm against gru_cpu_kernel.h).
 *   3. For every batch row, add the reset output into the third
 *      frame_size-wide chunk of that row of gate_value.
 *   4. Run forward_final_output with trailing flags (true, false).
 *
 * NOTE(review): when value.prev_out_value is null the GEMM is skipped but
 * reset_output_value is still read in step 3 -- presumably the caller
 * initializes that buffer; verify.
 *
 * The body is compiled out entirely under __NVCC__.
 */
template <typename T>
struct GRUUnitFunctorV2<platform::CPUDeviceContext, T> {
  static void compute(const platform::CPUDeviceContext &context,
                      GRUMetaValue<T> value, int frame_size, int batch_size,
                      const detail::ActivationType active_node,
                      const detail::ActivationType active_gate) {
#ifndef __NVCC__
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);

    const bool has_prev_out = (value.prev_out_value != nullptr);
    if (has_prev_out) {
      blas.GEMM(CblasNoTrans, CblasTrans, batch_size, frame_size, frame_size, 1,
                value.prev_out_value, value.state_weight, 0,
                value.reset_output_value);
    }

    detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
                                 frame_size, batch_size, active_gate, false);

    // Each batch row of gate_value is 3 * frame_size wide; the chunk updated
    // here starts 2 * frame_size into the row.
    const int gate_row_stride = frame_size * 3;
    T *state_row = value.gate_value + 2 * frame_size;
    T *reset_row = value.reset_output_value;
    int row = 0;
    while (row < batch_size) {
      blas.VADD(frame_size, state_row, reset_row, state_row);
      state_row += gate_row_stride;
      reset_row += frame_size;
      ++row;
    }

    detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
                                 frame_size, batch_size, active_node, true,
                                 false);
#endif
  }
};
/*
 * CPU specialization of the "V2" GRU backward step.
 *
 * Delegates the whole batch to detail::cpu_gru_backward with the fused
 * detail::backward::gru<T> functor, which produces the gradients of the
 * update gate, frame state, reset output and reset gate.
 *
 * `context` is not used by this specialization, and the body is compiled out
 * entirely under __NVCC__.
 */
template <typename T>
struct GRUUnitGradFunctorV2<platform::CPUDeviceContext, T> {
  static void compute(const platform::CPUDeviceContext &context,
                      GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                      int frame_size, int batch_size,
                      const detail::ActivationType active_node,
                      const detail::ActivationType active_gate) {
#ifndef __NVCC__
    detail::backward::gru<T> fused_grad_op;
    detail::cpu_gru_backward(fused_grad_op, value, grad, frame_size,
                             batch_size, active_node, active_gate);
#endif
  }
};
template struct GRUUnitFunctor<platform::CPUDeviceContext, float>; template struct GRUUnitFunctor<platform::CPUDeviceContext, float>;
template struct GRUUnitFunctor<platform::CPUDeviceContext, double>; template struct GRUUnitFunctor<platform::CPUDeviceContext, double>;
template struct GRUUnitGradFunctor<platform::CPUDeviceContext, float>; template struct GRUUnitGradFunctor<platform::CPUDeviceContext, float>;
template struct GRUUnitGradFunctor<platform::CPUDeviceContext, double>; template struct GRUUnitGradFunctor<platform::CPUDeviceContext, double>;
template struct GRUUnitFunctorV2<platform::CPUDeviceContext, float>;
template struct GRUUnitFunctorV2<platform::CPUDeviceContext, double>;
template struct GRUUnitGradFunctorV2<platform::CPUDeviceContext, float>;
template struct GRUUnitGradFunctorV2<platform::CPUDeviceContext, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -21,12 +21,13 @@ namespace math { ...@@ -21,12 +21,13 @@ namespace math {
template <typename T> template <typename T>
struct GRUMetaValue { struct GRUMetaValue {
T *gate_weight; const T *gate_weight;
T *state_weight; const T *state_weight;
const T *reset_bias;
T *gate_value; T *gate_value;
T *reset_output_value; T *reset_output_value;
T *output_value; T *output_value;
T *prev_out_value; const T *prev_out_value;
}; };
template <typename T> template <typename T>
...@@ -37,6 +38,7 @@ struct GRUMetaGrad { ...@@ -37,6 +38,7 @@ struct GRUMetaGrad {
T *reset_output_grad; T *reset_output_grad;
T *output_grad; T *output_grad;
T *prev_out_grad; T *prev_out_grad;
T *state_bias_grad;
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -57,6 +59,22 @@ struct GRUUnitGradFunctor { ...@@ -57,6 +59,22 @@ struct GRUUnitGradFunctor {
bool origin_mode); bool origin_mode);
}; };
// Primary template for the "V2" GRU forward step.
// Only compute() is declared here, without a body; definitions are expected
// to come from per-device specializations (presumably in the matching .cc/.cu
// files -- verify).
//
// compute() takes the device context, a GRUMetaValue<T> by value, the frame
// and batch sizes, and the activation types for the node and for the gates.
template <typename DeviceContext, typename T>
struct GRUUnitFunctorV2 {
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
};
// Primary template for the "V2" GRU backward step.
// Only compute() is declared here, without a body; definitions are expected
// to come from per-device specializations (presumably in the matching .cc/.cu
// files -- verify).
//
// compute() takes the device context, the forward values (GRUMetaValue<T>)
// and the gradient buffers (GRUMetaGrad<T>) by value, the frame and batch
// sizes, and the activation types for the node and for the gates.
template <typename DeviceContext, typename T>
struct GRUUnitGradFunctorV2 {
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
GRUMetaGrad<T> grad, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
};
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -33,10 +33,12 @@ struct LstmUnitFunctor<platform::CPUDeviceContext, T> { ...@@ -33,10 +33,12 @@ struct LstmUnitFunctor<platform::CPUDeviceContext, T> {
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
T cell_clip, const detail::ActivationType& gate_act, T cell_clip, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act,
bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size, detail::cpu_lstm_forward(context, detail::forward::lstm<T>(), value,
cell_clip, cand_act, gate_act, cell_act); frame_size, cell_clip, cand_act, gate_act,
cell_act, old_api_version);
value.gate_value += frame_size * 4; value.gate_value += frame_size * 4;
value.state_value += frame_size; value.state_value += frame_size;
value.state_active_value += frame_size; value.state_active_value += frame_size;
...@@ -55,11 +57,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> { ...@@ -55,11 +57,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
int frame_size, int batch_size, T cell_clip, int frame_size, int batch_size, T cell_clip,
const detail::ActivationType& gate_act, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act,
bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad, detail::cpu_lstm_backward(context, detail::backward::lstm<T>(), value,
frame_size, cell_clip, cand_act, gate_act, grad, frame_size, cell_clip, cand_act, gate_act,
cell_act); cell_act, old_api_version);
value.gate_value += frame_size * 4; value.gate_value += frame_size * 4;
value.state_value += frame_size; value.state_value += frame_size;
......
...@@ -26,7 +26,8 @@ struct LstmUnitFunctor<platform::CUDADeviceContext, T> { ...@@ -26,7 +26,8 @@ struct LstmUnitFunctor<platform::CUDADeviceContext, T> {
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
T cell_clip, const detail::ActivationType& gate_act, T cell_clip, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act,
bool old_api_version = true) {
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value, detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
frame_size, batch_size, cell_clip, cand_act, frame_size, batch_size, cell_clip, cand_act,
gate_act, cell_act); gate_act, cell_act);
...@@ -40,7 +41,8 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> { ...@@ -40,7 +41,8 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
int frame_size, int batch_size, T cell_clip, int frame_size, int batch_size, T cell_clip,
const detail::ActivationType& gate_act, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act,
bool old_api_version = true) {
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad, detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
frame_size, batch_size, cell_clip, cand_act, frame_size, batch_size, cell_clip, cand_act,
gate_act, cell_act); gate_act, cell_act);
......
...@@ -25,7 +25,7 @@ namespace math { ...@@ -25,7 +25,7 @@ namespace math {
template <class T> template <class T>
struct LstmMetaValue { struct LstmMetaValue {
T *gate_value; T *gate_value;
T *prev_state_value; const T *prev_state_value;
T *state_value; T *state_value;
T *state_active_value; T *state_active_value;
T *output_value; T *output_value;
...@@ -53,7 +53,8 @@ class LstmUnitFunctor { ...@@ -53,7 +53,8 @@ class LstmUnitFunctor {
int frame_size, int batch_size, T cell_clip, int frame_size, int batch_size, T cell_clip,
const detail::ActivationType &gate_act, const detail::ActivationType &gate_act,
const detail::ActivationType &cell_act, const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act); const detail::ActivationType &cand_act,
bool old_api_version = true);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -63,7 +64,8 @@ class LstmUnitGradFunctor { ...@@ -63,7 +64,8 @@ class LstmUnitGradFunctor {
LstmMetaGrad<T> grad, int frame_size, int batch_size, LstmMetaGrad<T> grad, int frame_size, int batch_size,
T cell_clip, const detail::ActivationType &gate_act, T cell_clip, const detail::ActivationType &gate_act,
const detail::ActivationType &cell_act, const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act); const detail::ActivationType &cand_act,
bool old_api_version = true);
}; };
} // namespace math } // namespace math
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/rnn_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -251,5 +252,10 @@ REGISTER_OPERATOR(rnn, ops::RNNOp, ops::RNNOpMaker, ...@@ -251,5 +252,10 @@ REGISTER_OPERATOR(rnn, ops::RNNOp, ops::RNNOpMaker,
ops::RNNGradOpMaker<paddle::imperative::OpBase>); ops::RNNGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp); REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp);
REGISTER_OP_CPU_KERNEL(rnn, ops::NotImpleKernel<float>); REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL(rnn_grad, ops::NotImpleKernel<float>); rnn, ops::RNNCPUKernel<paddle::platform::CPUDeviceContext, float>,
ops::RNNCPUKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
rnn_grad, ops::RNNCPUGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::RNNCPUGradKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -524,6 +524,12 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> { ...@@ -524,6 +524,12 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> {
offset += len; offset += len;
} }
Tensor input_grad_value;
if (!in_grad) {
in_grad = &input_grad_value;
in_grad->Resize(input->dims());
}
auto *init_h_data = pre_state[0]->data<T>(); auto *init_h_data = pre_state[0]->data<T>();
// auto *last_h_data = state[0]->data<T>(); // auto *last_h_data = state[0]->data<T>();
auto *last_h_grad_data = state_grad[0]->data<T>(); auto *last_h_grad_data = state_grad[0]->data<T>();
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/unique_op.h"
#include "paddle/fluid/operators/utils.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
using TensorList = std::vector<framework::Tensor>;
// Generates `is_<MODE_NAME>(ctx)`: a predicate that reports whether the
// operator's "mode" string attribute equals the stringified MODE_STR.
#define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR)                      \
  inline bool is_##MODE_NAME(const framework::ExecutionContext& ctx) { \
    return ctx.Attr<std::string>("mode") == #MODE_STR;                 \
  }

// One detector per value of "mode" that this header branches on.
DEFINE_MODE_DETECTOR(lstm, LSTM);
DEFINE_MODE_DETECTOR(gru, GRU);
DEFINE_MODE_DETECTOR(rnn_relu, RNN_RELU);
DEFINE_MODE_DETECTOR(rnn_tanh, RNN_TANH);
// Exchanges the two Tensor pointers that `a` and `b` refer to. Only the
// pointers are swapped; the tensors they point at are not touched.
//
// Declared `inline` because this definition lives in a header: a non-inline
// free function would be emitted by every translation unit that includes the
// header and fail at link time with a multiple-definition error.
inline void SwapPoniter(Tensor** a, Tensor** b) {
  Tensor* c = *a;
  *a = *b;
  *b = c;
}
// Fills `mask_matrix` with a 0/1 mask derived from `sequence_length`.
//
// The mask is first built in a scratch tensor whose shape is the transpose of
// `mask_matrix` (dims()[1] x dims()[0]), one row per entry of
// `sequence_length`, and is then transposed into `mask_matrix`.
// For a row with length `len` shorter than the row width:
//   - forward  (is_reverse == false): elements [len, width) are zeroed;
//   - reverse  (is_reverse == true):  elements [0, width - len) are zeroed.
// `*min_seq_len` receives the minimum of the row width and every length.
template <typename T>
void create_mask_matrix(const framework::ExecutionContext& context,
                        const Tensor* sequence_length, Tensor* mask_matrix,
                        const bool& is_reverse, int* min_seq_len) {
  const auto& lengths = GetDataFromTensor<int>(sequence_length);
  const int row_width = mask_matrix->dims()[0];

  // Scratch buffer laid out as the transpose of the requested mask.
  Tensor transposed_mask;
  transposed_mask.Resize(
      framework::make_ddim({mask_matrix->dims()[1], mask_matrix->dims()[0]}));
  T* rows = transposed_mask.mutable_data<T>(context.GetPlace());
  std::fill(rows, rows + mask_matrix->numel(), static_cast<T>(1.0));

  *min_seq_len = row_width;
  for (unsigned int row = 0; row < lengths.size(); row++) {
    const int len = lengths[row];
    *min_seq_len = std::min(len, *min_seq_len);
    if (len != row_width) {
      // Clear the part of this row that lies outside the valid length.
      T* row_begin = rows + row * row_width;
      T* row_end = rows + (row + 1) * row_width;
      if (is_reverse) {
        std::fill(row_begin, row_end - len, static_cast<T>(0));
      } else {
        std::fill(row_begin + len, row_end, static_cast<T>(0));
      }
    }
  }

  // Transpose the scratch buffer into the caller's tensor (axes 1, 0).
  mask_matrix->mutable_data<T>(context.GetPlace());
  std::vector<int> axes;
  axes.emplace_back(1);
  axes.emplace_back(0);
  auto& cpu_ctx = context.template device_context<platform::CPUDeviceContext>();
  TransCompute<platform::CPUDeviceContext, T>(2, cpu_ctx, transposed_mask,
                                              mask_matrix, axes);
}
// Polymorphic base for the per-time-step recurrent cells defined in this
// header. The default call operator does nothing; concrete cells override it.
template <typename T>
struct Cell {
  // Runs one time step. The base implementation is intentionally a no-op.
  virtual void operator()(const platform::CPUDeviceContext* device_ctx,
                          Tensor* input, const Tensor* weight_hh,
                          const Tensor* init_h, const Tensor* init_c,
                          Tensor* last_h, Tensor* last_c, Tensor* last_c_act,
                          Tensor* output, const Tensor* bias_hh,
                          Tensor* weight_hh_gru) const {}

  virtual ~Cell() = default;
};
// Single-step simple RNN cell.
//
// Accumulates init_h * weight_hh^T into `input` (which is modified in place)
// and then writes EigenActivationFunctor applied to `input` into `output`.
// `init_c`, `last_h`, `last_c`, `last_c_act`, `bias_hh` and `weight_hh_gru`
// are not read by this cell, and the `act_type` template argument is not
// referenced in the body.
template <typename T, template <typename> class EigenActivationFunctor,
          math::detail::ActivationType act_type>
struct SimpleRNNCell : Cell<T> {
  void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
                  const Tensor* weight_hh, const Tensor* init_h,
                  const Tensor* init_c, Tensor* last_h, Tensor* last_c,
                  Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
                  Tensor* weight_hh_gru) const override {
    // Fold the batch dimension of init_h into its height so that a single
    // plain matmul is issued instead of a batched one.
    auto hidden_desc = math::CreateMatrixDescriptor(init_h->dims(), 0, false);
    auto weight_desc =
        math::CreateMatrixDescriptor(weight_hh->dims(), 0, true);
    hidden_desc.height_ *= hidden_desc.batch_size_;
    hidden_desc.batch_size_ = 0;

    // input = 1 * init_h * weight_hh^T + 1 * input
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
    blas.MatMul(*init_h, hidden_desc, *weight_hh, weight_desc,
                static_cast<T>(1.0), input, static_cast<T>(1.0));

    // Element-wise activation from the accumulated input into the output.
    auto pre_activation = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(input, "Input", "z", "Activation"));
    auto activated = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(output, "Output", "hidden", "Activation"));
    auto* eigen_dev = device_ctx->eigen_device();
    EigenActivationFunctor<T> activation;
    activation(*eigen_dev, pre_activation, activated);
  }
};
// Single-step GRU cell.
//
// Accumulates init_h * weight_hh_gru^T into `input` in place, then hands the
// raw buffers to math::GRUUnitFunctorV2 with "sigmoid_v2" as the gate
// activation and "tanh_v2" as the candidate activation. `last_c` supplies the
// buffer used as the functor's reset-output value. `init_c`, `last_h` and
// `last_c_act` are not read here.
template <typename T>
struct GRUCell : Cell<T> {
  void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
                  const Tensor* weight_hh, const Tensor* init_h,
                  const Tensor* init_c, Tensor* last_h, Tensor* last_c,
                  Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
                  Tensor* weight_hh_gru) const override {
    // Fold the batch dimension of init_h into its height so that a single
    // plain matmul is issued instead of a batched one.
    auto hidden_desc = math::CreateMatrixDescriptor(init_h->dims(), 0, false);
    auto weight_desc =
        math::CreateMatrixDescriptor(weight_hh_gru->dims(), 0, true);
    hidden_desc.height_ *= hidden_desc.batch_size_;
    hidden_desc.batch_size_ = 0;

    // input = 1 * init_h * weight_hh_gru^T + 1 * input
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
    blas.MatMul(*init_h, hidden_desc, *weight_hh_gru, weight_desc,
                static_cast<T>(1.0), input, static_cast<T>(1.0));

    size_t frame_size = init_h->dims()[2];
    size_t batch_size = init_h->dims()[1];

    // Wire the raw data pointers the GRU functor operates on.
    math::GRUMetaValue<T> gru_value;
    gru_value.gate_value = input->data<T>();
    gru_value.prev_out_value = init_h->data<T>();
    gru_value.output_value = output->data<T>();
    gru_value.reset_output_value = last_c->data<T>();
    gru_value.gate_weight = weight_hh->data<T>();
    // The state weight starts 2 * frame_size^2 elements into weight_hh.
    gru_value.state_weight = weight_hh->data<T>() + 2 * frame_size * frame_size;
    // The reset bias starts 2 * frame_size elements into bias_hh.
    gru_value.reset_bias = bias_hh->data<T>() + 2 * frame_size;

    auto cand_act = math::detail::GetActivationType("tanh_v2");
    auto gate_act = math::detail::GetActivationType("sigmoid_v2");
    math::GRUUnitFunctorV2<platform::CPUDeviceContext, T>::compute(
        *device_ctx, gru_value, frame_size, batch_size, cand_act, gate_act);
  }
};
// Single-step LSTM cell.
//
// Accumulates init_h * weight_hh^T into `input` in place, then calls
// math::LstmUnitFunctor with "sigmoid_v2" gates, "tanh_v2" cell and candidate
// activations, no peephole weights (all check_* pointers are null), a cell
// clip of 0 and old_api_version == false. When `last_c_act` is null a local
// scratch tensor shaped like init_h is used for the activated cell state.
template <typename T>
struct LSTMCell : Cell<T> {
  void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
                  const Tensor* weight_hh, const Tensor* init_h,
                  const Tensor* init_c, Tensor* last_h, Tensor* last_c,
                  Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
                  Tensor* weight_hh_gru) const override {
    // Fold the batch dimension of init_h into its height so that a single
    // plain matmul is issued instead of a batched one.
    auto hidden_desc = math::CreateMatrixDescriptor(init_h->dims(), 0, false);
    auto weight_desc =
        math::CreateMatrixDescriptor(weight_hh->dims(), 0, true);
    hidden_desc.height_ *= hidden_desc.batch_size_;
    hidden_desc.batch_size_ = 0;

    // input = 1 * init_h * weight_hh^T + 1 * input
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
    blas.MatMul(*init_h, hidden_desc, *weight_hh, weight_desc,
                static_cast<T>(1.0), input, static_cast<T>(1.0));

    auto gate_act = math::detail::GetActivationType("sigmoid_v2");
    auto cell_act = math::detail::GetActivationType("tanh_v2");
    auto cand_act = math::detail::GetActivationType("tanh_v2");

    size_t frame_size = init_h->dims()[2];
    size_t batch_size = init_h->dims()[1];

    // No caller-provided buffer for the activated cell state: use a local one.
    Tensor scratch_cell_act;
    if (last_c_act == nullptr) {
      scratch_cell_act.mutable_data<T>(init_h->dims(), device_ctx->GetPlace());
      last_c_act = &scratch_cell_act;
    }

    // Wire the raw data pointers the LSTM functor operates on.
    math::LstmMetaValue<T> lstm_value;
    lstm_value.check_ig = nullptr;
    lstm_value.check_fg = nullptr;
    lstm_value.check_og = nullptr;
    lstm_value.gate_value = input->data<T>();
    lstm_value.prev_state_value = init_c->data<T>();
    lstm_value.state_value = last_c->data<T>();
    lstm_value.state_active_value = last_c_act->data<T>();
    lstm_value.output_value = output->data<T>();

    T cell_clip = 0.0;
    math::LstmUnitFunctor<platform::CPUDeviceContext, T>::compute(
        *device_ctx, lstm_value, frame_size, batch_size, cell_clip, gate_act,
        cell_act, cand_act, false);
  }
};
// Applies dropout to `x` in place and records or reuses the keep mask in `mask`.
//
// Parameters:
//   context       - execution context (not read by this function).
//   x             - tensor of T, modified in place.
//   mask          - tensor of uint8_t with one entry per element of x.
//   dropout_prob  - probability of dropping an element.
//   seed_number   - seed passed to framework::GetCPURandomEngine.
//   is_test       - when true the function returns without touching anything.
//   is_has_reset  - in/out flag. While false, a fresh mask is generated and
//                   the flag is set to true; once true, the stored mask is
//                   re-applied instead.
//
// Kept elements are divided by (1 - dropout_prob); dropped elements become 0.
// With dropout_prob == 1 every element is dropped and no division is done.
template <typename T>
void dropout_cpu_function_inplace(const framework::ExecutionContext& context,
                                  Tensor* x, Tensor* mask,
                                  const float& dropout_prob,
                                  const int& seed_number, const bool& is_test,
                                  bool* is_has_reset) {
  if (is_test) {
    return;
  }
  auto* x_data = x->data<T>();
  size_t size = framework::product(x->dims());
  auto* mask_data = mask->data<uint8_t>();
  const bool need_new_mask = !(*is_has_reset);

  // Special case when dropout_prob is 1.0: everything is dropped, and the
  // general path below would divide by zero.
  if (dropout_prob == 1.0f) {
    std::fill(x_data, x_data + size, static_cast<T>(0));
    if (need_new_mask) {
      // The mask holds uint8_t, so fill it with a uint8_t zero rather than a
      // T (float/double) zero that would be implicitly narrowed per element.
      std::fill(mask_data, mask_data + size, static_cast<uint8_t>(0));
      *is_has_reset = true;
    }
    return;
  }

  // Loop-invariant scale applied to every kept element.
  const T keep_prob = static_cast<T>(1.0f - dropout_prob);

  if (need_new_mask) {
    // Draw a fresh mask: an element is dropped when the sample is below
    // dropout_prob.
    auto engine = framework::GetCPURandomEngine(seed_number);
    std::uniform_real_distribution<float> dist(0, 1);
    for (size_t i = 0; i < size; ++i) {
      if (dist(*engine) < dropout_prob) {
        mask_data[i] = 0;
        x_data[i] = static_cast<T>(0);
      } else {
        mask_data[i] = 1;
        x_data[i] /= keep_prob;
      }
    }
    *is_has_reset = true;
  } else {
    // Re-apply the mask recorded by an earlier call.
    for (size_t i = 0; i < size; ++i) {
      if (mask_data[i] == 0) {
        x_data[i] = static_cast<T>(0);
      } else {
        x_data[i] /= keep_prob;
      }
    }
  }
}
// Backward counterpart of the in-place dropout: rewrites `grad_x` in place.
//
// With dropout_prob == 1 the gradient is multiplied by zero; otherwise it is
// multiplied element-wise by the uint8_t `mask` (cast to T) and divided by
// (1 - dropout_prob).
template <typename T>
void dropout_cpu_grad_function_inplace(
    const framework::ExecutionContext& context, Tensor* grad_x,
    const Tensor* mask, const float& dropout_prob) {
  auto& cpu_ctx = context.template device_context<platform::CPUDeviceContext>();
  auto& place = *cpu_ctx.eigen_device();
  auto M = EigenVector<uint8_t>::Flatten(*mask);
  auto dX = EigenVector<T>::Flatten(*grad_x);
  if (dropout_prob != 1.0f) {
    dX.device(place) = dX * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
    return;
  }
  // Everything was dropped in the forward pass, so no gradient flows back.
  dX.device(place) = static_cast<T>(0) * dX;
}
template <typename T, typename CellType>
struct Layer {
explicit Layer(const CellType& cell) : cell_(cell) {}
virtual ~Layer() {}
// Computes the input projection for all time steps in one matmul:
//   cache_input = input * weight^T + bias_ih + bias_hh'
// where bias_hh' is bias_hh itself, except in GRU mode, where a copy of
// bias_hh with its last third zeroed is used instead.
//
// `cache_input` is resized to {input.dims[0], input.dims[1], weight.dims[0]}.
// Its storage is allocated here only when `is_test` is true; otherwise the
// caller must already have provided memory for it.
void preprocess(const framework::ExecutionContext& context,
                const Tensor* input, const Tensor& weight,
                const Tensor& bias_ih, const Tensor& bias_hh,
                Tensor* cache_input, bool is_test) {
  auto& cpu_ctx = context.template device_context<platform::CPUDeviceContext>();

  // Shape (and, for inference, allocate) the projection buffer.
  const int proj_width = weight.dims()[0];
  cache_input->Resize(framework::make_ddim(
      {input->dims()[0], input->dims()[1], proj_width}));
  if (is_test) {
    cache_input->mutable_data<T>(context.GetPlace());
  }

  // Fold the batch dimension of the input into its height so that a single
  // plain matmul is issued instead of a batched one.
  auto input_desc = math::CreateMatrixDescriptor(input->dims(), 0, false);
  auto weight_desc = math::CreateMatrixDescriptor(weight.dims(), 0, true);
  input_desc.height_ *= input_desc.batch_size_;
  input_desc.batch_size_ = 0;
  auto blas = math::GetBlas<platform::CPUDeviceContext, T>(cpu_ctx);
  blas.MatMul(*input, input_desc, weight, weight_desc, static_cast<T>(1.0),
              cache_input, static_cast<T>(0));

  // View the projection as a 2-D matrix and add bias_ih to every row.
  auto eigen_in = framework::EigenMatrix<T>::Reshape(
      *cache_input, cache_input->dims().size() - 1);
  auto eigen_bias_ih = framework::EigenMatrix<T>::From(
      bias_ih, framework::make_ddim({1, bias_ih.dims()[0]}));
  const int row_num =
      framework::product(cache_input->dims()) / cache_input->dims()[2];
  const Eigen::DSizes<int, 2> per_row(row_num, 1);
  eigen_in = eigen_in + eigen_bias_ih.broadcast(per_row);

  if (is_gru(context)) {
    // reset_gate update_gate cell_gate = [1, 1, 0]
    // Work on a copy of bias_hh split into three equal parts and zero the
    // third part before adding it.
    Tensor gru_bias_hh;
    gru_bias_hh.Resize({bias_hh.numel()});
    gru_bias_hh.mutable_data<T>(context.GetPlace());
    framework::TensorCopy(bias_hh, context.GetPlace(), cpu_ctx, &gru_bias_hh);
    gru_bias_hh.Resize({3, gru_bias_hh.numel() / 3});
    auto gru_bias_parts = Unbind(gru_bias_hh);
    math::SetConstant<platform::CPUDeviceContext, T> set_zero;
    set_zero(cpu_ctx, &gru_bias_parts[2], static_cast<T>(0.0));
    auto eigen_gru_bias_hh = framework::EigenMatrix<T>::From(
        gru_bias_hh, framework::make_ddim({1, bias_hh.dims()[0]}));
    eigen_in = eigen_in + eigen_gru_bias_hh.broadcast(per_row);
  } else {
    auto eigen_bias_hh = framework::EigenMatrix<T>::From(
        bias_hh, framework::make_ddim({1, bias_hh.dims()[0]}));
    eigen_in = eigen_in + eigen_bias_hh.broadcast(per_row);
  }
}
// Applies the per-row mask of one time step after the cell has run:
//   last_h = output * mask + init_h * (1 - mask)
//   output = output * mask
// and, in LSTM mode only,
//   last_c = last_c * mask + init_c * (1 - mask)
// `mask_tensor` is viewed as a column ({dims[1], 1}) and broadcast across
// output->dims()[2] columns, so rows whose mask is 0 produce a zero output and
// keep their previous state.
void postprocess(const framework::ExecutionContext& context, Tensor* output,
                 const Tensor* init_h, const Tensor* init_c, Tensor* last_h,
                 Tensor* last_c, const Tensor& mask_tensor) {
  auto& cpu_ctx = context.template device_context<platform::CPUDeviceContext>();
  auto& place = *cpu_ctx.eigen_device();

  auto out_mat =
      framework::EigenMatrix<T>::Reshape(*output, output->dims().size() - 1);
  auto mask_col = framework::EigenMatrix<T>::From(
      mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
  auto init_h_mat =
      framework::EigenMatrix<T>::Reshape(*init_h, init_h->dims().size() - 1);
  auto last_h_mat =
      framework::EigenMatrix<T>::Reshape(*last_h, last_h->dims().size() - 1);
  auto mask_mat =
      mask_col.broadcast(Eigen::DSizes<int, 2>(1, output->dims()[2]));

  // last_h must be computed before the output is overwritten below.
  last_h_mat.device(place) =
      out_mat * mask_mat + init_h_mat * (1 - mask_mat);
  out_mat.device(place) = out_mat * mask_mat;

  if (is_lstm(context)) {
    auto init_c_mat = framework::EigenMatrix<T>::Reshape(
        *init_c, init_c->dims().size() - 1);
    auto last_c_mat = framework::EigenMatrix<T>::Reshape(
        *last_c, last_c->dims().size() - 1);
    last_c_mat.device(place) =
        last_c_mat * mask_mat + init_c_mat * (1 - mask_mat);
  }
}
// Entry point for running one layer. The base implementation is an empty
// body; SingleLayer and BidirLayer declare an operator() with this same
// parameter list and forward to RunIter.
// Note that `last_h` and `last_c` are taken by value (TensorList copies).
virtual void operator()(const framework::ExecutionContext& context,
const Tensor* input, const TensorList& vec,
const TensorList& init_h, const TensorList& init_c,
const Tensor* sequence_length, TensorList last_h,
TensorList last_c, Tensor* output,
const int& layer_idx, const int& gate_num,
Tensor* gate_value, Tensor* cell_value,
Tensor* cell_act_value, bool is_test) {}
// Inference-path time loop for one direction of one layer.
//
// Runs preprocess() once for the whole sequence (with is_test == true, so
// `gate_value` is allocated there), then calls `cell_` once per time step.
// Unlike RunIter, no per-step cell state is kept: the cell is always given
// nullptr for last_c_act, and the cell-state buffers are two tensors that are
// swapped between steps. `cell_value` and `cell_act_value` are not read here.
//
// `vec` holds this layer's parameters; the entries used are
// vec[offset * 4 + {0, 1, 2, 3}], passed as (preprocess weight, cell
// weight_hh, bias_ih, bias_hh). When `is_bidirect` is true, `layer_idx` is
// remapped to 2 * layer_idx + offset and offset > 0 selects the reverse
// direction. Results are written to `output` and to
// (*last_h_ptr)[layer_idx] / (*last_c_ptr)[layer_idx].
void RunTestIter(const framework::ExecutionContext& context,
const Tensor* input, const TensorList& vec,
const TensorList& init_h, const TensorList& init_c,
const Tensor* sequence_length, TensorList* last_h_ptr,
TensorList* last_c_ptr, Tensor* output, int layer_idx,
Tensor* gate_value, Tensor* cell_value,
Tensor* cell_act_value, bool is_bidirect, int offset) {
bool is_reverse = false;
if (is_bidirect) {
layer_idx = 2 * layer_idx + offset;
if (offset > 0) {
is_reverse = true;
}
}
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
const int& time_step = input->dims()[0];
this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4],
vec[3 + offset * 4], gate_value, true);
// Per-time-step pieces of the projected input and of the output; for the
// reverse direction they are visited last-to-first.
auto input_tensors = Unbind(*gate_value);
auto output_tensors = Unbind(*output);
if (is_reverse) {
std::reverse(input_tensors.begin(), input_tensors.end());
std::reverse(output_tensors.begin(), output_tensors.end());
}
TensorList mask_tensor_list;
// construct the mask matrix for the mask
bool has_sequence_length = false;
if (sequence_length != nullptr) {
has_sequence_length = true;
}
Tensor mask_matrix;
int mask_min_length = time_step;
if (has_sequence_length) {
mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]}));
create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
&mask_min_length);
mask_tensor_list = Unbind(mask_matrix);
}
// For the reverse direction the per-step test below is
// -i >= mask_min_length, so the threshold is shifted to make it equivalent
// to i <= time_step - 1 - (minimum sequence length).
if (is_reverse) {
mask_min_length = mask_min_length - time_step + 1;
}
bool has_allocate_mem_c = false;
bool has_use_last_h_holder = false;
const int& reverse_flag = is_reverse ? -1 : 1;
// define the init_h holder for the swap
Tensor init_h_temp;
framework::TensorCopy(*&init_h[layer_idx], context.GetPlace(), dev_ctx,
&init_h_temp);
Tensor* init_h_holder = &init_h_temp;
Tensor* last_h_holder = nullptr;
// If the first step is not masked, its hidden state is written into the
// first output slice; otherwise into the caller's last_h tensor.
if (0 < mask_min_length) {
last_h_holder = &(output_tensors[0]);
} else {
last_h_holder = &(*last_h_ptr)[layer_idx];
has_use_last_h_holder = true;
}
Tensor* init_c_holder = nullptr;
const Tensor* init_c_temp_holder = nullptr;
Tensor init_c_temp;
Tensor* last_c_holder = nullptr;
Tensor last_c_temp;
if (is_lstm(context)) {
last_c_holder = &(*last_c_ptr)[layer_idx];
init_c_temp_holder = &init_c[layer_idx];
} else if (is_gru(context)) {
// for reset output value
last_c_temp.Resize(init_h[layer_idx].dims());
last_c_temp.mutable_data<T>(context.GetPlace());
last_c_holder = &last_c_temp;
}
// GRU only: a copy of the hidden-to-hidden weight, split into three equal
// parts with the third part set to zero, passed to the cell as weight_hh_gru.
Tensor weight_hh_tmp; // for gru
if (is_gru(context)) {
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
weight_hh_tmp.mutable_data<T>(context.GetPlace());
framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx,
&weight_hh_tmp);
weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3});
auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp);
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast<T>(0.0));
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
}
for (int i = 0; i < time_step; i++) {
// True when this step must have the sequence mask applied afterwards.
bool in_mask = (reverse_flag * i) >= mask_min_length;
if (i > 0) {
// On the first step after step 0, lazily allocate the second cell-state
// buffer (LSTM/GRU only).
if (!has_allocate_mem_c) {
if (is_lstm(context) || is_gru(context)) {
init_c_temp.Resize(init_h[layer_idx].dims());
init_c_temp.mutable_data<T>(context.GetPlace());
init_c_holder = &init_c_temp;
}
has_allocate_mem_c = true;
}
// The cell state written by the previous step becomes this step's initial
// cell state; the other buffer becomes the destination.
SwapPoniter(&init_c_holder, &last_c_holder);
init_c_temp_holder = init_c_holder;
}
cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder,
init_c_temp_holder, last_h_holder, last_c_holder, nullptr,
&output_tensors[i], &vec[3 + offset * 4] /* bias_hh */,
&weight_hh_tmp);
if (in_mask) {
this->postprocess(context, &output_tensors[i], init_h_holder,
init_c_temp_holder, last_h_holder, last_c_holder,
mask_tensor_list[i]);
}
// prepare next step
if (i + 1 < time_step) {
// Choose where the next step's hidden state will be written (stored
// temporarily in init_h_holder), then swap so that the state just produced
// becomes the next step's initial hidden state.
bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length;
if (next_step_mask) {
if (!has_use_last_h_holder) {
init_h_holder = &(*last_h_ptr)[layer_idx];
}
} else {
init_h_holder = &(output_tensors[i + 1]);
}
SwapPoniter(&init_h_holder, &last_h_holder);
}
}
// Publish the final hidden state: with sequence lengths it is whatever
// last_h_holder ended on (copied only if that is not already the caller's
// tensor); without them it is the last visited output slice.
if (has_sequence_length) {
if (last_h_holder != &(*last_h_ptr)[layer_idx]) {
framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx,
&(*last_h_ptr)[layer_idx]);
}
} else {
framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(),
dev_ctx, &(*last_h_ptr)[layer_idx]);
}
// The cell-state holders are swapped time_step - 1 times, so for an even
// time_step last_c_holder ends on the temporary buffer and the LSTM cell
// state has to be copied into the caller's tensor.
if (time_step % 2 == 0) {
if (is_lstm(context)) {
framework::TensorCopy(*last_c_holder, context.GetPlace(), dev_ctx,
&(*last_c_ptr)[layer_idx]);
}
}
}
// Time loop for one direction of one layer.
//
// When `is_test` is true this simply delegates to RunTestIter. Otherwise it
// runs preprocess() once for the whole sequence (without allocating
// `gate_value`, which must already have storage), then calls `cell_` once per
// time step. In LSTM/GRU mode every step writes its cell-side state into its
// own piece of `cell_value` (and, for LSTM, of `cell_act_value`), so the
// per-step values remain available after the loop.
//
// `vec` holds this layer's parameters; the entries used are
// vec[offset * 4 + {0, 1, 2, 3}], passed as (preprocess weight, cell
// weight_hh, bias_ih, bias_hh). When `is_bidirect` is true, `layer_idx` is
// remapped to 2 * layer_idx + offset and offset > 0 selects the reverse
// direction. Results are written to `output` and to
// (*last_h_ptr)[layer_idx] / (*last_c_ptr)[layer_idx].
void RunIter(const framework::ExecutionContext& context, const Tensor* input,
const TensorList& vec, const TensorList& init_h,
const TensorList& init_c, const Tensor* sequence_length,
TensorList* last_h_ptr, TensorList* last_c_ptr, Tensor* output,
int layer_idx, Tensor* gate_value, Tensor* cell_value,
Tensor* cell_act_value, bool is_bidirect, int offset,
bool is_test) {
if (is_test) {
RunTestIter(context, input, vec, init_h, init_c, sequence_length,
last_h_ptr, last_c_ptr, output, layer_idx, gate_value,
cell_value, cell_act_value, is_bidirect, offset);
return;
}
bool is_reverse = false;
if (is_bidirect) {
layer_idx = 2 * layer_idx + offset;
if (offset > 0) {
is_reverse = true;
}
}
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
const int& time_step = input->dims()[0];
this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4],
vec[3 + offset * 4], gate_value, is_test);
// Per-time-step pieces of the projected input and of the output; for the
// reverse direction they are visited last-to-first.
auto input_tensors = Unbind(*gate_value);
auto output_tensors = Unbind(*output);
if (is_reverse) {
std::reverse(input_tensors.begin(), input_tensors.end());
std::reverse(output_tensors.begin(), output_tensors.end());
}
TensorList mask_tensor_list;
// construct the mask matrix for the mask
bool has_sequence_length = false;
if (sequence_length != nullptr) {
has_sequence_length = true;
}
Tensor mask_matrix;
int mask_min_length = time_step;
if (has_sequence_length) {
mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]}));
create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
&mask_min_length);
mask_tensor_list = Unbind(mask_matrix);
}
// For the reverse direction the per-step test below is
// -i >= mask_min_length, so the threshold is shifted to make it equivalent
// to i <= time_step - 1 - (minimum sequence length).
if (is_reverse) {
mask_min_length = mask_min_length - time_step + 1;
}
// define the init_h holder for the swap
bool has_use_last_h_holder = false;
const int& reverse_flag = is_reverse ? -1 : 1;
TensorList cell_value_tensors;
TensorList cell_act_value_tensors;
Tensor init_h_temp;
framework::TensorCopy(*&init_h[layer_idx], context.GetPlace(), dev_ctx,
&init_h_temp);
Tensor* init_h_holder = &init_h_temp;
Tensor* last_h_holder = nullptr;
// If the first step is not masked, its hidden state is written into the
// first output slice; otherwise into the caller's last_h tensor.
if (0 < mask_min_length) {
last_h_holder = &(output_tensors[0]);
} else {
last_h_holder = &(*last_h_ptr)[layer_idx];
has_use_last_h_holder = true;
}
const Tensor* init_c_holder = nullptr;
Tensor* last_c_holder = nullptr;
Tensor* last_c_act_holder = nullptr;
// Split the caller-provided state buffers into one piece per time step.
// NOTE(review): presumably the pieces returned by Unbind share storage with
// cell_value / cell_act_value - confirm against Unbind's definition.
if (is_lstm(context) || is_gru(context)) {
cell_value->Resize({time_step, cell_value->numel() / time_step});
cell_value_tensors = Unbind(*cell_value);
if (is_lstm(context)) {
cell_act_value->Resize(
{time_step, cell_act_value->numel() / time_step});
cell_act_value_tensors = Unbind(*cell_act_value);
}
}
// GRU only: a copy of the hidden-to-hidden weight, split into three equal
// parts with the third part set to zero, passed to the cell as weight_hh_gru.
Tensor weight_hh_tmp; // for gru
if (is_gru(context)) {
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
weight_hh_tmp.mutable_data<T>(context.GetPlace());
framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx,
&weight_hh_tmp);
weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3});
auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp);
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast<T>(0.0));
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
}
for (int i = 0; i < time_step; i++) {
// True when this step must have the sequence mask applied afterwards.
bool in_mask = (reverse_flag * i) >= mask_min_length;
if (is_lstm(context)) {
// The initial cell state is the caller's init_c on the first step and the
// previous step's stored cell state afterwards.
if (i == 0) {
init_c_holder = &init_c[layer_idx];
} else {
init_c_holder = &cell_value_tensors[i - 1];
}
cell_value_tensors[i].Resize(init_c[layer_idx].dims());
cell_act_value_tensors[i].Resize(init_c[layer_idx].dims());
last_c_holder = &cell_value_tensors[i];
last_c_act_holder = &cell_act_value_tensors[i];
} else if (is_gru(context)) {
cell_value_tensors[i].Resize(init_h[layer_idx].dims());
last_c_holder = &cell_value_tensors[i];
}
cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder,
init_c_holder, last_h_holder, last_c_holder, last_c_act_holder,
&output_tensors[i], &vec[3 + offset * 4] /* bias_hh */,
&weight_hh_tmp);
if (in_mask) {
this->postprocess(context, &output_tensors[i], init_h_holder,
init_c_holder, last_h_holder, last_c_holder,
mask_tensor_list[i]);
}
// prepare next step
if (i + 1 < time_step) {
// Choose where the next step's hidden state will be written (stored
// temporarily in init_h_holder), then swap so that the state just produced
// becomes the next step's initial hidden state.
bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length;
if (next_step_mask) {
if (!has_use_last_h_holder) {
init_h_holder = &(*last_h_ptr)[layer_idx];
}
} else {
init_h_holder = &(output_tensors[i + 1]);
}
SwapPoniter(&init_h_holder, &last_h_holder);
}
}
// Publish the final hidden state: with sequence lengths it is whatever
// last_h_holder ended on (copied only if that is not already the caller's
// tensor); without them it is the last visited output slice.
if (has_sequence_length) {
if (last_h_holder != &(*last_h_ptr)[layer_idx]) {
framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx,
&(*last_h_ptr)[layer_idx]);
}
} else {
framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(),
dev_ctx, &(*last_h_ptr)[layer_idx]);
}
// LSTM: the final cell state is the one stored for the last visited step.
if (is_lstm(context)) {
framework::TensorCopy(cell_value_tensors[time_step - 1],
context.GetPlace(), dev_ctx,
&(*last_c_ptr)[layer_idx]);
}
}
// Cell for the rnn module
CellType cell_;
};
// Unidirectional layer: runs the wrapped cell over the sequence once, in
// forward order, by delegating to Layer::RunIter with is_bidirect == false
// and offset == 0.
template <typename T, typename CellType>
struct SingleLayer : public Layer<T, CellType> {
  explicit SingleLayer(const CellType& cell) : Layer<T, CellType>(cell) {}

  // Overrides Layer::operator(). `gate_num` is accepted for interface
  // compatibility and is not used. `last_h` / `last_c` are taken by value, as
  // in the base declaration; RunIter writes through the copied list entries.
  // Marked `override` so a signature mismatch with the base class is a
  // compile error instead of a silently separate overload.
  void operator()(const framework::ExecutionContext& context,
                  const Tensor* input, const TensorList& vec,
                  const TensorList& init_h, const TensorList& init_c,
                  const Tensor* sequence_length, TensorList last_h,
                  TensorList last_c, Tensor* output, const int& layer_idx,
                  const int& gate_num, Tensor* gate_value, Tensor* cell_value,
                  Tensor* cell_act_value, bool is_test) override {
    this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h,
                  &last_c, output, layer_idx, gate_value, cell_value,
                  cell_act_value, false, 0, is_test);
  }
};
// Bidirectional layer: runs the wrapped cell over the sequence twice via
// Layer::RunIter (offset 0 for the forward pass, offset 1 for the reverse
// pass) and concatenates the two results along axis 2 into `output`.
template <typename T, typename CellType>
struct BidirLayer : public Layer<T, CellType> {
  explicit BidirLayer(const CellType& cell) : Layer<T, CellType>(cell) {}

  // Overrides Layer::operator(). `gate_num` is accepted for interface
  // compatibility and is not used. `last_h` / `last_c` are taken by value, as
  // in the base declaration; RunIter writes through the copied list entries.
  // Marked `override` so a signature mismatch with the base class is a
  // compile error instead of a silently separate overload.
  void operator()(const framework::ExecutionContext& context,
                  const Tensor* input, const TensorList& vec,
                  const TensorList& init_h, const TensorList& init_c,
                  const Tensor* sequence_length, TensorList last_h,
                  TensorList last_c, Tensor* output, const int& layer_idx,
                  const int& gate_num, Tensor* gate_value, Tensor* cell_value,
                  Tensor* cell_act_value, bool is_test) override {
    TensorList output_vec(2);
    Tensor forward_input_w, forward_cell_value, forward_cell_act_value;
    Tensor backward_input_w, backward_cell_value, backward_cell_act_value;
    int time_step = input->dims()[0];
    int batch_size = input->dims()[1];
    int hidden_size = output->dims()[2];

    // Each direction produces half of the output's last dimension.
    for (int i = 0; i < 2; ++i) {
      output_vec[i].Resize({time_step, batch_size, hidden_size / 2});
      output_vec[i].mutable_data<T>(context.GetPlace());
    }

    // Outside test mode the caller-provided buffers are split in two halves,
    // the first for the forward pass and the second for the reverse pass.
    // In test mode the per-direction tensors stay empty and are handed to
    // RunIter as-is.
    if (!is_test) {
      gate_value->Resize({2, gate_value->numel() / 2});
      forward_input_w = gate_value->Slice(0, 1);
      backward_input_w = gate_value->Slice(1, 2);

      if (is_lstm(context) || is_gru(context)) /* for lstm and gru */ {
        cell_value->Resize({2, cell_value->numel() / 2});
        cell_act_value->Resize({2, cell_act_value->numel() / 2});
        forward_cell_value = cell_value->Slice(0, 1);
        backward_cell_value = cell_value->Slice(1, 2);
        if (is_lstm(context)) {
          forward_cell_act_value = cell_act_value->Slice(0, 1);
          backward_cell_act_value = cell_act_value->Slice(1, 2);
        }
      }
    }

    // Forward direction (offset 0).
    this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h,
                  &last_c, &output_vec[0], layer_idx, &forward_input_w,
                  &forward_cell_value, &forward_cell_act_value, true, 0,
                  is_test);
    // Reverse direction (offset 1).
    this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h,
                  &last_c, &output_vec[1], layer_idx, &backward_input_w,
                  &backward_cell_value, &backward_cell_act_value, true, 1,
                  is_test);

    // concat the output result
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    paddle::operators::math::ConcatFunctor<platform::CPUDeviceContext, T>
        concat_functor;
    concat_functor(dev_ctx, output_vec, static_cast<int>(2), output);
  }
};
// Carves `reserve_data` into consecutive row ranges (via Slice on its first
// dimension) and stores them in the output tensors:
//   gate_data      rows [0, gate_num * num_layers)
//   cell_data      the next num_layers rows            (LSTM and GRU only)
//   cell_act_data  the next num_layers rows            (LSTM only)
//   hidden_data    the next num_layers - 1 rows        (only if num_layers > 1)
// Outputs that do not apply to the current mode are left untouched.
// `direction_num`, `time_step`, `batch_size` and `hidden_size` are not used.
template <typename TensorType>
void SplitReserveData(const framework::ExecutionContext& ctx,
                      TensorType* reserve_data, Tensor* gate_data,
                      Tensor* cell_data, Tensor* cell_act_data,
                      Tensor* hidden_data, int direction_num,
                      const int& time_step, const int& batch_size,
                      const int& hidden_size, const int& gate_num,
                      const int& num_layers) {
  const int gate_end = gate_num * num_layers;
  *gate_data = reserve_data->Slice(0, gate_end);

  // simple rnn: the hidden rows follow the gate rows directly
  int hidden_begin = gate_end;

  const bool lstm_mode = is_lstm(ctx);
  if (lstm_mode || is_gru(ctx)) {
    const int cell_end = gate_end + num_layers;
    *cell_data = reserve_data->Slice(gate_end, cell_end);
    hidden_begin = cell_end;
    if (lstm_mode) {
      const int cell_act_end = cell_end + num_layers;
      *cell_act_data = reserve_data->Slice(cell_end, cell_act_end);
      hidden_begin = cell_act_end;
    }
  }

  if (num_layers > 1) {
    const int hidden_end = hidden_begin + (num_layers - 1);
    *hidden_data = reserve_data->Slice(hidden_begin, hidden_end);
  }
}
template <typename TensorType>
void reset_parameter_vector(const std::vector<TensorType>& raw_params_vec,
const int& num_layers, const int& gate_num,
const bool& is_bidirec,
std::vector<TensorList>* params_vec) {
// the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers
// + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to
// ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers
const int& direction_num = is_bidirec ? 2 : 1;
const int& layer_weight_size = 4 * direction_num;
const int& all_weight_size = num_layers * layer_weight_size;
const int& bias_start_idx = all_weight_size / 2;
for (int i = 0; i < num_layers; i++) {
TensorList tensor_list;
tensor_list.reserve(layer_weight_size);
for (int j = 0; j < layer_weight_size; j++) {
Tensor tensor_holder;
tensor_list.emplace_back(tensor_holder);
}
for (int j = 0; j < layer_weight_size; j++) {
int k = j % 4;
const int& section = j / 4;
int tensor_idx = i * 2 * direction_num + section * 2 + k % 2;
if (k >= 2) {
tensor_idx += bias_start_idx;
}
tensor_list[j].ShareDataWith(*raw_params_vec[tensor_idx]);
}
params_vec->emplace_back(tensor_list);
}
}
// Sizes and allocates `reserve_data`, then splits it with SplitReserveData.
//
// The buffer has `block_size` columns, where
//   block_size = direction_num * time_step * batch_size * hidden_size,
// and the following number of rows:
//   LSTM:  (gate_num + 2) * num_layers + (num_layers - 1)
//   GRU:   (gate_num + 1) * num_layers + (num_layers - 1)
//   other:  gate_num      * num_layers + (num_layers - 1)
// `time_step` and `batch_size` are the first two dims of `input`.
// The CellType template argument is not referenced in the body.
template <typename CellType, typename T>
void AllocateReserveData(const framework::ExecutionContext& ctx,
                         Tensor* reserve_data, Tensor* gate_data,
                         Tensor* cell_data, Tensor* cell_act_data,
                         Tensor* hidden_data, const Tensor* input,
                         bool is_bidirec, int num_layers, int gate_num,
                         int hidden_size) {
  const int direction_num = is_bidirec ? 2 : 1;
  const int time_step = input->dims()[0];
  const int batch_size = input->dims()[1];
  const int block_size = direction_num * time_step * batch_size * hidden_size;

  // Extra per-layer rows on top of the gate rows, depending on the mode.
  int state_rows_per_layer = 0;
  if (is_lstm(ctx)) {
    state_rows_per_layer = 2;
  } else if (is_gru(ctx)) {
    state_rows_per_layer = 1;
  }
  const int total_rows =
      (num_layers - 1) + (gate_num + state_rows_per_layer) * num_layers;

  reserve_data->Resize({total_rows, block_size});
  reserve_data->mutable_data<T>(ctx.GetPlace());
  SplitReserveData(ctx, reserve_data, gate_data, cell_data, cell_act_data,
                   hidden_data, direction_num, time_step, batch_size,
                   hidden_size, gate_num, num_layers);
}
// Runs the forward pass of a stacked RNN on CPU, one layer after another.
//
// Template parameters:
//   CellType     - the recurrent cell; one instance is built and handed to the
//                  layer drivers.
//   LayerT       - common type through which the chosen driver is invoked.
//   SingleLayerT - driver used when `is_bidirec` is false.
//   BidirLayerT  - driver used when `is_bidirec` is true.
//
// Arguments:
//   input           - input sequence; read directly by layer 0.
//   weight_list     - flat weight list, regrouped per layer by
//                     reset_parameter_vector.
//   init_h, init_c  - initial states; their first dim must equal
//                     num_layers * direction_num. init_c is only read when
//                     is_lstm(ctx) holds.
//   sequence_length - forwarded to the layer driver (may be nullptr).
//   last_h, last_c  - final states; last_c is only read when is_lstm(ctx).
//   output          - receives the output of the last layer.
//   dropout_mask    - mask tensor for the inter-layer dropout (training only).
//   reserve_data    - training-only buffer that keeps per-layer intermediates.
//   input_size and cell_type are accepted but not read by this function.
template <typename CellType, template <typename, typename> class LayerT,
          template <typename, typename> class SingleLayerT,
          template <typename, typename> class BidirLayerT, typename T>
void RnnFunc(const framework::ExecutionContext& ctx, const Tensor* input,
             const std::vector<const Tensor*> weight_list, const Tensor* init_h,
             const Tensor* init_c, const Tensor* sequence_length,
             Tensor* last_h, Tensor* last_c, Tensor* output,
             Tensor* dropout_mask, const int& num_layers, const int& gate_num,
             const int& input_size, const int& hidden_size,
             const bool& is_bidirec, const std::string& cell_type,
             const float& dropout_prob, const bool& is_test, const int& seed,
             Tensor* reserve_data) {
  const int& direction_num = is_bidirec ? 2 : 1;
  const auto& init_h_dims = init_h->dims();
  PADDLE_ENFORCE_EQ(init_h_dims[0], num_layers * direction_num,
                    platform::errors::InvalidArgument(
                        "The num_layers of in RNN layer must be the same as "
                        "first dim of init hidden, but received"
                        " num_layers:%d, dim:%d",
                        num_layers, init_h_dims[0]));
  if (is_lstm(ctx)) {
    const auto& init_c_dims = init_c->dims();
    // Report the cell-state dim that was actually checked (the message used to
    // print the hidden-state dim instead).
    PADDLE_ENFORCE_EQ(init_c_dims[0], num_layers * direction_num,
                      platform::errors::InvalidArgument(
                          "The num_layers of in RNN layer must be the same as "
                          "first dim of cell state hidden, but received"
                          " num_layers:%d, dim:%d",
                          num_layers, init_c_dims[0]));
  }
  CellType cell;
  std::vector<TensorList> parameter_lists;
  parameter_lists.reserve(num_layers);
  reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec,
                         &parameter_lists);

  // In training mode every layer's intermediates live in reserve_data; the
  // views below are reshaped so that row i belongs to layer i.
  Tensor gate_data, cell_data, cell_act_data, hidden_data;
  if (!is_test) {
    AllocateReserveData<CellType, T>(
        ctx, reserve_data, &gate_data, &cell_data, &cell_act_data, &hidden_data,
        input, is_bidirec, num_layers, gate_num, hidden_size);
    gate_data.Resize({num_layers, gate_data.numel() / num_layers});
    cell_data.Resize({num_layers, cell_data.numel() / num_layers});
    cell_act_data.Resize({num_layers, cell_act_data.numel() / num_layers});
    if (num_layers > 1) {
      // Only the first num_layers - 1 layers store their output here; the
      // last layer writes straight into `output`.
      hidden_data.Resize(
          {num_layers - 1, hidden_data.numel() / (num_layers - 1)});
    }
  }

  // input_holder is assigned before its first use (i > 0); start from nullptr
  // so it never carries an indeterminate value into SwapPoniter.
  Tensor* input_holder = nullptr;
  Tensor* output_holder = output;
  Tensor temp;
  bool has_allocate_mem = false;

  auto init_h_unbind = Unbind(*init_h);
  auto last_h_unbind = Unbind(*last_h);
  TensorList init_c_unbind, last_c_unbind;
  if (is_lstm(ctx)) {
    init_c_unbind = Unbind(*init_c);
    last_c_unbind = Unbind(*last_c);
  }

  Tensor curr_gate_data, curr_cell_data, curr_cell_act_data;
  Tensor curr_hidden_data, prev_hidden_data;
  bool has_dropout_reset = false;
  for (int i = 0; i < num_layers; i++) {
    if (!is_test) {
      if (cell_data.numel() > 0) /** for lstm, gru **/ {
        curr_cell_data = cell_data.Slice(i, i + 1);
      }
      if (cell_act_data.numel() > 0) /*for lstm*/ {
        curr_cell_act_data = cell_act_data.Slice(i, i + 1);
      }
      curr_gate_data = gate_data.Slice(i, i + 1);
      output_holder = output;
      if (i < num_layers - 1 && num_layers > 1) {
        // Intermediate layers write into their slot of the reserve buffer.
        curr_hidden_data = hidden_data.Slice(i, i + 1);
        curr_hidden_data.Resize(output->dims());
        output_holder = &curr_hidden_data;
      }
    }
    if (i > 0) {
      if (!has_allocate_mem) {
        temp.Resize(output->dims());
        temp.mutable_data<T>(ctx.GetPlace());
        input_holder = &temp;
        has_allocate_mem = true;
      }
      if (!is_test) {
        // Training: the previous layer's output was kept in the reserve
        // buffer, read it from there.
        prev_hidden_data = hidden_data.Slice(i - 1, i);
        input_holder = &prev_hidden_data;
        input_holder->Resize(output->dims());
      } else {
        // Inference: ping-pong between `output` and `temp`.
        SwapPoniter(&output_holder, &input_holder);
      }
      if (dropout_prob != 0 && (!is_test)) {
        dropout_cpu_function_inplace<T>(ctx, input_holder, dropout_mask,
                                        dropout_prob, seed, is_test,
                                        &has_dropout_reset);
      }
    }
    const Tensor* input_temp_holder = input;
    if (i > 0) {
      input_temp_holder = input_holder;
    }
    LayerT<T, CellType>* layer;
    SingleLayerT<T, CellType> slayer(cell);
    BidirLayerT<T, CellType> blayer(cell);
    if (is_bidirec) {
      layer = &blayer;
    } else {
      layer = &slayer;
    }
    (*layer)(ctx, input_temp_holder, parameter_lists[i], init_h_unbind,
             init_c_unbind, sequence_length, last_h_unbind, last_c_unbind,
             output_holder, i, gate_num, &curr_gate_data, &curr_cell_data,
             &curr_cell_act_data, is_test);
  }
  // With an even layer count the inference ping-pong ends with the result in
  // `temp`, so it has to be copied into `output`.
  if (num_layers % 2 == 0) {
    framework::TensorCopy(
        *output_holder, ctx.GetPlace(),
        ctx.template device_context<platform::CPUDeviceContext>(), output);
  }
}
// CPU forward kernel of the rnn op. Compute() gathers the op's tensors and
// attributes, allocates the outputs and dispatches to RnnFunc with the cell
// type selected by is_lstm / is_rnn_relu / is_rnn_tanh / is_gru.
template <typename DeviceContext, typename T>
class RNNCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Tensors.
    auto* input = ctx.Input<Tensor>("Input");
    auto init_states = ctx.MultiInput<Tensor>("PreState");
    auto weights = ctx.MultiInput<framework::Tensor>("WeightList");
    auto last_states = ctx.MultiOutput<Tensor>("State");
    auto* output = ctx.Output<Tensor>("Out");
    auto* dropout_mask = ctx.Output<Tensor>("DropoutState");
    auto* reserve_data = ctx.Output<Tensor>("Reserve");

    // Attributes.
    const int num_layers = ctx.Attr<int>("num_layers");
    const bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    const int input_size = ctx.Attr<int>("input_size");
    const int hidden_size = ctx.Attr<int>("hidden_size");
    const float dropout_prob = ctx.Attr<float>("dropout_prob");
    const std::string& mode = ctx.Attr<std::string>("mode");
    const bool is_test = ctx.Attr<bool>("is_test");
    const int seed = ctx.Attr<int>("seed");

    // SequenceLength is an optional input.
    const Tensor* sequence_length = nullptr;
    if (ctx.HasInput("SequenceLength")) {
      sequence_length = ctx.Input<Tensor>("SequenceLength");
    }

    if (!dropout_mask->IsInitialized()) {
      dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
    }

    // Allocate the output and the first state before running any layer.
    output->mutable_data<T>(ctx.GetPlace());
    last_states[0]->mutable_data<T>(ctx.GetPlace());

    if (is_lstm(ctx)) {
      const int lstm_gate_num = 4;
      last_states[1]->mutable_data<T>(ctx.GetPlace());
      RnnFunc<LSTMCell<T>, Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weights, init_states[0], init_states[1], sequence_length,
          last_states[0], last_states[1], output, dropout_mask, num_layers,
          lstm_gate_num, input_size, hidden_size, is_bidirec, mode,
          dropout_prob, is_test, seed, reserve_data);
    } else if (is_rnn_relu(ctx)) {
      const int simple_gate_num = 1;
      RnnFunc<
          SimpleRNNCell<T, ReluFunctor, math::detail::ActivationType::kReLU>,
          Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weights, init_states[0], nullptr, sequence_length,
          last_states[0], nullptr, output, dropout_mask, num_layers,
          simple_gate_num, input_size, hidden_size, is_bidirec, mode,
          dropout_prob, is_test, seed, reserve_data);
    } else if (is_rnn_tanh(ctx)) {
      const int simple_gate_num = 1;
      RnnFunc<
          SimpleRNNCell<T, TanhFunctor, math::detail::ActivationType::kTanhV2>,
          Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weights, init_states[0], nullptr, sequence_length,
          last_states[0], nullptr, output, dropout_mask, num_layers,
          simple_gate_num, input_size, hidden_size, is_bidirec, mode,
          dropout_prob, is_test, seed, reserve_data);
    } else if (is_gru(ctx)) {
      const int gru_gate_num = 3;
      RnnFunc<GRUCell<T>, Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weights, init_states[0], nullptr, sequence_length,
          last_states[0], nullptr, output, dropout_mask, num_layers,
          gru_gate_num, input_size, hidden_size, is_bidirec, mode,
          dropout_prob, is_test, seed, reserve_data);
    }
  }
};
// Resets the check_ig / check_fg / check_og pointers of `lstm_value` to
// nullptr. No other field of the struct is touched.
template <typename T>
void create_lstm_value(math::LstmMetaValue<T>* lstm_value) {
  auto& value = *lstm_value;
  value.check_ig = nullptr;
  value.check_fg = nullptr;
  value.check_og = nullptr;
}
// Resets the check_ig_grad / check_fg_grad / check_og_grad pointers of
// `lstm_grad` to nullptr. No other field of the struct is touched.
template <typename T>
void create_lstm_grad(math::LstmMetaGrad<T>* lstm_grad) {
  auto& grad = *lstm_grad;
  grad.check_ig_grad = nullptr;
  grad.check_fg_grad = nullptr;
  grad.check_og_grad = nullptr;
}
// Turns `dst` into a 1-D tensor of v.size() elements allocated on the
// context's place and fills it with the values of `v`, in order.
template <typename T>
void create_tensor_by_list(const framework::ExecutionContext& context,
                           Tensor* dst, const std::vector<T>& v) {
  int element_count = v.size();
  dst->Resize({element_count});
  // mutable_data hands back the freshly allocated buffer; write through it.
  T* dst_data = dst->mutable_data<T>(context.GetPlace());
  for (int idx = 0; idx < element_count; ++idx) {
    dst_data[idx] = v[idx];
  }
}
// Fills `grad_gate_buf` from `grad_gate`, optionally mixing in
// `reset_output_grad`.
//
// batch_size and frame_size are the last two dims of `grad_gate`; frame_size
// is handled as three groups of frame_size / 3 values.
//   1. grad_gate_buf = grad_gate * mask, where the mask is built from
//      {1, 1, 0}.
//   2. If `reset_output_grad` is given:
//      grad_gate_buf += mask2 * reset_output_grad (tiled 3x), where mask2 is
//      built from {0, 0, 1}.
// NOTE(review): which columns each mask value covers depends on the storage
// order of framework::EigenMatrix (presumably row-major, making the groups
// contiguous thirds of the frame) — confirm before relying on it.
template <typename T>
void make_grad_gate_buf(const framework::ExecutionContext& context,
Tensor* grad_gate, Tensor* grad_gate_buf,
Tensor* reset_output_grad = nullptr) {
int dim_size = grad_gate->dims().size();
int batch_size = grad_gate->dims()[dim_size - 2];
int frame_size = grad_gate->dims()[dim_size - 1];
// Three-element mask tensor; expanded below to a {batch_size, frame_size}
// expression.
Tensor grad_gate_mask;
create_tensor_by_list<T>(context, &grad_gate_mask, {1, 1, 0});
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto eigen_grad_gate_mask = framework::EigenMatrix<T>::From(
grad_gate_mask, framework::make_ddim({3, 1}));
// {3, 1} -> {3, frame_size / 3} -> flat frame_size -> repeated per batch row.
auto eigen_grad_gate_mask_broadcast =
eigen_grad_gate_mask.broadcast(Eigen::DSizes<int, 2>(1, frame_size / 3))
.reshape(Eigen::DSizes<int, 1>(frame_size))
.broadcast(Eigen::DSizes<int, 2>(batch_size, 1));
auto eigen_grad_gate_buf = framework::EigenMatrix<T>::From(
*grad_gate_buf, framework::make_ddim({batch_size, frame_size}));
auto eigen_grad_gate = framework::EigenMatrix<T>::From(
*grad_gate, framework::make_ddim({batch_size, frame_size}));
// Step 1: masked copy of grad_gate into the buffer.
eigen_grad_gate_buf.device(place) =
eigen_grad_gate * eigen_grad_gate_mask_broadcast;
if (reset_output_grad) {
// Step 2: complementary mask selects where reset_output_grad is added.
Tensor grad_reset_output_mask;
create_tensor_by_list<T>(context, &grad_reset_output_mask, {0, 0, 1});
auto eigen_grad_reset_output_mask = framework::EigenMatrix<T>::From(
grad_reset_output_mask, framework::make_ddim({3, 1}));
auto eigen_grad_reset_output_mask_broadcast =
eigen_grad_reset_output_mask
.broadcast(Eigen::DSizes<int, 2>(1, frame_size / 3))
.reshape(Eigen::DSizes<int, 1>(frame_size))
.broadcast(Eigen::DSizes<int, 2>(batch_size, 1));
// reset_output_grad is tiled three times so that it spans frame_size
// columns.
auto eigen_grad_reset_output =
framework::EigenMatrix<T>::Reshape(*reset_output_grad,
reset_output_grad->dims().size() - 1)
.broadcast(Eigen::DSizes<int, 3>(1, 3, 1))
.reshape(Eigen::DSizes<int, 2>(batch_size, frame_size));
eigen_grad_gate_buf.device(place) =
eigen_grad_gate_buf +
eigen_grad_reset_output_mask_broadcast * eigen_grad_reset_output;
}
}
// Base class of the backward-pass layer drivers. It owns a copy of the
// gradient cell (`cell_`) and provides the shared machinery:
//   - run_rnn_grad_function: backward-through-time loop for one direction of
//     one layer;
//   - preprocess / mask_preprocess: fold the output gradient of a time step
//     into the running hidden-state gradient;
//   - postprocess: gradients of the input, the input weight and both biases.
template <typename T, typename GradCellType>
struct GradLayer {
explicit GradLayer(const GradCellType& cell) : cell_(cell) {}
virtual ~GradLayer() {}
// Runs the backward pass over all time steps for one direction
// (`is_reverse`) of layer `layer_idx`.
//
// The *_unbind lists hold one tensor per time step. When `is_reverse` is
// true the gate, grad-gate, output and output-grad lists are reversed IN
// PLACE, and state lists are indexed with an offset of `time_step`.
// Gradients w.r.t. the initial states are written to
// (*init_h_grad_unbind) / (*init_c_grad_unbind) when those lists are
// non-empty; weight gradients go to (*weight_list_grad)[layer_idx].
void run_rnn_grad_function(
const framework::ExecutionContext& context,
const platform::CPUDeviceContext& device_ctx, const Tensor* input,
Tensor* input_grad, const Tensor* sequence_length,
std::vector<Tensor>* init_h_unbind, std::vector<Tensor>* init_c_unbind,
std::vector<Tensor>* init_h_grad_unbind,
std::vector<Tensor>* init_c_grad_unbind, Tensor* layer_grad_gate_tensor,
std::vector<Tensor>* layer_gate_tensor_unbind,
std::vector<Tensor>* layer_grad_gate_tensor_unbind,
std::vector<Tensor>* layer_state_tensor_unbind,
std::vector<Tensor>* layer_act_state_tensor_unbind,
std::vector<Tensor>* output_tensor_unbind,
std::vector<Tensor>* output_grad_tensor_unbind,
const TensorList& last_h_grad_unbind,
const TensorList& last_c_grad_unbind,
const std::vector<TensorList>& parameter_lists,
std::vector<TensorList>* weight_list_grad, const int& layer_idx,
const int& time_step, const bool& has_sequence_length,
const bool& is_bidirec, const bool& is_reverse) {
const int& direction_num = is_bidirec ? 2 : 1;
const int& current_reverse_idx = is_reverse ? 1 : 0;
// Index of this (layer, direction) pair in the per-state lists.
const int& current_layer_idx =
direction_num * layer_idx + current_reverse_idx;
// Offset into the state lists: the reverse direction uses the entries that
// start at `time_step`.
int begin_idx = 0;
if (is_reverse) {
begin_idx = time_step;
}
// Optional per-step mask built from sequence_length, shaped
// {time_step, input->dims()[1]} and split into one tensor per step.
Tensor mask_matrix;
TensorList mask_tensor_list;
int mask_min_length = time_step;
if (has_sequence_length) {
mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]}));
create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
&mask_min_length);
mask_tensor_list = Unbind(mask_matrix);
}
// Copy last_h / last_c gradients into scratch tensors: the loop below swaps
// the "last" and "pre" pointers every step, so the caller's tensors must
// not be written through them.
Tensor a, b;
Tensor* dynamic_grad_last_h = &a;
Tensor* dynamic_grad_last_c = &b;
dynamic_grad_last_h->Resize(last_h_grad_unbind[current_layer_idx].dims());
dynamic_grad_last_h->mutable_data<T>(context.GetPlace());
framework::TensorCopy(last_h_grad_unbind[current_layer_idx],
context.GetPlace(), dynamic_grad_last_h);
if (last_c_grad_unbind.size() > 0) {
dynamic_grad_last_c->Resize(last_c_grad_unbind[current_layer_idx].dims());
dynamic_grad_last_c->mutable_data<T>(context.GetPlace());
framework::TensorCopy(last_c_grad_unbind[current_layer_idx],
context.GetPlace(), dynamic_grad_last_c);
} else {
dynamic_grad_last_c = nullptr;
}
// "pre" gradients: share storage with the init-state gradient when it was
// requested, otherwise use local scratch.
Tensor c, d;
Tensor* dynamic_grad_pre_h = &c;
Tensor* dynamic_grad_pre_c = &d;
math::SetConstant<platform::CPUDeviceContext, T> zero;
if (init_h_grad_unbind->size() > 0) {
dynamic_grad_pre_h->ShareDataWith(
(*init_h_grad_unbind)[current_layer_idx]);
} else {
dynamic_grad_pre_h->Resize(dynamic_grad_last_h->dims());
dynamic_grad_pre_h->mutable_data<T>(context.GetPlace());
zero(device_ctx, dynamic_grad_pre_h, static_cast<T>(0.0));
}
if (init_c_grad_unbind->size() > 0) {
dynamic_grad_pre_c->ShareDataWith(
(*init_c_grad_unbind)[current_layer_idx]);
} else {
if (is_lstm(context) || is_gru(context)) {
// Scratch is allocated but not zero-filled here.
dynamic_grad_pre_c->Resize(dynamic_grad_last_h->dims());
dynamic_grad_pre_c->mutable_data<T>(context.GetPlace());
if (is_gru(context)) {
// For GRU both "last" and "pre" state-gradient pointers use the same
// scratch tensor.
dynamic_grad_last_c = dynamic_grad_pre_c;
}
} else {
dynamic_grad_pre_c = nullptr;
}
}
if (is_reverse) {
// The reverse direction walks the sequence the other way round, so the
// per-step gate, grad-gate, output and output-grad lists are reversed.
std::reverse(layer_gate_tensor_unbind->begin(),
layer_gate_tensor_unbind->end());
std::reverse(layer_grad_gate_tensor_unbind->begin(),
layer_grad_gate_tensor_unbind->end());
/*
if (has_sequence_length) {
std::reverse(mask_tensor_list.begin(), mask_tensor_list.end());
}*/
std::reverse(output_tensor_unbind->begin(), output_tensor_unbind->end());
std::reverse(output_grad_tensor_unbind->begin(),
output_grad_tensor_unbind->end());
}
// Gradient slot at index 4 * direction + 1 of this layer; it is zeroed here
// and handed to the cell as grad_weight_hh, which accumulates into it.
Tensor* weight_grad =
&((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 1]);
weight_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, weight_grad, static_cast<T>(0.0));
Tensor* pre_hidden = nullptr;
Tensor* pre_state = nullptr;
Tensor* hidden = nullptr;
// Extra grad-gate buffer, only allocated for GRU.
Tensor grad_gate_buf;
TensorList grad_gate_buf_unbind;
if (is_gru(context)) {
grad_gate_buf.Resize(layer_grad_gate_tensor->dims());
grad_gate_buf.mutable_data<T>(context.GetPlace());
grad_gate_buf_unbind = Unbind(grad_gate_buf);
}
// Backward through time, from the last step to the first.
for (int i = time_step - 1; i >= 0; --i) {
// Fold this step's output gradient into the running hidden gradient.
if (has_sequence_length) {
this->mask_preprocess(context, &(*output_grad_tensor_unbind)[i],
dynamic_grad_last_h, dynamic_grad_last_c,
dynamic_grad_pre_h, dynamic_grad_pre_c,
mask_tensor_list[i]);
} else {
this->preprocess(context, &(*output_grad_tensor_unbind)[i],
dynamic_grad_last_h);
}
hidden = &(*output_tensor_unbind)[i];
// The state preceding step 0 is the initial state; otherwise it is the
// output / stored state of step i - 1.
if (i == 0) {
pre_hidden = &(*init_h_unbind)[current_layer_idx];
if (init_c_unbind->size() > 0) {
pre_state = &(*init_c_unbind)[current_layer_idx];
}
} else {
pre_hidden = &(*output_tensor_unbind)[i - 1];
if (layer_state_tensor_unbind->size() > 0) {
pre_state = &(*layer_state_tensor_unbind)[begin_idx + i - 1];
}
}
// NOTE(review): the call below indexes layer_state_tensor_unbind,
// layer_act_state_tensor_unbind, grad_gate_buf_unbind and
// mask_tensor_list unconditionally, although each of them can be empty
// (no cell state, non-GRU, no sequence_length). That looks like an
// out-of-range access on an empty vector — confirm and guard.
this->cell_(
context, &(*layer_gate_tensor_unbind)[i],
&(*layer_state_tensor_unbind)[begin_idx + i],
&(*layer_act_state_tensor_unbind)[begin_idx + i], hidden,
&(parameter_lists[layer_idx][current_reverse_idx * 4 + 1]),
pre_hidden, pre_state, dynamic_grad_last_h, dynamic_grad_last_c,
&(*layer_grad_gate_tensor_unbind)[i], weight_grad, dynamic_grad_pre_h,
dynamic_grad_pre_c, &grad_gate_buf_unbind[i],
&((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]),
mask_tensor_list[i], has_sequence_length);
// The "pre" gradient of this step is the "last" gradient of the next one.
SwapPoniter(&dynamic_grad_last_h, &dynamic_grad_pre_h);
SwapPoniter(&dynamic_grad_last_c, &dynamic_grad_pre_c);
}
// Post-process: gradients for w_hi, X, bias_hi, bias_hh.
this->postprocess(context, *layer_grad_gate_tensor, *input, input_grad,
parameter_lists[layer_idx],
&((*weight_list_grad)[layer_idx]), &grad_gate_buf,
is_reverse);
// Copy the final gradient into init_h / init_c. After an even number of
// swaps the result sits in the scratch tensor rather than in the storage
// shared with the init-state gradient, hence the explicit copy; after an odd
// number it is already in place.
if ((*init_h_grad_unbind).size() > 0 && time_step % 2 == 0) {
framework::TensorCopy(*dynamic_grad_last_h, context.GetPlace(),
&((*init_h_grad_unbind)[current_layer_idx]));
}
if ((*init_c_grad_unbind).size() > 0 && time_step % 2 == 0) {
framework::TensorCopy(*dynamic_grad_last_c, context.GetPlace(),
&((*init_c_grad_unbind)[current_layer_idx]));
}
}
// Default entry point; does nothing.
// NOTE(review): the operator() of SingleGradLayer / BidirGradLayer takes
// pointer parameters where this one takes const references, so those appear
// to be separate overloads rather than overrides of this virtual — confirm.
virtual void operator()(
const framework::ExecutionContext& context, const Tensor* input,
const Tensor* output, const TensorList& init_h_unbind,
const TensorList& init_c_unbind, const TensorList& last_h_grad_unbind,
const TensorList& last_c_grad_unbind,
const TensorList& gate_tensor_unbind,
const TensorList& state_tensor_unbind,
const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
const std::vector<TensorList>& parameter_lists,
const Tensor* sequence_length, Tensor* input_grad,
TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
const std::vector<TensorList>& weight_list_grad, const int& layer_idx,
const int& gate_num) {}
// grad_last_h += grad_output (element-wise, in place).
void preprocess(const framework::ExecutionContext& context,
const Tensor* grad_output, Tensor* grad_last_h) {
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto eigen_grad_output = framework::EigenMatrix<T>::Reshape(
*grad_output, grad_output->dims().size() - 1);
auto eigen_grad_last_h = framework::EigenMatrix<T>::Reshape(
*grad_last_h, grad_last_h->dims().size() - 1);
// the output gradient contribute the gradient to last_h
eigen_grad_last_h.device(place) = eigen_grad_last_h + eigen_grad_output;
}
// Masked variant of preprocess, used when sequence lengths are given:
//   grad_last_h += grad_output * mask
//   grad_pre_h   = (1 - mask) * grad_last_h
//   grad_last_h  = mask * grad_last_h
// and, for LSTM with both cell-gradient pointers set, the same split of
// grad_last_c into grad_pre_c / grad_last_c. The mask is broadcast along
// grad_output->dims()[2].
void mask_preprocess(const framework::ExecutionContext& context,
const Tensor* grad_output, Tensor* grad_last_h,
Tensor* grad_last_c, Tensor* grad_pre_h,
Tensor* grad_pre_c, const Tensor& mask_tensor) {
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto eigen_mask = framework::EigenMatrix<T>::From(
mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
auto eigen_mask_broadcast =
eigen_mask.broadcast(Eigen::DSizes<int, 2>(1, grad_output->dims()[2]));
auto eigen_grad_last_h = framework::EigenMatrix<T>::Reshape(
*grad_last_h, grad_last_h->dims().size() - 1);
auto eigen_grad_pre_h = framework::EigenMatrix<T>::Reshape(
*grad_pre_h, grad_pre_h->dims().size() - 1);
auto eigen_grad_output = framework::EigenMatrix<T>::Reshape(
*grad_output, grad_output->dims().size() - 1);
eigen_grad_last_h.device(place) =
eigen_grad_last_h + eigen_grad_output * eigen_mask_broadcast;
// Order matters: grad_pre_h must read grad_last_h before it is masked.
eigen_grad_pre_h.device(place) =
(1 - eigen_mask_broadcast) * eigen_grad_last_h;
eigen_grad_last_h.device(place) = eigen_mask_broadcast * eigen_grad_last_h;
if (grad_last_c && grad_pre_c && is_lstm(context)) {
auto eigen_grad_last_c = framework::EigenMatrix<T>::Reshape(
*grad_last_c, grad_last_c->dims().size() - 1);
auto eigen_grad_pre_c = framework::EigenMatrix<T>::Reshape(
*grad_pre_c, grad_pre_c->dims().size() - 1);
eigen_grad_pre_c.device(place) =
(1 - eigen_mask_broadcast) * eigen_grad_last_c;
eigen_grad_last_c.device(place) =
eigen_mask_broadcast * eigen_grad_last_c;
}
}
// Computes, from the accumulated gate gradient of all time steps:
//   grad_parameters[begin_idx + 0]  = grad_gate^T * input   (overwritten)
//   input_grad                     += grad_gate * parameters[begin_idx + 0]
//   grad_parameters[begin_idx + 2]  = column sums of grad_gate
//   grad_parameters[begin_idx + 3]  = column sums of grad_gate, or of
//                                     grad_gate_buf for GRU
// with begin_idx = 4 for the reverse direction and 0 otherwise.
// `grad_gate_buf` is resized here when is_gru(context) holds.
void postprocess(const framework::ExecutionContext& context,
const Tensor& grad_gate, const Tensor& input,
Tensor* input_grad, const TensorList& parameters,
TensorList* grad_parameters, Tensor* grad_gate_buf,
const int& is_reverse) {
// The grad_gate was produced step by step; here it is reduced into
// grad_w_hi, grad_bias_hi and grad_bias_hh.
int begin_idx = 0;
if (is_reverse) {
begin_idx = 4;
}
auto& device_ctx =
context.template device_context<platform::CPUDeviceContext>();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(device_ctx);
// calc the gradient for the w_hi
auto mat_dim_out_grad =
math::CreateMatrixDescriptor(grad_gate.dims(), 0, true);
auto mat_dim_input = math::CreateMatrixDescriptor(input.dims(), 0, false);
// Fold the leading (batch) dimension into the matrix so that a single
// non-batched MatMul covers all time steps.
mat_dim_out_grad.width_ *= mat_dim_out_grad.batch_size_;
mat_dim_out_grad.batch_size_ = 0;
mat_dim_input.height_ *= mat_dim_input.batch_size_;
mat_dim_input.batch_size_ = 0;
blas.MatMul(grad_gate, mat_dim_out_grad, input, mat_dim_input,
static_cast<T>(1.0), &((*grad_parameters)[begin_idx + 0]),
T(0));
// calc the gradient for the X
auto mat_dim_out_grad_new =
math::CreateMatrixDescriptor(grad_gate.dims(), 0, false);
mat_dim_out_grad_new.height_ *= mat_dim_out_grad_new.batch_size_;
mat_dim_out_grad_new.batch_size_ = 0;
// NOTE(review): the descriptor is built from parameters[0] while the
// multiplication uses parameters[begin_idx + 0]; presumably both have the
// same dims — confirm.
auto mat_dim_parameter =
math::CreateMatrixDescriptor(parameters[0].dims(), 0, false);
// beta = 1: accumulate into input_grad (both directions add to it).
blas.MatMul(grad_gate, mat_dim_out_grad_new, parameters[begin_idx + 0],
mat_dim_parameter, static_cast<T>(1.0), input_grad, T(1));
// calc the gradient of Bias_hi, Bias_hh
math::ColwiseSum<platform::CPUDeviceContext, T> col_sum;
Tensor tmp_grad_gate;
tmp_grad_gate.ShareDataWith(grad_gate);
// Flatten the first two dims into rows so the column sum reduces over both.
tmp_grad_gate.Resize(
{grad_gate.dims()[0] * grad_gate.dims()[1], grad_gate.dims()[2]});
col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 2]));
// Bias_hh
if (is_gru(context)) {
grad_gate_buf->Resize(tmp_grad_gate.dims());
col_sum(device_ctx, *grad_gate_buf, &((*grad_parameters)[begin_idx + 3]));
} else {
col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 3]));
}
}
// Gradient cell invoked once per time step by run_rnn_grad_function.
GradCellType cell_;
};
// Backward driver for a uni-directional layer: prepares the per-time-step
// views of the saved forward data for layer `layer_idx` and runs a single
// forward-direction pass of run_rnn_grad_function.
template <typename T, typename GradCellType>
struct SingleGradLayer : GradLayer<T, GradCellType> {
  explicit SingleGradLayer(const GradCellType& cell)
      : GradLayer<T, GradCellType>(cell) {}
  virtual ~SingleGradLayer() {}

  // `input_grad` is zero-filled first. time_step and batch_size come from the
  // first two dims of `input`; hidden_size and is_bidirec are read from the
  // op attributes.
  void operator()(
      const framework::ExecutionContext& context, const Tensor* input,
      const Tensor* output, std::vector<Tensor>* init_h_unbind,
      std::vector<Tensor>* init_c_unbind, const TensorList& last_h_grad_unbind,
      const TensorList& last_c_grad_unbind,
      const TensorList& gate_tensor_unbind,
      const TensorList& state_tensor_unbind,
      const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
      const std::vector<TensorList>& parameter_lists,
      const Tensor* sequence_length, Tensor* input_grad,
      TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
      std::vector<TensorList>* weight_list_grad, const int& layer_idx,
      const int& gate_num) {
    auto& cpu_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    math::SetConstant<platform::CPUDeviceContext, T> fill;
    fill(cpu_ctx, input_grad, static_cast<T>(0.0));

    const bool is_bidirec = context.Attr<bool>("is_bidirec");
    const int time_step = input->dims()[0];
    const int batch_size = input->dims()[1];
    const int direction_num = is_bidirec ? 2 : 1;
    const int hidden_size = context.Attr<int>("hidden_size");
    const int total_steps = time_step * direction_num;

    // Per-step views of the layer output and of its gradient.
    auto output_steps = Unbind(*output);
    auto output_grad_steps = Unbind(*output_grad);

    // Saved gate values of this layer, viewed one tensor per step.
    auto gates = gate_tensor_unbind[layer_idx];
    gates.Resize({total_steps, batch_size, hidden_size * gate_num});
    auto gate_steps = Unbind(gates);

    // Freshly allocated gate gradient with the same layout.
    Tensor grad_gates;
    grad_gates.Resize(gates.dims());
    grad_gates.mutable_data<T>(context.GetPlace());
    auto grad_gate_steps = Unbind(grad_gates);

    // Cell state and activated cell state exist only for some cell types.
    Tensor states;
    TensorList state_steps;
    if (!state_tensor_unbind.empty()) {
      states = state_tensor_unbind[layer_idx];
      states.Resize({total_steps, batch_size, hidden_size});
      state_steps = Unbind(states);
    }
    Tensor act_states;
    TensorList act_state_steps;
    if (!act_state_tensor_unbind.empty()) {
      act_states = act_state_tensor_unbind[layer_idx];
      act_states.Resize({total_steps, batch_size, hidden_size});
      act_state_steps = Unbind(act_states);
    }

    const bool has_sequence_length = (sequence_length != nullptr);
    this->run_rnn_grad_function(
        context, cpu_ctx, input, input_grad, sequence_length, init_h_unbind,
        init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &grad_gates,
        &gate_steps, &grad_gate_steps, &state_steps, &act_state_steps,
        &output_steps, &output_grad_steps, last_h_grad_unbind,
        last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx,
        time_step, has_sequence_length, is_bidirec, false);
  }
};
// Splits the 3-D tensor `output` into the two tensors of `output_vec` along
// `axis`. Both targets are resized to {d0, d1, d2 / 2} (d* being the dims of
// `output`) and allocated before math::SplitFunctor fills them.
template <typename T>
void split_tensor_at_last_dim(const framework::ExecutionContext& context,
                              const platform::CPUDeviceContext& dev_ctx,
                              const Tensor* output,
                              std::vector<Tensor*>* output_vec,
                              const int& axis) {
  std::vector<const framework::Tensor*> shape_refer;
  for (int part = 0; part < 2; ++part) {
    Tensor* half = (*output_vec)[part];
    half->Resize(
        {output->dims()[0], output->dims()[1], output->dims()[2] / 2});
    half->mutable_data<T>(context.GetPlace());
    shape_refer.emplace_back(half);
  }
  math::SplitFunctor<platform::CPUDeviceContext, T> split;
  split(dev_ctx, *output, shape_refer, axis, output_vec);
}
// Backward driver for a bi-directional layer. The layer output and its
// gradient are split in two along axis 2 (forward / backward halves), the
// saved gates are split in two along the time axis, and
// run_rnn_grad_function is run once per direction.
template <typename T, typename GradCellType>
struct BidirGradLayer : GradLayer<T, GradCellType> {
  explicit BidirGradLayer(const GradCellType& cell)
      : GradLayer<T, GradCellType>(cell) {}
  virtual ~BidirGradLayer() {}

  // `input_grad` is zero-filled first; both direction passes then accumulate
  // into it. The state / activated-state step lists are shared by both passes.
  void operator()(
      const framework::ExecutionContext& context, const Tensor* input,
      const Tensor* output, std::vector<Tensor>* init_h_unbind,
      std::vector<Tensor>* init_c_unbind, const TensorList& last_h_grad_unbind,
      const TensorList& last_c_grad_unbind,
      const TensorList& gate_tensor_unbind,
      const TensorList& state_tensor_unbind,
      const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
      const std::vector<TensorList>& parameter_lists,
      const Tensor* sequence_length, Tensor* input_grad,
      TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
      std::vector<TensorList>* weight_list_grad, const int& layer_idx,
      const int& gate_num) {
    const bool is_bidirec = context.Attr<bool>("is_bidirec");
    const int time_step = input->dims()[0];
    const int batch_size = input->dims()[1];
    const int direction_num = is_bidirec ? 2 : 1;
    const int hidden_size = context.Attr<int>("hidden_size");

    auto& cpu_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    math::SetConstant<platform::CPUDeviceContext, T> fill;
    fill(cpu_ctx, input_grad, static_cast<T>(0.0));

    // The output is the concatenation of the forward and backward hidden
    // states, so split it back into its two halves.
    Tensor fw_output;
    Tensor bw_output;
    std::vector<Tensor*> output_halves;
    output_halves.emplace_back(&fw_output);
    output_halves.emplace_back(&bw_output);
    split_tensor_at_last_dim<T>(context, cpu_ctx, output, &output_halves, 2);
    std::vector<Tensor> fw_output_steps = Unbind(*(output_halves[0]));
    std::vector<Tensor> bw_output_steps = Unbind(*(output_halves[1]));

    // Same split for the gradient flowing into the output.
    Tensor fw_output_grad;
    Tensor bw_output_grad;
    std::vector<Tensor*> output_grad_halves;
    output_grad_halves.emplace_back(&fw_output_grad);
    output_grad_halves.emplace_back(&bw_output_grad);
    split_tensor_at_last_dim<T>(context, cpu_ctx, output_grad,
                                &output_grad_halves, 2);
    auto fw_output_grad_steps = Unbind(*(output_grad_halves[0]));
    auto bw_output_grad_steps = Unbind(*(output_grad_halves[1]));

    // Saved gates: the first time_step rows belong to the forward direction,
    // the next time_step rows to the backward direction.
    auto gates = gate_tensor_unbind[layer_idx];
    gates.Resize({time_step * 2, batch_size, hidden_size * gate_num});
    auto fw_gates = gates.Slice(0, time_step);
    auto bw_gates = gates.Slice(time_step, 2 * time_step);
    auto fw_gate_steps = Unbind(fw_gates);
    auto bw_gate_steps = Unbind(bw_gates);

    // Zero-initialised gate gradient with the same layout and split.
    Tensor grad_gates;
    grad_gates.Resize(gates.dims());
    grad_gates.mutable_data<T>(context.GetPlace());
    fill(cpu_ctx, &grad_gates, static_cast<T>(0.0));
    auto fw_grad_gates = grad_gates.Slice(0, time_step);
    auto bw_grad_gates = grad_gates.Slice(time_step, 2 * time_step);
    auto fw_grad_gate_steps = Unbind(fw_grad_gates);
    auto bw_grad_gate_steps = Unbind(bw_grad_gates);

    // Cell state and activated cell state exist only for some cell types.
    Tensor states;
    TensorList state_steps;
    if (!state_tensor_unbind.empty()) {
      states = state_tensor_unbind[layer_idx];
      states.Resize({time_step * direction_num, batch_size, hidden_size});
      state_steps = Unbind(states);
    }
    Tensor act_states;
    TensorList act_state_steps;
    if (!act_state_tensor_unbind.empty()) {
      act_states = act_state_tensor_unbind[layer_idx];
      act_states.Resize({time_step * direction_num, batch_size, hidden_size});
      act_state_steps = Unbind(act_states);
    }

    const bool has_sequence_length = (sequence_length != nullptr);

    // Forward direction.
    this->run_rnn_grad_function(
        context, cpu_ctx, input, input_grad, sequence_length, init_h_unbind,
        init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &fw_grad_gates,
        &fw_gate_steps, &fw_grad_gate_steps, &state_steps, &act_state_steps,
        &fw_output_steps, &fw_output_grad_steps, last_h_grad_unbind,
        last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx,
        time_step, has_sequence_length, is_bidirec, false);
    // Backward direction.
    this->run_rnn_grad_function(
        context, cpu_ctx, input, input_grad, sequence_length, init_h_unbind,
        init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &bw_grad_gates,
        &bw_gate_steps, &bw_grad_gate_steps, &state_steps, &act_state_steps,
        &bw_output_steps, &bw_output_grad_steps, last_h_grad_unbind,
        last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx,
        time_step, has_sequence_length, is_bidirec, true);
  }
};
// Makes `dst` a copy of `src`: `dst` takes the dims of `src`, is allocated on
// the context's place and then receives the contents via TensorCopy.
template <typename T>
void backup_tensor(const framework::ExecutionContext& context, Tensor* dst,
                   Tensor* src) {
  auto& cpu_ctx = context.template device_context<platform::CPUDeviceContext>();
  dst->Resize(src->dims());
  dst->mutable_data<T>(context.GetPlace());
  framework::TensorCopy(*src, cpu_ctx.GetPlace(), cpu_ctx, dst);
}
template <typename T>
struct GradCell {
virtual ~GradCell() {}
virtual void operator()(const framework::ExecutionContext& context,
Tensor* gate_tensor, Tensor* state_tensor,
Tensor* act_state_tensor, Tensor* hidden_tensor,
const Tensor* weight_hh, Tensor* pre_hidden,
Tensor* pre_state, Tensor* grad_hidden,
Tensor* grad_state, Tensor* grad_gate,
Tensor* grad_weight_hh, Tensor* grad_pre_hidden,
Tensor* grad_pre_state, Tensor* grad_gate_buf,
Tensor* grad_bias_hh, const Tensor& mask_tensor,
bool has_sequence_length) const {}
// Back-propagates the gate gradient into the previous hidden state:
//   grad_pre_hidden = gate_grad * weight_hh + beta * grad_pre_hidden
// In GRU mode gate_grad is grad_gate_buf and beta is 1 (the product is added
// to what grad_pre_hidden already holds); otherwise gate_grad is grad_gate
// and beta is 0 (the product replaces the previous contents).
//
// When has_sequence_length is true the result is blended with the backup
// copy using the mask:
//   out = (1 - mask) * backup + out * mask
// The same blend is applied to grad_pre_state when that pointer is non-null.
virtual void update_pre_hidden_grad(
    const framework::ExecutionContext& context, Tensor* grad_gate,
    const Tensor* weight_hh, Tensor* grad_pre_hidden,
    Tensor* grad_pre_hidden_bak, Tensor* grad_pre_state,
    Tensor* grad_pre_state_bak, Tensor* grad_gate_buf,
    const Tensor& mask_tensor, bool has_sequence_length) const {
  auto& dev_ctx =
      context.template device_context<platform::CPUDeviceContext>();
  auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);

  // Pick the gradient source and the accumulation factor in one place.
  const bool gru_mode = is_gru(context);
  Tensor* gate_grad_src = gru_mode ? grad_gate_buf : grad_gate;
  const T beta = gru_mode ? static_cast<T>(1.0) : static_cast<T>(0);

  // Fold the batch dimension into the height so one plain matmul suffices.
  auto gate_desc =
      math::CreateMatrixDescriptor(gate_grad_src->dims(), 0, false);
  gate_desc.height_ *= gate_desc.batch_size_;
  gate_desc.batch_size_ = 0;
  auto weight_desc =
      math::CreateMatrixDescriptor(weight_hh->dims(), 0, false);
  blas.MatMul(*gate_grad_src, gate_desc, *weight_hh, weight_desc,
              static_cast<T>(1.0), grad_pre_hidden, beta);

  if (!has_sequence_length) {
    return;
  }

  auto& eigen_dev = *dev_ctx.eigen_device();
  // View the mask as a column and repeat it across the hidden dimension.
  // NOTE: mask_column must stay a named local; the broadcast expression
  // below refers to it.
  auto mask_column = framework::EigenMatrix<T>::From(
      mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
  auto mask_wide = mask_column.broadcast(
      Eigen::DSizes<int, 2>(1, grad_pre_hidden->dims()[2]));

  auto hidden_now = framework::EigenMatrix<T>::Reshape(
      *grad_pre_hidden, grad_pre_hidden->dims().size() - 1);
  auto hidden_saved = framework::EigenMatrix<T>::Reshape(
      *grad_pre_hidden_bak, grad_pre_hidden_bak->dims().size() - 1);
  hidden_now.device(eigen_dev) =
      (1 - mask_wide) * hidden_saved + hidden_now * mask_wide;

  if (grad_pre_state == nullptr) {
    return;
  }
  auto state_now = framework::EigenMatrix<T>::Reshape(
      *grad_pre_state, grad_pre_state->dims().size() - 1);
  auto state_saved = framework::EigenMatrix<T>::Reshape(
      *grad_pre_state_bak, grad_pre_state_bak->dims().size() - 1);
  state_now.device(eigen_dev) =
      (1 - mask_wide) * state_saved + state_now * mask_wide;
}
// Accumulates the recurrent-weight gradient in place:
//   grad_weight_hh = gate_grad (described by lhs_desc) * pre_hidden
//                    + grad_weight_hh
// gate_grad is grad_gate_buf in GRU mode and grad_gate otherwise. Note that
// the left-hand descriptor is always derived from grad_gate's dims, even when
// the GRU buffer supplies the data.
virtual void update_weight_hh_grad(const framework::ExecutionContext& context,
                                   Tensor* grad_gate, Tensor* pre_hidden,
                                   Tensor* grad_weight_hh,
                                   Tensor* grad_gate_buf) const {
  auto& dev_ctx =
      context.template device_context<platform::CPUDeviceContext>();
  auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);

  // Left operand descriptor; the trailing `true` presumably requests the
  // transposed view of the gate gradient (confirm against
  // CreateMatrixDescriptor). The batch dimension is folded into the height.
  auto lhs_desc = math::CreateMatrixDescriptor(grad_gate->dims(), 0, true);
  lhs_desc.height_ *= lhs_desc.batch_size_;
  lhs_desc.batch_size_ = 0;

  // Right operand descriptor, folded the same way.
  auto rhs_desc = math::CreateMatrixDescriptor(pre_hidden->dims(), 0, false);
  rhs_desc.height_ *= rhs_desc.batch_size_;
  rhs_desc.batch_size_ = 0;

  Tensor* gate_grad_src = is_gru(context) ? grad_gate_buf : grad_gate;

  // alpha = 1, beta = 1: add the product to the existing weight gradient.
  blas.MatMul(*gate_grad_src, lhs_desc, *pre_hidden, rhs_desc,
              static_cast<T>(1.0), grad_weight_hh, static_cast<T>(1.0));
}
};
// Gradient cell for one step of a simple RNN, where h = act(z).
//
// EigenActivationBackwardFunctor<T> is the backward functor of the
// activation; it is invoked as functor(device, z, h, dh, dz) and fills dz
// (stored in grad_gate). The base-class helpers then propagate dz into
// grad_pre_hidden and grad_weight_hh.
template <typename T, template <typename> class EigenActivationBackwardFunctor>
struct SimpleRNNGradCell : GradCell<T> {
  void operator()(const framework::ExecutionContext& context,
                  Tensor* gate_tensor, Tensor* state_tensor,
                  Tensor* act_state_tensor, Tensor* hidden_tensor,
                  const Tensor* weight_hh, Tensor* pre_hidden,
                  Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                  Tensor* grad_gate, Tensor* grad_weight_hh,
                  Tensor* grad_pre_hidden, Tensor* grad_pre_state,
                  Tensor* grad_gate_buf, Tensor* grad_bias_hh,
                  const Tensor& mask_tensor,
                  bool has_sequence_length) const override {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();

    // Keep a copy of the incoming gradient; the masked blend in
    // update_pre_hidden_grad reads it.
    Tensor saved_grad_pre_hidden;
    if (has_sequence_length) {
      backup_tensor<T>(context, &saved_grad_pre_hidden, grad_pre_hidden);
    }

    // Flattened views used by the activation backward functor.
    auto grad_z = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(grad_gate, "Output", "dz", "Grad"));
    auto grad_h = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(grad_hidden, "Input", "dh", "Grad"));
    auto h_value = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(hidden_tensor, "Input", "h", "Value"));
    // The functor signature requires z even though its value is not needed.
    auto z_value = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(gate_tensor, "Input", "z", "Value"));

    // dz = act'(.) * dh
    EigenActivationBackwardFunctor<T> act_backward;
    act_backward(*dev_ctx.eigen_device(), z_value, h_value, grad_h, grad_z);

    // No cell state in a simple RNN, hence the two null state pointers.
    this->update_pre_hidden_grad(context, grad_gate, weight_hh,
                                 grad_pre_hidden, &saved_grad_pre_hidden,
                                 nullptr, nullptr, grad_gate_buf, mask_tensor,
                                 has_sequence_length);
    this->update_weight_hh_grad(context, grad_gate, pre_hidden,
                                grad_weight_hh, grad_gate_buf);
  }
};
// Gradient cell for one GRU time step.
//
// Packs the raw tensor buffers into GRUMetaValue / GRUMetaGrad, runs
// GRUUnitGradFunctorV2 with the "sigmoid_v2" gate activation and "tanh_v2"
// node activation, builds grad_gate_buf via make_grad_gate_buf, and finally
// lets the base-class helpers update grad_pre_hidden and grad_weight_hh.
template <typename T>
struct GRUGradCell : GradCell<T> {
  void operator()(const framework::ExecutionContext& context,
                  Tensor* gate_tensor, Tensor* state_tensor,
                  Tensor* act_state_tensor, Tensor* hidden_tensor,
                  const Tensor* weight_hh, Tensor* pre_hidden,
                  Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                  Tensor* grad_gate, Tensor* grad_weight_hh,
                  Tensor* grad_pre_hidden, Tensor* grad_pre_state,
                  Tensor* grad_gate_buf, Tensor* grad_bias_hh,
                  const Tensor& mask_tensor,
                  bool has_sequence_length) const override {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    // Sizes are read from pre_hidden: dim 1 is the batch, dim 2 the frame.
    const size_t frame_len = pre_hidden->dims()[2];
    const size_t batch_len = pre_hidden->dims()[1];

    // Back up first, then clear: the unit functor writes into a zeroed
    // grad_pre_hidden and the backup feeds the masked blend later on.
    Tensor saved_grad_pre_hidden;
    if (has_sequence_length) {
      backup_tensor<T>(context, &saved_grad_pre_hidden, grad_pre_hidden);
    }
    math::SetConstant<platform::CPUDeviceContext, T> fill_const;
    fill_const(dev_ctx, grad_pre_hidden, static_cast<T>(0.0));

    // Forward values.
    math::GRUMetaValue<T> fwd_pack;
    fwd_pack.gate_value = gate_tensor->data<T>();
    fwd_pack.prev_out_value = pre_hidden->data<T>();
    fwd_pack.reset_output_value = state_tensor->data<T>();

    // Gradient buffers. The state weight/bias gradients live behind the
    // first two gate blocks of the shared hh buffers.
    math::GRUMetaGrad<T> bwd_pack;
    bwd_pack.gate_grad = grad_gate->data<T>();
    bwd_pack.reset_output_grad = grad_state->data<T>();
    bwd_pack.prev_out_grad = grad_pre_hidden->data<T>();
    bwd_pack.output_grad = grad_hidden->data<T>();
    auto* weight_hh_grad_base = grad_weight_hh->data<T>();
    bwd_pack.gate_weight_grad = weight_hh_grad_base;
    bwd_pack.state_weight_grad =
        weight_hh_grad_base + 2 * frame_len * frame_len;
    bwd_pack.state_bias_grad = grad_bias_hh->data<T>() + 2 * frame_len;

    const auto gate_activation =
        math::detail::GetActivationType("sigmoid_v2");
    const auto node_activation = math::detail::GetActivationType("tanh_v2");
    math::GRUUnitGradFunctorV2<platform::CPUDeviceContext, T>::compute(
        dev_ctx, fwd_pack, bwd_pack, frame_len, batch_len, node_activation,
        gate_activation);

    make_grad_gate_buf<T>(context, grad_gate, grad_gate_buf, grad_state);

    // GRU has no separate cell state, hence the two null state pointers.
    this->update_pre_hidden_grad(context, grad_gate, weight_hh,
                                 grad_pre_hidden, &saved_grad_pre_hidden,
                                 nullptr, nullptr, grad_gate_buf, mask_tensor,
                                 has_sequence_length);
    this->update_weight_hh_grad(context, grad_gate, pre_hidden,
                                grad_weight_hh, grad_gate_buf);
  }
};
// Gradient cell for one LSTM time step.
//
// Packs the raw tensor buffers into LstmMetaValue / LstmMetaGrad, runs
// LstmUnitGradFunctor with "sigmoid_v2" for the gates and "tanh_v2" for both
// the state and candidate activations (cell clip 0, trailing flag false), and
// then lets the base-class helpers update grad_pre_hidden, grad_pre_state and
// grad_weight_hh.
template <typename T>
struct LSTMGradCell : GradCell<T> {
  void operator()(const framework::ExecutionContext& context,
                  Tensor* gate_tensor, Tensor* state_tensor,
                  Tensor* act_state_tensor, Tensor* hidden_tensor,
                  const Tensor* weight_hh, Tensor* pre_hidden,
                  Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                  Tensor* grad_gate, Tensor* grad_weight_hh,
                  Tensor* grad_pre_hidden, Tensor* grad_pre_state,
                  Tensor* grad_gate_buf, Tensor* grad_bias_hh,
                  const Tensor& mask_tensor,
                  bool has_sequence_length) const override {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    // Sizes are read from state_tensor: dim 1 is the batch, dim 2 the frame.
    const size_t frame_len = state_tensor->dims()[2];
    const size_t batch_len = state_tensor->dims()[1];

    // Copies of the incoming gradients; the masked blend in
    // update_pre_hidden_grad reads them.
    Tensor saved_grad_pre_hidden;
    Tensor saved_grad_pre_state;
    if (has_sequence_length) {
      backup_tensor<T>(context, &saved_grad_pre_hidden, grad_pre_hidden);
      backup_tensor<T>(context, &saved_grad_pre_state, grad_pre_state);
    }

    // Initialise both packs before filling in the fields used here.
    math::LstmMetaValue<T> fwd_pack;
    math::LstmMetaGrad<T> bwd_pack;
    create_lstm_value(&fwd_pack);
    create_lstm_grad(&bwd_pack);

    // Forward values.
    fwd_pack.gate_value = gate_tensor->data<T>();
    fwd_pack.state_value = state_tensor->data<T>();
    fwd_pack.state_active_value = act_state_tensor->data<T>();
    fwd_pack.prev_state_value = pre_state->data<T>();
    fwd_pack.output_value = nullptr;

    // Gradient buffers.
    bwd_pack.state_grad = grad_state->data<T>();
    bwd_pack.gate_grad = grad_gate->data<T>();
    bwd_pack.output_grad = grad_hidden->data<T>();
    bwd_pack.prev_state_grad = grad_pre_state->data<T>();
    bwd_pack.state_active_grad = nullptr;

    const auto sigmoid_act = math::detail::GetActivationType("sigmoid_v2");
    const auto tanh_act = math::detail::GetActivationType("tanh_v2");
    const T clip_value = 0.0;
    // Argument order: gate activation, state activation, candidate activation.
    math::LstmUnitGradFunctor<platform::CPUDeviceContext, T>::compute(
        dev_ctx, fwd_pack, bwd_pack, frame_len, batch_len, clip_value,
        sigmoid_act, tanh_act, tanh_act, false);

    this->update_pre_hidden_grad(context, grad_gate, weight_hh,
                                 grad_pre_hidden, &saved_grad_pre_hidden,
                                 grad_pre_state, &saved_grad_pre_state,
                                 grad_gate_buf, mask_tensor,
                                 has_sequence_length);
    this->update_weight_hh_grad(context, grad_gate, pre_hidden,
                                grad_weight_hh, grad_gate_buf);
  }
};
// Computes the gradients of the RNN op on CPU.
//
// Template parameters:
//   GradCellType     - per-step gradient cell; a single instance is created
//                      here and handed to every layer functor.
//   SingleGradLayerT - layer template used when `is_bidirec` is false,
//                      instantiated as SingleGradLayerT<T, GradCellType>.
//   BidirGradLayerT  - layer template used when `is_bidirec` is true,
//                      instantiated as BidirGradLayerT<T, GradCellType>.
//   T                - element type of the tensors.
//
// Arguments:
//   context  - execution context supplying inputs, outputs and attributes.
//   gate_num - number of gates per cell; 0 is treated as 1 when sizing the
//              gate tensor and when calling the layer functor.
//
// The function reads the inputs and upstream gradients, allocates the
// gradient outputs, splits `Reserve` into per-layer tensors and then walks the
// layers from the last one down to layer 0.
template <typename GradCellType,
template <typename, typename> class SingleGradLayerT,
template <typename, typename> class BidirGradLayerT, typename T>
void RnnGradFunc(const framework::ExecutionContext& context,
const int& gate_num) {
// get the tensor pointer for the input
auto* input = context.Input<Tensor>("Input");
auto weight_list = context.MultiInput<Tensor>("WeightList");
auto pre_state = context.MultiInput<Tensor>("PreState");
// PreState[0] is the initial hidden state; PreState[1] (the initial cell
// state) is only read in LSTM mode.
const Tensor* init_h = pre_state[0];
const Tensor* init_c = nullptr;
if (is_lstm(context)) {
init_c = pre_state[1];
}
auto* reserve_state = context.Input<Tensor>("Reserve");
auto* dropout_state = context.Input<Tensor>("DropoutState");
auto* output = context.Input<Tensor>("Out");
auto* output_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto state_grad = context.MultiInput<Tensor>(framework::GradVarName("State"));
// Same layout as PreState: index 0 is the hidden-state gradient, index 1 the
// cell-state gradient (LSTM only).
const Tensor* last_h_grad = state_grad[0];
const Tensor* last_c_grad = nullptr;
if (is_lstm(context)) {
last_c_grad = state_grad[1];
}
// SequenceLength is optional; sequence_length stays null when it is absent.
bool has_seq_length = context.HasInput("SequenceLength");
const Tensor* sequence_length = nullptr;
if (has_seq_length) {
sequence_length = context.Input<Tensor>("SequenceLength");
}
// get the tensor pointer for the output
auto* input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
auto weight_grad_list = context.MultiOutput<framework::Tensor>(
framework::GradVarName("WeightList"));
auto pre_state_grad =
context.MultiOutput<Tensor>(framework::GradVarName("PreState"));
Tensor* init_h_grad = nullptr;
Tensor* init_c_grad = nullptr;
if (pre_state_grad.size() > 0) { // has gradient
init_h_grad = pre_state_grad[0];
if (is_lstm(context)) {
init_c_grad = pre_state_grad[1];
}
}
// get the attributes used by the calculation
const int& num_layers = context.Attr<int>("num_layers");
const bool& is_bidirec = context.Attr<bool>("is_bidirec");
const float& dropout_prob = context.Attr<float>("dropout_prob");
const bool& is_test = context.Attr<bool>("is_test");
// get the input_size, batch_size, time_step, hidden_size
// (input is indexed as [time_step, batch_size, ...])
const int& time_step = input->dims()[0];
const int& batch_size = input->dims()[1];
const int& hidden_size = context.Attr<int>("hidden_size");
const int& direction_num = is_bidirec ? 2 : 1;
// allocate memory for input_grad and the initial-state gradients
// When the Input gradient output is absent, fall back to a local scratch
// tensor so the code below can always write through input_grad.
Tensor input_grad_value;
if (!input_grad) {
input_grad = &input_grad_value;
}
input_grad->mutable_data<T>(input->dims(), context.GetPlace());
if (init_h_grad) {
init_h_grad->mutable_data<T>(init_h->dims(), context.GetPlace());
}
if (init_c_grad) {
init_c_grad->mutable_data<T>(init_c->dims(), context.GetPlace());
}
// reset the parameter to sorted order and allocate the memory
std::vector<TensorList> parameter_lists;
parameter_lists.reserve(num_layers);
reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec,
&parameter_lists);
for (unsigned int i = 0; i < weight_grad_list.size(); ++i) {
weight_grad_list[i]->mutable_data<T>(context.GetPlace());
}
// The weight gradients are regrouped the same way as the weights above.
std::vector<TensorList> parameter_lists_grad;
parameter_lists_grad.reserve(num_layers);
reset_parameter_vector(weight_grad_list, num_layers, gate_num, is_bidirec,
&parameter_lists_grad);
// split reserve_state into the gate / state / act_state / hidden tensors
Tensor gate_tensor;
Tensor state_tensor;
Tensor act_state_tensor;
Tensor hidden_tensor;
SplitReserveData(context, reserve_state, &gate_tensor, &state_tensor,
&act_state_tensor, &hidden_tensor, direction_num, time_step,
batch_size, hidden_size, gate_num, num_layers);
// A gate_num of 0 is treated as a single gate from here on.
int gate_num_tmp = gate_num;
if (gate_num == 0) {
gate_num_tmp = 1;
}
// Give every tensor a leading num_layers dimension so Unbind below yields
// one entry per layer.
gate_tensor.Resize({num_layers, time_step * direction_num, batch_size,
hidden_size * gate_num_tmp});
if (state_tensor.numel() > 0) {
state_tensor.Resize(
{num_layers, time_step * direction_num, batch_size, hidden_size});
}
if (act_state_tensor.numel() > 0) {
act_state_tensor.Resize(
{num_layers, time_step * direction_num, batch_size, hidden_size});
}
// hidden_tensor only covers the first num_layers - 1 layers; the last
// layer's output is `output`, which is attached further below.
if (num_layers > 1) {
hidden_tensor.Resize(
{num_layers - 1, time_step, batch_size, hidden_size * direction_num});
}
// unbind
auto last_h_grad_unbind = Unbind(*last_h_grad);
auto gate_tensor_unbind = Unbind(gate_tensor);
TensorList last_c_grad_unbind;
if (last_c_grad) {
last_c_grad_unbind = Unbind(*last_c_grad);
}
// Optional tensors are only unbound when present / non-empty; their lists
// stay empty otherwise.
TensorList init_h_unbind, init_c_unbind;
TensorList init_h_grad_unbind, init_c_grad_unbind;
TensorList state_tensor_unbind, act_state_tensor_unbind;
TensorList hidden_tensor_unbind;
init_h_unbind = Unbind(*init_h);
if (init_c) {
init_c_unbind = Unbind(*init_c);
}
if (init_h_grad != nullptr) {
init_h_grad_unbind = Unbind(*init_h_grad);
}
if (init_c_grad != nullptr) {
init_c_grad_unbind = Unbind(*init_c_grad);
}
if (state_tensor.numel() > 0) {
state_tensor_unbind = Unbind(state_tensor);
}
if (act_state_tensor.numel() > 0) {
act_state_tensor_unbind = Unbind(act_state_tensor);
}
if (num_layers > 1) {
hidden_tensor_unbind = Unbind(hidden_tensor);
}
// squeeze the hidden first dim
for (unsigned int i = 0; i < hidden_tensor_unbind.size(); i++) {
hidden_tensor_unbind[i].Resize(
framework::slice_ddim(hidden_tensor_unbind[i].dims(), 1,
hidden_tensor_unbind[i].dims().size()));
}
// add the output tensor to the hidden vector
// (after this, hidden_tensor_unbind[i] is the output of layer i for every i)
Tensor tmp;
hidden_tensor_unbind.emplace_back(tmp);
hidden_tensor_unbind[num_layers - 1].ShareDataWith(*output);
GradCellType cell;
Tensor layer_input;
Tensor layer_output;
// Two holders are swapped at the end of each iteration: the input gradient
// written by layer i becomes the output gradient consumed by layer i - 1.
Tensor* layer_input_grad_holder = nullptr;
Tensor tmp_out;
tmp_out.ShareDataWith(*output_grad);
Tensor* layer_output_grad_holder = &tmp_out;
Tensor input_grad_temp;
Tensor output_grad_temp;
bool has_allocate_mem = false;
// Walk the layers from the last one down to layer 0.
for (int i = num_layers - 1; i >= 0; --i) {
// the layer input output had saved, just use the data
if (i > 0) {
layer_input.ShareDataWith(hidden_tensor_unbind[i - 1]);
} else {
layer_input.ShareDataWith(*input);
}
layer_output.ShareDataWith(hidden_tensor_unbind[i]);
if (num_layers == 1) {
// Single layer: write the input gradient straight into input_grad.
layer_input_grad_holder = input_grad;
} else {
// Multi-layer: the top layer writes into a scratch buffer that is
// allocated once, on the first iteration.
if (i == num_layers - 1) {
input_grad_temp.Resize(layer_input.dims());
input_grad_temp.mutable_data<T>(context.GetPlace());
layer_input_grad_holder = &input_grad_temp;
}
}
if (is_bidirec) {
BidirGradLayerT<T, GradCellType> layer(cell);
layer(context, &layer_input, &layer_output, &init_h_unbind,
&init_c_unbind, last_h_grad_unbind, last_c_grad_unbind,
gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind,
layer_output_grad_holder, parameter_lists, sequence_length,
layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind,
&parameter_lists_grad, i, gate_num_tmp);
} else {
SingleGradLayerT<T, GradCellType> layer(cell);
layer(context, &layer_input, &layer_output, &init_h_unbind,
&init_c_unbind, last_h_grad_unbind, last_c_grad_unbind,
gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind,
layer_output_grad_holder, parameter_lists, sequence_length,
layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind,
&parameter_lists_grad, i, gate_num_tmp);
}
// calculate the dropout gradient for the layer_input_grad_holder
// dropout_state save in the forward process
if (i > 0) {
if ((!is_test) && (dropout_prob != 0)) {
dropout_cpu_grad_function_inplace<T>(context, layer_input_grad_holder,
dropout_state, dropout_prob);
}
}
// Prepare the holder that becomes the *input* gradient target after the
// swap below. When the next layer is layer 0 that target is input_grad
// itself; otherwise a second scratch buffer is allocated once.
if (i - 1 == 0) {
layer_output_grad_holder = input_grad;
} else {
if (!has_allocate_mem) {
output_grad_temp.Resize(layer_input_grad_holder->dims());
output_grad_temp.mutable_data<T>(context.GetPlace());
layer_output_grad_holder = &output_grad_temp;
has_allocate_mem = true;
}
}
SwapPoniter(&layer_input_grad_holder, &layer_output_grad_holder);
}
}
// CPU gradient kernel of the RNN op. Compute() selects the gradient cell and
// the gate count from the op mode and forwards to RnnGradFunc:
//   LSTM     -> LSTMGradCell,                       4 gates
//   GRU      -> GRUGradCell,                        3 gates
//   RNN_RELU -> SimpleRNNGradCell<ReluGradFunctor>, 1 gate
//   RNN_TANH -> SimpleRNNGradCell<TanhGradFunctor>, 1 gate
// Nothing is computed when none of the mode predicates matches.
template <typename DeviceContext, typename T>
class RNNCPUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    if (is_lstm(ctx)) {
      const int lstm_gate_num = 4;
      RnnGradFunc<LSTMGradCell<T>, SingleGradLayer, BidirGradLayer, T>(
          ctx, lstm_gate_num);
      return;
    }
    if (is_gru(ctx)) {
      const int gru_gate_num = 3;
      RnnGradFunc<GRUGradCell<T>, SingleGradLayer, BidirGradLayer, T>(
          ctx, gru_gate_num);
      return;
    }
    // Both simple-RNN flavours use a single gate and differ only in the
    // activation backward functor.
    const int simple_gate_num = 1;
    if (is_rnn_relu(ctx)) {
      RnnGradFunc<SimpleRNNGradCell<T, ReluGradFunctor>, SingleGradLayer,
                  BidirGradLayer, T>(ctx, simple_gate_num);
      return;
    }
    if (is_rnn_tanh(ctx)) {
      RnnGradFunc<SimpleRNNGradCell<T, TanhGradFunctor>, SingleGradLayer,
                  BidirGradLayer, T>(ctx, simple_gate_num);
    }
  }
};
} // namespace operators
} // namespace paddle
...@@ -65,10 +65,18 @@ class TestLstm(unittest.TestCase): ...@@ -65,10 +65,18 @@ class TestLstm(unittest.TestCase):
paddle.jit.ProgramTranslator().enable(True) paddle.jit.ProgramTranslator().enable(True)
net = Net(12, 2) net = Net(12, 2)
x = paddle.randn((2, 10, 12)) x = paddle.randn((2, 10, 12))
x.stop_gradient = False
dygraph_out = net(x) dygraph_out = net(x)
loss = paddle.mean(dygraph_out)
sgd = paddle.optimizer.SGD(learning_rate=0.001,
parameters=net.parameters())
loss.backward()
sgd.step()
# switch eval mode firstly # switch eval mode firstly
net.eval() net.eval()
x = paddle.randn((2, 10, 12))
dygraph_out = net(x)
dropout_out = net(x)
net = paddle.jit.to_static( net = paddle.jit.to_static(
net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])]) net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])])
paddle.jit.save(net, 'simple_lstm') paddle.jit.save(net, 'simple_lstm')
...@@ -106,6 +114,14 @@ class TestSaveInEvalMode(unittest.TestCase): ...@@ -106,6 +114,14 @@ class TestSaveInEvalMode(unittest.TestCase):
def test_save_in_eval(self): def test_save_in_eval(self):
paddle.jit.ProgramTranslator().enable(True) paddle.jit.ProgramTranslator().enable(True)
net = LinearNet() net = LinearNet()
x = paddle.randn((2, 10))
x.stop_gradient = False
dygraph_out = net(x)
loss = paddle.mean(dygraph_out)
sgd = paddle.optimizer.SGD(learning_rate=0.001,
parameters=net.parameters())
loss.backward()
sgd.step()
# switch eval mode firstly # switch eval mode firstly
net.eval() net.eval()
# save directly # save directly
...@@ -129,6 +145,14 @@ class TestEvalAfterSave(unittest.TestCase): ...@@ -129,6 +145,14 @@ class TestEvalAfterSave(unittest.TestCase):
def test_eval_after_save(self): def test_eval_after_save(self):
x = paddle.randn((2, 10, 12)).astype('float32') x = paddle.randn((2, 10, 12)).astype('float32')
net = Net(12, 2) net = Net(12, 2)
x.stop_gradient = False
dy_out = net(x)
loss = paddle.mean(dy_out)
sgd = paddle.optimizer.SGD(learning_rate=0.001,
parameters=net.parameters())
loss.backward()
sgd.step()
x = paddle.randn((2, 10, 12)).astype('float32')
dy_out = net(x) dy_out = net(x)
# save model # save model
paddle.jit.save(net, 'jit.save/lstm', input_spec=[x]) paddle.jit.save(net, 'jit.save/lstm', input_spec=[x])
......
...@@ -49,3 +49,34 @@ def convert_params_for_net_static(np_net, paddle_net, place): ...@@ -49,3 +49,34 @@ def convert_params_for_net_static(np_net, paddle_net, place):
paddle_layer.cell_fw, place) paddle_layer.cell_fw, place)
convert_params_for_cell_static(np_layer.cell_bw, convert_params_for_cell_static(np_layer.cell_bw,
paddle_layer.cell_bw, place) paddle_layer.cell_bw, place)
def get_params_for_cell(np_cell, num_layers, idx):
    """Collect the named parameters of one NumPy reference cell.

    Args:
        np_cell: object whose ``parameters`` mapping has the keys
            ``weight_ih``, ``weight_hh``, ``bias_ih`` and ``bias_hh``.
        num_layers: prefix used in the generated names. Despite its name the
            caller in this file passes the *index* of the layer here.
        idx: suffix used for the input-to-hidden entries; the
            hidden-to-hidden entries use ``idx + 1``.

    Returns:
        ``(weight_list, bias_list)``: two lists of ``(name, value)`` pairs,
        each ordered input-to-hidden first, hidden-to-hidden second.
    """
    state = np_cell.parameters
    weight_list = [
        ('{}.weight_{}'.format(num_layers, idx), state['weight_ih']),
        ('{}.weight_{}'.format(num_layers, idx + 1), state['weight_hh']),
    ]
    bias_list = [
        ('{}.bias_{}'.format(num_layers, idx), state['bias_ih']),
        ('{}.bias_{}'.format(num_layers, idx + 1), state['bias_hh']),
    ]
    return weight_list, bias_list
def get_params_for_net(np_net):
    """Flatten the parameters of a NumPy reference network into one list.

    Args:
        np_net: iterable of layers. A layer with a ``cell`` attribute
            contributes that single cell; any other layer must provide
            ``cell_fw`` and ``cell_bw``.

    Returns:
        A list of ``(name, value)`` pairs: all weights in layer order,
        followed by all biases in the same order.
    """
    weight_list = []
    bias_list = []
    for layer_idx, np_layer in enumerate(np_net):
        if hasattr(np_layer, "cell"):
            cells = [np_layer.cell]
        else:
            cells = [np_layer.cell_fw, np_layer.cell_bw]
        # Each cell owns two consecutive name suffixes, hence the step of 2.
        for count, cell in enumerate(cells):
            weight, bias = get_params_for_cell(cell, layer_idx, count * 2)
            weight_list.extend(weight)
            bias_list.extend(bias)
    weight_list.extend(bias_list)
    return weight_list
...@@ -33,11 +33,16 @@ class LayerListMixin(LayerMixin): ...@@ -33,11 +33,16 @@ class LayerListMixin(LayerMixin):
class SimpleRNNCell(LayerMixin): class SimpleRNNCell(LayerMixin):
def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"): def __init__(self,
input_size,
hidden_size,
bias=True,
nonlinearity="RNN_TANH",
dtype="float64"):
self.input_size = input_size self.input_size = input_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.bias = bias self.bias = bias
if nonlinearity == 'tanh': if nonlinearity == 'RNN_TANH':
self.nonlinearity = np.tanh self.nonlinearity = np.tanh
else: else:
self.nonlinearity = lambda x: np.maximum(x, 0.) self.nonlinearity = lambda x: np.maximum(x, 0.)
...@@ -45,16 +50,16 @@ class SimpleRNNCell(LayerMixin): ...@@ -45,16 +50,16 @@ class SimpleRNNCell(LayerMixin):
self.parameters = dict() self.parameters = dict()
std = 1.0 / math.sqrt(hidden_size) std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = np.random.uniform(-std, std, ( self.weight_ih = np.random.uniform(-std, std, (
hidden_size, input_size)).astype('float64') hidden_size, input_size)).astype(dtype)
self.weight_hh = np.random.uniform(-std, std, ( self.weight_hh = np.random.uniform(-std, std, (
hidden_size, hidden_size)).astype('float64') hidden_size, hidden_size)).astype(dtype)
self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_ih'] = self.weight_ih
self.parameters['weight_hh'] = self.weight_hh self.parameters['weight_hh'] = self.weight_hh
if bias: if bias:
self.bias_ih = np.random.uniform(-std, std, self.bias_ih = np.random.uniform(-std, std,
(hidden_size, )).astype('float64') (hidden_size, )).astype(dtype)
self.bias_hh = np.random.uniform(-std, std, self.bias_hh = np.random.uniform(-std, std,
(hidden_size, )).astype('float64') (hidden_size, )).astype(dtype)
self.parameters['bias_ih'] = self.bias_ih self.parameters['bias_ih'] = self.bias_ih
self.parameters['bias_hh'] = self.bias_hh self.parameters['bias_hh'] = self.bias_hh
else: else:
...@@ -80,23 +85,23 @@ class SimpleRNNCell(LayerMixin): ...@@ -80,23 +85,23 @@ class SimpleRNNCell(LayerMixin):
class GRUCell(LayerMixin): class GRUCell(LayerMixin):
def __init__(self, input_size, hidden_size, bias=True): def __init__(self, input_size, hidden_size, bias=True, dtype="float64"):
self.input_size = input_size self.input_size = input_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.bias = bias self.bias = bias
self.parameters = dict() self.parameters = dict()
std = 1.0 / math.sqrt(hidden_size) std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = np.random.uniform(-std, std, ( self.weight_ih = np.random.uniform(-std, std, (
3 * hidden_size, input_size)).astype('float64') 3 * hidden_size, input_size)).astype(dtype)
self.weight_hh = np.random.uniform(-std, std, ( self.weight_hh = np.random.uniform(-std, std, (
3 * hidden_size, hidden_size)).astype('float64') 3 * hidden_size, hidden_size)).astype(dtype)
self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_ih'] = self.weight_ih
self.parameters['weight_hh'] = self.weight_hh self.parameters['weight_hh'] = self.weight_hh
if bias: if bias:
self.bias_ih = np.random.uniform(-std, std, ( self.bias_ih = np.random.uniform(-std, std,
3 * hidden_size)).astype('float64') (3 * hidden_size)).astype(dtype)
self.bias_hh = np.random.uniform(-std, std, ( self.bias_hh = np.random.uniform(-std, std,
3 * hidden_size)).astype('float64') (3 * hidden_size)).astype(dtype)
self.parameters['bias_ih'] = self.bias_ih self.parameters['bias_ih'] = self.bias_ih
self.parameters['bias_hh'] = self.bias_hh self.parameters['bias_hh'] = self.bias_hh
else: else:
...@@ -128,23 +133,23 @@ class GRUCell(LayerMixin): ...@@ -128,23 +133,23 @@ class GRUCell(LayerMixin):
class LSTMCell(LayerMixin): class LSTMCell(LayerMixin):
def __init__(self, input_size, hidden_size, bias=True): def __init__(self, input_size, hidden_size, bias=True, dtype="float64"):
self.input_size = input_size self.input_size = input_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.bias = bias self.bias = bias
self.parameters = dict() self.parameters = dict()
std = 1.0 / math.sqrt(hidden_size) std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = np.random.uniform(-std, std, ( self.weight_ih = np.random.uniform(-std, std, (
4 * hidden_size, input_size)).astype('float64') 4 * hidden_size, input_size)).astype(dtype)
self.weight_hh = np.random.uniform(-std, std, ( self.weight_hh = np.random.uniform(-std, std, (
4 * hidden_size, hidden_size)).astype('float64') 4 * hidden_size, hidden_size)).astype(dtype)
self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_ih'] = self.weight_ih
self.parameters['weight_hh'] = self.weight_hh self.parameters['weight_hh'] = self.weight_hh
if bias: if bias:
self.bias_ih = np.random.uniform(-std, std, ( self.bias_ih = np.random.uniform(-std, std,
4 * hidden_size)).astype('float64') (4 * hidden_size)).astype(dtype)
self.bias_hh = np.random.uniform(-std, std, ( self.bias_hh = np.random.uniform(-std, std,
4 * hidden_size)).astype('float64') (4 * hidden_size)).astype(dtype)
self.parameters['bias_ih'] = self.bias_ih self.parameters['bias_ih'] = self.bias_ih
self.parameters['bias_hh'] = self.bias_hh self.parameters['bias_hh'] = self.bias_hh
else: else:
...@@ -403,28 +408,36 @@ class SimpleRNN(RNNMixin): ...@@ -403,28 +408,36 @@ class SimpleRNN(RNNMixin):
input_size, input_size,
hidden_size, hidden_size,
num_layers=1, num_layers=1,
nonlinearity="tanh", nonlinearity="RNN_TANH",
direction="forward", direction="forward",
dropout=0., dropout=0.,
time_major=False): time_major=False,
dtype="float64"):
super(SimpleRNN, self).__init__() super(SimpleRNN, self).__init__()
if direction in ["forward", "backward"]: if direction in ["forward", "backward"]:
is_reverse = direction == "backward" is_reverse = direction == "backward"
cell = SimpleRNNCell(input_size, hidden_size, nonlinearity) cell = SimpleRNNCell(
input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell = SimpleRNNCell(hidden_size, hidden_size, nonlinearity) cell = SimpleRNNCell(
hidden_size,
hidden_size,
nonlinearity=nonlinearity,
dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional": elif direction == "bidirectional":
cell_fw = SimpleRNNCell(input_size, hidden_size, nonlinearity) cell_fw = SimpleRNNCell(
cell_bw = SimpleRNNCell(input_size, hidden_size, nonlinearity) input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
cell_bw = SimpleRNNCell(
input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell_fw = SimpleRNNCell(2 * hidden_size, hidden_size, cell_fw = SimpleRNNCell(
nonlinearity) 2 * hidden_size, hidden_size, nonlinearity, dtype=dtype)
cell_bw = SimpleRNNCell(2 * hidden_size, hidden_size, cell_bw = SimpleRNNCell(
nonlinearity) 2 * hidden_size, hidden_size, nonlinearity, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
else: else:
raise ValueError( raise ValueError(
...@@ -447,23 +460,24 @@ class LSTM(RNNMixin): ...@@ -447,23 +460,24 @@ class LSTM(RNNMixin):
num_layers=1, num_layers=1,
direction="forward", direction="forward",
dropout=0., dropout=0.,
time_major=False): time_major=False,
dtype="float64"):
super(LSTM, self).__init__() super(LSTM, self).__init__()
if direction in ["forward", "backward"]: if direction in ["forward", "backward"]:
is_reverse = direction == "backward" is_reverse = direction == "backward"
cell = LSTMCell(input_size, hidden_size) cell = LSTMCell(input_size, hidden_size, dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell = LSTMCell(hidden_size, hidden_size) cell = LSTMCell(hidden_size, hidden_size, dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional": elif direction == "bidirectional":
cell_fw = LSTMCell(input_size, hidden_size) cell_fw = LSTMCell(input_size, hidden_size, dtype=dtype)
cell_bw = LSTMCell(input_size, hidden_size) cell_bw = LSTMCell(input_size, hidden_size, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell_fw = LSTMCell(2 * hidden_size, hidden_size) cell_fw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype)
cell_bw = LSTMCell(2 * hidden_size, hidden_size) cell_bw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
else: else:
raise ValueError( raise ValueError(
...@@ -486,23 +500,24 @@ class GRU(RNNMixin): ...@@ -486,23 +500,24 @@ class GRU(RNNMixin):
num_layers=1, num_layers=1,
direction="forward", direction="forward",
dropout=0., dropout=0.,
time_major=False): time_major=False,
dtype="float64"):
super(GRU, self).__init__() super(GRU, self).__init__()
if direction in ["forward", "backward"]: if direction in ["forward", "backward"]:
is_reverse = direction == "backward" is_reverse = direction == "backward"
cell = GRUCell(input_size, hidden_size) cell = GRUCell(input_size, hidden_size, dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell = GRUCell(hidden_size, hidden_size) cell = GRUCell(hidden_size, hidden_size, dtype=dtype)
self.append(RNN(cell, is_reverse, time_major)) self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional": elif direction == "bidirectional":
cell_fw = GRUCell(input_size, hidden_size) cell_fw = GRUCell(input_size, hidden_size, dtype=dtype)
cell_bw = GRUCell(input_size, hidden_size) cell_bw = GRUCell(input_size, hidden_size, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers): for i in range(1, num_layers):
cell_fw = GRUCell(2 * hidden_size, hidden_size) cell_fw = GRUCell(2 * hidden_size, hidden_size, dtype=dtype)
cell_bw = GRUCell(2 * hidden_size, hidden_size) cell_bw = GRUCell(2 * hidden_size, hidden_size, dtype=dtype)
self.append(BiRNN(cell_fw, cell_bw, time_major)) self.append(BiRNN(cell_fw, cell_bw, time_major))
else: else:
raise ValueError( raise ValueError(
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import math
from op_test import OpTest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import random
import sys
# Add the relative "./rnn" directory to the import search path; the two
# imports below presumably resolve from there (this depends on the current
# working directory - verify how the test is launched).
sys.path.append("./rnn")
from rnn_numpy import GRU
from convert import get_params_for_net
# Seed Python's `random` module.
# NOTE(review): the test data in this file is drawn with np.random, which this
# call does not seed - confirm whether np.random.seed was intended.
random.seed(2)
# Print NumPy arrays in full instead of summarising them.
np.set_printoptions(threshold=np.inf)
# Switch paddle to static graph mode for the tests in this module.
paddle.enable_static()
class TestGRUOp(OpTest):
    """Checks the ``rnn`` op in GRU mode against the NumPy ``GRU`` reference.

    ``setUp`` installs the default configuration, calls ``set_attrs`` so that
    subclasses can override it, and then builds the op inputs, attributes and
    expected outputs.
    """

    def get_weight_names(self):
        """Return the parameter names: every weight first, then every bias.

        Each layer contributes ``2 * direction_num`` weights and as many
        biases, named ``"<layer>.weight_<j>"`` and ``"<layer>.bias_<j>"``.
        """
        weight_names = []
        for i in range(self.num_layers):
            for j in range(0, 2 * self.direction_num):
                weight_names.append("{}.weight_{}".format(i, j))
        for i in range(self.num_layers):
            for j in range(0, 2 * self.direction_num):
                weight_names.append("{}.bias_{}".format(i, j))
        return weight_names

    def setUp(self):
        """Build inputs, attributes and reference outputs for the op."""
        self.op_type = "rnn"
        self.dtype = "float64"
        self.sequence_length = np.array(
            [12, 11, 10, 9, 8, 7, 6, 5], dtype=np.int32)
        self.num_layers = 1
        self.is_bidirec = False
        self.is_test = False
        self.mode = "GRU"
        self.dropout = 0.
        seq_length = 12
        batch_size = 8
        input_size = 4
        self.hidden_size = 2

        # Subclass hook: may override any of the defaults set above.
        self.set_attrs()

        self.direction_num = 2 if self.is_bidirec else 1
        direction = "bidirectional" if self.is_bidirec else "forward"

        input_data = np.random.uniform(
            low=-0.1, high=0.1,
            size=(seq_length, batch_size, input_size)).astype(self.dtype)

        if self.sequence_length is not None:
            # Zero parts of a few time steps when variable lengths are used.
            input_data[3][1:][:] = 0
            input_data[4][2:][:] = 0
            input_data[2][3:][:] = 0
            input_data[1][4:][:] = 0

        rnn1 = GRU(input_size,
                   self.hidden_size,
                   num_layers=self.num_layers,
                   time_major=True,
                   direction=direction,
                   dropout=self.dropout,
                   dtype=self.dtype)

        flat_w = get_params_for_net(rnn1)
        output, last_hidden = rnn1(
            input_data, sequence_length=self.sequence_length)

        init_h = np.zeros((self.num_layers * self.direction_num, batch_size,
                           self.hidden_size)).astype(self.dtype)

        # Placeholder buffer; DropoutState is excluded from the output check.
        state_out = np.ndarray((300)).astype("uint8")

        # Build the inputs once; SequenceLength is only supplied when set.
        self.inputs = {
            'Input': input_data,
            'WeightList': flat_w,
            'PreState': [('init_h', init_h)],
        }
        if self.sequence_length is not None:
            self.inputs['SequenceLength'] = self.sequence_length

        self.attrs = {
            'dropout_prob': self.dropout,
            'is_bidirec': self.is_bidirec,
            'input_size': input_size,
            'hidden_size': self.hidden_size,
            'num_layers': self.num_layers,
            'is_test': self.is_test,
            'mode': self.mode
        }
        self.outputs = {
            'Out': output,
            'State': [('last_hidden', last_hidden)],
            # Placeholder buffer; Reserve is excluded from the output check.
            'Reserve': np.ndarray((400)).astype("uint8"),
            'DropoutState': state_out
        }

    def set_attrs(self):
        """Hook for subclasses to override the defaults chosen in setUp."""
        pass

    def test_output(self):
        """Compare the forward outputs, skipping Reserve and DropoutState."""
        self.check_output(no_check_set=['Reserve', 'DropoutState'])

    def test_grad(self):
        """Check gradients of the input, init_h and all weights.

        Nothing is checked when ``is_test`` is set.
        """
        if not self.is_test:
            var_name_list = self.get_weight_names()
            grad_check_list = ['Input', 'init_h']
            grad_check_list.extend(var_name_list)
            self.check_grad(set(grad_check_list), ['Out', 'last_hidden'])
class TestGRUOp1(TestGRUOp):
    """GRU case without a ``SequenceLength`` input."""

    def set_attrs(self):
        """Disable the per-sample sequence lengths."""
        self.sequence_length = None
class TestGRUOp2(TestGRUOp):
    """Bidirectional GRU case without a ``SequenceLength`` input."""

    def set_attrs(self):
        """Turn on both directions and drop the sequence lengths."""
        self.is_bidirec = True
        self.sequence_length = None
class TestGRUOp3(TestGRUOp):
    """GRU case with ``is_test`` set and no ``SequenceLength`` input.

    With ``is_test`` enabled, ``test_grad`` performs no gradient check.
    """

    def set_attrs(self):
        """Enable ``is_test`` and drop the sequence lengths."""
        self.is_test = True
        self.sequence_length = None
class TestGRUOp4(TestGRUOp):
    """Bidirectional GRU case with ``is_test`` set and no ``SequenceLength``.

    With ``is_test`` enabled, ``test_grad`` performs no gradient check.
    """

    def set_attrs(self):
        """Enable ``is_test`` and both directions; drop the sequence lengths."""
        self.is_test = True
        self.is_bidirec = True
        self.sequence_length = None
class TestGRUOpAvx(TestGRUOp):
    """GRU case using ``float32`` data and a hidden size of 8.

    NOTE(review): judging by the class name this presumably targets a
    vectorised (AVX) kernel path -- confirm against the kernel code.
    """

    def set_attrs(self):
        """Switch to single precision and widen the hidden state."""
        self.hidden_size = 8
        self.dtype = "float32"
# Run the test cases of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import paddle.fluid.core as core
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import random
import sys
from op_test import OpTest
sys.path.append("./rnn")
from rnn_numpy import SimpleRNN, LSTM, GRU
from convert import get_params_for_net
# Seed Python's `random` module.
# NOTE(review): the test inputs below are drawn with `np.random`, which this
# call does not seed -- confirm whether a numpy seed is wanted as well.
random.seed(2)
# Print numpy arrays in full instead of summarising large ones.
np.set_printoptions(threshold=np.inf)
# Switch Paddle to static-graph mode before the test classes are defined.
paddle.enable_static()
class TestRNNOp(OpTest):
    """Forward and gradient check of the ``rnn`` operator in ``LSTM`` mode.

    ``setUp`` builds a reference ``LSTM`` network, runs it on a random
    time-major input and registers its results as the expected operator
    outputs.  Subclasses change the scenario by overriding ``set_attrs``.
    """

    def get_weight_names(self):
        """Return the parameter names, all weights first and then all biases.

        Each layer contributes ``2 * direction_num`` entries of the form
        ``"<layer>.weight_<k>"`` followed (after every weight of every layer)
        by the matching ``"<layer>.bias_<k>"`` entries.
        """
        layer_ids = range(self.num_layers)
        slot_ids = range(2 * self.direction_num)
        names = [
            "{}.weight_{}".format(layer, slot)
            for layer in layer_ids for slot in slot_ids
        ]
        names.extend("{}.bias_{}".format(layer, slot)
                     for layer in layer_ids for slot in slot_ids)
        return names

    def set_attrs(self):
        """Hook for subclasses; called by ``setUp`` after the defaults are set."""
        pass

    def setUp(self):
        """Create the operator inputs, attributes and expected outputs."""
        # Default configuration; set_attrs() may overwrite any of these.
        self.op_type = "rnn"
        self.mode = "LSTM"
        self.dtype = np.float64
        self.num_layers = 1
        self.is_bidirec = False
        self.is_test = False
        self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
        self.set_attrs()

        time_steps, batch, features, hidden = 12, 5, 3, 2
        if self.is_bidirec:
            self.direction_num, direction = 2, "bidirectional"
        else:
            self.direction_num, direction = 1, "forward"

        x = np.random.uniform(
            low=-0.1, high=0.1,
            size=(time_steps, batch, features)).astype(self.dtype)
        if self.sequence_length is not None:
            # Zero the steps that lie past the default lengths
            # [12, 11, 10, 9, 8]: at step 11 only sample 0 is still valid,
            # at step 10 samples 0-1, and so on.
            for step, first_row in ((11, 1), (10, 2), (9, 3), (8, 4)):
                x[step, first_row:] = 0

        reference = LSTM(
            features,
            hidden,
            num_layers=self.num_layers,
            time_major=True,
            direction=direction)
        flat_w = get_params_for_net(reference)
        output, (last_hidden, last_cell) = reference(
            x, sequence_length=self.sequence_length)

        # Hidden and cell states start from zeros of identical shape.
        state_shape = (self.num_layers * self.direction_num, batch, hidden)
        init_h = np.zeros(state_shape).astype(self.dtype)
        init_c = np.zeros(state_shape).astype(self.dtype)

        # 'SequenceLength' is only fed when a length array is configured.
        op_inputs = {
            'Input': x,
            'WeightList': flat_w,
            'PreState': [('init_h', init_h), ('init_c', init_c)],
        }
        if self.sequence_length is not None:
            op_inputs['SequenceLength'] = self.sequence_length
        self.inputs = op_inputs

        self.attrs = {
            'dropout_prob': 0.0,
            'is_bidirec': self.is_bidirec,
            'input_size': features,
            'hidden_size': hidden,
            'num_layers': self.num_layers,
            'mode': self.mode,
            'is_test': self.is_test
        }

        # 'Reserve' and 'DropoutState' are placeholders; test_output skips them.
        self.outputs = {
            'Out': output,
            "State": [('last_hidden', last_hidden), ('last_cell', last_cell)],
            'Reserve': np.ndarray(400).astype("uint8"),
            'DropoutState': np.ndarray(300).astype("uint8")
        }

    def test_output(self):
        """Compare the forward results, ignoring the placeholder outputs."""
        self.check_output(no_check_set=['Reserve', 'DropoutState'])

    def test_grad(self):
        """Check gradients of the input, both initial states and all weights.

        Nothing is checked when the operator is configured with ``is_test``.
        """
        if self.is_test:
            return
        grad_targets = ['Input', 'init_h', 'init_c'] + self.get_weight_names()
        self.check_grad(
            set(grad_targets), ['Out', 'last_hidden', 'last_cell'])
class TestRNNOp1(TestRNNOp):
    """LSTM case without a ``SequenceLength`` input."""

    def set_attrs(self):
        """Disable the per-sample sequence lengths."""
        self.sequence_length = None
class TestRNNOp2(TestRNNOp):
    """Bidirectional LSTM case without a ``SequenceLength`` input."""

    def set_attrs(self):
        """Turn on both directions and drop the sequence lengths."""
        self.is_bidirec = True
        self.sequence_length = None
class TestRNNOp3(TestRNNOp):
    """LSTM case with ``is_test`` set and no ``SequenceLength`` input.

    With ``is_test`` enabled, ``test_grad`` performs no gradient check.
    """

    def set_attrs(self):
        """Drop the sequence lengths and enable ``is_test``."""
        self.sequence_length = None
        self.is_test = True
class TestRNNOp4(TestRNNOp):
    """Bidirectional LSTM case with ``is_test`` set and no ``SequenceLength``.

    With ``is_test`` enabled, ``test_grad`` performs no gradient check.
    """

    def set_attrs(self):
        """Drop the sequence lengths; enable both directions and ``is_test``."""
        self.sequence_length = None
        self.is_bidirec = True
        self.is_test = True
# Run the test cases of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import math
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import random
import sys
sys.path.append("./rnn")
from rnn_numpy import SimpleRNN
from convert import get_params_for_net
# Seed Python's `random` module.
# NOTE(review): the test inputs below are drawn with `np.random`, which this
# call does not seed -- confirm whether a numpy seed is wanted as well.
random.seed(2)
# Print numpy arrays in full instead of summarising large ones.
np.set_printoptions(threshold=np.inf)
# Switch Paddle to static-graph mode before the test classes are defined.
paddle.enable_static()
class TestSimpleRNNOp(OpTest):
    """Forward and gradient check of the ``rnn`` operator in simple-RNN mode.

    ``setUp`` builds a reference ``SimpleRNN`` network whose nonlinearity is
    ``self.mode`` (``"RNN_TANH"`` by default), runs it on a random time-major
    input and registers its results as the expected operator outputs.
    Subclasses change the scenario by overriding ``set_attrs``.
    """

    def get_weight_names(self):
        """Return the parameter names, all weights first and then all biases.

        Each layer contributes ``2 * direction_num`` entries of the form
        ``"<layer>.weight_<k>"`` followed (after every weight of every layer)
        by the matching ``"<layer>.bias_<k>"`` entries.
        """
        layer_ids = range(self.num_layers)
        slot_ids = range(2 * self.direction_num)
        names = [
            "{}.weight_{}".format(layer, slot)
            for layer in layer_ids for slot in slot_ids
        ]
        names.extend("{}.bias_{}".format(layer, slot)
                     for layer in layer_ids for slot in slot_ids)
        return names

    def set_attrs(self):
        """Hook for subclasses; called by ``setUp`` after the defaults are set."""
        pass

    def setUp(self):
        """Create the operator inputs, attributes and expected outputs."""
        # Default configuration; set_attrs() may overwrite any of these.
        self.op_type = "rnn"
        self.mode = "RNN_TANH"
        self.dtype = np.float64
        self.num_layers = 1
        self.dropout = 0.
        self.is_bidirec = False
        self.is_test = False
        self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
        self.set_attrs()

        time_steps, batch, features, hidden = 12, 5, 3, 2
        if self.is_bidirec:
            self.direction_num, direction = 2, "bidirectional"
        else:
            self.direction_num, direction = 1, "forward"

        x = np.random.uniform(
            low=-0.1, high=0.1,
            size=(time_steps, batch, features)).astype(self.dtype)
        if self.sequence_length is not None:
            # Zero the steps that lie past the default lengths
            # [12, 11, 10, 9, 8]: at step 11 only sample 0 is still valid,
            # at step 10 samples 0-1, and so on.
            for step, first_row in ((11, 1), (10, 2), (9, 3), (8, 4)):
                x[step, first_row:] = 0

        reference = SimpleRNN(
            features,
            hidden,
            num_layers=self.num_layers,
            time_major=True,
            direction=direction,
            dropout=self.dropout,
            nonlinearity=self.mode)
        flat_w = get_params_for_net(reference)
        output, last_hidden = reference(
            x, sequence_length=self.sequence_length)

        state_shape = (self.num_layers * self.direction_num, batch, hidden)
        init_h = np.zeros(state_shape).astype(self.dtype)

        # 'SequenceLength' is only fed when a length array is configured.
        op_inputs = {
            'Input': x,
            'WeightList': flat_w,
            'PreState': [('init_h', init_h)],
        }
        if self.sequence_length is not None:
            op_inputs['SequenceLength'] = self.sequence_length
        self.inputs = op_inputs

        self.attrs = {
            'dropout_prob': self.dropout,
            'is_bidirec': self.is_bidirec,
            'input_size': features,
            'hidden_size': hidden,
            'num_layers': self.num_layers,
            'is_test': self.is_test,
            'mode': self.mode
        }

        # 'Reserve' and 'DropoutState' are placeholders; test_output skips them.
        self.outputs = {
            'Out': output,
            'State': [('last_hidden', last_hidden)],
            'Reserve': np.ndarray(400).astype("uint8"),
            'DropoutState': np.ndarray(300).astype("uint8")
        }

    def test_output(self):
        """Compare the forward results, ignoring the placeholder outputs."""
        self.check_output(no_check_set=['Reserve', 'DropoutState'])

    def test_grad(self):
        """Check gradients of the input, initial state and all weights.

        Nothing is checked when the operator is configured with ``is_test``.
        """
        if self.is_test:
            return
        grad_targets = ['Input', 'init_h'] + self.get_weight_names()
        self.check_grad(set(grad_targets), ['Out', 'last_hidden'])
class TestSimpleRNNOp1(TestSimpleRNNOp):
    """Simple-RNN case without a ``SequenceLength`` input."""

    def set_attrs(self):
        """Disable the per-sample sequence lengths."""
        self.sequence_length = None
class TestSimpleRNNOp2(TestSimpleRNNOp):
    """Bidirectional simple-RNN case without a ``SequenceLength`` input."""

    def set_attrs(self):
        """Turn on both directions and drop the sequence lengths."""
        self.is_bidirec = True
        self.sequence_length = None
class TestSimpleRNNOp3(TestSimpleRNNOp):
    """Simple-RNN case with ``is_test`` set and no ``SequenceLength`` input.

    With ``is_test`` enabled, ``test_grad`` performs no gradient check.
    """

    def set_attrs(self):
        """Enable ``is_test`` and drop the sequence lengths."""
        self.is_test = True
        self.sequence_length = None
class TestSimpleRNNOp4(TestSimpleRNNOp):
    """Bidirectional simple-RNN case with ``is_test`` set, no ``SequenceLength``.

    With ``is_test`` enabled, ``test_grad`` performs no gradient check.
    """

    def set_attrs(self):
        """Enable ``is_test`` and both directions; drop the sequence lengths."""
        self.is_test = True
        self.is_bidirec = True
        self.sequence_length = None
class TestSimpleRNNOp5(TestSimpleRNNOp):
    """Simple-RNN case that uses the ``"RNN_RELU"`` mode instead of the default."""

    def set_attrs(self):
        """Select the ReLU variant; all other defaults are kept."""
        self.mode = "RNN_RELU"
if __name__ == '__main__':
unittest.main()
...@@ -27,4 +27,5 @@ NEED_TO_FIX_OP_LIST = [ ...@@ -27,4 +27,5 @@ NEED_TO_FIX_OP_LIST = [
'tree_conv', 'tree_conv',
'cvm', 'cvm',
'cudnn_lstm', 'cudnn_lstm',
'rnn',
] ]
...@@ -28,4 +28,5 @@ no_check_set_white_list = [ ...@@ -28,4 +28,5 @@ no_check_set_white_list = [
'check_finite_and_unscale', 'check_finite_and_unscale',
'update_loss_scaling', 'update_loss_scaling',
'cudnn_lstm', 'cudnn_lstm',
'rnn',
] ]
...@@ -43,7 +43,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ ...@@ -43,7 +43,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'yolov3_loss', \ 'yolov3_loss', \
'inverse', \ 'inverse', \
'bilateral_slice',\ 'bilateral_slice',\
'cudnn_lstm' 'cudnn_lstm', \
'rnn', \
] ]
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\ NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\
......
...@@ -985,8 +985,7 @@ class RNNBase(LayerList): ...@@ -985,8 +985,7 @@ class RNNBase(LayerList):
"direction should be forward, backward or bidirectional, " "direction should be forward, backward or bidirectional, "
"received direction = {}".format(direction)) "received direction = {}".format(direction))
self.could_use_cudnn = get_device().startswith( self.could_use_cudnn = True
"gpu:") and get_cudnn_version()
self.could_use_cudnn &= direction != "backward" self.could_use_cudnn &= direction != "backward"
self.could_use_cudnn &= len(self.parameters()) == num_layers * 4 * ( self.could_use_cudnn &= len(self.parameters()) == num_layers * 4 * (
2 if direction == "bidirectional" else 1) 2 if direction == "bidirectional" else 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册