Commit 4feae253 authored by Qiao Longfei

fix build problem test=develop

Parent e641ffe7
@@ -298,8 +298,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *reset_output_grad,
                                       int frame_size,
-                                      ActivationType active_gate,
-                                      bool origin_mode) {
+                                      ActivationType active_gate) {
   T r_update_gate_value;
   T r_update_gate_grad;
   T r_reset_gate_value;
@@ -329,8 +328,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
     op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
                   &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
-                  &r_prev_out_grad, &r_reset_output_grad, active_gate,
-                  origin_mode);
+                  &r_prev_out_grad, &r_reset_output_grad, active_gate);
     update_gate_grad[i] = r_update_gate_grad;
     reset_gate_grad[i] = r_reset_gate_grad;
@@ -389,8 +387,8 @@ template <class OpResetGrad, typename T>
 void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                     T *gate_grad, T *prev_out_value,
                                     T *prev_out_grad, T *reset_output_grad,
-                                    int frame_size, ActivationType active_gate,
-                                    bool origin_mode) {
+                                    int frame_size,
+                                    ActivationType active_gate) {
 #ifdef __AVX__
   __m256 r_update_gate_value;
   __m256 r_update_gate_grad;
@@ -422,8 +420,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
     op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
                   &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
-                  &r_prev_out_grad, &r_reset_output_grad, active_gate,
-                  origin_mode);
+                  &r_prev_out_grad, &r_reset_output_grad, active_gate);
     update_gate_grad[i] = r_update_gate_grad;
     reset_gate_grad[i] = r_reset_gate_grad;
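The AVX path above works on eight packed floats per __m256 register, and its gradient arithmetic bottoms out in the _mm256_mul_ps / _mm256_add_ps intrinsics visible in the functor hunk further down. A minimal standalone sketch of that eight-wide elementwise product, with arbitrary toy values (compile with -mavx; only the intrinsics themselves come from the patch):

#include <immintrin.h>
#include <cstdio>

int main() {
  // Eight lanes at a time, as in the AVX kernel:
  // grad_reset_gate = grad_reset_output * value_prev_out
  alignas(32) float grad_reset_output[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  alignas(32) float value_prev_out[8] = {8, 7, 6, 5, 4, 3, 2, 1};
  alignas(32) float grad_reset_gate[8];

  __m256 g = _mm256_load_ps(grad_reset_output);
  __m256 h = _mm256_load_ps(value_prev_out);
  _mm256_store_ps(grad_reset_gate, _mm256_mul_ps(g, h));

  for (int i = 0; i < 8; ++i) printf("%.0f ", grad_reset_gate[i]);
  printf("\n");  // prints: 8 14 18 20 20 18 14 8
  return 0;
}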
@@ -469,18 +466,16 @@ template <class OpResetGrad, typename T>
 inline void backward_reset_grad(OpResetGrad op_reset_grad,
                                 GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                                 int frame_size, int batch_size,
-                                ActivationType active_gate, bool origin_mode) {
+                                ActivationType active_gate) {
   for (int b = 0; b < batch_size; b++) {
     if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
-      hl_avx_gru_backward_reset_grad(op_reset_grad, value.gate_value,
-                                     grad.gate_grad, value.prev_out_value,
-                                     grad.prev_out_grad, grad.reset_output_grad,
-                                     frame_size, active_gate, origin_mode);
+      hl_avx_gru_backward_reset_grad(
+          op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
+          grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
     } else {
       hl_naive_gru_backward_reset_grad(
           op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
-          grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate,
-          origin_mode);
+          grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
     }
     value.gate_value += frame_size * 3;
......
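The dispatch in backward_reset_grad takes the AVX kernel only when the functor advertises an AVX overload, the element type is a 4-byte float, and frame_size is a multiple of eight (eight floats fill one 256-bit register). The multiple-of-eight test is the power-of-two bit trick frame_size & (8 - 1) rather than a modulo. A small sketch of the same predicate; the helper name is invented for illustration:

#include <cstdio>

// Mirrors the condition in backward_reset_grad: for a power-of-two
// modulus, frame_size % 8 == 0 is equivalent to !(frame_size & (8 - 1)).
static bool takes_avx_path(bool has_avx, int frame_size, int elem_size) {
  return has_avx && !(frame_size & (8 - 1)) && elem_size == 4;
}

int main() {
  printf("%d\n", takes_avx_path(true, 64, 4));  // 1: 64 floats = 8 full registers
  printf("%d\n", takes_avx_path(true, 60, 4));  // 0: 60 % 8 != 0, naive path
  printf("%d\n", takes_avx_path(true, 64, 8));  // 0: doubles take the naive path
  return 0;
}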
@@ -159,7 +159,8 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
                                        T *gate_grad, T *prev_out_value,
                                        T *prev_out_grad, T *reset_output_grad,
                                        int frame_size, int batch_size,
-                                       ActivationType active_gate) {
+                                       ActivationType active_gate,
+                                       bool origin_mode) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
@@ -189,7 +190,7 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
   op_reset_grad(&r_update_gate_value, &r_update_gate_grad, &r_reset_gate_value,
                 &r_reset_gate_grad, &r_prev_out_value, &r_prev_out_grad,
-                &r_reset_output_grad, active_gate);
+                &r_reset_output_grad, active_gate, origin_mode);
   gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
   gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
......
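KeGruBackwardResetGrad assigns one CUDA thread per frame element, frame_idx = blockIdx.x * blockDim.x + threadIdx.x, and the early return discards the spare threads of the last, partially filled block. A host-side C++ emulation of that index arithmetic; the block width of 256 is an assumed example, since the launch configuration is not part of this diff:

#include <cstdio>

int main() {
  const int frame_size = 1000;
  const int block_dim = 256;  // assumed launch width, not from the patch
  // Round up so every frame element is covered by some thread.
  const int grid_dim = (frame_size + block_dim - 1) / block_dim;

  int covered = 0;
  for (int block = 0; block < grid_dim; ++block) {
    for (int thread = 0; thread < block_dim; ++thread) {
      int frame_idx = block * block_dim + thread;
      if (frame_idx >= frame_size) continue;  // the kernel's early return
      ++covered;
    }
  }
  // 4 blocks of 256 threads cover all 1000 elements; 24 threads idle.
  printf("covered %d of %d\n", covered, frame_size);
  return 0;
}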
@@ -163,8 +163,7 @@ class gru_resetGrad {
   HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
                              T *value_reset_gate, T *grad_reset_gate,
                              T *value_prev_out, T *grad_prev_out,
-                             T *grad_reset_output, ActivationType act_gate,
-                             bool origin_mode) {
+                             T *grad_reset_output, ActivationType act_gate) {
     *grad_reset_gate = (*grad_reset_output * (*value_prev_out));
     *grad_prev_out += (*grad_reset_output * (*value_reset_gate));
     *grad_update_gate =
@@ -181,7 +180,7 @@ class gru_resetGrad {
                              __m256 *grad_update_gate, __m256 *value_reset_gate,
                              __m256 *grad_reset_gate, __m256 *value_prev_out,
                              __m256 *grad_prev_out, __m256 *grad_reset_output,
-                             ActivationType act_gate, bool origin_mode) {
+                             ActivationType act_gate) {
     *grad_reset_gate = _mm256_mul_ps(*grad_reset_output, *value_prev_out);
     *grad_prev_out = _mm256_add_ps(
         *grad_prev_out, _mm256_mul_ps(*grad_reset_output, *value_reset_gate));
......
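The first two statements of gru_resetGrad::operator() are the product rule for the reset output r_t ⊙ h_{t-1}. Writing g_r for grad_reset_output, they compute:

\frac{\partial L}{\partial r_t} = g_r \odot h_{t-1},
\qquad
\frac{\partial L}{\partial h_{t-1}} \;{+}{=}\; g_r \odot r_t

Neither term depends on origin_mode, which is consistent with dropping the flag from this functor: the flag only changes how h_t mixes h_{t-1} with the candidate state, a step handled elsewhere in the backward pass.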
@@ -78,8 +78,7 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
     }
     detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value,
-                                grad, frame_size, batch_size, active_gate,
-                                origin_mode);
+                                grad, frame_size, batch_size, active_gate);
     if (grad.prev_out_grad && value.prev_out_value) {
       blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
                 grad.gate_grad, frame_size * 3, value.gate_weight,
......
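Reading the truncated blas.GEMM call as (transA, transB, M, N, K, alpha, A, lda, B, ...): with M = batch_size, N = frame_size, K = 2 * frame_size and lda = 3 * frame_size, it multiplies the update- and reset-gate slice of gate_grad (the three gates are packed side by side per frame, hence the 3 * frame_size stride) by the transposed gate weight. The output argument is cut off in this view; judging from the prev_out_grad guard, the product accumulates the previous hidden state's gradient, roughly:

\frac{\partial L}{\partial h_{t-1}}
\;{+}{=}\;
\frac{\partial L}{\partial [u_t,\, r_t]} \, W_g^{\top},
\qquad
\frac{\partial L}{\partial [u_t,\, r_t]} \in \mathbb{R}^{B \times 2F},
\quad
W_g \in \mathbb{R}^{F \times 2F}

with B = batch_size and F = frame_size.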
@@ -885,8 +885,9 @@ def dynamic_gru(input,
         h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \\tilde{h_t}

-    if origin_mode is True, then the equation is from paper
-    `Learning Phrase Representations using RNN Encoder–Decoder for Statistical
+
+    if origin_mode is True then the equation is from paper
+    `Learning Phrase Representations using RNN Encoder-Decoder for Statistical
     Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_

     .. math::
@@ -1014,7 +1015,7 @@ def gru_unit(input,
     **GRU unit layer**

     if origin_mode is True, then the equation of a gru step is from paper
-    `Learning Phrase Representations using RNN Encoder–Decoder for Statistical
+    `Learning Phrase Representations using RNN Encoder-Decoder for Statistical
     Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_

     .. math::
......
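For reference, the two formulations that both docstrings contrast differ only in which operand the update gate u_t weights. The first equation is the default shown in the dynamic_gru docstring; the origin-mode form matches the referenced Cho et al. paper and is recalled from the Paddle documentation, not from the lines visible in this hunk:

h_t = (1 - u_t) \odot h_{t-1} + u_t \odot \tilde{h}_t
\qquad (\text{default})

h_t = u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t
\qquad (\text{origin mode, arXiv:1406.1078})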