From 3e552cdcac5370a59152c60670008e575a80da5d Mon Sep 17 00:00:00 2001 From: guosheng Date: Wed, 29 Nov 2017 11:31:15 +0800 Subject: [PATCH] Fix gru_op related code style --- paddle/operators/gru_op.h | 46 +- paddle/operators/math/detail/gru_cpu_kernel.h | 540 +++++++++--------- paddle/operators/math/detail/gru_gpu_kernel.h | 252 ++++---- paddle/operators/math/detail/gru_kernel.h | 135 +++-- paddle/operators/math/gru_compute.cc | 64 ++- paddle/operators/math/gru_compute.cu | 148 ++--- paddle/operators/math/gru_compute.h | 31 +- 7 files changed, 617 insertions(+), 599 deletions(-) diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h index 1b18368e0..564489d3a 100644 --- a/paddle/operators/gru_op.h +++ b/paddle/operators/gru_op.h @@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel { int frame_size = hidden_dims[1]; math::hl_gru_value gru_value; - gru_value.gateWeight = const_cast(weight_data); - gru_value.stateWeight = + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = const_cast(weight_data + 2 * frame_size * frame_size); Tensor ordered_h0; const size_t* order = batch_gate->lod()[2].data(); @@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel { // to reorder. ReorderInitState(context.device_context(), *h0, order, &ordered_h0, true); - gru_value.prevOutValue = ordered_h0.data(); + gru_value.prev_out_value = ordered_h0.data(); } else { - gru_value.prevOutValue = nullptr; + gru_value.prev_out_value = nullptr; } auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; @@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel { Tensor gate_t = batch_gate->Slice(bstart, bend); Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); Tensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.outputValue = hidden_t.data(); - gru_value.gateValue = gate_t.data(); - gru_value.resetOutputValue = reset_hidden_prev_t.data(); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); math::GRUUnitFunctor::compute( dev_ctx, gru_value, frame_size, cur_batch_size, math::ActiveType(context.Attr("activation")), math::ActiveType(context.Attr("gate_activation"))); - gru_value.prevOutValue = gru_value.outputValue; + gru_value.prev_out_value = gru_value.output_value; } math::Batch2LoDTensorFunctor to_seq; @@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel { to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse); math::hl_gru_value gru_value; - gru_value.gateWeight = const_cast(weight_data); - gru_value.stateWeight = + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = const_cast(weight_data + 2 * frame_size * frame_size); math::hl_gru_grad gru_grad; if (weight_grad) { - gru_grad.gateWeightGrad = + gru_grad.gate_weight_grad = weight_grad->mutable_data(context.GetPlace()); zero(dev_ctx, weight_grad, static_cast(0.0)); - gru_grad.stateWeightGrad = + gru_grad.state_weight_grad = weight_grad->data() + 2 * frame_size * frame_size; } else { - gru_grad.gateWeightGrad = nullptr; - gru_grad.stateWeightGrad = nullptr; + gru_grad.gate_weight_grad = nullptr; + gru_grad.state_weight_grad = nullptr; } auto batch_starts = batch_hidden_grad.lod()[0]; @@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel { int cur_batch_size = bend - bstart; Tensor gate_t = batch_gate->Slice(bstart, bend); - gru_value.gateValue = gate_t.data(); + gru_value.gate_value = gate_t.data(); Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); - gru_value.resetOutputValue = reset_hidden_prev_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend); - gru_grad.outputGrad = hidden_grad_t.data(); + gru_grad.output_grad = hidden_grad_t.data(); Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend); - gru_grad.gateGrad = gate_grad_t.data(); + gru_grad.gate_grad = gate_grad_t.data(); Tensor reset_hidden_prev_grad_t = batch_reset_hidden_prev_grad.Slice(bstart, bend); - gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data(); + gru_grad.reset_output_grad = reset_hidden_prev_grad_t.data(); if (n == 0) { - gru_value.prevOutValue = h0 ? ordered_h0.data() : nullptr; - gru_grad.prevOutGrad = + gru_value.prev_out_value = h0 ? ordered_h0.data() : nullptr; + gru_grad.prev_out_grad = h0 && h0_grad ? ordered_h0_grad.data() : nullptr; } else { int bstart_pre = static_cast(batch_starts[n - 1]); Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart); - gru_value.prevOutValue = hidden_prev_t.data(); + gru_value.prev_out_value = hidden_prev_t.data(); Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart); - gru_grad.prevOutGrad = hidden_prev_grad_t.data(); + gru_grad.prev_out_grad = hidden_prev_grad_t.data(); } math::GRUUnitGradFunctor::compute( diff --git a/paddle/operators/math/detail/gru_cpu_kernel.h b/paddle/operators/math/detail/gru_cpu_kernel.h index 51af140cf..4c67dec9c 100644 --- a/paddle/operators/math/detail/gru_cpu_kernel.h +++ b/paddle/operators/math/detail/gru_cpu_kernel.h @@ -25,393 +25,397 @@ namespace detail { #ifndef __NVCC__ template -void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput, - T *gateValue, T *resetOutputValue, - T *prevOutputValue, int frameSize, +void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, + T *gate_value, T *reset_output_value, + T *prev_output_value, int frame_size, activation_mode_t active_gate) { - T rValueUpdateGate; - T rValueResetGate; - T rValueResetOutput; - T rPrevOut = 0; - T *updateGate = gateValue; - T *resetGate = gateValue + frameSize; - - for (int i = 0; i < frameSize; i++) { - rValueUpdateGate = updateGate[i]; - rValueResetGate = resetGate[i]; - if (prevOutputValue) { - rPrevOut = prevOutputValue[i]; + T r_value_update_gate; + T r_value_reset_gate; + T r_value_reset_output; + T r_prev_out = 0; + T *update_gate = gate_value; + T *reset_gate = gate_value + frame_size; + + for (int i = 0; i < frame_size; i++) { + r_value_update_gate = update_gate[i]; + r_value_reset_gate = reset_gate[i]; + if (prev_output_value) { + r_prev_out = prev_output_value[i]; } - opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, - rValueResetOutput, active_gate); + op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out, + r_value_reset_output, active_gate); - updateGate[i] = rValueUpdateGate; - resetGate[i] = rValueResetGate; - resetOutputValue[i] = rValueResetOutput; + update_gate[i] = r_value_update_gate; + reset_gate[i] = r_value_reset_gate; + reset_output_value[i] = r_value_reset_output; } } template -void hl_naive_gru_forward_final_output(OpFinalOutput opFinalOutput, - T *gateValue, T *prevOutputValue, - T *outputValue, int frameSize, +void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, + T *gate_value, T *prev_output_value, + T *output_value, int frame_size, activation_mode_t active_node) { - T rValueUpdateGate; - T rValueFrameState; - T rPrevOut = 0; - T rOutput; - T *updateGate = gateValue; - T *frameState = gateValue + frameSize * 2; - - for (int i = 0; i < frameSize; i++) { - rValueUpdateGate = updateGate[i]; - rValueFrameState = frameState[i]; - if (prevOutputValue) { - rPrevOut = prevOutputValue[i]; + T r_value_update_gate; + T r_value_frame_state; + T r_prev_out = 0; + T r_output; + T *update_gate = gate_value; + T *frame_state = gate_value + frame_size * 2; + + for (int i = 0; i < frame_size; i++) { + r_value_update_gate = update_gate[i]; + r_value_frame_state = frame_state[i]; + if (prev_output_value) { + r_prev_out = prev_output_value[i]; } - opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, - active_node); + op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out, + r_output, active_node); - frameState[i] = rValueFrameState; - outputValue[i] = rOutput; + frame_state[i] = r_value_frame_state; + output_value[i] = r_output; } } template -void hl_avx_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue, - T *resetOutputValue, T *prevOutputValue, - int frameSize, +void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, + T *gate_value, T *reset_output_value, + T *prev_output_value, int frame_size, activation_mode_t active_gate) { #ifdef __AVX__ - __m256 rValueUpdateGate; - __m256 rValueResetGate; - __m256 rValueResetOutput; - __m256 rPrevOut = _mm256_set1_ps(0.0f); - __m256 *updateGate = (__m256 *)gateValue; - __m256 *resetGate = (__m256 *)(gateValue + frameSize); - - for (int i = 0; i < frameSize / 8; i++) { - rValueUpdateGate = updateGate[i]; - rValueResetGate = resetGate[i]; - if (prevOutputValue) { - rPrevOut = ((__m256 *)prevOutputValue)[i]; + __m256 r_value_update_gate; + __m256 r_value_reset_gate; + __m256 r_value_reset_output; + __m256 r_prev_out = _mm256_set1_ps(0.0f); + __m256 *update_gate = (__m256 *)gate_value; + __m256 *reset_gate = (__m256 *)(gate_value + frame_size); + + for (int i = 0; i < frame_size / 8; i++) { + r_value_update_gate = update_gate[i]; + r_value_reset_gate = reset_gate[i]; + if (prev_output_value) { + r_prev_out = ((__m256 *)prev_output_value)[i]; } - opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, - rValueResetOutput, active_gate); + op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out, + r_value_reset_output, active_gate); - updateGate[i] = rValueUpdateGate; - resetGate[i] = rValueResetGate; - ((__m256 *)resetOutputValue)[i] = rValueResetOutput; + update_gate[i] = r_value_update_gate; + reset_gate[i] = r_value_reset_gate; + ((__m256 *)reset_output_value)[i] = r_value_reset_output; } #endif } template -void hl_avx_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue, - T *prevOutputValue, T *outputValue, - int frameSize, +void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, + T *gate_value, T *prev_output_value, + T *output_value, int frame_size, activation_mode_t active_node) { #ifdef __AVX__ - __m256 rValueUpdateGate; - __m256 rValueFrameState; - __m256 rPrevOut = _mm256_set1_ps(0.0f); - __m256 rOutput; - __m256 *updateGate = (__m256 *)gateValue; - __m256 *frameState = (__m256 *)(gateValue + frameSize * 2); - - for (int i = 0; i < frameSize / 8; i++) { - rValueUpdateGate = updateGate[i]; - rValueFrameState = frameState[i]; - if (prevOutputValue) { - rPrevOut = ((__m256 *)prevOutputValue)[i]; + __m256 r_value_update_gate; + __m256 r_value_frame_state; + __m256 r_prev_out = _mm256_set1_ps(0.0f); + __m256 r_output; + __m256 *update_gate = (__m256 *)gate_value; + __m256 *frame_state = (__m256 *)(gate_value + frame_size * 2); + + for (int i = 0; i < frame_size / 8; i++) { + r_value_update_gate = update_gate[i]; + r_value_frame_state = frame_state[i]; + if (prev_output_value) { + r_prev_out = ((__m256 *)prev_output_value)[i]; } - opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, - active_node); + op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out, + r_output, active_node); - frameState[i] = rValueFrameState; - ((__m256 *)outputValue)[i] = rOutput; + frame_state[i] = r_value_frame_state; + ((__m256 *)output_value)[i] = r_output; } #endif } template -inline void forward_reset_output(OpResetOutput opResetOutput, - hl_gru_value value, int frameSize, - int batchSize, activation_mode_t active_gate) { - for (int b = 0; b < batchSize; b++) { - if (OpResetOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { +inline void forward_reset_output(OpResetOutput op_reset_output, + hl_gru_value value, int frame_size, + int batch_size, + activation_mode_t active_gate) { + for (int b = 0; b < batch_size; b++) { + if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_forward_reset_output( - opResetOutput, value.gateValue, value.resetOutputValue, - value.prevOutValue, frameSize, active_gate); + op_reset_output, value.gate_value, value.reset_output_value, + value.prev_out_value, frame_size, active_gate); } else { hl_naive_gru_forward_reset_output( - opResetOutput, value.gateValue, value.resetOutputValue, - value.prevOutValue, frameSize, active_gate); + op_reset_output, value.gate_value, value.reset_output_value, + value.prev_out_value, frame_size, active_gate); } - value.gateValue += frameSize * 3; - value.resetOutputValue += frameSize; - if (value.prevOutValue) { - value.prevOutValue += frameSize; + value.gate_value += frame_size * 3; + value.reset_output_value += frame_size; + if (value.prev_out_value) { + value.prev_out_value += frame_size; } } } template -inline void forward_final_output(OpFinalOutput opFinalOutput, - hl_gru_value value, int frameSize, - int batchSize, activation_mode_t active_node) { - for (int b = 0; b < batchSize; b++) { - if (OpFinalOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { - hl_avx_gru_forward_final_output(opFinalOutput, value.gateValue, - value.prevOutValue, value.outputValue, - frameSize, active_node); +inline void forward_final_output(OpFinalOutput op_final_output, + hl_gru_value value, int frame_size, + int batch_size, + activation_mode_t active_node) { + for (int b = 0; b < batch_size; b++) { + if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { + hl_avx_gru_forward_final_output(op_final_output, value.gate_value, + value.prev_out_value, value.output_value, + frame_size, active_node); } else { - hl_naive_gru_forward_final_output(opFinalOutput, value.gateValue, - value.prevOutValue, value.outputValue, - frameSize, active_node); + hl_naive_gru_forward_final_output( + op_final_output, value.gate_value, value.prev_out_value, + value.output_value, frame_size, active_node); } - value.gateValue += frameSize * 3; - value.outputValue += frameSize; - if (value.prevOutValue) { - value.prevOutValue += frameSize; + value.gate_value += frame_size * 3; + value.output_value += frame_size; + if (value.prev_out_value) { + value.prev_out_value += frame_size; } } } template -void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue, - T *gateGrad, T *prevOutValue, - T *prevOutGrad, T *outputGrad, - int frameSize, +void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, + T *gate_grad, T *prev_out_value, + T *prev_out_grad, T *output_grad, + int frame_size, activation_mode_t active_node) { - T rUpdateGateValue; - T rUpdateGateGrad; - T rFrameStateValue; - T rFrameStateGrad; - T rOutGrad; - T rPrevOutValue = 0; - T rPrevOutGrad = 0; - T *updateGateValue = gateValue; - T *updateGateGrad = gateGrad; - T *frameStateValue = gateValue + frameSize * 2; - T *frameStateGrad = gateGrad + frameSize * 2; - - for (int i = 0; i < frameSize; i++) { - rUpdateGateValue = updateGateValue[i]; - rFrameStateValue = frameStateValue[i]; - rOutGrad = outputGrad[i]; - if (prevOutValue) { - rPrevOutValue = prevOutValue[i]; + T r_update_gate_value; + T r_update_gate_grad; + T r_frame_state_value; + T r_frame_state_grad; + T r_out_grad; + T r_prev_out_value = 0; + T r_prev_out_grad = 0; + T *update_gate_value = gate_value; + T *update_gate_grad = gate_grad; + T *frame_state_value = gate_value + frame_size * 2; + T *frame_state_grad = gate_grad + frame_size * 2; + + for (int i = 0; i < frame_size; i++) { + r_update_gate_value = update_gate_value[i]; + r_frame_state_value = frame_state_value[i]; + r_out_grad = output_grad[i]; + if (prev_out_value) { + r_prev_out_value = prev_out_value[i]; } - if (prevOutGrad) { - rPrevOutGrad = prevOutGrad[i]; + if (prev_out_grad) { + r_prev_out_grad = prev_out_grad[i]; } - opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, - rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, - active_node); + op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value, + r_frame_state_grad, r_prev_out_value, r_prev_out_grad, + r_out_grad, active_node); - updateGateGrad[i] = rUpdateGateGrad; - frameStateGrad[i] = rFrameStateGrad; - if (prevOutGrad) { - prevOutGrad[i] = rPrevOutGrad; + update_gate_grad[i] = r_update_gate_grad; + frame_state_grad[i] = r_frame_state_grad; + if (prev_out_grad) { + prev_out_grad[i] = r_prev_out_grad; } } } template -void hl_naive_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue, - T *gateGrad, T *prevOutValue, - T *prevOutGrad, T *resetOutputGrad, - int frameSize, +void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, + T *gate_grad, T *prev_out_value, + T *prev_out_grad, T *reset_output_grad, + int frame_size, activation_mode_t active_gate) { - T rUpdateGateValue; - T rUpdateGateGrad; - T rResetGateValue; - T rResetGateGrad; - T rResetOutputGrad = 0; - T rPrevOutValue = 0; - T rPrevOutGrad = 0; - T *updateGateValue = gateValue; - T *updateGateGrad = gateGrad; - T *resetGateValue = gateValue + frameSize; - T *resetGateGrad = gateGrad + frameSize; - - for (int i = 0; i < frameSize; i++) { - rUpdateGateValue = updateGateValue[i]; - rUpdateGateGrad = updateGateGrad[i]; - rResetGateValue = resetGateValue[i]; - - if (prevOutValue && prevOutGrad) { - rResetOutputGrad = resetOutputGrad[i]; + T r_update_gate_value; + T r_update_gate_grad; + T r_reset_gate_value; + T r_reset_gate_grad; + T r_reset_output_grad = 0; + T r_prev_out_value = 0; + T r_prev_out_grad = 0; + T *update_gate_value = gate_value; + T *update_gate_grad = gate_grad; + T *reset_gate_value = gate_value + frame_size; + T *reset_gate_grad = gate_grad + frame_size; + + for (int i = 0; i < frame_size; i++) { + r_update_gate_value = update_gate_value[i]; + r_update_gate_grad = update_gate_grad[i]; + r_reset_gate_value = reset_gate_value[i]; + + if (prev_out_value && prev_out_grad) { + r_reset_output_grad = reset_output_grad[i]; } - if (prevOutValue) { - rPrevOutValue = prevOutValue[i]; + if (prev_out_value) { + r_prev_out_value = prev_out_value[i]; } - if (prevOutGrad) { - rPrevOutGrad = prevOutGrad[i]; + if (prev_out_grad) { + r_prev_out_grad = prev_out_grad[i]; } - opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, - rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, - active_gate); + op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value, + r_reset_gate_grad, r_prev_out_value, r_prev_out_grad, + r_reset_output_grad, active_gate); - updateGateGrad[i] = rUpdateGateGrad; - resetGateGrad[i] = rResetGateGrad; - if (prevOutGrad) { - prevOutGrad[i] = rPrevOutGrad; + update_gate_grad[i] = r_update_gate_grad; + reset_gate_grad[i] = r_reset_gate_grad; + if (prev_out_grad) { + prev_out_grad[i] = r_prev_out_grad; } } } template -void hl_avx_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue, - T *gateGrad, T *prevOutValue, - T *prevOutGrad, T *outputGrad, - int frameSize, +void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, + T *gate_grad, T *prev_out_value, + T *prev_out_grad, T *output_grad, + int frame_size, activation_mode_t active_node) { #ifdef __AVX__ - __m256 rUpdateGateValue; - __m256 rUpdateGateGrad; - __m256 rFrameStateValue; - __m256 rFrameStateGrad; - __m256 rOutGrad; - __m256 rPrevOutValue = _mm256_set1_ps(0.0f); - __m256 rPrevOutGrad = _mm256_set1_ps(0.0f); - __m256 *updateGateValue = (__m256 *)gateValue; - __m256 *updateGateGrad = (__m256 *)gateGrad; - __m256 *frameStateValue = (__m256 *)(gateValue + frameSize * 2); - __m256 *frameStateGrad = (__m256 *)(gateGrad + frameSize * 2); - - for (int i = 0; i < frameSize / 8; i++) { - rUpdateGateValue = updateGateValue[i]; - rFrameStateValue = frameStateValue[i]; - rOutGrad = ((__m256 *)outputGrad)[i]; - if (prevOutValue) { - rPrevOutValue = ((__m256 *)prevOutValue)[i]; + __m256 r_update_gate_value; + __m256 r_update_gate_grad; + __m256 r_frame_state_value; + __m256 r_frame_state_grad; + __m256 r_out_grad; + __m256 r_prev_out_value = _mm256_set1_ps(0.0f); + __m256 r_prev_out_grad = _mm256_set1_ps(0.0f); + __m256 *update_gate_value = (__m256 *)gate_value; + __m256 *update_gate_grad = (__m256 *)gate_grad; + __m256 *frame_state_value = (__m256 *)(gate_value + frame_size * 2); + __m256 *frame_state_grad = (__m256 *)(gate_grad + frame_size * 2); + + for (int i = 0; i < frame_size / 8; i++) { + r_update_gate_value = update_gate_value[i]; + r_frame_state_value = frame_state_value[i]; + r_out_grad = ((__m256 *)output_grad)[i]; + if (prev_out_value) { + r_prev_out_value = ((__m256 *)prev_out_value)[i]; } - if (prevOutGrad) { - rPrevOutGrad = ((__m256 *)prevOutGrad)[i]; + if (prev_out_grad) { + r_prev_out_grad = ((__m256 *)prev_out_grad)[i]; } - opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, - rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, - active_node); + op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value, + r_frame_state_grad, r_prev_out_value, r_prev_out_grad, + r_out_grad, active_node); - updateGateGrad[i] = rUpdateGateGrad; - frameStateGrad[i] = rFrameStateGrad; - if (prevOutGrad) { - ((__m256 *)prevOutGrad)[i] = rPrevOutGrad; + update_gate_grad[i] = r_update_gate_grad; + frame_state_grad[i] = r_frame_state_grad; + if (prev_out_grad) { + ((__m256 *)prev_out_grad)[i] = r_prev_out_grad; } } #endif } template -void hl_avx_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue, - T *gateGrad, T *prevOutValue, - T *prevOutGrad, T *resetOutputGrad, - int frameSize, +void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, + T *gate_grad, T *prev_out_value, + T *prev_out_grad, T *reset_output_grad, + int frame_size, activation_mode_t active_gate) { #ifdef __AVX__ - __m256 rUpdateGateValue; - __m256 rUpdateGateGrad; - __m256 rResetGateValue; - __m256 rResetGateGrad; - __m256 rResetOutputGrad = _mm256_set1_ps(0.0f); - __m256 rPrevOutValue = _mm256_set1_ps(0.0f); - __m256 rPrevOutGrad = _mm256_set1_ps(0.0f); - __m256 *updateGateValue = (__m256 *)gateValue; - __m256 *updateGateGrad = (__m256 *)gateGrad; - __m256 *resetGateValue = (__m256 *)(gateValue + frameSize); - __m256 *resetGateGrad = (__m256 *)(gateGrad + frameSize); - - for (int i = 0; i < frameSize / 8; i++) { - rUpdateGateValue = updateGateValue[i]; - rUpdateGateGrad = updateGateGrad[i]; - rResetGateValue = resetGateValue[i]; - - if (prevOutValue && prevOutGrad) { - rResetOutputGrad = ((__m256 *)resetOutputGrad)[i]; + __m256 r_update_gate_value; + __m256 r_update_gate_grad; + __m256 r_reset_gate_value; + __m256 r_reset_gate_grad; + __m256 r_reset_output_grad = _mm256_set1_ps(0.0f); + __m256 r_prev_out_value = _mm256_set1_ps(0.0f); + __m256 r_prev_out_grad = _mm256_set1_ps(0.0f); + __m256 *update_gate_value = (__m256 *)gate_value; + __m256 *update_gate_grad = (__m256 *)gate_grad; + __m256 *reset_gate_value = (__m256 *)(gate_value + frame_size); + __m256 *reset_gate_grad = (__m256 *)(gate_grad + frame_size); + + for (int i = 0; i < frame_size / 8; i++) { + r_update_gate_value = update_gate_value[i]; + r_update_gate_grad = update_gate_grad[i]; + r_reset_gate_value = reset_gate_value[i]; + + if (prev_out_value && prev_out_grad) { + r_reset_output_grad = ((__m256 *)reset_output_grad)[i]; } - if (prevOutValue) { - rPrevOutValue = ((__m256 *)prevOutValue)[i]; + if (prev_out_value) { + r_prev_out_value = ((__m256 *)prev_out_value)[i]; } - if (prevOutGrad) { - rPrevOutGrad = ((__m256 *)prevOutGrad)[i]; + if (prev_out_grad) { + r_prev_out_grad = ((__m256 *)prev_out_grad)[i]; } - opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, - rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, - active_gate); + op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value, + r_reset_gate_grad, r_prev_out_value, r_prev_out_grad, + r_reset_output_grad, active_gate); - updateGateGrad[i] = rUpdateGateGrad; - resetGateGrad[i] = rResetGateGrad; - if (prevOutGrad) { - ((__m256 *)prevOutGrad)[i] = rPrevOutGrad; + update_gate_grad[i] = r_update_gate_grad; + reset_gate_grad[i] = r_reset_gate_grad; + if (prev_out_grad) { + ((__m256 *)prev_out_grad)[i] = r_prev_out_grad; } } #endif } template -inline void backward_state_grad(OpStateGrad opStateGrad, hl_gru_value value, - hl_gru_grad grad, int frameSize, - int batchSize, activation_mode_t active_node) { - for (int b = 0; b < batchSize; b++) { - if (OpStateGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { +inline void backward_state_grad(OpStateGrad op_state_grad, + hl_gru_value value, hl_gru_grad grad, + int frame_size, int batch_size, + activation_mode_t active_node) { + for (int b = 0; b < batch_size; b++) { + if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_backward_state_grad( - opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue, - grad.prevOutGrad, grad.outputGrad, frameSize, active_node); + op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value, + grad.prev_out_grad, grad.output_grad, frame_size, active_node); } else { hl_naive_gru_backward_state_grad( - opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue, - grad.prevOutGrad, grad.outputGrad, frameSize, active_node); + op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value, + grad.prev_out_grad, grad.output_grad, frame_size, active_node); } - value.gateValue += frameSize * 3; - if (value.prevOutValue) { - value.prevOutValue += frameSize; + value.gate_value += frame_size * 3; + if (value.prev_out_value) { + value.prev_out_value += frame_size; } - grad.gateGrad += frameSize * 3; - grad.outputGrad += frameSize; - if (grad.prevOutGrad) { - grad.prevOutGrad += frameSize; + grad.gate_grad += frame_size * 3; + grad.output_grad += frame_size; + if (grad.prev_out_grad) { + grad.prev_out_grad += frame_size; } } } template -inline void backward_reset_grad(OpResetGrad opResetGrad, hl_gru_value value, - hl_gru_grad grad, int frameSize, - int batchSize, activation_mode_t active_gate) { - for (int b = 0; b < batchSize; b++) { - if (OpResetGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { +inline void backward_reset_grad(OpResetGrad op_reset_grad, + hl_gru_value value, hl_gru_grad grad, + int frame_size, int batch_size, + activation_mode_t active_gate) { + for (int b = 0; b < batch_size; b++) { + if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_backward_reset_grad( - opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue, - grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate); + op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value, + grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate); } else { hl_naive_gru_backward_reset_grad( - opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue, - grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate); + op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value, + grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate); } - value.gateValue += frameSize * 3; - if (value.prevOutValue) { - value.prevOutValue += frameSize; + value.gate_value += frame_size * 3; + if (value.prev_out_value) { + value.prev_out_value += frame_size; } - grad.gateGrad += frameSize * 3; - grad.resetOutputGrad += frameSize; - if (grad.prevOutGrad) { - grad.prevOutGrad += frameSize; + grad.gate_grad += frame_size * 3; + grad.reset_output_grad += frame_size; + if (grad.prev_out_grad) { + grad.prev_out_grad += frame_size; } } } diff --git a/paddle/operators/math/detail/gru_gpu_kernel.h b/paddle/operators/math/detail/gru_gpu_kernel.h index 6441c648b..f3983c519 100644 --- a/paddle/operators/math/detail/gru_gpu_kernel.h +++ b/paddle/operators/math/detail/gru_gpu_kernel.h @@ -27,174 +27,174 @@ namespace math { namespace detail { /* - * threads(framePerBlock, batchPerBlock) - * grid(frameBlocks, batchBlocks) + * threads(frame_per_block, batch_per_block) + * grid(frame_blocks, batch_blocks) */ -template -__global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput, - T *gateValue, T *resetOutputValue, - T *prevOutputValue, int frameSize, - int batchSize, +template +__global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, + T *gate_value, T *reset_output_value, + T *prev_output_value, int frame_size, + int batch_size, activation_mode_t active_gate) { - const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; - if (frameIdx >= frameSize) return; - - int batchIdx = 0; - if (isBatch) { - batchIdx = blockIdx.y * blockDim.y + threadIdx.y; - if (batchIdx >= batchSize) return; - gateValue += batchIdx * 3 * frameSize; - resetOutputValue += batchIdx * frameSize; + const int frame_idx = block_idx.x * block_dim.x + thread_idx.x; + if (frame_idx >= frame_size) return; + + int batch_idx = 0; + if (is_batch) { + batch_idx = block_idx.y * block_dim.y + thread_idx.y; + if (batch_idx >= batch_size) return; + gate_value += batch_idx * 3 * frame_size; + reset_output_value += batch_idx * frame_size; } - T rPrevOut = 0; - T rValueResetOutput; - T rValueUpdateGate = gateValue[frameIdx + frameSize * 0]; - T rValueResetGate = gateValue[frameIdx + frameSize * 1]; + T r_prev_out = 0; + T r_value_reset_output; + T r_value_update_gate = gate_value[frame_idx + frame_size * 0]; + T r_value_reset_gate = gate_value[frame_idx + frame_size * 1]; - if (prevOutputValue) { - if (isBatch) prevOutputValue += batchIdx * frameSize; - rPrevOut = prevOutputValue[frameIdx]; + if (prev_output_value) { + if (is_batch) prev_output_value += batch_idx * frame_size; + r_prev_out = prev_output_value[frame_idx]; } - opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput, - active_gate); + op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out, + r_value_reset_output, active_gate); - gateValue[frameIdx + frameSize * 0] = rValueUpdateGate; - gateValue[frameIdx + frameSize * 1] = rValueResetGate; - resetOutputValue[frameIdx] = rValueResetOutput; + gate_value[frame_idx + frame_size * 0] = r_value_update_gate; + gate_value[frame_idx + frame_size * 1] = r_value_reset_gate; + reset_output_value[frame_idx] = r_value_reset_output; } /* - * threads(framePerBlock, batchPerBlock) - * grid(frameBlocks, batchBlocks) + * threads(frame_per_block, batch_per_block) + * grid(frame_blocks, batch_blocks) */ -template -__global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput, - T *gateValue, T *prevOutputValue, - T *outputValue, int frameSize, - int batchSize, +template +__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, + T *gate_value, T *prev_output_value, + T *output_value, int frame_size, + int batch_size, activation_mode_t active_node) { - const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; - if (frameIdx >= frameSize) return; - int batchIdx = 0; - if (isBatch) { - batchIdx = blockIdx.y * blockDim.y + threadIdx.y; - if (batchIdx >= batchSize) return; - gateValue += batchIdx * 3 * frameSize; - outputValue += batchIdx * frameSize; + const int frame_idx = block_idx.x * block_dim.x + thread_idx.x; + if (frame_idx >= frame_size) return; + int batch_idx = 0; + if (is_batch) { + batch_idx = block_idx.y * block_dim.y + thread_idx.y; + if (batch_idx >= batch_size) return; + gate_value += batch_idx * 3 * frame_size; + output_value += batch_idx * frame_size; } - T rOutput; - T rPrevOut = 0; - T rValueUpdateGate = gateValue[frameIdx + frameSize * 0]; - T rValueFrameState = gateValue[frameIdx + frameSize * 2]; + T r_output; + T r_prev_out = 0; + T r_value_update_gate = gate_value[frame_idx + frame_size * 0]; + T r_value_frame_state = gate_value[frame_idx + frame_size * 2]; - if (prevOutputValue) { - if (isBatch) prevOutputValue += batchIdx * frameSize; - rPrevOut = prevOutputValue[frameIdx]; + if (prev_output_value) { + if (is_batch) prev_output_value += batch_idx * frame_size; + r_prev_out = prev_output_value[frame_idx]; } - opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, - active_node); + op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out, + r_output, active_node); - gateValue[frameIdx + frameSize * 2] = rValueFrameState; - outputValue[frameIdx] = rOutput; + gate_value[frame_idx + frame_size * 2] = r_value_frame_state; + output_value[frame_idx] = r_output; } /* - * threads(framePerBlock, batchPerBlock) - * grid(frameBlocks, batchBlocks) + * threads(frame_per_block, batch_per_block) + * grid(frame_blocks, batch_blocks) */ -template -__global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue, - T *gateGrad, T *prevOutValue, - T *prevOutGrad, T *outputGrad, - int frameSize, int batchSize, +template +__global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, + T *gate_grad, T *prev_out_value, + T *prev_out_grad, T *output_grad, + int frame_size, int batch_size, activation_mode_t active_node) { - const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; - if (frameIdx >= frameSize) return; - int batchIdx = 0; - if (isBatch) { - batchIdx = blockIdx.y * blockDim.y + threadIdx.y; - if (batchIdx >= batchSize) return; - gateValue += batchIdx * 3 * frameSize; - gateGrad += batchIdx * 3 * frameSize; - outputGrad += batchIdx * frameSize; + const int frame_idx = block_idx.x * block_dim.x + thread_idx.x; + if (frame_idx >= frame_size) return; + int batch_idx = 0; + if (is_batch) { + batch_idx = block_idx.y * block_dim.y + thread_idx.y; + if (batch_idx >= batch_size) return; + gate_value += batch_idx * 3 * frame_size; + gate_grad += batch_idx * 3 * frame_size; + output_grad += batch_idx * frame_size; } - T rUpdateGateGrad; - T rFrameStateGrad; - T rPrevOutValue = 0; - T rPrevOutGrad = 0; - T rUpdateGateValue = gateValue[frameIdx + frameSize * 0]; - T rFrameStateValue = gateValue[frameIdx + frameSize * 2]; - T rOutGrad = outputGrad[frameIdx]; + T r_update_gate_grad; + T r_frame_state_grad; + T r_prev_out_value = 0; + T r_prev_out_grad = 0; + T r_update_gate_value = gate_value[frame_idx + frame_size * 0]; + T r_frame_state_value = gate_value[frame_idx + frame_size * 2]; + T r_out_grad = output_grad[frame_idx]; - if (prevOutValue && prevOutGrad) { - if (isBatch) prevOutValue += batchIdx * frameSize; - rPrevOutValue = prevOutValue[frameIdx]; + if (prev_out_value && prev_out_grad) { + if (is_batch) prev_out_value += batch_idx * frame_size; + r_prev_out_value = prev_out_value[frame_idx]; - if (isBatch) prevOutGrad += batchIdx * frameSize; - rPrevOutGrad = prevOutGrad[frameIdx]; + if (is_batch) prev_out_grad += batch_idx * frame_size; + r_prev_out_grad = prev_out_grad[frame_idx]; } - opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, - rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, - active_node); + op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value, + r_frame_state_grad, r_prev_out_value, r_prev_out_grad, + r_out_grad, active_node); - gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad; - gateGrad[frameIdx + frameSize * 2] = rFrameStateGrad; - if (prevOutGrad) { - prevOutGrad[frameIdx] = rPrevOutGrad; + gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad; + gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad; + if (prev_out_grad) { + prev_out_grad[frame_idx] = r_prev_out_grad; } } /* - * threads(framePerBlock, batchPerBlock) - * grid(frameBlocks, batchBlocks) + * threads(frame_per_block, batch_per_block) + * grid(frame_blocks, batch_blocks) */ -template -__global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue, - T *gateGrad, T *prevOutValue, - T *prevOutGrad, T *resetOutputGrad, - int frameSize, int batchSize, +template +__global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value, + T *gate_grad, T *prev_out_value, + T *prev_out_grad, T *reset_output_grad, + int frame_size, int batch_size, activation_mode_t active_gate) { - const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; - if (frameIdx >= frameSize) return; - int batchIdx = 0; - if (isBatch) { - batchIdx = blockIdx.y * blockDim.y + threadIdx.y; - if (batchIdx >= batchSize) return; - gateValue += batchIdx * 3 * frameSize; - gateGrad += batchIdx * 3 * frameSize; - resetOutputGrad += batchIdx * frameSize; + const int frame_idx = block_idx.x * block_dim.x + thread_idx.x; + if (frame_idx >= frame_size) return; + int batch_idx = 0; + if (is_batch) { + batch_idx = block_idx.y * block_dim.y + thread_idx.y; + if (batch_idx >= batch_size) return; + gate_value += batch_idx * 3 * frame_size; + gate_grad += batch_idx * 3 * frame_size; + reset_output_grad += batch_idx * frame_size; } - T rResetGateGrad; - T rPrevOutValue = 0; - T rPrevOutGrad = 0; - T rResetOutputGrad = 0; - T rUpdateGateValue = gateValue[frameIdx + frameSize * 0]; - T rUpdateGateGrad = gateGrad[frameIdx + frameSize * 0]; - T rResetGateValue = gateValue[frameIdx + frameSize * 1]; - - if (prevOutValue && prevOutGrad) { - if (isBatch) prevOutValue += batchIdx * frameSize; - if (isBatch) prevOutGrad += batchIdx * frameSize; - rPrevOutValue = prevOutValue[frameIdx]; - rPrevOutGrad = prevOutGrad[frameIdx]; - rResetOutputGrad = resetOutputGrad[frameIdx]; + T r_reset_gate_grad; + T r_prev_out_value = 0; + T r_prev_out_grad = 0; + T r_reset_output_grad = 0; + T r_update_gate_value = gate_value[frame_idx + frame_size * 0]; + T r_update_gate_grad = gate_grad[frame_idx + frame_size * 0]; + T r_reset_gate_value = gate_value[frame_idx + frame_size * 1]; + + if (prev_out_value && prev_out_grad) { + if (is_batch) prev_out_value += batch_idx * frame_size; + if (is_batch) prev_out_grad += batch_idx * frame_size; + r_prev_out_value = prev_out_value[frame_idx]; + r_prev_out_grad = prev_out_grad[frame_idx]; + r_reset_output_grad = reset_output_grad[frame_idx]; } - opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, - rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, - active_gate); + op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value, + r_reset_gate_grad, r_prev_out_value, r_prev_out_grad, + r_reset_output_grad, active_gate); - gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad; - gateGrad[frameIdx + frameSize * 1] = rResetGateGrad; - if (prevOutGrad) { - prevOutGrad[frameIdx] = rPrevOutGrad; + gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad; + gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad; + if (prev_out_grad) { + prev_out_grad[frame_idx] = r_prev_out_grad; } } } // namespace detail diff --git a/paddle/operators/math/detail/gru_kernel.h b/paddle/operators/math/detail/gru_kernel.h index 8a681d8d8..acd84be01 100644 --- a/paddle/operators/math/detail/gru_kernel.h +++ b/paddle/operators/math/detail/gru_kernel.h @@ -28,23 +28,25 @@ namespace forward { template class gru_resetOutput { public: - HOSTDEVICE void operator()(T &valueUpdateGate, T &valueResetGate, T &prevOut, - T &valueResetOutput, activation_mode_t actGate) { - valueUpdateGate = activation(valueUpdateGate, actGate); - valueResetGate = activation(valueResetGate, actGate); - valueResetOutput = prevOut * valueResetGate; + HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate, + T &prev_out, T &value_reset_output, + activation_mode_t act_gate) { + value_update_gate = activation(value_update_gate, act_gate); + value_reset_gate = activation(value_reset_gate, act_gate); + value_reset_output = prev_out * value_reset_gate; } #ifndef __NVCC__ #ifndef __AVX__ static const bool avx = false; #else static const bool avx = true; - HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueResetGate, - __m256 &prevOut, __m256 &valueResetOutput, - activation_mode_t actGate) { - valueUpdateGate = activation(valueUpdateGate, actGate); - valueResetGate = activation(valueResetGate, actGate); - valueResetOutput = _mm256_mul_ps(prevOut, valueResetGate); + HOSTDEVICE void operator()(__m256 &value_update_gate, + __m256 &value_reset_gate, __m256 &prev_out, + __m256 &value_reset_output, + activation_mode_t act_gate) { + value_update_gate = activation(value_update_gate, act_gate); + value_reset_gate = activation(value_reset_gate, act_gate); + value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate); } #endif #endif @@ -53,24 +55,26 @@ class gru_resetOutput { template class gru_finalOutput { public: - HOSTDEVICE void operator()(T &valueUpdateGate, T &valueFrameState, T &prevOut, - T &valueOutput, activation_mode_t actInput) { - valueFrameState = activation(valueFrameState, actInput); - valueOutput = prevOut - (valueUpdateGate * prevOut) + - (valueUpdateGate * valueFrameState); + HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state, + T &prev_out, T &value_output, + activation_mode_t act_input) { + value_frame_state = activation(value_frame_state, act_input); + value_output = prev_out - (value_update_gate * prev_out) + + (value_update_gate * value_frame_state); } #ifndef __NVCC__ #ifndef __AVX__ static const bool avx = false; #else static const bool avx = true; - HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueFrameState, - __m256 &prevOut, __m256 &valueOutput, - activation_mode_t actInput) { - valueFrameState = activation(valueFrameState, actInput); - valueOutput = _mm256_add_ps( - _mm256_sub_ps(prevOut, _mm256_mul_ps(valueUpdateGate, prevOut)), - _mm256_mul_ps(valueUpdateGate, valueFrameState)); + HOSTDEVICE void operator()(__m256 &value_update_gate, + __m256 &value_frame_state, __m256 &prev_out, + __m256 &value_output, + activation_mode_t act_input) { + value_frame_state = activation(value_frame_state, act_input); + value_output = _mm256_add_ps( + _mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)), + _mm256_mul_ps(value_update_gate, value_frame_state)); } #endif #endif @@ -82,34 +86,37 @@ namespace backward { template class gru_stateGrad { public: - HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate, - T &valueFrameState, T &gradFrameState, - T &valuePrevOut, T &gradPrevOut, T &gradOutput, - activation_mode_t actInput) { - gradUpdateGate = (gradOutput * valueFrameState); - gradUpdateGate -= (gradOutput * valuePrevOut); - gradPrevOut -= (gradOutput * valueUpdateGate); - gradPrevOut += gradOutput; - gradFrameState = - activation(gradOutput * valueUpdateGate, valueFrameState, actInput); + HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate, + T &value_frame_state, T &grad_frame_state, + T &value_prev_out, T &grad_prev_out, + T &grad_output, activation_mode_t act_input) { + grad_update_gate = (grad_output * value_frame_state); + grad_update_gate -= (grad_output * value_prev_out); + grad_prev_out -= (grad_output * value_update_gate); + grad_prev_out += grad_output; + grad_frame_state = activation(grad_output * value_update_gate, + value_frame_state, act_input); } #ifndef __NVCC__ #ifndef __AVX__ static const bool avx = false; #else static const bool avx = true; - HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate, - __m256 &valueFrameState, __m256 &gradFrameState, - __m256 &valuePrevOut, __m256 &gradPrevOut, - __m256 &gradOutput, activation_mode_t actInput) { - gradUpdateGate = _mm256_mul_ps(gradOutput, valueFrameState); - gradUpdateGate = - _mm256_sub_ps(gradUpdateGate, _mm256_mul_ps(gradOutput, valuePrevOut)); - gradPrevOut = _mm256_add_ps( - _mm256_sub_ps(gradPrevOut, _mm256_mul_ps(gradOutput, valueUpdateGate)), - gradOutput); - gradFrameState = activation(_mm256_mul_ps(gradOutput, valueUpdateGate), - valueFrameState, actInput); + HOSTDEVICE void operator()(__m256 &value_update_gate, + __m256 &grad_update_gate, + __m256 &value_frame_state, + __m256 &grad_frame_state, __m256 &value_prev_out, + __m256 &grad_prev_out, __m256 &grad_output, + activation_mode_t act_input) { + grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state); + grad_update_gate = _mm256_sub_ps( + grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out)); + grad_prev_out = _mm256_add_ps( + _mm256_sub_ps(grad_prev_out, + _mm256_mul_ps(grad_output, value_update_gate)), + grad_output); + grad_frame_state = activation(_mm256_mul_ps(grad_output, value_update_gate), + value_frame_state, act_input); } #endif #endif @@ -118,30 +125,32 @@ class gru_stateGrad { template class gru_resetGrad { public: - HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate, - T &valueResetGate, T &gradResetGate, - T &valuePrevOut, T &gradPrevOut, - T &gradResetOutput, activation_mode_t actGate) { - gradResetGate = (gradResetOutput * valuePrevOut); - gradPrevOut += (gradResetOutput * valueResetGate); - gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate); - gradResetGate = activation(gradResetGate, valueResetGate, actGate); + HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate, + T &value_reset_gate, T &grad_reset_gate, + T &value_prev_out, T &grad_prev_out, + T &grad_reset_output, activation_mode_t act_gate) { + grad_reset_gate = (grad_reset_output * value_prev_out); + grad_prev_out += (grad_reset_output * value_reset_gate); + grad_update_gate = + activation(grad_update_gate, value_update_gate, act_gate); + grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate); } #ifndef __NVCC__ #ifndef __AVX__ static const bool avx = false; #else static const bool avx = true; - HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate, - __m256 &valueResetGate, __m256 &gradResetGate, - __m256 &valuePrevOut, __m256 &gradPrevOut, - __m256 &gradResetOutput, - activation_mode_t actGate) { - gradResetGate = _mm256_mul_ps(gradResetOutput, valuePrevOut); - gradPrevOut = _mm256_add_ps(gradPrevOut, - _mm256_mul_ps(gradResetOutput, valueResetGate)); - gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate); - gradResetGate = activation(gradResetGate, valueResetGate, actGate); + HOSTDEVICE void operator()(__m256 &value_update_gate, + __m256 &grad_update_gate, __m256 &value_reset_gate, + __m256 &grad_reset_gate, __m256 &value_prev_out, + __m256 &grad_prev_out, __m256 &grad_reset_output, + activation_mode_t act_gate) { + grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out); + grad_prev_out = _mm256_add_ps( + grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate)); + grad_update_gate = + activation(grad_update_gate, value_update_gate, act_gate); + grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate); } #endif #endif diff --git a/paddle/operators/math/gru_compute.cc b/paddle/operators/math/gru_compute.cc index 125af449d..ae4e47b01 100644 --- a/paddle/operators/math/gru_compute.cc +++ b/paddle/operators/math/gru_compute.cc @@ -21,29 +21,29 @@ namespace math { template struct GRUUnitFunctor { static void compute(const platform::DeviceContext &context, - hl_gru_value value, int frameSize, int batchSize, + hl_gru_value value, int frame_size, int batch_size, activation_mode_t active_node, activation_mode_t active_gate) { #ifndef __NVCC__ - if (value.prevOutValue) { + if (value.prev_out_value) { math::gemm( - context, false, false, batchSize, frameSize * 2, frameSize, 1, - value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1, - value.gateValue, frameSize * 3); + context, false, false, batch_size, frame_size * 2, frame_size, 1, + value.prev_out_value, frame_size, value.gate_weight, frame_size * 2, + 1, value.gate_value, frame_size * 3); } detail::forward_reset_output(detail::forward::gru_resetOutput(), value, - frameSize, batchSize, active_gate); + frame_size, batch_size, active_gate); - if (value.prevOutValue) { + if (value.prev_out_value) { math::gemm( - context, false, false, batchSize, frameSize, frameSize, 1, - value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1, - value.gateValue + frameSize * 2, frameSize * 3); + context, false, false, batch_size, frame_size, frame_size, 1, + value.reset_output_value, frame_size, value.state_weight, frame_size, + 1, value.gate_value + frame_size * 2, frame_size * 3); } detail::forward_final_output(detail::forward::gru_finalOutput(), value, - frameSize, batchSize, active_node); + frame_size, batch_size, active_node); #endif } }; @@ -51,41 +51,43 @@ struct GRUUnitFunctor { template struct GRUUnitGradFunctor { static void compute(const platform::DeviceContext &context, - hl_gru_value value, hl_gru_grad grad, int frameSize, - int batchSize, activation_mode_t active_node, + hl_gru_value value, hl_gru_grad grad, + int frame_size, int batch_size, + activation_mode_t active_node, activation_mode_t active_gate) { #ifndef __NVCC__ detail::backward_state_grad(detail::backward::gru_stateGrad(), value, - grad, frameSize, batchSize, active_node); + grad, frame_size, batch_size, active_node); - if (value.prevOutValue && grad.prevOutGrad) { + if (value.prev_out_value && grad.prev_out_grad) { math::gemm( - context, false, true, batchSize, frameSize, frameSize, 1, - grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight, - frameSize, 0, grad.resetOutputGrad, frameSize); + context, false, true, batch_size, frame_size, frame_size, 1, + grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight, + frame_size, 0, grad.reset_output_grad, frame_size); - if (grad.stateWeightGrad) { + if (grad.state_weight_grad) { math::gemm( - context, true, false, frameSize, frameSize, batchSize, 1, - value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2, - frameSize * 3, 1, grad.stateWeightGrad, frameSize); + context, true, false, frame_size, frame_size, batch_size, 1, + value.reset_output_value, frame_size, + grad.gate_grad + frame_size * 2, frame_size * 3, 1, + grad.state_weight_grad, frame_size); } } detail::backward_reset_grad(detail::backward::gru_resetGrad(), value, - grad, frameSize, batchSize, active_gate); + grad, frame_size, batch_size, active_gate); - if (grad.prevOutGrad && value.prevOutValue) { + if (grad.prev_out_grad && value.prev_out_value) { math::gemm( - context, false, true, batchSize, frameSize, frameSize * 2, 1, - grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1, - grad.prevOutGrad, frameSize); + context, false, true, batch_size, frame_size, frame_size * 2, 1, + grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1, + grad.prev_out_grad, frame_size); - if (grad.gateWeightGrad) { + if (grad.gate_weight_grad) { math::gemm( - context, true, false, frameSize, frameSize * 2, batchSize, 1, - value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1, - grad.gateWeightGrad, frameSize * 2); + context, true, false, frame_size, frame_size * 2, batch_size, 1, + value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1, + grad.gate_weight_grad, frame_size * 2); } } #endif diff --git a/paddle/operators/math/gru_compute.cu b/paddle/operators/math/gru_compute.cu index 7b9e54ac0..0252bdbdb 100644 --- a/paddle/operators/math/gru_compute.cu +++ b/paddle/operators/math/gru_compute.cu @@ -21,66 +21,66 @@ namespace math { template struct GRUUnitFunctor { static void compute(const platform::DeviceContext &context, - hl_gru_value value, int frameSize, int batchSize, + hl_gru_value value, int frame_size, int batch_size, activation_mode_t active_node, activation_mode_t active_gate) { auto stream = reinterpret_cast(context).stream(); dim3 threads; dim3 grid; - if (batchSize == 1) { - int framePerBlock = frameSize <= 1024 ? frameSize : 1024; - int frameBlocks = (frameSize + 1024 - 1) / 1024; - threads = dim3(framePerBlock, 1); - grid = dim3(frameBlocks, 1); + if (batch_size == 1) { + int frame_per_block = frame_size <= 1024 ? frame_size : 1024; + int frame_blocks = (frame_size + 1024 - 1) / 1024; + threads = dim3(frame_per_block, 1); + grid = dim3(frame_blocks, 1); } else { threads = dim3(32, 32); - grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); + grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32); } - if (value.prevOutValue) { + if (value.prev_out_value) { math::gemm( - context, false, false, batchSize, frameSize * 2, frameSize, 1, - value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1, - value.gateValue, frameSize * 3); + context, false, false, batch_size, frame_size * 2, frame_size, 1, + value.prev_out_value, frame_size, value.gate_weight, frame_size * 2, + 1, value.gate_value, frame_size * 3); } - if (batchSize == 1) { + if (batch_size == 1) { detail::KeGruForwardResetOutput, - /* isBatch= */ false, + /* is_batch= */ false, T><<>>( - detail::forward::gru_resetOutput(), value.gateValue, - value.resetOutputValue, value.prevOutValue, frameSize, batchSize, - active_gate); + detail::forward::gru_resetOutput(), value.gate_value, + value.reset_output_value, value.prev_out_value, frame_size, + batch_size, active_gate); } else { detail::KeGruForwardResetOutput, - /* isBatch= */ true, + /* is_batch= */ true, T><<>>( - detail::forward::gru_resetOutput(), value.gateValue, - value.resetOutputValue, value.prevOutValue, frameSize, batchSize, - active_gate); + detail::forward::gru_resetOutput(), value.gate_value, + value.reset_output_value, value.prev_out_value, frame_size, + batch_size, active_gate); } - if (value.prevOutValue) { + if (value.prev_out_value) { math::gemm( - context, false, false, batchSize, frameSize, frameSize, 1, - value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1, - value.gateValue + frameSize * 2, frameSize * 3); + context, false, false, batch_size, frame_size, frame_size, 1, + value.reset_output_value, frame_size, value.state_weight, frame_size, + 1, value.gate_value + frame_size * 2, frame_size * 3); } - if (batchSize == 1) { + if (batch_size == 1) { detail::KeGruForwardFinalOutput, - /* isBatch= */ false, + /* is_batch= */ false, T><<>>( - detail::forward::gru_finalOutput(), value.gateValue, - value.prevOutValue, value.outputValue, frameSize, batchSize, + detail::forward::gru_finalOutput(), value.gate_value, + value.prev_out_value, value.output_value, frame_size, batch_size, active_node); } else { detail::KeGruForwardFinalOutput, - /* isBatch= */ true, + /* is_batch= */ true, T><<>>( - detail::forward::gru_finalOutput(), value.gateValue, - value.prevOutValue, value.outputValue, frameSize, batchSize, + detail::forward::gru_finalOutput(), value.gate_value, + value.prev_out_value, value.output_value, frame_size, batch_size, active_node); } } @@ -89,80 +89,82 @@ struct GRUUnitFunctor { template struct GRUUnitGradFunctor { static void compute(const platform::DeviceContext &context, - hl_gru_value value, hl_gru_grad grad, int frameSize, - int batchSize, activation_mode_t active_node, + hl_gru_value value, hl_gru_grad grad, + int frame_size, int batch_size, + activation_mode_t active_node, activation_mode_t active_gate) { auto stream = reinterpret_cast(context).stream(); dim3 threads; dim3 grid; - if (batchSize == 1) { - int framePerBlock = frameSize <= 1024 ? frameSize : 1024; - int frameBlocks = (frameSize + 1024 - 1) / 1024; - threads = dim3(framePerBlock, 1); - grid = dim3(frameBlocks, 1); + if (batch_size == 1) { + int frame_per_block = frame_size <= 1024 ? frame_size : 1024; + int frame_blocks = (frame_size + 1024 - 1) / 1024; + threads = dim3(frame_per_block, 1); + grid = dim3(frame_blocks, 1); } else { threads = dim3(32, 32); - grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); + grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32); } - if (batchSize == 1) { + if (batch_size == 1) { detail::KeGruBackwardStateGrad< detail::backward::gru_stateGrad, - /* isBatch= */ false><<>>( - detail::backward::gru_stateGrad(), value.gateValue, grad.gateGrad, - value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize, - batchSize, active_node); + /* is_batch= */ false><<>>( + detail::backward::gru_stateGrad(), value.gate_value, + grad.gate_grad, value.prev_out_value, grad.prev_out_grad, + grad.output_grad, frame_size, batch_size, active_node); } else { detail::KeGruBackwardStateGrad< detail::backward::gru_stateGrad, - /* isBatch= */ true><<>>( - detail::backward::gru_stateGrad(), value.gateValue, grad.gateGrad, - value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize, - batchSize, active_node); + /* is_batch= */ true><<>>( + detail::backward::gru_stateGrad(), value.gate_value, + grad.gate_grad, value.prev_out_value, grad.prev_out_grad, + grad.output_grad, frame_size, batch_size, active_node); } - if (value.prevOutValue && grad.prevOutGrad) { + if (value.prev_out_value && grad.prev_out_grad) { math::gemm( - context, false, true, batchSize, frameSize, frameSize, 1, - grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight, - frameSize, 0, grad.resetOutputGrad, frameSize); + context, false, true, batch_size, frame_size, frame_size, 1, + grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight, + frame_size, 0, grad.reset_output_grad, frame_size); - if (grad.stateWeightGrad) { + if (grad.state_weight_grad) { math::gemm( - context, true, false, frameSize, frameSize, batchSize, 1, - value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2, - frameSize * 3, 1, grad.stateWeightGrad, frameSize); + context, true, false, frame_size, frame_size, batch_size, 1, + value.reset_output_value, frame_size, + grad.gate_grad + frame_size * 2, frame_size * 3, 1, + grad.state_weight_grad, frame_size); } } - if (batchSize == 1) { + if (batch_size == 1) { detail::KeGruBackwardResetGrad< detail::backward::gru_resetGrad, - /* isBatch= */ false><<>>( - detail::backward::gru_resetGrad(), value.gateValue, grad.gateGrad, - value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize, - batchSize, active_gate); + /* is_batch= */ false><<>>( + detail::backward::gru_resetGrad(), value.gate_value, + grad.gate_grad, value.prev_out_value, grad.prev_out_grad, + grad.reset_output_grad, frame_size, batch_size, active_gate); } else { detail::KeGruBackwardResetGrad< detail::backward::gru_resetGrad, - /* isBatch= */ true><<>>( - detail::backward::gru_resetGrad(), value.gateValue, grad.gateGrad, - value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize, - batchSize, active_gate); + /* is_batch= */ true><<>>( + detail::backward::gru_resetGrad(), value.gate_value, + grad.gate_grad, value.prev_out_value, grad.prev_out_grad, + grad.reset_output_grad, frame_size, batch_size, active_gate); } - if (grad.prevOutGrad && value.prevOutValue) { + if (grad.prev_out_grad && value.prev_out_value) { math::gemm( - context, false, true, batchSize, frameSize, frameSize * 2, 1, - grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1, - grad.prevOutGrad, frameSize); + context, false, true, batch_size, frame_size, frame_size * 2, 1, + grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1, + grad.prev_out_grad, frame_size); - if (grad.gateWeightGrad) { + if (grad.gate_weight_grad) { math::gemm( - context, true, false, frameSize, frameSize * 2, batchSize, 1, - value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1, - grad.gateWeightGrad, frameSize * 2); + context, true, false, frame_size, frame_size * 2, batch_size, 1, + value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1, + grad.gate_weight_grad, frame_size * 2); } } } diff --git a/paddle/operators/math/gru_compute.h b/paddle/operators/math/gru_compute.h index 1475fb381..58ea59f68 100644 --- a/paddle/operators/math/gru_compute.h +++ b/paddle/operators/math/gru_compute.h @@ -22,28 +22,28 @@ namespace math { // TODO(guosheng): refine code style in gru_compute template struct hl_gru_value { - T *gateWeight; - T *stateWeight; - T *gateValue; - T *resetOutputValue; - T *outputValue; - T *prevOutValue; + T *gate_weight; + T *state_weight; + T *gate_value; + T *reset_output_value; + T *output_value; + T *prev_out_value; }; template struct hl_gru_grad { - T *gateWeightGrad; - T *stateWeightGrad; - T *gateGrad; - T *resetOutputGrad; - T *outputGrad; - T *prevOutGrad; + T *gate_weight_grad; + T *state_weight_grad; + T *gate_grad; + T *reset_output_grad; + T *output_grad; + T *prev_out_grad; }; template struct GRUUnitFunctor { static void compute(const platform::DeviceContext &context, - hl_gru_value value, int frameSize, int batchSize, + hl_gru_value value, int frame_size, int batch_size, activation_mode_t active_node, activation_mode_t active_gate); }; @@ -51,8 +51,9 @@ struct GRUUnitFunctor { template struct GRUUnitGradFunctor { static void compute(const platform::DeviceContext &context, - hl_gru_value value, hl_gru_grad grad, int frameSize, - int batchSize, activation_mode_t active_node, + hl_gru_value value, hl_gru_grad grad, + int frame_size, int batch_size, + activation_mode_t active_node, activation_mode_t active_gate); }; -- GitLab