From e5b51c4d102ed180aef3940bd8e885c4bf5f9d95 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Sun, 3 Dec 2017 16:50:24 +0800 Subject: [PATCH] Make lstm_op follow google code style. --- paddle/operators/lstm_op.h | 70 +-- .../operators/math/detail/lstm_cpu_kernel.h | 426 +++++++++--------- .../operators/math/detail/lstm_gpu_kernel.h | 305 ++++++------- paddle/operators/math/detail/lstm_kernel.h | 128 +++--- paddle/operators/math/lstm_compute.cc | 36 +- paddle/operators/math/lstm_compute.h | 32 +- 6 files changed, 505 insertions(+), 492 deletions(-) diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 721aa42c92..a78f548aaf 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -73,15 +73,15 @@ class LSTMKernel : public framework::OpKernel { T* bias_data = const_cast(bias->data()); // the code style in LstmMetaValue will be updated later. - lstm_value.checkIg = bias_data + 4 * frame_size; - lstm_value.checkFg = lstm_value.checkIg + frame_size; - lstm_value.checkOg = lstm_value.checkFg + frame_size; + lstm_value.check_ig = bias_data + 4 * frame_size; + lstm_value.check_fg = lstm_value.check_ig + frame_size; + lstm_value.check_og = lstm_value.check_fg + frame_size; } else { - lstm_value.checkIg = nullptr; - lstm_value.checkFg = nullptr; - lstm_value.checkOg = nullptr; + lstm_value.check_ig = nullptr; + lstm_value.check_fg = nullptr; + lstm_value.check_og = nullptr; } - lstm_value.prevStateValue = nullptr; + lstm_value.prev_state_value = nullptr; Tensor ordered_c0; const size_t* order = batch_gate->lod()[2].data(); if (cell_t0) { @@ -90,7 +90,7 @@ class LSTMKernel : public framework::OpKernel { // to reorder. ReorderInitState(device_ctx, *cell_t0, order, &ordered_c0, true); - lstm_value.prevStateValue = ordered_c0.data(); + lstm_value.prev_state_value = ordered_c0.data(); } // Use the local variable as here. @@ -140,14 +140,14 @@ class LSTMKernel : public framework::OpKernel { static_cast(1.0)); } - lstm_value.gateValue = gate_t.data(); - lstm_value.outputValue = out_t.data(); - lstm_value.stateValue = cell_t.data(); - lstm_value.stateActiveValue = cell_pre_act_t.data(); + lstm_value.gate_value = gate_t.data(); + lstm_value.output_value = out_t.data(); + lstm_value.state_value = cell_t.data(); + lstm_value.state_active_value = cell_pre_act_t.data(); math::LstmUnitFunctor::compute(device_ctx, lstm_value, frame_size, cur_batch_size, gate_act, cell_act, cand_act); - lstm_value.prevStateValue = lstm_value.stateValue; + lstm_value.prev_state_value = lstm_value.state_value; } math::Batch2LoDTensorFunctor to_seq; @@ -214,13 +214,13 @@ class LSTMGradKernel : public framework::OpKernel { math::LstmMetaValue lstm_value; if (bias && ctx.Attr("use_peepholes")) { T* bias_data = const_cast(bias->data()); - lstm_value.checkIg = bias_data + 4 * frame_size; - lstm_value.checkFg = lstm_value.checkIg + frame_size; - lstm_value.checkOg = lstm_value.checkFg + frame_size; + lstm_value.check_ig = bias_data + 4 * frame_size; + lstm_value.check_fg = lstm_value.check_ig + frame_size; + lstm_value.check_og = lstm_value.check_fg + frame_size; } else { - lstm_value.checkIg = nullptr; - lstm_value.checkFg = nullptr; - lstm_value.checkOg = nullptr; + lstm_value.check_ig = nullptr; + lstm_value.check_fg = nullptr; + lstm_value.check_og = nullptr; } math::LstmMetaGrad lstm_grad; @@ -231,13 +231,13 @@ class LSTMGradKernel : public framework::OpKernel { } if (bias && bias_g && ctx.Attr("use_peepholes")) { T* bias_g_data = bias_g->data(); - lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size; - lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size; - lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size; + lstm_grad.check_ig_grad = bias_g_data + 4 * frame_size; + lstm_grad.check_fg_grad = lstm_grad.check_ig_grad + frame_size; + lstm_grad.check_og_grad = lstm_grad.check_fg_grad + frame_size; } else { - lstm_grad.checkIgGrad = nullptr; - lstm_grad.checkFgGrad = nullptr; - lstm_grad.checkOgGrad = nullptr; + lstm_grad.check_ig_grad = nullptr; + lstm_grad.check_fg_grad = nullptr; + lstm_grad.check_og_grad = nullptr; } math::LoDTensor2BatchFunctor to_batch; @@ -276,26 +276,26 @@ class LSTMGradKernel : public framework::OpKernel { Tensor gate = batch_gate->Slice(bstart, bend); Tensor cell = batch_cell.Slice(bstart, bend); Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend); - lstm_value.gateValue = gate.data(); - lstm_value.stateValue = cell.data(); - lstm_value.stateActiveValue = cell_pre_act.data(); + lstm_value.gate_value = gate.data(); + lstm_value.state_value = cell.data(); + lstm_value.state_active_value = cell_pre_act.data(); Tensor out_g = batch_hidden_g.Slice(bstart, bend); Tensor gate_g = batch_gate_g.Slice(bstart, bend); Tensor cell_g = batch_cell_g.Slice(bstart, bend); - lstm_grad.stateGrad = cell_g.data(); - lstm_grad.gateGrad = gate_g.data(); - lstm_grad.outputGrad = out_g.data(); + lstm_grad.state_grad = cell_g.data(); + lstm_grad.gate_grad = gate_g.data(); + lstm_grad.output_grad = out_g.data(); if (n > 0) { int bstart_pre = static_cast(batch_starts[n - 1]); Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart); Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart); - lstm_value.prevStateValue = cell_pre.data(); - lstm_grad.prevStateGrad = cell_pre_g.data(); + lstm_value.prev_state_value = cell_pre.data(); + lstm_grad.prev_state_grad = cell_pre_g.data(); } else { - lstm_value.prevStateValue = c0 ? ordered_c0.data() : nullptr; - lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data() : nullptr; + lstm_value.prev_state_value = c0 ? ordered_c0.data() : nullptr; + lstm_grad.prev_state_grad = c0_g ? ordered_c0_g.data() : nullptr; } int cur_batch_size = bend - bstart; diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h index fc3ad0ce58..a734ad31ee 100644 --- a/paddle/operators/math/detail/lstm_cpu_kernel.h +++ b/paddle/operators/math/detail/lstm_cpu_kernel.h @@ -26,278 +26,284 @@ namespace detail { template void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frameSize, + int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - T rValueIn; - T rValueIg; - T rValueFg; - T rValueOg; - T rCheckI; - T rCheckF; - T rCheckO; - T rState; - T rPrevState = 0; - T rStateAtv; - T rOut; - - T *valueIn = value.gateValue; - T *valueIg = value.gateValue + frameSize; - T *valueFg = value.gateValue + frameSize * 2; - T *valueOg = value.gateValue + frameSize * 3; - - for (int i = 0; i < frameSize; i++) { - rValueIn = valueIn[i]; - rValueIg = valueIg[i]; - rValueFg = valueFg[i]; - rValueOg = valueOg[i]; - rCheckI = value.checkIg ? value.checkIg[i] : 0; - rCheckF = value.checkFg ? value.checkFg[i] : 0; - rCheckO = value.checkOg ? value.checkOg[i] : 0; - - if (value.prevStateValue) { - rPrevState = value.prevStateValue[i]; + T r_value_in; + T r_value_ig; + T r_value_fg; + T r_value_og; + T r_checkI; + T r_checkF; + T r_checkO; + T r_state; + T r_prev_state = 0; + T r_state_atv; + T r_out; + + T *value_in = value.gate_value; + T *value_ig = value.gate_value + frame_size; + T *value_fg = value.gate_value + frame_size * 2; + T *value_og = value.gate_value + frame_size * 3; + + for (int i = 0; i < frame_size; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + r_checkI = value.check_ig ? value.check_ig[i] : 0; + r_checkF = value.check_fg ? value.check_fg[i] : 0; + r_checkO = value.check_og ? value.check_og[i] : 0; + + if (value.prev_state_value) { + r_prev_state = value.prev_state_value[i]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); - - valueIn[i] = rValueIn; - valueIg[i] = rValueIg; - valueFg[i] = rValueFg; - valueOg[i] = rValueOg; - value.stateValue[i] = rState; - value.stateActiveValue[i] = rStateAtv; - value.outputValue[i] = rOut; + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state, + r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, + active_gate, active_state); + + value_in[i] = r_value_in; + value_ig[i] = r_value_ig; + value_fg[i] = r_value_fg; + value_og[i] = r_value_og; + value.state_value[i] = r_state; + value.state_active_value[i] = r_state_atv; + value.output_value[i] = r_out; } } template void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frameSize, + LstmMetaGrad grad, int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - T rValueIn; - T rValueIg; - T rValueFg; - T rValueOg; - T rGradIn; - T rGradIg; - T rGradFg; - T rGradOg; - T rPrevState = 0; - T rPrevStateGrad; - T rState; - T rStateGrad; - T rStateAtv; - T rOutputGrad; - T rCheckI; - T rCheckF; - T rCheckO; - T rCheckIGrad; - T rCheckFGrad; - T rCheckOGrad; - - T *valueIn = value.gateValue; - T *valueIg = value.gateValue + frameSize; - T *valueFg = value.gateValue + frameSize * 2; - T *valueOg = value.gateValue + frameSize * 3; - T *gradIn = grad.gateGrad; - T *gradIg = grad.gateGrad + frameSize; - T *gradFg = grad.gateGrad + frameSize * 2; - T *gradOg = grad.gateGrad + frameSize * 3; - - for (int i = 0; i < frameSize; i++) { - rValueIn = valueIn[i]; - rValueIg = valueIg[i]; - rValueFg = valueFg[i]; - rValueOg = valueOg[i]; - rCheckI = value.checkIg ? value.checkIg[i] : 0; - rCheckF = value.checkFg ? value.checkFg[i] : 0; - rCheckO = value.checkOg ? value.checkOg[i] : 0; - rState = value.stateValue[i]; - rStateAtv = value.stateActiveValue[i]; - rOutputGrad = grad.outputGrad[i]; - rStateGrad = grad.stateGrad[i]; - if (value.prevStateValue) { - rPrevState = value.prevStateValue[i]; + T r_value_in; + T r_value_ig; + T r_value_fg; + T r_value_og; + T r_grad_in; + T r_grad_ig; + T r_grad_fg; + T r_grad_og; + T r_prev_state = 0; + T r_prev_state_grad; + T r_state; + T r_state_grad; + T r_state_atv; + T r_output_grad; + T r_checkI; + T r_checkF; + T r_checkO; + T r_checkIGrad; + T r_checkFGrad; + T r_checkOGrad; + + T *value_in = value.gate_value; + T *value_ig = value.gate_value + frame_size; + T *value_fg = value.gate_value + frame_size * 2; + T *value_og = value.gate_value + frame_size * 3; + T *grad_in = grad.gate_grad; + T *grad_ig = grad.gate_grad + frame_size; + T *grad_fg = grad.gate_grad + frame_size * 2; + T *grad_og = grad.gate_grad + frame_size * 3; + + for (int i = 0; i < frame_size; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + r_checkI = value.check_ig ? value.check_ig[i] : 0; + r_checkF = value.check_fg ? value.check_fg[i] : 0; + r_checkO = value.check_og ? value.check_og[i] : 0; + r_state = value.state_value[i]; + r_state_atv = value.state_active_value[i]; + r_output_grad = grad.output_grad[i]; + r_state_grad = grad.state_grad[i]; + if (value.prev_state_value) { + r_prev_state = value.prev_state_value[i]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, - rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, - rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, - rCheckOGrad, active_node, active_gate, active_state); - - gradIn[i] = rGradIn; - gradIg[i] = rGradIg; - gradFg[i] = rGradFg; - gradOg[i] = rGradOg; - grad.stateGrad[i] = rStateGrad; - - if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad; - if (value.prevStateValue) { - if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad; - if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad; + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig, + r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state, + r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO, + r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate, + active_state); + + grad_in[i] = r_grad_in; + grad_ig[i] = r_grad_ig; + grad_fg[i] = r_grad_fg; + grad_og[i] = r_grad_og; + grad.state_grad[i] = r_state_grad; + + if (grad.prev_state_grad) grad.prev_state_grad[i] = r_prev_state_grad; + if (value.prev_state_value) { + if (grad.check_ig_grad) grad.check_ig_grad[i] += r_checkIGrad; + if (grad.check_fg_grad) grad.check_fg_grad[i] += r_checkFGrad; } - if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad; + if (grad.check_og_grad) grad.check_og_grad[i] += r_checkOGrad; } } template -void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, int frameSize, +void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, + int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { #ifdef __AVX__ - __m256 rValueIn; - __m256 rValueIg; - __m256 rValueFg; - __m256 rValueOg; - __m256 rCheckI = _mm256_set1_ps(0.0f); - __m256 rCheckF = _mm256_set1_ps(0.0f); - __m256 rCheckO = _mm256_set1_ps(0.0f); - __m256 rState; - __m256 rPrevState = _mm256_set1_ps(0.0f); - __m256 rStateAtv; - __m256 rOut; - - __m256 *valueIn = (__m256 *)value.gateValue; - __m256 *valueIg = (__m256 *)(value.gateValue + frameSize); - __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2); - __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3); - - for (int i = 0; i < frameSize / 8; i++) { - rValueIn = valueIn[i]; - rValueIg = valueIg[i]; - rValueFg = valueFg[i]; - rValueOg = valueOg[i]; - if (value.checkIg) { - rCheckI = ((__m256 *)value.checkIg)[i]; - rCheckF = ((__m256 *)value.checkFg)[i]; - rCheckO = ((__m256 *)value.checkOg)[i]; + __m256 r_value_in; + __m256 r_value_ig; + __m256 r_value_fg; + __m256 r_value_og; + __m256 r_checkI = _mm256_set1_ps(0.0f); + __m256 r_checkF = _mm256_set1_ps(0.0f); + __m256 r_checkO = _mm256_set1_ps(0.0f); + __m256 r_state; + __m256 r_prev_state = _mm256_set1_ps(0.0f); + __m256 r_state_atv; + __m256 r_out; + + __m256 *value_in = (__m256 *)value.gate_value; + __m256 *value_ig = (__m256 *)(value.gate_value + frame_size); + __m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2); + __m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3); + + for (int i = 0; i < frame_size / 8; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + if (value.check_ig) { + r_checkI = ((__m256 *)value.check_ig)[i]; + r_checkF = ((__m256 *)value.check_fg)[i]; + r_checkO = ((__m256 *)value.check_og)[i]; } - if (value.prevStateValue) { - rPrevState = ((__m256 *)value.prevStateValue)[i]; + if (value.prev_state_value) { + r_prev_state = ((__m256 *)value.prev_state_value)[i]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); - - valueIn[i] = rValueIn; - valueIg[i] = rValueIg; - valueFg[i] = rValueFg; - valueOg[i] = rValueOg; - ((__m256 *)value.stateValue)[i] = rState; - ((__m256 *)value.stateActiveValue)[i] = rStateAtv; - ((__m256 *)value.outputValue)[i] = rOut; + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state, + r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, + active_gate, active_state); + + value_in[i] = r_value_in; + value_ig[i] = r_value_ig; + value_fg[i] = r_value_fg; + value_og[i] = r_value_og; + ((__m256 *)value.state_value)[i] = r_state; + ((__m256 *)value.state_active_value)[i] = r_state_atv; + ((__m256 *)value.output_value)[i] = r_out; } #endif } template void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frameSize, + LstmMetaGrad grad, int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { #ifdef __AVX__ - __m256 rValueIn; - __m256 rValueIg; - __m256 rValueFg; - __m256 rValueOg; - __m256 rGradIn; - __m256 rGradIg; - __m256 rGradFg; - __m256 rGradOg; - __m256 rPrevState = _mm256_set1_ps(0.0f); - __m256 rPrevStateGrad; - __m256 rStateGrad; - __m256 rState; - __m256 rStateAtv; - __m256 rOutputGrad; - __m256 rCheckI = _mm256_set1_ps(0.0f); - __m256 rCheckF = _mm256_set1_ps(0.0f); - __m256 rCheckO = _mm256_set1_ps(0.0f); - __m256 rCheckIGrad; - __m256 rCheckFGrad; - __m256 rCheckOGrad; - - __m256 *valueIn = (__m256 *)value.gateValue; - __m256 *valueIg = (__m256 *)(value.gateValue + frameSize); - __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2); - __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3); - __m256 *gradIn = (__m256 *)grad.gateGrad; - __m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize); - __m256 *gradFg = (__m256 *)(grad.gateGrad + frameSize * 2); - __m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3); - - for (int i = 0; i < frameSize / 8; i++) { - rValueIn = valueIn[i]; - rValueIg = valueIg[i]; - rValueFg = valueFg[i]; - rValueOg = valueOg[i]; - if (value.checkIg) { - rCheckI = ((__m256 *)value.checkIg)[i]; - rCheckF = ((__m256 *)value.checkFg)[i]; - rCheckO = ((__m256 *)value.checkOg)[i]; + __m256 r_value_in; + __m256 r_value_ig; + __m256 r_value_fg; + __m256 r_value_og; + __m256 r_grad_in; + __m256 r_grad_ig; + __m256 r_grad_fg; + __m256 r_grad_og; + __m256 r_prev_state = _mm256_set1_ps(0.0f); + __m256 r_prev_state_grad; + __m256 r_state_grad; + __m256 r_state; + __m256 r_state_atv; + __m256 r_output_grad; + __m256 r_checkI = _mm256_set1_ps(0.0f); + __m256 r_checkF = _mm256_set1_ps(0.0f); + __m256 r_checkO = _mm256_set1_ps(0.0f); + __m256 r_checkIGrad; + __m256 r_checkFGrad; + __m256 r_checkOGrad; + + __m256 *value_in = (__m256 *)value.gate_value; + __m256 *value_ig = (__m256 *)(value.gate_value + frame_size); + __m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2); + __m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3); + __m256 *grad_in = (__m256 *)grad.gate_grad; + __m256 *grad_ig = (__m256 *)(grad.gate_grad + frame_size); + __m256 *grad_fg = (__m256 *)(grad.gate_grad + frame_size * 2); + __m256 *grad_og = (__m256 *)(grad.gate_grad + frame_size * 3); + + for (int i = 0; i < frame_size / 8; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + if (value.check_ig) { + r_checkI = ((__m256 *)value.check_ig)[i]; + r_checkF = ((__m256 *)value.check_fg)[i]; + r_checkO = ((__m256 *)value.check_og)[i]; } - rState = ((__m256 *)value.stateValue)[i]; - rStateAtv = ((__m256 *)value.stateActiveValue)[i]; - rOutputGrad = ((__m256 *)grad.outputGrad)[i]; - rStateGrad = ((__m256 *)grad.stateGrad)[i]; - if (value.prevStateValue) { - rPrevState = ((__m256 *)value.prevStateValue)[i]; + r_state = ((__m256 *)value.state_value)[i]; + r_state_atv = ((__m256 *)value.state_active_value)[i]; + r_output_grad = ((__m256 *)grad.output_grad)[i]; + r_state_grad = ((__m256 *)grad.state_grad)[i]; + if (value.prev_state_value) { + r_prev_state = ((__m256 *)value.prev_state_value)[i]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, - rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, - rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, - rCheckOGrad, active_node, active_gate, active_state); - - gradIn[i] = rGradIn; - gradIg[i] = rGradIg; - gradFg[i] = rGradFg; - gradOg[i] = rGradOg; - ((__m256 *)grad.stateGrad)[i] = rStateGrad; - - if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad; - if (value.prevStateValue) { - if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad; - if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad; + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig, + r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state, + r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO, + r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate, + active_state); + + grad_in[i] = r_grad_in; + grad_ig[i] = r_grad_ig; + grad_fg[i] = r_grad_fg; + grad_og[i] = r_grad_og; + ((__m256 *)grad.state_grad)[i] = r_state_grad; + + if (grad.prev_state_grad) + ((__m256 *)grad.prev_state_grad)[i] = r_prev_state_grad; + if (value.prev_state_value) { + if (grad.check_ig_grad) ((__m256 *)grad.check_ig_grad)[i] += r_checkIGrad; + if (grad.check_fg_grad) ((__m256 *)grad.check_fg_grad)[i] += r_checkFGrad; } - if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad; + if (grad.check_og_grad) ((__m256 *)grad.check_og_grad)[i] += r_checkOGrad; } #endif } template -void cpu_lstm_forward(Op op, LstmMetaValue value, int frameSize, +void cpu_lstm_forward(Op op, LstmMetaValue value, int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same::value)) { - avx_lstm_forward_one_sequence(op, value, frameSize, active_node, + if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { + avx_lstm_forward_one_sequence(op, value, frame_size, active_node, active_gate, active_state); } else { - naive_lstm_forward_one_sequence(op, value, frameSize, active_node, + naive_lstm_forward_one_sequence(op, value, frame_size, active_node, active_gate, active_state); } } template void cpu_lstm_backward(Op op, LstmMetaValue value, LstmMetaGrad grad, - int frameSize, activation_mode_t active_node, + int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same::value)) { - avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, + if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { + avx_lstm_backward_one_sequence(op, value, grad, frame_size, active_node, active_gate, active_state); } else { - naive_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, - active_gate, active_state); + naive_lstm_backward_one_sequence(op, value, grad, frame_size, + active_node, active_gate, active_state); } } diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index d138bbe411..91bfedea53 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -26,189 +26,192 @@ namespace math { namespace detail { /* - * threads(framePerBlock, batchPerBlock) - * grid(frameBlocks, batchBlocks) + * threads(frame_per_block, batch_per_block) + * grid(frame_blocks, batch_blocks) */ -template -__global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, - int batchSize, activation_mode_t active_node, +template +__global__ void KeLstmForward(Op op, LstmMetaValue value, int frame_size, + int batch_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; - if (frameIdx >= frameSize) return; - - int batchIdx = 0; - if (isBatch) { - batchIdx = blockIdx.y * blockDim.y + threadIdx.y; - if (batchIdx >= batchSize) return; - value.gateValue += batchIdx * frameSize * 4; - value.outputValue += batchIdx * frameSize; - value.stateValue += batchIdx * frameSize; - value.stateActiveValue += batchIdx * frameSize; + const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (frame_idx >= frame_size) return; + + int batch_idx = 0; + if (is_batch) { + batch_idx = blockIdx.y * blockDim.y + threadIdx.y; + if (batch_idx >= batch_size) return; + value.gate_value += batch_idx * frame_size * 4; + value.output_value += batch_idx * frame_size; + value.state_value += batch_idx * frame_size; + value.state_active_value += batch_idx * frame_size; } - T rState; - T rPrevState = 0; - T rStateAtv; - T rOut; - T rValueIn; - T rValueIg; - T rValueFg; - T rValueOg; - - T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0; - T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0; - T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0; - - rValueIn = value.gateValue[frameIdx]; - rValueIg = value.gateValue[frameIdx + frameSize]; - rValueFg = value.gateValue[frameIdx + frameSize * 2]; - rValueOg = value.gateValue[frameIdx + frameSize * 3]; - - if (value.prevStateValue) { - if (isBatch) value.prevStateValue += batchIdx * frameSize; - rPrevState = value.prevStateValue[frameIdx]; + T r_state; + T r_prev_state = 0; + T r_state_atv; + T r_out; + T r_value_in; + T r_value_ig; + T r_value_fg; + T r_value_og; + + T r_checkI = value.check_ig ? value.check_ig[frame_idx] : 0; + T r_checkF = value.check_fg ? value.check_fg[frame_idx] : 0; + T r_checkO = value.check_og ? value.check_og[frame_idx] : 0; + + r_value_in = value.gate_value[frame_idx]; + r_value_ig = value.gate_value[frame_idx + frame_size]; + r_value_fg = value.gate_value[frame_idx + frame_size * 2]; + r_value_og = value.gate_value[frame_idx + frame_size * 3]; + + if (value.prev_state_value) { + if (is_batch) value.prev_state_value += batch_idx * frame_size; + r_prev_state = value.prev_state_value[frame_idx]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state, + r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, active_gate, + active_state); - value.gateValue[frameIdx] = rValueIn; - value.gateValue[frameIdx + frameSize] = rValueIg; - value.gateValue[frameIdx + frameSize * 2] = rValueFg; - value.gateValue[frameIdx + frameSize * 3] = rValueOg; + value.gate_value[frame_idx] = r_value_in; + value.gate_value[frame_idx + frame_size] = r_value_ig; + value.gate_value[frame_idx + frame_size * 2] = r_value_fg; + value.gate_value[frame_idx + frame_size * 3] = r_value_og; - value.stateValue[frameIdx] = rState; - value.stateActiveValue[frameIdx] = rStateAtv; - value.outputValue[frameIdx] = rOut; + value.state_value[frame_idx] = r_state; + value.state_active_value[frame_idx] = r_state_atv; + value.output_value[frame_idx] = r_out; } /* - * threads(framePerBlock, batchPerBlock) - * grid(frameBlocks, batchBlocks) + * threads(frame_per_block, batch_per_block) + * grid(frame_blocks, batch_blocks) */ -template +template __global__ void KeLstmBackward(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frameSize, - int batchSize, activation_mode_t active_node, + LstmMetaGrad grad, int frame_size, + int batch_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; - if (frameIdx >= frameSize) return; - - int batchIdx = 0; - if (isBatch) { - batchIdx = blockIdx.y * blockDim.y + threadIdx.y; - if (batchIdx >= batchSize) return; - value.gateValue += batchIdx * frameSize * 4; - value.stateValue += batchIdx * frameSize; - value.stateActiveValue += batchIdx * frameSize; - grad.gateGrad += batchIdx * frameSize * 4; - grad.stateGrad += batchIdx * frameSize; - grad.outputGrad += batchIdx * frameSize; + const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (frame_idx >= frame_size) return; + + int batch_idx = 0; + if (is_batch) { + batch_idx = blockIdx.y * blockDim.y + threadIdx.y; + if (batch_idx >= batch_size) return; + value.gate_value += batch_idx * frame_size * 4; + value.state_value += batch_idx * frame_size; + value.state_active_value += batch_idx * frame_size; + grad.gate_grad += batch_idx * frame_size * 4; + grad.state_grad += batch_idx * frame_size; + grad.output_grad += batch_idx * frame_size; } - T rValueIn; - T rValueIg; - T rValueFg; - T rValueOg; - T rGradIn; - T rGradIg; - T rGradFg; - T rGradOg; - T rPrevState = 0; - T rPrevStateGrad; - T rState; - T rStateGrad; - T rStateAtv; - T rOutputGrad; - T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0; - T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0; - T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0; - - T rCheckIGrad; - T rCheckFGrad; - T rCheckOGrad; - - rValueIn = value.gateValue[frameIdx]; - rValueIg = value.gateValue[frameIdx + frameSize]; - rValueFg = value.gateValue[frameIdx + frameSize * 2]; - rValueOg = value.gateValue[frameIdx + frameSize * 3]; - rState = value.stateValue[frameIdx]; - rStateAtv = value.stateActiveValue[frameIdx]; - rOutputGrad = grad.outputGrad[frameIdx]; - rStateGrad = grad.stateGrad[frameIdx]; - - if (value.prevStateValue) { - if (isBatch) value.prevStateValue += batchIdx * frameSize; - rPrevState = value.prevStateValue[frameIdx]; + T r_value_in; + T r_value_ig; + T r_value_fg; + T r_value_og; + T r_grad_in; + T r_grad_ig; + T r_grad_fg; + T r_grad_og; + T r_prev_state = 0; + T r_prev_state_grad; + T r_state; + T r_state_grad; + T r_state_atv; + T r_output_grad; + T r_checkI = value.check_ig ? value.check_ig[frame_idx] : 0; + T r_checkF = value.check_fg ? value.check_fg[frame_idx] : 0; + T r_checkO = value.check_og ? value.check_og[frame_idx] : 0; + + T r_checkIGrad; + T r_checkFGrad; + T r_checkOGrad; + + r_value_in = value.gate_value[frame_idx]; + r_value_ig = value.gate_value[frame_idx + frame_size]; + r_value_fg = value.gate_value[frame_idx + frame_size * 2]; + r_value_og = value.gate_value[frame_idx + frame_size * 3]; + r_state = value.state_value[frame_idx]; + r_state_atv = value.state_active_value[frame_idx]; + r_output_grad = grad.output_grad[frame_idx]; + r_state_grad = grad.state_grad[frame_idx]; + + if (value.prev_state_value) { + if (is_batch) value.prev_state_value += batch_idx * frame_size; + r_prev_state = value.prev_state_value[frame_idx]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, - rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, - rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, - active_node, active_gate, active_state); - - grad.gateGrad[frameIdx] = rGradIn; - grad.gateGrad[frameIdx + frameSize] = rGradIg; - grad.gateGrad[frameIdx + frameSize * 2] = rGradFg; - grad.gateGrad[frameIdx + frameSize * 3] = rGradOg; - grad.stateGrad[frameIdx] = rStateGrad; - if (grad.prevStateGrad) { - if (isBatch) grad.prevStateGrad += batchIdx * frameSize; - grad.prevStateGrad[frameIdx] = rPrevStateGrad; + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig, + r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state, + r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO, + r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate, + active_state); + + grad.gate_grad[frame_idx] = r_grad_in; + grad.gate_grad[frame_idx + frame_size] = r_grad_ig; + grad.gate_grad[frame_idx + frame_size * 2] = r_grad_fg; + grad.gate_grad[frame_idx + frame_size * 3] = r_grad_og; + grad.state_grad[frame_idx] = r_state_grad; + if (grad.prev_state_grad) { + if (is_batch) grad.prev_state_grad += batch_idx * frame_size; + grad.prev_state_grad[frame_idx] = r_prev_state_grad; } - if (isBatch) { - if (value.prevStateValue) { - if (grad.checkIgGrad) - paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx, - rCheckIGrad); - if (grad.checkFgGrad) - paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx, - rCheckFGrad); + if (is_batch) { + if (value.prev_state_value) { + if (grad.check_ig_grad) + paddle::platform::CudaAtomicAdd(grad.check_ig_grad + frame_idx, + r_checkIGrad); + if (grad.check_fg_grad) + paddle::platform::CudaAtomicAdd(grad.check_fg_grad + frame_idx, + r_checkFGrad); } - if (grad.checkOgGrad) - paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad); + if (grad.check_og_grad) + paddle::platform::CudaAtomicAdd(grad.check_og_grad + frame_idx, + r_checkOGrad); } else { - if (value.prevStateValue) { - if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad; - if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad; + if (value.prev_state_value) { + if (grad.check_ig_grad) grad.check_ig_grad[frame_idx] += r_checkIGrad; + if (grad.check_fg_grad) grad.check_fg_grad[frame_idx] += r_checkFGrad; } - if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad; + if (grad.check_og_grad) grad.check_og_grad[frame_idx] += r_checkOGrad; } } template void gpu_lstm_forward(const platform::DeviceContext& context, Op op, - LstmMetaValue value, int frameSize, int batchSize, + LstmMetaValue value, int frame_size, int batch_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { dim3 threads; dim3 grid; - if (batchSize == 1) { - int framePerBlock = frameSize <= 1024 ? frameSize : 1024; - int frameBlocks = (frameSize + 1024 - 1) / 1024; - threads = dim3(framePerBlock, 1); - grid = dim3(frameBlocks, 1); + if (batch_size == 1) { + int frame_per_block = frame_size <= 1024 ? frame_size : 1024; + int frame_blocks = (frame_size + 1024 - 1) / 1024; + threads = dim3(frame_per_block, 1); + grid = dim3(frame_blocks, 1); } else { - /* framePerBlock = 32 batchPerBlock = 32 */ + /* frame_per_block = 32 batch_per_block = 32 */ threads = dim3(32, 32); - grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); + grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32); } auto stream = reinterpret_cast(context).stream(); - if (batchSize == 1) { + if (batch_size == 1) { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, + /* is_batch= */ false><<>>( + op, value, frame_size, batch_size, active_node, active_gate, active_state); } else { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, + /* is_batch= */ true><<>>( + op, value, frame_size, batch_size, active_node, active_gate, active_state); } } @@ -216,34 +219,34 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, template void gpu_lstm_backward(const platform::DeviceContext& context, Op op, LstmMetaValue value, LstmMetaGrad grad, - int frameSize, int batchSize, + int frame_size, int batch_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { dim3 threads; dim3 grid; - if (batchSize == 1) { - int framePerBlock = frameSize <= 1024 ? frameSize : 1024; - int frameBlocks = (frameSize + 1024 - 1) / 1024; - threads = dim3(framePerBlock, 1); - grid = dim3(frameBlocks, 1); + if (batch_size == 1) { + int frame_per_block = frame_size <= 1024 ? frame_size : 1024; + int frame_blocks = (frame_size + 1024 - 1) / 1024; + threads = dim3(frame_per_block, 1); + grid = dim3(frame_blocks, 1); } else { - /* framePerBlock = 32 batchPerBlock = 16 */ + /* frame_per_block = 32 batch_per_block = 16 */ threads = dim3(32, 16); - grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 16 - 1) / 16); + grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 16 - 1) / 16); } auto stream = reinterpret_cast(context).stream(); - if (batchSize == 1) { + if (batch_size == 1) { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, + /* is_batch= */ false><<>>( + op, value, grad, frame_size, batch_size, active_node, active_gate, active_state); } else { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, + /* is_batch= */ true><<>>( + op, value, grad, frame_size, batch_size, active_node, active_gate, active_state); } } diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h index 9daaf91981..78f9a249a3 100644 --- a/paddle/operators/math/detail/lstm_kernel.h +++ b/paddle/operators/math/detail/lstm_kernel.h @@ -27,19 +27,19 @@ namespace forward { template class lstm { public: - HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, - T &prevState, T &state, T &stateAtv, T &output, + HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og, + T &prev_state, T &state, T &state_atv, T &output, T &checkI, T &checkF, T &checkO, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - valueIn = activation(valueIn, active_node); - valueIg = activation(valueIg + prevState * checkI, active_gate); - valueFg = activation(valueFg + prevState * checkF, active_gate); - state = valueIn * valueIg + prevState * valueFg; - valueOg = activation(valueOg + state * checkO, active_gate); - stateAtv = activation(state, active_state); - output = valueOg * stateAtv; + value_in = activation(value_in, active_node); + value_ig = activation(value_ig + prev_state * checkI, active_gate); + value_fg = activation(value_fg + prev_state * checkF, active_gate); + state = value_in * value_ig + prev_state * value_fg; + value_og = activation(value_og + state * checkO, active_gate); + state_atv = activation(state, active_state); + output = value_og * state_atv; } #ifndef __NVCC__ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default @@ -48,24 +48,27 @@ class lstm { // Only float support AVX optimization static const bool avx = std::is_same::value; - HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, - __m256 &valueOg, __m256 &prevState, __m256 &state, - __m256 &stateAtv, __m256 &output, __m256 &checkI, + HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig, + __m256 &value_fg, __m256 &value_og, + __m256 &prev_state, __m256 &state, + __m256 &state_atv, __m256 &output, __m256 &checkI, __m256 &checkF, __m256 &checkO, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - valueIn = activation(valueIn, active_node); - valueIg = activation( - _mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)), active_gate); - valueFg = activation( - _mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)), active_gate); - state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg), - _mm256_mul_ps(prevState, valueFg)); - valueOg = activation(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)), - active_gate); - stateAtv = activation(state, active_state); - output = _mm256_mul_ps(valueOg, stateAtv); + value_in = activation(value_in, active_node); + value_ig = + activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)), + active_gate); + value_fg = + activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)), + active_gate); + state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig), + _mm256_mul_ps(prev_state, value_fg)); + value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)), + active_gate); + state_atv = activation(state, active_state); + output = _mm256_mul_ps(value_og, state_atv); } #endif #endif @@ -78,25 +81,26 @@ namespace backward { template class lstm { public: - HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, - T &gradIn, T &gradIg, T &gradFg, T &gradOg, - T &prevState, T &prevStateGrad, T &state, - T &stateGrad, T &stateAtv, T &outputGrad, + HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og, + T &grad_in, T &grad_ig, T &grad_fg, T &grad_og, + T &prev_state, T &prev_state_grad, T &state, + T &state_grad, T &state_atv, T &output_grad, T &checkI, T &checkF, T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - gradOg = activation(outputGrad * stateAtv, valueOg, active_gate); - stateGrad += activation(outputGrad * valueOg, stateAtv, active_state) + - gradOg * checkO; - gradIn = activation(stateGrad * valueIg, valueIn, active_node); - gradIg = activation(stateGrad * valueIn, valueIg, active_gate); - gradFg = activation(stateGrad * prevState, valueFg, active_gate); - prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg; - checkIGrad = gradIg * prevState; - checkFGrad = gradFg * prevState; - checkOGrad = gradOg * state; + grad_og = activation(output_grad * state_atv, value_og, active_gate); + state_grad += activation(output_grad * value_og, state_atv, active_state) + + grad_og * checkO; + grad_in = activation(state_grad * value_ig, value_in, active_node); + grad_ig = activation(state_grad * value_in, value_ig, active_gate); + grad_fg = activation(state_grad * prev_state, value_fg, active_gate); + prev_state_grad = + grad_ig * checkI + grad_fg * checkF + state_grad * value_fg; + checkIGrad = grad_ig * prev_state; + checkFGrad = grad_fg * prev_state; + checkOGrad = grad_og * state; } #ifndef __NVCC__ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default @@ -105,32 +109,32 @@ class lstm { // Only float support AVX optimization static const bool avx = std::is_same::value; HOSTDEVICE void operator()( - __m256 &valueIn, __m256 &valueIg, __m256 &valueFg, __m256 &valueOg, - __m256 &gradIn, __m256 &gradIg, __m256 &gradFg, __m256 &gradOg, - __m256 &prevState, __m256 &prevStateGrad, __m256 &state, - __m256 &stateGrad, __m256 &stateAtv, __m256 &outputGrad, __m256 &checkI, - __m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad, - __m256 &checkOGrad, activation_mode_t active_node, + __m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og, + __m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og, + __m256 &prev_state, __m256 &prev_state_grad, __m256 &state, + __m256 &state_grad, __m256 &state_atv, __m256 &output_grad, + __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad, + __m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - gradOg = - activation(_mm256_mul_ps(outputGrad, stateAtv), valueOg, active_gate); - stateGrad = _mm256_add_ps( - activation(_mm256_mul_ps(outputGrad, valueOg), stateAtv, active_state), - stateGrad); - stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad); - gradIn = - activation(_mm256_mul_ps(stateGrad, valueIg), valueIn, active_node); - gradIg = - activation(_mm256_mul_ps(stateGrad, valueIn), valueIg, active_gate); - gradFg = - activation(_mm256_mul_ps(stateGrad, prevState), valueFg, active_gate); - prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI), - _mm256_mul_ps(gradFg, checkF)); - prevStateGrad = - _mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad); - checkIGrad = _mm256_mul_ps(gradIg, prevState); - checkFGrad = _mm256_mul_ps(gradFg, prevState); - checkOGrad = _mm256_mul_ps(gradOg, state); + grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og, + active_gate); + state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og), + state_atv, active_state), + state_grad); + state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad); + grad_in = + activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node); + grad_ig = + activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate); + grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg, + active_gate); + prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI), + _mm256_mul_ps(grad_fg, checkF)); + prev_state_grad = + _mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad); + checkIGrad = _mm256_mul_ps(grad_ig, prev_state); + checkFGrad = _mm256_mul_ps(grad_fg, prev_state); + checkOGrad = _mm256_mul_ps(grad_og, state); } #endif #endif diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc index 0febf8e3b7..ad3a59bcdb 100644 --- a/paddle/operators/math/lstm_compute.cc +++ b/paddle/operators/math/lstm_compute.cc @@ -30,12 +30,12 @@ struct LstmUnitFunctor { detail::cpu_lstm_forward(detail::forward::lstm(), value, frame_size, ActiveType(cand_act), ActiveType(gate_act), ActiveType(cell_act)); - value.gateValue += frame_size * 4; - value.stateValue += frame_size; - value.stateActiveValue += frame_size; - value.outputValue += frame_size; - if (value.prevStateValue) { - value.prevStateValue += frame_size; + value.gate_value += frame_size * 4; + value.state_value += frame_size; + value.state_active_value += frame_size; + value.output_value += frame_size; + if (value.prev_state_value) { + value.prev_state_value += frame_size; } } } @@ -53,20 +53,20 @@ struct LstmUnitGradFunctor { frame_size, ActiveType(cand_act), ActiveType(gate_act), ActiveType(cell_act)); - value.gateValue += frame_size * 4; - value.stateValue += frame_size; - value.stateActiveValue += frame_size; - value.outputValue += frame_size; - if (value.prevStateValue) { - value.prevStateValue += frame_size; + value.gate_value += frame_size * 4; + value.state_value += frame_size; + value.state_active_value += frame_size; + value.output_value += frame_size; + if (value.prev_state_value) { + value.prev_state_value += frame_size; } - grad.gateGrad += frame_size * 4; - grad.stateGrad += frame_size; - grad.stateActiveGrad += frame_size; - grad.outputGrad += frame_size; - if (grad.prevStateGrad) { - grad.prevStateGrad += frame_size; + grad.gate_grad += frame_size * 4; + grad.state_grad += frame_size; + grad.state_active_grad += frame_size; + grad.output_grad += frame_size; + if (grad.prev_state_grad) { + grad.prev_state_grad += frame_size; } } } diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h index 28d2c6fd3b..9652399d4c 100644 --- a/paddle/operators/math/lstm_compute.h +++ b/paddle/operators/math/lstm_compute.h @@ -31,26 +31,26 @@ typedef enum { template struct LstmMetaValue { - T *gateValue; - T *prevStateValue; - T *stateValue; - T *stateActiveValue; - T *outputValue; - T *checkIg; - T *checkFg; - T *checkOg; + T *gate_value; + T *prev_state_value; + T *state_value; + T *state_active_value; + T *output_value; + T *check_ig; + T *check_fg; + T *check_og; }; template struct LstmMetaGrad { - T *gateGrad; - T *prevStateGrad; - T *stateGrad; - T *stateActiveGrad; - T *outputGrad; - T *checkIgGrad; - T *checkFgGrad; - T *checkOgGrad; + T *gate_grad; + T *prev_state_grad; + T *state_grad; + T *state_active_grad; + T *output_grad; + T *check_ig_grad; + T *check_fg_grad; + T *check_og_grad; }; inline activation_mode_t ActiveType(const std::string &type) { -- GitLab