提交 9b16e540 编写于 作者: Q Qiao Longfei

update gru_grad_op

test=develop
上级 e477d789
...@@ -256,7 +256,8 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, ...@@ -256,7 +256,8 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad, T *prev_out_grad, T *output_grad,
int frame_size, int frame_size,
ActivationType active_node) { ActivationType active_node,
bool origin_mode) {
T r_update_gate_value; T r_update_gate_value;
T r_update_gate_grad; T r_update_gate_grad;
T r_frame_state_value; T r_frame_state_value;
...@@ -282,7 +283,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, ...@@ -282,7 +283,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad(&r_update_gate_value, &r_update_gate_grad, op_state_grad(&r_update_gate_value, &r_update_gate_grad,
&r_frame_state_value, &r_frame_state_grad, &r_prev_out_value, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
&r_prev_out_grad, &r_out_grad, active_node); &r_prev_out_grad, &r_out_grad, active_node, origin_mode);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad; frame_state_grad[i] = r_frame_state_grad;
...@@ -297,7 +298,8 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, ...@@ -297,7 +298,8 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *reset_output_grad, T *prev_out_grad, T *reset_output_grad,
int frame_size, int frame_size,
ActivationType active_gate) { ActivationType active_gate,
bool origin_mode) {
T r_update_gate_value; T r_update_gate_value;
T r_update_gate_grad; T r_update_gate_grad;
T r_reset_gate_value; T r_reset_gate_value;
...@@ -327,7 +329,8 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, ...@@ -327,7 +329,8 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
op_reset_grad(&r_update_gate_value, &r_update_gate_grad, op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
&r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value, &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
&r_prev_out_grad, &r_reset_output_grad, active_gate); &r_prev_out_grad, &r_reset_output_grad, active_gate,
origin_mode);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
reset_gate_grad[i] = r_reset_gate_grad; reset_gate_grad[i] = r_reset_gate_grad;
...@@ -341,8 +344,8 @@ template <class OpStateGrad, typename T> ...@@ -341,8 +344,8 @@ template <class OpStateGrad, typename T>
void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad, T *prev_out_grad, T *output_grad,
int frame_size, int frame_size, ActivationType active_node,
ActivationType active_node) { bool origin_mode) {
#ifdef __AVX__ #ifdef __AVX__
__m256 r_update_gate_value; __m256 r_update_gate_value;
__m256 r_update_gate_grad; __m256 r_update_gate_grad;
...@@ -371,7 +374,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, ...@@ -371,7 +374,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad(&r_update_gate_value, &r_update_gate_grad, op_state_grad(&r_update_gate_value, &r_update_gate_grad,
&r_frame_state_value, &r_frame_state_grad, &r_prev_out_value, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
&r_prev_out_grad, &r_out_grad, active_node); &r_prev_out_grad, &r_out_grad, active_node, origin_mode);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad; frame_state_grad[i] = r_frame_state_grad;
...@@ -386,8 +389,8 @@ template <class OpResetGrad, typename T> ...@@ -386,8 +389,8 @@ template <class OpResetGrad, typename T>
void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value, T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *reset_output_grad, T *prev_out_grad, T *reset_output_grad,
int frame_size, int frame_size, ActivationType active_gate,
ActivationType active_gate) { bool origin_mode) {
#ifdef __AVX__ #ifdef __AVX__
__m256 r_update_gate_value; __m256 r_update_gate_value;
__m256 r_update_gate_grad; __m256 r_update_gate_grad;
...@@ -419,7 +422,8 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, ...@@ -419,7 +422,8 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
op_reset_grad(&r_update_gate_value, &r_update_gate_grad, op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
&r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value, &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
&r_prev_out_grad, &r_reset_output_grad, active_gate); &r_prev_out_grad, &r_reset_output_grad, active_gate,
origin_mode);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
reset_gate_grad[i] = r_reset_gate_grad; reset_gate_grad[i] = r_reset_gate_grad;
...@@ -434,16 +438,18 @@ template <class OpStateGrad, typename T> ...@@ -434,16 +438,18 @@ template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad op_state_grad, inline void backward_state_grad(OpStateGrad op_state_grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad, GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size,
ActivationType active_node) { ActivationType active_node, bool origin_mode) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_state_grad( hl_avx_gru_backward_state_grad(op_state_grad, value.gate_value,
op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad, frame_size, active_node); grad.prev_out_grad, grad.output_grad,
frame_size, active_node, origin_mode);
} else { } else {
hl_naive_gru_backward_state_grad( hl_naive_gru_backward_state_grad(op_state_grad, value.gate_value,
op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad, frame_size, active_node); grad.prev_out_grad, grad.output_grad,
frame_size, active_node, origin_mode);
} }
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
...@@ -463,16 +469,18 @@ template <class OpResetGrad, typename T> ...@@ -463,16 +469,18 @@ template <class OpResetGrad, typename T>
inline void backward_reset_grad(OpResetGrad op_reset_grad, inline void backward_reset_grad(OpResetGrad op_reset_grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad, GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size,
ActivationType active_gate) { ActivationType active_gate, bool origin_mode) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_reset_grad( hl_avx_gru_backward_reset_grad(op_reset_grad, value.gate_value,
op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate); grad.prev_out_grad, grad.reset_output_grad,
frame_size, active_gate, origin_mode);
} else { } else {
hl_naive_gru_backward_reset_grad( hl_naive_gru_backward_reset_grad(
op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value, op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate); grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate,
origin_mode);
} }
value.gate_value += frame_size * 3; value.gate_value += frame_size * 3;
......
...@@ -103,13 +103,23 @@ class gru_stateGrad { ...@@ -103,13 +103,23 @@ class gru_stateGrad {
HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate, HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
T *value_frame_state, T *grad_frame_state, T *value_frame_state, T *grad_frame_state,
T *value_prev_out, T *grad_prev_out, T *value_prev_out, T *grad_prev_out,
T *grad_output, ActivationType act_input) { T *grad_output, ActivationType act_input,
*grad_update_gate = (*grad_output * (*value_frame_state)); bool origin_mode) {
*grad_update_gate -= (*grad_output * (*value_prev_out)); if (origin_mode) {
*grad_prev_out -= (*grad_output * (*value_update_gate)); *grad_update_gate =
*grad_prev_out += *grad_output; (*grad_output) * ((*value_prev_out) - (*value_frame_state));
*grad_frame_state = activation(*grad_output * (*value_update_gate), *grad_prev_out += (*grad_output * (*value_update_gate));
*value_frame_state, act_input); *grad_frame_state = activation(
*grad_output * (static_cast<T>(1.0) - (*value_update_gate)),
*value_frame_state, act_input);
} else {
*grad_update_gate =
(*grad_output) * ((*value_frame_state) - (*value_prev_out));
*grad_prev_out +=
(*grad_output * (static_cast<T>(1.0) - *value_update_gate));
*grad_frame_state = activation(*grad_output * (*value_update_gate),
*value_frame_state, act_input);
}
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
...@@ -121,7 +131,7 @@ class gru_stateGrad { ...@@ -121,7 +131,7 @@ class gru_stateGrad {
__m256 *value_frame_state, __m256 *value_frame_state,
__m256 *grad_frame_state, __m256 *value_prev_out, __m256 *grad_frame_state, __m256 *value_prev_out,
__m256 *grad_prev_out, __m256 *grad_output, __m256 *grad_prev_out, __m256 *grad_output,
ActivationType act_input) { ActivationType act_input, bool origin_mode) {
*grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state); *grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state);
*grad_update_gate = _mm256_sub_ps( *grad_update_gate = _mm256_sub_ps(
*grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out)); *grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out));
...@@ -143,7 +153,8 @@ class gru_resetGrad { ...@@ -143,7 +153,8 @@ class gru_resetGrad {
HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate, HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
T *value_reset_gate, T *grad_reset_gate, T *value_reset_gate, T *grad_reset_gate,
T *value_prev_out, T *grad_prev_out, T *value_prev_out, T *grad_prev_out,
T *grad_reset_output, ActivationType act_gate) { T *grad_reset_output, ActivationType act_gate,
bool origin_mode) {
*grad_reset_gate = (*grad_reset_output * (*value_prev_out)); *grad_reset_gate = (*grad_reset_output * (*value_prev_out));
*grad_prev_out += (*grad_reset_output * (*value_reset_gate)); *grad_prev_out += (*grad_reset_output * (*value_reset_gate));
*grad_update_gate = *grad_update_gate =
...@@ -160,7 +171,7 @@ class gru_resetGrad { ...@@ -160,7 +171,7 @@ class gru_resetGrad {
__m256 *grad_update_gate, __m256 *value_reset_gate, __m256 *grad_update_gate, __m256 *value_reset_gate,
__m256 *grad_reset_gate, __m256 *value_prev_out, __m256 *grad_reset_gate, __m256 *value_prev_out,
__m256 *grad_prev_out, __m256 *grad_reset_output, __m256 *grad_prev_out, __m256 *grad_reset_output,
ActivationType act_gate) { ActivationType act_gate, bool origin_mode) {
*grad_reset_gate = _mm256_mul_ps(*grad_reset_output, *value_prev_out); *grad_reset_gate = _mm256_mul_ps(*grad_reset_output, *value_prev_out);
*grad_prev_out = _mm256_add_ps( *grad_prev_out = _mm256_add_ps(
*grad_prev_out, _mm256_mul_ps(*grad_reset_output, *value_reset_gate)); *grad_prev_out, _mm256_mul_ps(*grad_reset_output, *value_reset_gate));
......
...@@ -60,7 +60,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> { ...@@ -60,7 +60,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
bool origin_mode) { bool origin_mode) {
#ifndef __NVCC__ #ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value, detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frame_size, batch_size, active_node); grad, frame_size, batch_size, active_node,
origin_mode);
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (value.prev_out_value && grad.prev_out_grad) { if (value.prev_out_value && grad.prev_out_grad) {
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
...@@ -77,7 +78,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> { ...@@ -77,7 +78,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
} }
detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value, detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value,
grad, frame_size, batch_size, active_gate); grad, frame_size, batch_size, active_gate,
origin_mode);
if (grad.prev_out_grad && value.prev_out_value) { if (grad.prev_out_grad && value.prev_out_value) {
blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1, blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
grad.gate_grad, frame_size * 3, value.gate_weight, grad.gate_grad, frame_size * 3, value.gate_weight,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册