Commit 4c7be265 authored by Qiao Longfei

update avx gru grad kernel test=develop

Parent 9b16e540
@@ -110,7 +110,8 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
                                        T *gate_grad, T *prev_out_value,
                                        T *prev_out_grad, T *output_grad,
                                        int frame_size, int batch_size,
-                                       ActivationType active_node) {
+                                       ActivationType active_node,
+                                       bool origin_mode) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
@@ -140,7 +141,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
   op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value,
                 &r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad,
-                &r_out_grad, active_node);
+                &r_out_grad, active_node, origin_mode);
   gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
   gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
...
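Background, not stated in the commit message but recoverable from the gradient code further down: origin_mode selects between two conventions for combining the update gate with the previous hidden state, so the backward state kernel has to know which convention the forward pass used. Writing u for the update gate, c for the candidate (frame) state, h_{t-1} for the previous output and \delta_h for the incoming gradient of h_t, the two cases implemented below are:

% origin_mode = true:  h_t = u \odot h_{t-1} + (1 - u) \odot c
\frac{\partial L}{\partial u} = \delta_h \odot (h_{t-1} - c), \quad
\frac{\partial L}{\partial h_{t-1}} \mathrel{+}= \delta_h \odot u, \quad
\frac{\partial L}{\partial c} = \delta_h \odot (1 - u)

% origin_mode = false: h_t = (1 - u) \odot h_{t-1} + u \odot c
\frac{\partial L}{\partial u} = \delta_h \odot (c - h_{t-1}), \quad
\frac{\partial L}{\partial h_{t-1}} \mathrel{+}= \delta_h \odot (1 - u), \quad
\frac{\partial L}{\partial c} = \delta_h \odot u

The \partial L / \partial c term is then pushed through the candidate activation, which is what the activation(..., act_input) call in the kernels does.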
@@ -132,16 +132,26 @@ class gru_stateGrad {
                              __m256 *grad_frame_state, __m256 *value_prev_out,
                              __m256 *grad_prev_out, __m256 *grad_output,
                              ActivationType act_input, bool origin_mode) {
-    *grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state);
-    *grad_update_gate = _mm256_sub_ps(
-        *grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out));
-    *grad_prev_out = _mm256_add_ps(
-        _mm256_sub_ps(*grad_prev_out,
-                      _mm256_mul_ps(*grad_output, *value_update_gate)),
-        *grad_output);
-    *grad_frame_state =
-        activation(_mm256_mul_ps(*grad_output, *value_update_gate),
-                   *value_frame_state, act_input);
+    if (origin_mode) {
+      *grad_update_gate = _mm256_mul_ps(
+          *grad_output, _mm256_sub_ps(*value_prev_out, *value_frame_state));
+      *grad_prev_out = _mm256_add_ps(
+          *grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate));
+      *grad_frame_state = activation(
+          _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
+                                                    *value_update_gate)),
+          *value_frame_state, act_input);
+    } else {
+      *grad_update_gate = _mm256_mul_ps(
+          *grad_output, _mm256_sub_ps(*value_frame_state, *value_prev_out));
+      *grad_prev_out = _mm256_add_ps(
+          *grad_prev_out,
+          _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
+                                                    *value_update_gate)));
+      *grad_frame_state =
+          activation(_mm256_mul_ps(*grad_output, *value_update_gate),
+                     *value_frame_state, act_input);
+    }
   }
 #endif
 #endif
...
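An illustrative scalar transcription of the branch above (not part of the commit): the helper name tanh_grad is made up here and stands in for the activation(grad, value, act_input) call under the assumption of a tanh candidate activation.

// Element-wise view of gru_stateGrad's AVX body; grad_output is dL/dh_t.
inline float tanh_grad(float grad, float value) {
  // Backward of tanh, expressed through its forward output value.
  return grad * (1.0f - value * value);
}

void gru_state_grad_scalar(float value_update_gate, float &grad_update_gate,
                           float value_frame_state, float &grad_frame_state,
                           float value_prev_out, float &grad_prev_out,
                           float grad_output, bool origin_mode) {
  if (origin_mode) {
    // Forward was h = u * h_prev + (1 - u) * c.
    grad_update_gate = grad_output * (value_prev_out - value_frame_state);
    grad_prev_out += grad_output * value_update_gate;
    grad_frame_state =
        tanh_grad(grad_output * (1.0f - value_update_gate), value_frame_state);
  } else {
    // Forward was h = (1 - u) * h_prev + u * c.
    grad_update_gate = grad_output * (value_frame_state - value_prev_out);
    grad_prev_out += grad_output * (1.0f - value_update_gate);
    grad_frame_state =
        tanh_grad(grad_output * value_update_gate, value_frame_state);
  }
}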
@@ -92,7 +92,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
                       GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                       int frame_size, int batch_size,
                       const detail::ActivationType active_node,
-                      const detail::ActivationType active_gate) {
+                      const detail::ActivationType active_gate,
+                      bool origin_mode) {
     auto stream = context.stream();
     dim3 threads;
     dim3 grid;
@@ -112,14 +113,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
           /* is_batch= */ false><<<grid, threads, 0, stream>>>(
           detail::backward::gru_stateGrad<T>(), value.gate_value,
           grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
-          grad.output_grad, frame_size, batch_size, active_node);
+          grad.output_grad, frame_size, batch_size, active_node, origin_mode);
     } else {
       detail::KeGruBackwardStateGrad<
           detail::backward::gru_stateGrad<T>,
           /* is_batch= */ true><<<grid, threads, 0, stream>>>(
           detail::backward::gru_stateGrad<T>(), value.gate_value,
           grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
-          grad.output_grad, frame_size, batch_size, active_node);
+          grad.output_grad, frame_size, batch_size, active_node, origin_mode);
     }
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
...
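For completeness (also not part of this excerpt): every caller of the backward functor now has to forward the new flag as the last argument. A minimal call sketch, assuming the functor's entry point is a static compute method (the method name sits above the excerpted hunk, so it is an assumption here) and that value, grad, active_node and active_gate come from the surrounding GRU grad op:

// Hypothetical call site; only the trailing origin_mode argument is new.
math::GRUUnitGradFunctor<platform::CUDADeviceContext, T>::compute(
    dev_ctx, value, grad, frame_size, batch_size, active_node, active_gate,
    origin_mode);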