Commit 4feae253 authored by Qiao Longfei

fix build problem test=develop

Parent e641ffe7
@@ -298,8 +298,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *reset_output_grad,
int frame_size,
- ActivationType active_gate,
- bool origin_mode) {
+ ActivationType active_gate) {
T r_update_gate_value;
T r_update_gate_grad;
T r_reset_gate_value;
@@ -329,8 +328,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
&r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
- &r_prev_out_grad, &r_reset_output_grad, active_gate,
- origin_mode);
+ &r_prev_out_grad, &r_reset_output_grad, active_gate);
update_gate_grad[i] = r_update_gate_grad;
reset_gate_grad[i] = r_reset_gate_grad;
@@ -389,8 +387,8 @@ template <class OpResetGrad, typename T>
void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *reset_output_grad,
- int frame_size, ActivationType active_gate,
- bool origin_mode) {
+ int frame_size,
+ ActivationType active_gate) {
#ifdef __AVX__
__m256 r_update_gate_value;
__m256 r_update_gate_grad;
@@ -422,8 +420,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
&r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
- &r_prev_out_grad, &r_reset_output_grad, active_gate,
- origin_mode);
+ &r_prev_out_grad, &r_reset_output_grad, active_gate);
update_gate_grad[i] = r_update_gate_grad;
reset_gate_grad[i] = r_reset_gate_grad;
@@ -469,18 +466,16 @@ template <class OpResetGrad, typename T>
inline void backward_reset_grad(OpResetGrad op_reset_grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
- ActivationType active_gate, bool origin_mode) {
+ ActivationType active_gate) {
for (int b = 0; b < batch_size; b++) {
if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
- hl_avx_gru_backward_reset_grad(op_reset_grad, value.gate_value,
- grad.gate_grad, value.prev_out_value,
- grad.prev_out_grad, grad.reset_output_grad,
- frame_size, active_gate, origin_mode);
+ hl_avx_gru_backward_reset_grad(
+ op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
+ grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
} else {
hl_naive_gru_backward_reset_grad(
op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
- grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate,
- origin_mode);
+ grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
}
value.gate_value += frame_size * 3;
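(Editor's aside, not part of the commit.) The dispatch in backward_reset_grad above takes the AVX path only when the functor supports AVX, frame_size is a multiple of 8, and the element type is 4 bytes wide: one __m256 register holds exactly 8 floats, so hl_avx_gru_backward_reset_grad can consume each frame in whole registers, and everything else falls back to the scalar kernel. A minimal standalone sketch of that predicate, with a hypothetical helper name:

#include <cstdio>

// Hypothetical helper mirroring the dispatch test in backward_reset_grad:
// `frame_size & (8 - 1)` is a power-of-two remainder check, so the AVX
// path requires frame_size % 8 == 0, and sizeof(T) == 4 restricts it to
// 32-bit floats (8 of which fill one 256-bit __m256 register).
template <typename T>
bool use_avx_path(bool op_supports_avx, int frame_size) {
  return op_supports_avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4);
}

int main() {
  std::printf("%d\n", use_avx_path<float>(true, 16));   // 1: 16 % 8 == 0
  std::printf("%d\n", use_avx_path<float>(true, 12));   // 0: remainder 4
  std::printf("%d\n", use_avx_path<double>(true, 16));  // 0: 8-byte type
  return 0;
}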
@@ -159,7 +159,8 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *reset_output_grad,
int frame_size, int batch_size,
- ActivationType active_gate) {
+ ActivationType active_gate,
+ bool origin_mode) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
@@ -189,7 +190,7 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
op_reset_grad(&r_update_gate_value, &r_update_gate_grad, &r_reset_gate_value,
&r_reset_gate_grad, &r_prev_out_value, &r_prev_out_grad,
- &r_reset_output_grad, active_gate);
+ &r_reset_output_grad, active_gate, origin_mode);
gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
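(Editor's aside, not part of the commit.) In the CUDA kernel above, each thread owns one frame element: frame_idx = blockIdx.x * blockDim.x + threadIdx.x, and the early return discards the overhang threads when frame_size is not a multiple of the block size. Launch code for kernels indexed this way therefore sizes the grid by ceiling division; a host-side sketch of that arithmetic, names hypothetical:

#include <cstdio>

// Hypothetical helper: smallest block count such that
// num_blocks * block_size >= frame_size (ceiling division).
int num_blocks(int frame_size, int block_size) {
  return (frame_size + block_size - 1) / block_size;
}

int main() {
  // 1000 elements with 256-thread blocks -> 4 blocks = 1024 threads;
  // the 24 surplus threads hit the `frame_idx >= frame_size` guard.
  std::printf("%d\n", num_blocks(1000, 256));  // 4
  return 0;
}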
@@ -163,8 +163,7 @@ class gru_resetGrad {
HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
T *value_reset_gate, T *grad_reset_gate,
T *value_prev_out, T *grad_prev_out,
- T *grad_reset_output, ActivationType act_gate,
- bool origin_mode) {
+ T *grad_reset_output, ActivationType act_gate) {
*grad_reset_gate = (*grad_reset_output * (*value_prev_out));
*grad_prev_out += (*grad_reset_output * (*value_reset_gate));
*grad_update_gate =
@@ -181,7 +180,7 @@ class gru_resetGrad {
__m256 *grad_update_gate, __m256 *value_reset_gate,
__m256 *grad_reset_gate, __m256 *value_prev_out,
__m256 *grad_prev_out, __m256 *grad_reset_output,
- ActivationType act_gate, bool origin_mode) {
+ ActivationType act_gate) {
*grad_reset_gate = _mm256_mul_ps(*grad_reset_output, *value_prev_out);
*grad_prev_out = _mm256_add_ps(
*grad_prev_out, _mm256_mul_ps(*grad_reset_output, *value_reset_gate));
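(Editor's aside, not part of the commit.) Dropping origin_mode from gru_resetGrad is consistent with the math: the functor backpropagates through reset_output = reset_gate ⊙ prev_out, and by the product rule each factor's gradient is the upstream grad_reset_output times the other factor. origin_mode only changes how the update gate blends h_{t-1} with the candidate state in the final output, so it can never appear in this term. A scalar sketch of just those two assignments (the real functor also computes grad_update_gate through the gate activation's derivative, omitted here):

#include <cassert>

int main() {
  // One element of the forward pass: reset_output = reset_gate * prev_out.
  double value_reset_gate = 0.25, value_prev_out = 0.8;
  double grad_reset_output = 1.5;  // upstream gradient
  double grad_prev_out = 0.0;      // accumulates across terms

  // Product rule, mirroring the two lines in gru_resetGrad::operator():
  double grad_reset_gate = grad_reset_output * value_prev_out;
  grad_prev_out += grad_reset_output * value_reset_gate;

  assert(grad_reset_gate == 1.5 * 0.8);  // d loss / d reset_gate
  assert(grad_prev_out == 1.5 * 0.25);   // d loss / d prev_out
  return 0;
}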
@@ -78,8 +78,7 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
}
detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value,
- grad, frame_size, batch_size, active_gate,
- origin_mode);
+ grad, frame_size, batch_size, active_gate);
if (grad.prev_out_grad && value.prev_out_value) {
blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
grad.gate_grad, frame_size * 3, value.gate_weight,
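(Editor's aside, not part of the commit.) The GEMM in the truncated hunk above has shape M = batch_size, N = frame_size, K = frame_size * 2 with B transposed, and reads A = grad.gate_grad with leading dimension frame_size * 3: the gate-gradient matrix stores update, reset, and candidate gates side by side, and only the first two (2 * frame_size columns) flow back through value.gate_weight here. A naive triple-loop sketch of a GEMM with that shape (hypothetical standalone form; the output operand and beta are cut off in the diff, so the accumulate into C is an assumption):

// C[M x N] += A[M x K] * B[N x K]^T, row-major, alpha = 1.
// lda may exceed K (here lda = 3 * frame_size while K = 2 * frame_size),
// which is how the call skips the candidate-gate columns of gate_grad.
void gemm_nt_accumulate(int M, int N, int K, const float* A, int lda,
                        const float* B, int ldb, float* C, int ldc) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * lda + k] * B[n * ldb + k];  // B accessed transposed
      }
      C[m * ldc + n] += acc;  // assumed beta = 1: accumulate into C
    }
  }
}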
@@ -885,8 +885,9 @@ def dynamic_gru(input,
h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \\tilde{h_t}
- if origin_mode is True, then the equation is from paper
- `Learning Phrase Representations using RNN Encoder–Decoder for Statistical
+ if origin_mode is True then the equation is from paper
+ Learning Phrase Representations using RNN Encoder-Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_
.. math::
@@ -1014,7 +1015,7 @@ def gru_unit(input,
**GRU unit layer**
if origin_mode is True, then the equation of a gru step is from paper
- `Learning Phrase Representations using RNN Encoder–Decoder for Statistical
+ `Learning Phrase Representations using RNN Encoder-Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_
.. math::
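(Editor's note, not text from the commit; the origin_mode=True equation is reconstructed from the cited paper and the surrounding docstring, so treat it as an assumption.) The two formulations that origin_mode switches between differ only in which term the update gate u_t weights:

.. math::

    h_t = (1 - u_t) \odot h_{t-1} + u_t \odot \tilde{h_t} \qquad \text{origin\_mode=False (shown above)}

    h_t = u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h_t} \qquad \text{origin\_mode=True (Encoder-Decoder paper)}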