提交 e5b51c4d 编写于 作者: Q qingqing01

Make lstm_op follow google code style.

上级 d89061c3
...@@ -73,15 +73,15 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -73,15 +73,15 @@ class LSTMKernel : public framework::OpKernel<T> {
T* bias_data = const_cast<T*>(bias->data<T>()); T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later. // the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size; lstm_value.check_ig = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size; lstm_value.check_fg = lstm_value.check_ig + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size; lstm_value.check_og = lstm_value.check_fg + frame_size;
} else { } else {
lstm_value.checkIg = nullptr; lstm_value.check_ig = nullptr;
lstm_value.checkFg = nullptr; lstm_value.check_fg = nullptr;
lstm_value.checkOg = nullptr; lstm_value.check_og = nullptr;
} }
lstm_value.prevStateValue = nullptr; lstm_value.prev_state_value = nullptr;
Tensor ordered_c0; Tensor ordered_c0;
const size_t* order = batch_gate->lod()[2].data(); const size_t* order = batch_gate->lod()[2].data();
if (cell_t0) { if (cell_t0) {
...@@ -90,7 +90,7 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -90,7 +90,7 @@ class LSTMKernel : public framework::OpKernel<T> {
// to reorder. // to reorder.
ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0, ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0,
true); true);
lstm_value.prevStateValue = ordered_c0.data<T>(); lstm_value.prev_state_value = ordered_c0.data<T>();
} }
// Use the local variable as here. // Use the local variable as here.
...@@ -140,14 +140,14 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -140,14 +140,14 @@ class LSTMKernel : public framework::OpKernel<T> {
static_cast<T>(1.0)); static_cast<T>(1.0));
} }
lstm_value.gateValue = gate_t.data<T>(); lstm_value.gate_value = gate_t.data<T>();
lstm_value.outputValue = out_t.data<T>(); lstm_value.output_value = out_t.data<T>();
lstm_value.stateValue = cell_t.data<T>(); lstm_value.state_value = cell_t.data<T>();
lstm_value.stateActiveValue = cell_pre_act_t.data<T>(); lstm_value.state_active_value = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<Place, T>::compute(device_ctx, lstm_value, math::LstmUnitFunctor<Place, T>::compute(device_ctx, lstm_value,
frame_size, cur_batch_size, frame_size, cur_batch_size,
gate_act, cell_act, cand_act); gate_act, cell_act, cand_act);
lstm_value.prevStateValue = lstm_value.stateValue; lstm_value.prev_state_value = lstm_value.state_value;
} }
math::Batch2LoDTensorFunctor<Place, T> to_seq; math::Batch2LoDTensorFunctor<Place, T> to_seq;
...@@ -214,13 +214,13 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -214,13 +214,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
math::LstmMetaValue<T> lstm_value; math::LstmMetaValue<T> lstm_value;
if (bias && ctx.Attr<bool>("use_peepholes")) { if (bias && ctx.Attr<bool>("use_peepholes")) {
T* bias_data = const_cast<T*>(bias->data<T>()); T* bias_data = const_cast<T*>(bias->data<T>());
lstm_value.checkIg = bias_data + 4 * frame_size; lstm_value.check_ig = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size; lstm_value.check_fg = lstm_value.check_ig + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size; lstm_value.check_og = lstm_value.check_fg + frame_size;
} else { } else {
lstm_value.checkIg = nullptr; lstm_value.check_ig = nullptr;
lstm_value.checkFg = nullptr; lstm_value.check_fg = nullptr;
lstm_value.checkOg = nullptr; lstm_value.check_og = nullptr;
} }
math::LstmMetaGrad<T> lstm_grad; math::LstmMetaGrad<T> lstm_grad;
...@@ -231,13 +231,13 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -231,13 +231,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
} }
if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) { if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
T* bias_g_data = bias_g->data<T>(); T* bias_g_data = bias_g->data<T>();
lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size; lstm_grad.check_ig_grad = bias_g_data + 4 * frame_size;
lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size; lstm_grad.check_fg_grad = lstm_grad.check_ig_grad + frame_size;
lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size; lstm_grad.check_og_grad = lstm_grad.check_fg_grad + frame_size;
} else { } else {
lstm_grad.checkIgGrad = nullptr; lstm_grad.check_ig_grad = nullptr;
lstm_grad.checkFgGrad = nullptr; lstm_grad.check_fg_grad = nullptr;
lstm_grad.checkOgGrad = nullptr; lstm_grad.check_og_grad = nullptr;
} }
math::LoDTensor2BatchFunctor<Place, T> to_batch; math::LoDTensor2BatchFunctor<Place, T> to_batch;
...@@ -276,26 +276,26 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -276,26 +276,26 @@ class LSTMGradKernel : public framework::OpKernel<T> {
Tensor gate = batch_gate->Slice(bstart, bend); Tensor gate = batch_gate->Slice(bstart, bend);
Tensor cell = batch_cell.Slice(bstart, bend); Tensor cell = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend); Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
lstm_value.gateValue = gate.data<T>(); lstm_value.gate_value = gate.data<T>();
lstm_value.stateValue = cell.data<T>(); lstm_value.state_value = cell.data<T>();
lstm_value.stateActiveValue = cell_pre_act.data<T>(); lstm_value.state_active_value = cell_pre_act.data<T>();
Tensor out_g = batch_hidden_g.Slice(bstart, bend); Tensor out_g = batch_hidden_g.Slice(bstart, bend);
Tensor gate_g = batch_gate_g.Slice(bstart, bend); Tensor gate_g = batch_gate_g.Slice(bstart, bend);
Tensor cell_g = batch_cell_g.Slice(bstart, bend); Tensor cell_g = batch_cell_g.Slice(bstart, bend);
lstm_grad.stateGrad = cell_g.data<T>(); lstm_grad.state_grad = cell_g.data<T>();
lstm_grad.gateGrad = gate_g.data<T>(); lstm_grad.gate_grad = gate_g.data<T>();
lstm_grad.outputGrad = out_g.data<T>(); lstm_grad.output_grad = out_g.data<T>();
if (n > 0) { if (n > 0) {
int bstart_pre = static_cast<int>(batch_starts[n - 1]); int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart); Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart); Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
lstm_value.prevStateValue = cell_pre.data<T>(); lstm_value.prev_state_value = cell_pre.data<T>();
lstm_grad.prevStateGrad = cell_pre_g.data<T>(); lstm_grad.prev_state_grad = cell_pre_g.data<T>();
} else { } else {
lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr; lstm_value.prev_state_value = c0 ? ordered_c0.data<T>() : nullptr;
lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data<T>() : nullptr; lstm_grad.prev_state_grad = c0_g ? ordered_c0_g.data<T>() : nullptr;
} }
int cur_batch_size = bend - bstart; int cur_batch_size = bend - bstart;
......
...@@ -26,278 +26,284 @@ namespace detail { ...@@ -26,278 +26,284 @@ namespace detail {
template <class T, class Op> template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frameSize, int frame_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
T rValueIn; T r_value_in;
T rValueIg; T r_value_ig;
T rValueFg; T r_value_fg;
T rValueOg; T r_value_og;
T rCheckI; T r_checkI;
T rCheckF; T r_checkF;
T rCheckO; T r_checkO;
T rState; T r_state;
T rPrevState = 0; T r_prev_state = 0;
T rStateAtv; T r_state_atv;
T rOut; T r_out;
T *valueIn = value.gateValue; T *value_in = value.gate_value;
T *valueIg = value.gateValue + frameSize; T *value_ig = value.gate_value + frame_size;
T *valueFg = value.gateValue + frameSize * 2; T *value_fg = value.gate_value + frame_size * 2;
T *valueOg = value.gateValue + frameSize * 3; T *value_og = value.gate_value + frame_size * 3;
for (int i = 0; i < frameSize; i++) { for (int i = 0; i < frame_size; i++) {
rValueIn = valueIn[i]; r_value_in = value_in[i];
rValueIg = valueIg[i]; r_value_ig = value_ig[i];
rValueFg = valueFg[i]; r_value_fg = value_fg[i];
rValueOg = valueOg[i]; r_value_og = value_og[i];
rCheckI = value.checkIg ? value.checkIg[i] : 0; r_checkI = value.check_ig ? value.check_ig[i] : 0;
rCheckF = value.checkFg ? value.checkFg[i] : 0; r_checkF = value.check_fg ? value.check_fg[i] : 0;
rCheckO = value.checkOg ? value.checkOg[i] : 0; r_checkO = value.check_og ? value.check_og[i] : 0;
if (value.prevStateValue) { if (value.prev_state_value) {
rPrevState = value.prevStateValue[i]; r_prev_state = value.prev_state_value[i];
} }
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
active_gate, active_state);
valueIn[i] = rValueIn;
valueIg[i] = rValueIg; value_in[i] = r_value_in;
valueFg[i] = rValueFg; value_ig[i] = r_value_ig;
valueOg[i] = rValueOg; value_fg[i] = r_value_fg;
value.stateValue[i] = rState; value_og[i] = r_value_og;
value.stateActiveValue[i] = rStateAtv; value.state_value[i] = r_state;
value.outputValue[i] = rOut; value.state_active_value[i] = r_state_atv;
value.output_value[i] = r_out;
} }
} }
template <class T, class Op> template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize, LstmMetaGrad<T> grad, int frame_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
T rValueIn; T r_value_in;
T rValueIg; T r_value_ig;
T rValueFg; T r_value_fg;
T rValueOg; T r_value_og;
T rGradIn; T r_grad_in;
T rGradIg; T r_grad_ig;
T rGradFg; T r_grad_fg;
T rGradOg; T r_grad_og;
T rPrevState = 0; T r_prev_state = 0;
T rPrevStateGrad; T r_prev_state_grad;
T rState; T r_state;
T rStateGrad; T r_state_grad;
T rStateAtv; T r_state_atv;
T rOutputGrad; T r_output_grad;
T rCheckI; T r_checkI;
T rCheckF; T r_checkF;
T rCheckO; T r_checkO;
T rCheckIGrad; T r_checkIGrad;
T rCheckFGrad; T r_checkFGrad;
T rCheckOGrad; T r_checkOGrad;
T *valueIn = value.gateValue; T *value_in = value.gate_value;
T *valueIg = value.gateValue + frameSize; T *value_ig = value.gate_value + frame_size;
T *valueFg = value.gateValue + frameSize * 2; T *value_fg = value.gate_value + frame_size * 2;
T *valueOg = value.gateValue + frameSize * 3; T *value_og = value.gate_value + frame_size * 3;
T *gradIn = grad.gateGrad; T *grad_in = grad.gate_grad;
T *gradIg = grad.gateGrad + frameSize; T *grad_ig = grad.gate_grad + frame_size;
T *gradFg = grad.gateGrad + frameSize * 2; T *grad_fg = grad.gate_grad + frame_size * 2;
T *gradOg = grad.gateGrad + frameSize * 3; T *grad_og = grad.gate_grad + frame_size * 3;
for (int i = 0; i < frameSize; i++) { for (int i = 0; i < frame_size; i++) {
rValueIn = valueIn[i]; r_value_in = value_in[i];
rValueIg = valueIg[i]; r_value_ig = value_ig[i];
rValueFg = valueFg[i]; r_value_fg = value_fg[i];
rValueOg = valueOg[i]; r_value_og = value_og[i];
rCheckI = value.checkIg ? value.checkIg[i] : 0; r_checkI = value.check_ig ? value.check_ig[i] : 0;
rCheckF = value.checkFg ? value.checkFg[i] : 0; r_checkF = value.check_fg ? value.check_fg[i] : 0;
rCheckO = value.checkOg ? value.checkOg[i] : 0; r_checkO = value.check_og ? value.check_og[i] : 0;
rState = value.stateValue[i]; r_state = value.state_value[i];
rStateAtv = value.stateActiveValue[i]; r_state_atv = value.state_active_value[i];
rOutputGrad = grad.outputGrad[i]; r_output_grad = grad.output_grad[i];
rStateGrad = grad.stateGrad[i]; r_state_grad = grad.state_grad[i];
if (value.prevStateValue) { if (value.prev_state_value) {
rPrevState = value.prevStateValue[i]; r_prev_state = value.prev_state_value[i];
} }
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
rCheckOGrad, active_node, active_gate, active_state); r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
active_state);
gradIn[i] = rGradIn;
gradIg[i] = rGradIg; grad_in[i] = r_grad_in;
gradFg[i] = rGradFg; grad_ig[i] = r_grad_ig;
gradOg[i] = rGradOg; grad_fg[i] = r_grad_fg;
grad.stateGrad[i] = rStateGrad; grad_og[i] = r_grad_og;
grad.state_grad[i] = r_state_grad;
if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad;
if (value.prevStateValue) { if (grad.prev_state_grad) grad.prev_state_grad[i] = r_prev_state_grad;
if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad; if (value.prev_state_value) {
if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad; if (grad.check_ig_grad) grad.check_ig_grad[i] += r_checkIGrad;
if (grad.check_fg_grad) grad.check_fg_grad[i] += r_checkFGrad;
} }
if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad; if (grad.check_og_grad) grad.check_og_grad[i] += r_checkOGrad;
} }
} }
template <class T, class Op> template <class T, class Op>
void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize, void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
#ifdef __AVX__ #ifdef __AVX__
__m256 rValueIn; __m256 r_value_in;
__m256 rValueIg; __m256 r_value_ig;
__m256 rValueFg; __m256 r_value_fg;
__m256 rValueOg; __m256 r_value_og;
__m256 rCheckI = _mm256_set1_ps(0.0f); __m256 r_checkI = _mm256_set1_ps(0.0f);
__m256 rCheckF = _mm256_set1_ps(0.0f); __m256 r_checkF = _mm256_set1_ps(0.0f);
__m256 rCheckO = _mm256_set1_ps(0.0f); __m256 r_checkO = _mm256_set1_ps(0.0f);
__m256 rState; __m256 r_state;
__m256 rPrevState = _mm256_set1_ps(0.0f); __m256 r_prev_state = _mm256_set1_ps(0.0f);
__m256 rStateAtv; __m256 r_state_atv;
__m256 rOut; __m256 r_out;
__m256 *valueIn = (__m256 *)value.gateValue; __m256 *value_in = (__m256 *)value.gate_value;
__m256 *valueIg = (__m256 *)(value.gateValue + frameSize); __m256 *value_ig = (__m256 *)(value.gate_value + frame_size);
__m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2); __m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2);
__m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3); __m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3);
for (int i = 0; i < frameSize / 8; i++) { for (int i = 0; i < frame_size / 8; i++) {
rValueIn = valueIn[i]; r_value_in = value_in[i];
rValueIg = valueIg[i]; r_value_ig = value_ig[i];
rValueFg = valueFg[i]; r_value_fg = value_fg[i];
rValueOg = valueOg[i]; r_value_og = value_og[i];
if (value.checkIg) { if (value.check_ig) {
rCheckI = ((__m256 *)value.checkIg)[i]; r_checkI = ((__m256 *)value.check_ig)[i];
rCheckF = ((__m256 *)value.checkFg)[i]; r_checkF = ((__m256 *)value.check_fg)[i];
rCheckO = ((__m256 *)value.checkOg)[i]; r_checkO = ((__m256 *)value.check_og)[i];
} }
if (value.prevStateValue) { if (value.prev_state_value) {
rPrevState = ((__m256 *)value.prevStateValue)[i]; r_prev_state = ((__m256 *)value.prev_state_value)[i];
} }
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
active_gate, active_state);
valueIn[i] = rValueIn;
valueIg[i] = rValueIg; value_in[i] = r_value_in;
valueFg[i] = rValueFg; value_ig[i] = r_value_ig;
valueOg[i] = rValueOg; value_fg[i] = r_value_fg;
((__m256 *)value.stateValue)[i] = rState; value_og[i] = r_value_og;
((__m256 *)value.stateActiveValue)[i] = rStateAtv; ((__m256 *)value.state_value)[i] = r_state;
((__m256 *)value.outputValue)[i] = rOut; ((__m256 *)value.state_active_value)[i] = r_state_atv;
((__m256 *)value.output_value)[i] = r_out;
} }
#endif #endif
} }
template <class T, class Op> template <class T, class Op>
void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize, LstmMetaGrad<T> grad, int frame_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
#ifdef __AVX__ #ifdef __AVX__
__m256 rValueIn; __m256 r_value_in;
__m256 rValueIg; __m256 r_value_ig;
__m256 rValueFg; __m256 r_value_fg;
__m256 rValueOg; __m256 r_value_og;
__m256 rGradIn; __m256 r_grad_in;
__m256 rGradIg; __m256 r_grad_ig;
__m256 rGradFg; __m256 r_grad_fg;
__m256 rGradOg; __m256 r_grad_og;
__m256 rPrevState = _mm256_set1_ps(0.0f); __m256 r_prev_state = _mm256_set1_ps(0.0f);
__m256 rPrevStateGrad; __m256 r_prev_state_grad;
__m256 rStateGrad; __m256 r_state_grad;
__m256 rState; __m256 r_state;
__m256 rStateAtv; __m256 r_state_atv;
__m256 rOutputGrad; __m256 r_output_grad;
__m256 rCheckI = _mm256_set1_ps(0.0f); __m256 r_checkI = _mm256_set1_ps(0.0f);
__m256 rCheckF = _mm256_set1_ps(0.0f); __m256 r_checkF = _mm256_set1_ps(0.0f);
__m256 rCheckO = _mm256_set1_ps(0.0f); __m256 r_checkO = _mm256_set1_ps(0.0f);
__m256 rCheckIGrad; __m256 r_checkIGrad;
__m256 rCheckFGrad; __m256 r_checkFGrad;
__m256 rCheckOGrad; __m256 r_checkOGrad;
__m256 *valueIn = (__m256 *)value.gateValue; __m256 *value_in = (__m256 *)value.gate_value;
__m256 *valueIg = (__m256 *)(value.gateValue + frameSize); __m256 *value_ig = (__m256 *)(value.gate_value + frame_size);
__m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2); __m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2);
__m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3); __m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3);
__m256 *gradIn = (__m256 *)grad.gateGrad; __m256 *grad_in = (__m256 *)grad.gate_grad;
__m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize); __m256 *grad_ig = (__m256 *)(grad.gate_grad + frame_size);
__m256 *gradFg = (__m256 *)(grad.gateGrad + frameSize * 2); __m256 *grad_fg = (__m256 *)(grad.gate_grad + frame_size * 2);
__m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3); __m256 *grad_og = (__m256 *)(grad.gate_grad + frame_size * 3);
for (int i = 0; i < frameSize / 8; i++) { for (int i = 0; i < frame_size / 8; i++) {
rValueIn = valueIn[i]; r_value_in = value_in[i];
rValueIg = valueIg[i]; r_value_ig = value_ig[i];
rValueFg = valueFg[i]; r_value_fg = value_fg[i];
rValueOg = valueOg[i]; r_value_og = value_og[i];
if (value.checkIg) { if (value.check_ig) {
rCheckI = ((__m256 *)value.checkIg)[i]; r_checkI = ((__m256 *)value.check_ig)[i];
rCheckF = ((__m256 *)value.checkFg)[i]; r_checkF = ((__m256 *)value.check_fg)[i];
rCheckO = ((__m256 *)value.checkOg)[i]; r_checkO = ((__m256 *)value.check_og)[i];
} }
rState = ((__m256 *)value.stateValue)[i]; r_state = ((__m256 *)value.state_value)[i];
rStateAtv = ((__m256 *)value.stateActiveValue)[i]; r_state_atv = ((__m256 *)value.state_active_value)[i];
rOutputGrad = ((__m256 *)grad.outputGrad)[i]; r_output_grad = ((__m256 *)grad.output_grad)[i];
rStateGrad = ((__m256 *)grad.stateGrad)[i]; r_state_grad = ((__m256 *)grad.state_grad)[i];
if (value.prevStateValue) { if (value.prev_state_value) {
rPrevState = ((__m256 *)value.prevStateValue)[i]; r_prev_state = ((__m256 *)value.prev_state_value)[i];
} }
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
rCheckOGrad, active_node, active_gate, active_state); r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
active_state);
gradIn[i] = rGradIn;
gradIg[i] = rGradIg; grad_in[i] = r_grad_in;
gradFg[i] = rGradFg; grad_ig[i] = r_grad_ig;
gradOg[i] = rGradOg; grad_fg[i] = r_grad_fg;
((__m256 *)grad.stateGrad)[i] = rStateGrad; grad_og[i] = r_grad_og;
((__m256 *)grad.state_grad)[i] = r_state_grad;
if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad;
if (value.prevStateValue) { if (grad.prev_state_grad)
if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad; ((__m256 *)grad.prev_state_grad)[i] = r_prev_state_grad;
if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad; if (value.prev_state_value) {
if (grad.check_ig_grad) ((__m256 *)grad.check_ig_grad)[i] += r_checkIGrad;
if (grad.check_fg_grad) ((__m256 *)grad.check_fg_grad)[i] += r_checkFGrad;
} }
if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad; if (grad.check_og_grad) ((__m256 *)grad.check_og_grad)[i] += r_checkOGrad;
} }
#endif #endif
} }
template <class T, class Op> template <class T, class Op>
void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize, void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node, avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
active_gate, active_state); active_gate, active_state);
} else { } else {
naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node, naive_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
active_gate, active_state); active_gate, active_state);
} }
} }
template <class T, class Op> template <class T, class Op>
void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad, void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frameSize, activation_mode_t active_node, int frame_size, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node, avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, active_node,
active_gate, active_state); active_gate, active_state);
} else { } else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node, naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size,
active_gate, active_state); active_node, active_gate, active_state);
} }
} }
......
...@@ -26,189 +26,192 @@ namespace math { ...@@ -26,189 +26,192 @@ namespace math {
namespace detail { namespace detail {
/* /*
* threads(framePerBlock, batchPerBlock) * threads(frame_per_block, batch_per_block)
* grid(frameBlocks, batchBlocks) * grid(frame_blocks, batch_blocks)
*/ */
template <class T, class Op, bool isBatch> template <class T, class Op, bool is_batch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize, __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
int batchSize, activation_mode_t active_node, int batch_size, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return; if (frame_idx >= frame_size) return;
int batchIdx = 0; int batch_idx = 0;
if (isBatch) { if (is_batch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y; batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return; if (batch_idx >= batch_size) return;
value.gateValue += batchIdx * frameSize * 4; value.gate_value += batch_idx * frame_size * 4;
value.outputValue += batchIdx * frameSize; value.output_value += batch_idx * frame_size;
value.stateValue += batchIdx * frameSize; value.state_value += batch_idx * frame_size;
value.stateActiveValue += batchIdx * frameSize; value.state_active_value += batch_idx * frame_size;
} }
T rState; T r_state;
T rPrevState = 0; T r_prev_state = 0;
T rStateAtv; T r_state_atv;
T rOut; T r_out;
T rValueIn; T r_value_in;
T rValueIg; T r_value_ig;
T rValueFg; T r_value_fg;
T rValueOg; T r_value_og;
T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0; T r_checkI = value.check_ig ? value.check_ig[frame_idx] : 0;
T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0; T r_checkF = value.check_fg ? value.check_fg[frame_idx] : 0;
T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0; T r_checkO = value.check_og ? value.check_og[frame_idx] : 0;
rValueIn = value.gateValue[frameIdx]; r_value_in = value.gate_value[frame_idx];
rValueIg = value.gateValue[frameIdx + frameSize]; r_value_ig = value.gate_value[frame_idx + frame_size];
rValueFg = value.gateValue[frameIdx + frameSize * 2]; r_value_fg = value.gate_value[frame_idx + frame_size * 2];
rValueOg = value.gateValue[frameIdx + frameSize * 3]; r_value_og = value.gate_value[frame_idx + frame_size * 3];
if (value.prevStateValue) { if (value.prev_state_value) {
if (isBatch) value.prevStateValue += batchIdx * frameSize; if (is_batch) value.prev_state_value += batch_idx * frame_size;
rPrevState = value.prevStateValue[frameIdx]; r_prev_state = value.prev_state_value[frame_idx];
} }
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, active_gate,
active_state);
value.gateValue[frameIdx] = rValueIn; value.gate_value[frame_idx] = r_value_in;
value.gateValue[frameIdx + frameSize] = rValueIg; value.gate_value[frame_idx + frame_size] = r_value_ig;
value.gateValue[frameIdx + frameSize * 2] = rValueFg; value.gate_value[frame_idx + frame_size * 2] = r_value_fg;
value.gateValue[frameIdx + frameSize * 3] = rValueOg; value.gate_value[frame_idx + frame_size * 3] = r_value_og;
value.stateValue[frameIdx] = rState; value.state_value[frame_idx] = r_state;
value.stateActiveValue[frameIdx] = rStateAtv; value.state_active_value[frame_idx] = r_state_atv;
value.outputValue[frameIdx] = rOut; value.output_value[frame_idx] = r_out;
} }
/* /*
* threads(framePerBlock, batchPerBlock) * threads(frame_per_block, batch_per_block)
* grid(frameBlocks, batchBlocks) * grid(frame_blocks, batch_blocks)
*/ */
template <class T, class Op, bool isBatch> template <class T, class Op, bool is_batch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize, LstmMetaGrad<T> grad, int frame_size,
int batchSize, activation_mode_t active_node, int batch_size, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return; if (frame_idx >= frame_size) return;
int batchIdx = 0; int batch_idx = 0;
if (isBatch) { if (is_batch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y; batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return; if (batch_idx >= batch_size) return;
value.gateValue += batchIdx * frameSize * 4; value.gate_value += batch_idx * frame_size * 4;
value.stateValue += batchIdx * frameSize; value.state_value += batch_idx * frame_size;
value.stateActiveValue += batchIdx * frameSize; value.state_active_value += batch_idx * frame_size;
grad.gateGrad += batchIdx * frameSize * 4; grad.gate_grad += batch_idx * frame_size * 4;
grad.stateGrad += batchIdx * frameSize; grad.state_grad += batch_idx * frame_size;
grad.outputGrad += batchIdx * frameSize; grad.output_grad += batch_idx * frame_size;
} }
T rValueIn; T r_value_in;
T rValueIg; T r_value_ig;
T rValueFg; T r_value_fg;
T rValueOg; T r_value_og;
T rGradIn; T r_grad_in;
T rGradIg; T r_grad_ig;
T rGradFg; T r_grad_fg;
T rGradOg; T r_grad_og;
T rPrevState = 0; T r_prev_state = 0;
T rPrevStateGrad; T r_prev_state_grad;
T rState; T r_state;
T rStateGrad; T r_state_grad;
T rStateAtv; T r_state_atv;
T rOutputGrad; T r_output_grad;
T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0; T r_checkI = value.check_ig ? value.check_ig[frame_idx] : 0;
T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0; T r_checkF = value.check_fg ? value.check_fg[frame_idx] : 0;
T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0; T r_checkO = value.check_og ? value.check_og[frame_idx] : 0;
T rCheckIGrad; T r_checkIGrad;
T rCheckFGrad; T r_checkFGrad;
T rCheckOGrad; T r_checkOGrad;
rValueIn = value.gateValue[frameIdx]; r_value_in = value.gate_value[frame_idx];
rValueIg = value.gateValue[frameIdx + frameSize]; r_value_ig = value.gate_value[frame_idx + frame_size];
rValueFg = value.gateValue[frameIdx + frameSize * 2]; r_value_fg = value.gate_value[frame_idx + frame_size * 2];
rValueOg = value.gateValue[frameIdx + frameSize * 3]; r_value_og = value.gate_value[frame_idx + frame_size * 3];
rState = value.stateValue[frameIdx]; r_state = value.state_value[frame_idx];
rStateAtv = value.stateActiveValue[frameIdx]; r_state_atv = value.state_active_value[frame_idx];
rOutputGrad = grad.outputGrad[frameIdx]; r_output_grad = grad.output_grad[frame_idx];
rStateGrad = grad.stateGrad[frameIdx]; r_state_grad = grad.state_grad[frame_idx];
if (value.prevStateValue) { if (value.prev_state_value) {
if (isBatch) value.prevStateValue += batchIdx * frameSize; if (is_batch) value.prev_state_value += batch_idx * frame_size;
rPrevState = value.prevStateValue[frameIdx]; r_prev_state = value.prev_state_value[frame_idx];
} }
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
active_node, active_gate, active_state); r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
active_state);
grad.gateGrad[frameIdx] = rGradIn;
grad.gateGrad[frameIdx + frameSize] = rGradIg; grad.gate_grad[frame_idx] = r_grad_in;
grad.gateGrad[frameIdx + frameSize * 2] = rGradFg; grad.gate_grad[frame_idx + frame_size] = r_grad_ig;
grad.gateGrad[frameIdx + frameSize * 3] = rGradOg; grad.gate_grad[frame_idx + frame_size * 2] = r_grad_fg;
grad.stateGrad[frameIdx] = rStateGrad; grad.gate_grad[frame_idx + frame_size * 3] = r_grad_og;
if (grad.prevStateGrad) { grad.state_grad[frame_idx] = r_state_grad;
if (isBatch) grad.prevStateGrad += batchIdx * frameSize; if (grad.prev_state_grad) {
grad.prevStateGrad[frameIdx] = rPrevStateGrad; if (is_batch) grad.prev_state_grad += batch_idx * frame_size;
grad.prev_state_grad[frame_idx] = r_prev_state_grad;
} }
if (isBatch) { if (is_batch) {
if (value.prevStateValue) { if (value.prev_state_value) {
if (grad.checkIgGrad) if (grad.check_ig_grad)
paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx, paddle::platform::CudaAtomicAdd(grad.check_ig_grad + frame_idx,
rCheckIGrad); r_checkIGrad);
if (grad.checkFgGrad) if (grad.check_fg_grad)
paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx, paddle::platform::CudaAtomicAdd(grad.check_fg_grad + frame_idx,
rCheckFGrad); r_checkFGrad);
} }
if (grad.checkOgGrad) if (grad.check_og_grad)
paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad); paddle::platform::CudaAtomicAdd(grad.check_og_grad + frame_idx,
r_checkOGrad);
} else { } else {
if (value.prevStateValue) { if (value.prev_state_value) {
if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad; if (grad.check_ig_grad) grad.check_ig_grad[frame_idx] += r_checkIGrad;
if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad; if (grad.check_fg_grad) grad.check_fg_grad[frame_idx] += r_checkFGrad;
} }
if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad; if (grad.check_og_grad) grad.check_og_grad[frame_idx] += r_checkOGrad;
} }
} }
template <class T, class Op> template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op, void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, int frameSize, int batchSize, LstmMetaValue<T> value, int frame_size, int batch_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
if (batchSize == 1) { if (batch_size == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024; int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024; int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1); threads = dim3(frame_per_block, 1);
grid = dim3(frameBlocks, 1); grid = dim3(frame_blocks, 1);
} else { } else {
/* framePerBlock = 32 batchPerBlock = 32 */ /* frame_per_block = 32 batch_per_block = 32 */
threads = dim3(32, 32); threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
} }
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream(); reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
if (batchSize == 1) { if (batch_size == 1) {
KeLstmForward<T, Op, KeLstmForward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>( /* is_batch= */ false><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate, op, value, frame_size, batch_size, active_node, active_gate,
active_state); active_state);
} else { } else {
KeLstmForward<T, Op, KeLstmForward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>( /* is_batch= */ true><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate, op, value, frame_size, batch_size, active_node, active_gate,
active_state); active_state);
} }
} }
...@@ -216,34 +219,34 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, ...@@ -216,34 +219,34 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
template <class T, class Op> template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op, void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frameSize, int batchSize, int frame_size, int batch_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
if (batchSize == 1) { if (batch_size == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024; int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024; int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1); threads = dim3(frame_per_block, 1);
grid = dim3(frameBlocks, 1); grid = dim3(frame_blocks, 1);
} else { } else {
/* framePerBlock = 32 batchPerBlock = 16 */ /* frame_per_block = 32 batch_per_block = 16 */
threads = dim3(32, 16); threads = dim3(32, 16);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 16 - 1) / 16); grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 16 - 1) / 16);
} }
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream(); reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
if (batchSize == 1) { if (batch_size == 1) {
KeLstmBackward<T, Op, KeLstmBackward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>( /* is_batch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate, op, value, grad, frame_size, batch_size, active_node, active_gate,
active_state); active_state);
} else { } else {
KeLstmBackward<T, Op, KeLstmBackward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>( /* is_batch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate, op, value, grad, frame_size, batch_size, active_node, active_gate,
active_state); active_state);
} }
} }
......
...@@ -27,19 +27,19 @@ namespace forward { ...@@ -27,19 +27,19 @@ namespace forward {
template <class T> template <class T>
class lstm { class lstm {
public: public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
T &prevState, T &state, T &stateAtv, T &output, T &prev_state, T &state, T &state_atv, T &output,
T &checkI, T &checkF, T &checkO, T &checkI, T &checkF, T &checkO,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
valueIn = activation(valueIn, active_node); value_in = activation(value_in, active_node);
valueIg = activation(valueIg + prevState * checkI, active_gate); value_ig = activation(value_ig + prev_state * checkI, active_gate);
valueFg = activation(valueFg + prevState * checkF, active_gate); value_fg = activation(value_fg + prev_state * checkF, active_gate);
state = valueIn * valueIg + prevState * valueFg; state = value_in * value_ig + prev_state * value_fg;
valueOg = activation(valueOg + state * checkO, active_gate); value_og = activation(value_og + state * checkO, active_gate);
stateAtv = activation(state, active_state); state_atv = activation(state, active_state);
output = valueOg * stateAtv; output = value_og * state_atv;
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
...@@ -48,24 +48,27 @@ class lstm { ...@@ -48,24 +48,27 @@ class lstm {
// Only float support AVX optimization // Only float support AVX optimization
static const bool avx = std::is_same<T, float>::value; static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig,
__m256 &valueOg, __m256 &prevState, __m256 &state, __m256 &value_fg, __m256 &value_og,
__m256 &stateAtv, __m256 &output, __m256 &checkI, __m256 &prev_state, __m256 &state,
__m256 &state_atv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO, __m256 &checkF, __m256 &checkO,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
valueIn = activation(valueIn, active_node); value_in = activation(value_in, active_node);
valueIg = activation( value_ig =
_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)), active_gate); activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
valueFg = activation( active_gate);
_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)), active_gate); value_fg =
state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg), activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)),
_mm256_mul_ps(prevState, valueFg)); active_gate);
valueOg = activation(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)), state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
active_gate); _mm256_mul_ps(prev_state, value_fg));
stateAtv = activation(state, active_state); value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)),
output = _mm256_mul_ps(valueOg, stateAtv); active_gate);
state_atv = activation(state, active_state);
output = _mm256_mul_ps(value_og, state_atv);
} }
#endif #endif
#endif #endif
...@@ -78,25 +81,26 @@ namespace backward { ...@@ -78,25 +81,26 @@ namespace backward {
template <class T> template <class T>
class lstm { class lstm {
public: public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
T &gradIn, T &gradIg, T &gradFg, T &gradOg, T &grad_in, T &grad_ig, T &grad_fg, T &grad_og,
T &prevState, T &prevStateGrad, T &state, T &prev_state, T &prev_state_grad, T &state,
T &stateGrad, T &stateAtv, T &outputGrad, T &state_grad, T &state_atv, T &output_grad,
T &checkI, T &checkF, T &checkO, T &checkIGrad, T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad, T &checkFGrad, T &checkOGrad,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_gate,
activation_mode_t active_state) { activation_mode_t active_state) {
gradOg = activation(outputGrad * stateAtv, valueOg, active_gate); grad_og = activation(output_grad * state_atv, value_og, active_gate);
stateGrad += activation(outputGrad * valueOg, stateAtv, active_state) + state_grad += activation(output_grad * value_og, state_atv, active_state) +
gradOg * checkO; grad_og * checkO;
gradIn = activation(stateGrad * valueIg, valueIn, active_node); grad_in = activation(state_grad * value_ig, value_in, active_node);
gradIg = activation(stateGrad * valueIn, valueIg, active_gate); grad_ig = activation(state_grad * value_in, value_ig, active_gate);
gradFg = activation(stateGrad * prevState, valueFg, active_gate); grad_fg = activation(state_grad * prev_state, value_fg, active_gate);
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg; prev_state_grad =
checkIGrad = gradIg * prevState; grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
checkFGrad = gradFg * prevState; checkIGrad = grad_ig * prev_state;
checkOGrad = gradOg * state; checkFGrad = grad_fg * prev_state;
checkOGrad = grad_og * state;
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
...@@ -105,32 +109,32 @@ class lstm { ...@@ -105,32 +109,32 @@ class lstm {
// Only float support AVX optimization // Only float support AVX optimization
static const bool avx = std::is_same<T, float>::value; static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()( HOSTDEVICE void operator()(
__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, __m256 &valueOg, __m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og,
__m256 &gradIn, __m256 &gradIg, __m256 &gradFg, __m256 &gradOg, __m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og,
__m256 &prevState, __m256 &prevStateGrad, __m256 &state, __m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
__m256 &stateGrad, __m256 &stateAtv, __m256 &outputGrad, __m256 &checkI, __m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
__m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad, __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
__m256 &checkOGrad, activation_mode_t active_node, __m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_state) { activation_mode_t active_gate, activation_mode_t active_state) {
gradOg = grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
activation(_mm256_mul_ps(outputGrad, stateAtv), valueOg, active_gate); active_gate);
stateGrad = _mm256_add_ps( state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
activation(_mm256_mul_ps(outputGrad, valueOg), stateAtv, active_state), state_atv, active_state),
stateGrad); state_grad);
stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad); state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad);
gradIn = grad_in =
activation(_mm256_mul_ps(stateGrad, valueIg), valueIn, active_node); activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node);
gradIg = grad_ig =
activation(_mm256_mul_ps(stateGrad, valueIn), valueIg, active_gate); activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate);
gradFg = grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg,
activation(_mm256_mul_ps(stateGrad, prevState), valueFg, active_gate); active_gate);
prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI), prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI),
_mm256_mul_ps(gradFg, checkF)); _mm256_mul_ps(grad_fg, checkF));
prevStateGrad = prev_state_grad =
_mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad); _mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad);
checkIGrad = _mm256_mul_ps(gradIg, prevState); checkIGrad = _mm256_mul_ps(grad_ig, prev_state);
checkFGrad = _mm256_mul_ps(gradFg, prevState); checkFGrad = _mm256_mul_ps(grad_fg, prev_state);
checkOGrad = _mm256_mul_ps(gradOg, state); checkOGrad = _mm256_mul_ps(grad_og, state);
} }
#endif #endif
#endif #endif
......
...@@ -30,12 +30,12 @@ struct LstmUnitFunctor<platform::CPUPlace, T> { ...@@ -30,12 +30,12 @@ struct LstmUnitFunctor<platform::CPUPlace, T> {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size, detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
ActiveType(cand_act), ActiveType(gate_act), ActiveType(cand_act), ActiveType(gate_act),
ActiveType(cell_act)); ActiveType(cell_act));
value.gateValue += frame_size * 4; value.gate_value += frame_size * 4;
value.stateValue += frame_size; value.state_value += frame_size;
value.stateActiveValue += frame_size; value.state_active_value += frame_size;
value.outputValue += frame_size; value.output_value += frame_size;
if (value.prevStateValue) { if (value.prev_state_value) {
value.prevStateValue += frame_size; value.prev_state_value += frame_size;
} }
} }
} }
...@@ -53,20 +53,20 @@ struct LstmUnitGradFunctor<platform::CPUPlace, T> { ...@@ -53,20 +53,20 @@ struct LstmUnitGradFunctor<platform::CPUPlace, T> {
frame_size, ActiveType(cand_act), frame_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act)); ActiveType(gate_act), ActiveType(cell_act));
value.gateValue += frame_size * 4; value.gate_value += frame_size * 4;
value.stateValue += frame_size; value.state_value += frame_size;
value.stateActiveValue += frame_size; value.state_active_value += frame_size;
value.outputValue += frame_size; value.output_value += frame_size;
if (value.prevStateValue) { if (value.prev_state_value) {
value.prevStateValue += frame_size; value.prev_state_value += frame_size;
} }
grad.gateGrad += frame_size * 4; grad.gate_grad += frame_size * 4;
grad.stateGrad += frame_size; grad.state_grad += frame_size;
grad.stateActiveGrad += frame_size; grad.state_active_grad += frame_size;
grad.outputGrad += frame_size; grad.output_grad += frame_size;
if (grad.prevStateGrad) { if (grad.prev_state_grad) {
grad.prevStateGrad += frame_size; grad.prev_state_grad += frame_size;
} }
} }
} }
......
...@@ -31,26 +31,26 @@ typedef enum { ...@@ -31,26 +31,26 @@ typedef enum {
template <class T> template <class T>
struct LstmMetaValue { struct LstmMetaValue {
T *gateValue; T *gate_value;
T *prevStateValue; T *prev_state_value;
T *stateValue; T *state_value;
T *stateActiveValue; T *state_active_value;
T *outputValue; T *output_value;
T *checkIg; T *check_ig;
T *checkFg; T *check_fg;
T *checkOg; T *check_og;
}; };
template <class T> template <class T>
struct LstmMetaGrad { struct LstmMetaGrad {
T *gateGrad; T *gate_grad;
T *prevStateGrad; T *prev_state_grad;
T *stateGrad; T *state_grad;
T *stateActiveGrad; T *state_active_grad;
T *outputGrad; T *output_grad;
T *checkIgGrad; T *check_ig_grad;
T *checkFgGrad; T *check_fg_grad;
T *checkOgGrad; T *check_og_grad;
}; };
inline activation_mode_t ActiveType(const std::string &type) { inline activation_mode_t ActiveType(const std::string &type) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册