Commit e5b51c4d authored by: Q qingqing01

Make lstm_op follow google code style.

Parent: d89061c3
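
The commit is essentially a mechanical rename of camelCase members, locals, and parameters to snake_case, per the Google C++ style guide referenced in the title. As a minimal before/after sketch of the pattern, using the LstmMetaValue struct changed near the end of this diff (old names shown in the comments):

template <class T>
struct LstmMetaValue {
  T *gate_value;          // was gateValue
  T *prev_state_value;    // was prevStateValue
  T *state_value;         // was stateValue
  T *state_active_value;  // was stateActiveValue
  T *output_value;        // was outputValue
  T *check_ig;            // was checkIg
  T *check_fg;            // was checkFg
  T *check_og;            // was checkOg
};

There are no behavior changes: the LSTM forward/backward kernels, the CPU and GPU detail implementations, and the LstmUnitFunctor callers are updated to the new names, and a few long call lines are re-wrapped to fit the formatter.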
......@@ -73,15 +73,15 @@ class LSTMKernel : public framework::OpKernel<T> {
T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
lstm_value.check_ig = bias_data + 4 * frame_size;
lstm_value.check_fg = lstm_value.check_ig + frame_size;
lstm_value.check_og = lstm_value.check_fg + frame_size;
} else {
lstm_value.checkIg = nullptr;
lstm_value.checkFg = nullptr;
lstm_value.checkOg = nullptr;
lstm_value.check_ig = nullptr;
lstm_value.check_fg = nullptr;
lstm_value.check_og = nullptr;
}
lstm_value.prevStateValue = nullptr;
lstm_value.prev_state_value = nullptr;
Tensor ordered_c0;
const size_t* order = batch_gate->lod()[2].data();
if (cell_t0) {
......@@ -90,7 +90,7 @@ class LSTMKernel : public framework::OpKernel<T> {
// to reorder.
ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0,
true);
lstm_value.prevStateValue = ordered_c0.data<T>();
lstm_value.prev_state_value = ordered_c0.data<T>();
}
// Use the local variable as here.
......@@ -140,14 +140,14 @@ class LSTMKernel : public framework::OpKernel<T> {
static_cast<T>(1.0));
}
lstm_value.gateValue = gate_t.data<T>();
lstm_value.outputValue = out_t.data<T>();
lstm_value.stateValue = cell_t.data<T>();
lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
lstm_value.gate_value = gate_t.data<T>();
lstm_value.output_value = out_t.data<T>();
lstm_value.state_value = cell_t.data<T>();
lstm_value.state_active_value = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<Place, T>::compute(device_ctx, lstm_value,
frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
lstm_value.prevStateValue = lstm_value.stateValue;
lstm_value.prev_state_value = lstm_value.state_value;
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
......@@ -214,13 +214,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
math::LstmMetaValue<T> lstm_value;
if (bias && ctx.Attr<bool>("use_peepholes")) {
T* bias_data = const_cast<T*>(bias->data<T>());
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
lstm_value.check_ig = bias_data + 4 * frame_size;
lstm_value.check_fg = lstm_value.check_ig + frame_size;
lstm_value.check_og = lstm_value.check_fg + frame_size;
} else {
lstm_value.checkIg = nullptr;
lstm_value.checkFg = nullptr;
lstm_value.checkOg = nullptr;
lstm_value.check_ig = nullptr;
lstm_value.check_fg = nullptr;
lstm_value.check_og = nullptr;
}
math::LstmMetaGrad<T> lstm_grad;
......@@ -231,13 +231,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
}
if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
T* bias_g_data = bias_g->data<T>();
lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
lstm_grad.check_ig_grad = bias_g_data + 4 * frame_size;
lstm_grad.check_fg_grad = lstm_grad.check_ig_grad + frame_size;
lstm_grad.check_og_grad = lstm_grad.check_fg_grad + frame_size;
} else {
lstm_grad.checkIgGrad = nullptr;
lstm_grad.checkFgGrad = nullptr;
lstm_grad.checkOgGrad = nullptr;
lstm_grad.check_ig_grad = nullptr;
lstm_grad.check_fg_grad = nullptr;
lstm_grad.check_og_grad = nullptr;
}
math::LoDTensor2BatchFunctor<Place, T> to_batch;
......@@ -276,26 +276,26 @@ class LSTMGradKernel : public framework::OpKernel<T> {
Tensor gate = batch_gate->Slice(bstart, bend);
Tensor cell = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
lstm_value.gateValue = gate.data<T>();
lstm_value.stateValue = cell.data<T>();
lstm_value.stateActiveValue = cell_pre_act.data<T>();
lstm_value.gate_value = gate.data<T>();
lstm_value.state_value = cell.data<T>();
lstm_value.state_active_value = cell_pre_act.data<T>();
Tensor out_g = batch_hidden_g.Slice(bstart, bend);
Tensor gate_g = batch_gate_g.Slice(bstart, bend);
Tensor cell_g = batch_cell_g.Slice(bstart, bend);
lstm_grad.stateGrad = cell_g.data<T>();
lstm_grad.gateGrad = gate_g.data<T>();
lstm_grad.outputGrad = out_g.data<T>();
lstm_grad.state_grad = cell_g.data<T>();
lstm_grad.gate_grad = gate_g.data<T>();
lstm_grad.output_grad = out_g.data<T>();
if (n > 0) {
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
lstm_value.prevStateValue = cell_pre.data<T>();
lstm_grad.prevStateGrad = cell_pre_g.data<T>();
lstm_value.prev_state_value = cell_pre.data<T>();
lstm_grad.prev_state_grad = cell_pre_g.data<T>();
} else {
lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;
lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data<T>() : nullptr;
lstm_value.prev_state_value = c0 ? ordered_c0.data<T>() : nullptr;
lstm_grad.prev_state_grad = c0_g ? ordered_c0_g.data<T>() : nullptr;
}
int cur_batch_size = bend - bstart;
......
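
One detail worth spelling out from the hunks above: when use_peepholes is true, the peephole (check) weights are not a separate input but are carved out of the bias tensor right after the 4 * frame_size gate biases. A hypothetical helper restating that layout (the 7 * frame_size total is inferred from the pointer arithmetic, not stated explicitly in this diff):

template <class T>
void SetCheckPointers(T* bias_data, int frame_size, LstmMetaValue<T>* value) {
  // Assumed bias layout with peepholes enabled:
  //   [ 4 * frame_size gate biases | check_ig | check_fg | check_og ]
  // i.e. 7 * frame_size values in total.
  value->check_ig = bias_data + 4 * frame_size;
  value->check_fg = value->check_ig + frame_size;
  value->check_og = value->check_fg + frame_size;
}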
......@@ -26,278 +26,284 @@ namespace detail {
template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frameSize,
int frame_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rCheckI;
T rCheckF;
T rCheckO;
T rState;
T rPrevState = 0;
T rStateAtv;
T rOut;
T *valueIn = value.gateValue;
T *valueIg = value.gateValue + frameSize;
T *valueFg = value.gateValue + frameSize * 2;
T *valueOg = value.gateValue + frameSize * 3;
for (int i = 0; i < frameSize; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = value.checkIg ? value.checkIg[i] : 0;
rCheckF = value.checkFg ? value.checkFg[i] : 0;
rCheckO = value.checkOg ? value.checkOg[i] : 0;
if (value.prevStateValue) {
rPrevState = value.prevStateValue[i];
T r_value_in;
T r_value_ig;
T r_value_fg;
T r_value_og;
T r_checkI;
T r_checkF;
T r_checkO;
T r_state;
T r_prev_state = 0;
T r_state_atv;
T r_out;
T *value_in = value.gate_value;
T *value_ig = value.gate_value + frame_size;
T *value_fg = value.gate_value + frame_size * 2;
T *value_og = value.gate_value + frame_size * 3;
for (int i = 0; i < frame_size; i++) {
r_value_in = value_in[i];
r_value_ig = value_ig[i];
r_value_fg = value_fg[i];
r_value_og = value_og[i];
r_checkI = value.check_ig ? value.check_ig[i] : 0;
r_checkF = value.check_fg ? value.check_fg[i] : 0;
r_checkO = value.check_og ? value.check_og[i] : 0;
if (value.prev_state_value) {
r_prev_state = value.prev_state_value[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state);
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
active_gate, active_state);
valueIn[i] = rValueIn;
valueIg[i] = rValueIg;
valueFg[i] = rValueFg;
valueOg[i] = rValueOg;
value.stateValue[i] = rState;
value.stateActiveValue[i] = rStateAtv;
value.outputValue[i] = rOut;
value_in[i] = r_value_in;
value_ig[i] = r_value_ig;
value_fg[i] = r_value_fg;
value_og[i] = r_value_og;
value.state_value[i] = r_state;
value.state_active_value[i] = r_state_atv;
value.output_value[i] = r_out;
}
}
template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
LstmMetaGrad<T> grad, int frame_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rGradIn;
T rGradIg;
T rGradFg;
T rGradOg;
T rPrevState = 0;
T rPrevStateGrad;
T rState;
T rStateGrad;
T rStateAtv;
T rOutputGrad;
T rCheckI;
T rCheckF;
T rCheckO;
T rCheckIGrad;
T rCheckFGrad;
T rCheckOGrad;
T *valueIn = value.gateValue;
T *valueIg = value.gateValue + frameSize;
T *valueFg = value.gateValue + frameSize * 2;
T *valueOg = value.gateValue + frameSize * 3;
T *gradIn = grad.gateGrad;
T *gradIg = grad.gateGrad + frameSize;
T *gradFg = grad.gateGrad + frameSize * 2;
T *gradOg = grad.gateGrad + frameSize * 3;
for (int i = 0; i < frameSize; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = value.checkIg ? value.checkIg[i] : 0;
rCheckF = value.checkFg ? value.checkFg[i] : 0;
rCheckO = value.checkOg ? value.checkOg[i] : 0;
rState = value.stateValue[i];
rStateAtv = value.stateActiveValue[i];
rOutputGrad = grad.outputGrad[i];
rStateGrad = grad.stateGrad[i];
if (value.prevStateValue) {
rPrevState = value.prevStateValue[i];
T r_value_in;
T r_value_ig;
T r_value_fg;
T r_value_og;
T r_grad_in;
T r_grad_ig;
T r_grad_fg;
T r_grad_og;
T r_prev_state = 0;
T r_prev_state_grad;
T r_state;
T r_state_grad;
T r_state_atv;
T r_output_grad;
T r_checkI;
T r_checkF;
T r_checkO;
T r_checkIGrad;
T r_checkFGrad;
T r_checkOGrad;
T *value_in = value.gate_value;
T *value_ig = value.gate_value + frame_size;
T *value_fg = value.gate_value + frame_size * 2;
T *value_og = value.gate_value + frame_size * 3;
T *grad_in = grad.gate_grad;
T *grad_ig = grad.gate_grad + frame_size;
T *grad_fg = grad.gate_grad + frame_size * 2;
T *grad_og = grad.gate_grad + frame_size * 3;
for (int i = 0; i < frame_size; i++) {
r_value_in = value_in[i];
r_value_ig = value_ig[i];
r_value_fg = value_fg[i];
r_value_og = value_og[i];
r_checkI = value.check_ig ? value.check_ig[i] : 0;
r_checkF = value.check_fg ? value.check_fg[i] : 0;
r_checkO = value.check_og ? value.check_og[i] : 0;
r_state = value.state_value[i];
r_state_atv = value.state_active_value[i];
r_output_grad = grad.output_grad[i];
r_state_grad = grad.state_grad[i];
if (value.prev_state_value) {
r_prev_state = value.prev_state_value[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, active_node, active_gate, active_state);
gradIn[i] = rGradIn;
gradIg[i] = rGradIg;
gradFg[i] = rGradFg;
gradOg[i] = rGradOg;
grad.stateGrad[i] = rStateGrad;
if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad;
if (value.prevStateValue) {
if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad;
if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad;
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
active_state);
grad_in[i] = r_grad_in;
grad_ig[i] = r_grad_ig;
grad_fg[i] = r_grad_fg;
grad_og[i] = r_grad_og;
grad.state_grad[i] = r_state_grad;
if (grad.prev_state_grad) grad.prev_state_grad[i] = r_prev_state_grad;
if (value.prev_state_value) {
if (grad.check_ig_grad) grad.check_ig_grad[i] += r_checkIGrad;
if (grad.check_fg_grad) grad.check_fg_grad[i] += r_checkFGrad;
}
if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad;
if (grad.check_og_grad) grad.check_og_grad[i] += r_checkOGrad;
}
}
template <class T, class Op>
void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
#ifdef __AVX__
__m256 rValueIn;
__m256 rValueIg;
__m256 rValueFg;
__m256 rValueOg;
__m256 rCheckI = _mm256_set1_ps(0.0f);
__m256 rCheckF = _mm256_set1_ps(0.0f);
__m256 rCheckO = _mm256_set1_ps(0.0f);
__m256 rState;
__m256 rPrevState = _mm256_set1_ps(0.0f);
__m256 rStateAtv;
__m256 rOut;
__m256 *valueIn = (__m256 *)value.gateValue;
__m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
__m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
__m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
for (int i = 0; i < frameSize / 8; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
if (value.checkIg) {
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
__m256 r_value_in;
__m256 r_value_ig;
__m256 r_value_fg;
__m256 r_value_og;
__m256 r_checkI = _mm256_set1_ps(0.0f);
__m256 r_checkF = _mm256_set1_ps(0.0f);
__m256 r_checkO = _mm256_set1_ps(0.0f);
__m256 r_state;
__m256 r_prev_state = _mm256_set1_ps(0.0f);
__m256 r_state_atv;
__m256 r_out;
__m256 *value_in = (__m256 *)value.gate_value;
__m256 *value_ig = (__m256 *)(value.gate_value + frame_size);
__m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2);
__m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3);
for (int i = 0; i < frame_size / 8; i++) {
r_value_in = value_in[i];
r_value_ig = value_ig[i];
r_value_fg = value_fg[i];
r_value_og = value_og[i];
if (value.check_ig) {
r_checkI = ((__m256 *)value.check_ig)[i];
r_checkF = ((__m256 *)value.check_fg)[i];
r_checkO = ((__m256 *)value.check_og)[i];
}
if (value.prevStateValue) {
rPrevState = ((__m256 *)value.prevStateValue)[i];
if (value.prev_state_value) {
r_prev_state = ((__m256 *)value.prev_state_value)[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state);
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
active_gate, active_state);
valueIn[i] = rValueIn;
valueIg[i] = rValueIg;
valueFg[i] = rValueFg;
valueOg[i] = rValueOg;
((__m256 *)value.stateValue)[i] = rState;
((__m256 *)value.stateActiveValue)[i] = rStateAtv;
((__m256 *)value.outputValue)[i] = rOut;
value_in[i] = r_value_in;
value_ig[i] = r_value_ig;
value_fg[i] = r_value_fg;
value_og[i] = r_value_og;
((__m256 *)value.state_value)[i] = r_state;
((__m256 *)value.state_active_value)[i] = r_state_atv;
((__m256 *)value.output_value)[i] = r_out;
}
#endif
}
template <class T, class Op>
void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
LstmMetaGrad<T> grad, int frame_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
#ifdef __AVX__
__m256 rValueIn;
__m256 rValueIg;
__m256 rValueFg;
__m256 rValueOg;
__m256 rGradIn;
__m256 rGradIg;
__m256 rGradFg;
__m256 rGradOg;
__m256 rPrevState = _mm256_set1_ps(0.0f);
__m256 rPrevStateGrad;
__m256 rStateGrad;
__m256 rState;
__m256 rStateAtv;
__m256 rOutputGrad;
__m256 rCheckI = _mm256_set1_ps(0.0f);
__m256 rCheckF = _mm256_set1_ps(0.0f);
__m256 rCheckO = _mm256_set1_ps(0.0f);
__m256 rCheckIGrad;
__m256 rCheckFGrad;
__m256 rCheckOGrad;
__m256 *valueIn = (__m256 *)value.gateValue;
__m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
__m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
__m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
__m256 *gradIn = (__m256 *)grad.gateGrad;
__m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize);
__m256 *gradFg = (__m256 *)(grad.gateGrad + frameSize * 2);
__m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3);
for (int i = 0; i < frameSize / 8; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
if (value.checkIg) {
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
__m256 r_value_in;
__m256 r_value_ig;
__m256 r_value_fg;
__m256 r_value_og;
__m256 r_grad_in;
__m256 r_grad_ig;
__m256 r_grad_fg;
__m256 r_grad_og;
__m256 r_prev_state = _mm256_set1_ps(0.0f);
__m256 r_prev_state_grad;
__m256 r_state_grad;
__m256 r_state;
__m256 r_state_atv;
__m256 r_output_grad;
__m256 r_checkI = _mm256_set1_ps(0.0f);
__m256 r_checkF = _mm256_set1_ps(0.0f);
__m256 r_checkO = _mm256_set1_ps(0.0f);
__m256 r_checkIGrad;
__m256 r_checkFGrad;
__m256 r_checkOGrad;
__m256 *value_in = (__m256 *)value.gate_value;
__m256 *value_ig = (__m256 *)(value.gate_value + frame_size);
__m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2);
__m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3);
__m256 *grad_in = (__m256 *)grad.gate_grad;
__m256 *grad_ig = (__m256 *)(grad.gate_grad + frame_size);
__m256 *grad_fg = (__m256 *)(grad.gate_grad + frame_size * 2);
__m256 *grad_og = (__m256 *)(grad.gate_grad + frame_size * 3);
for (int i = 0; i < frame_size / 8; i++) {
r_value_in = value_in[i];
r_value_ig = value_ig[i];
r_value_fg = value_fg[i];
r_value_og = value_og[i];
if (value.check_ig) {
r_checkI = ((__m256 *)value.check_ig)[i];
r_checkF = ((__m256 *)value.check_fg)[i];
r_checkO = ((__m256 *)value.check_og)[i];
}
rState = ((__m256 *)value.stateValue)[i];
rStateAtv = ((__m256 *)value.stateActiveValue)[i];
rOutputGrad = ((__m256 *)grad.outputGrad)[i];
rStateGrad = ((__m256 *)grad.stateGrad)[i];
if (value.prevStateValue) {
rPrevState = ((__m256 *)value.prevStateValue)[i];
r_state = ((__m256 *)value.state_value)[i];
r_state_atv = ((__m256 *)value.state_active_value)[i];
r_output_grad = ((__m256 *)grad.output_grad)[i];
r_state_grad = ((__m256 *)grad.state_grad)[i];
if (value.prev_state_value) {
r_prev_state = ((__m256 *)value.prev_state_value)[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, active_node, active_gate, active_state);
gradIn[i] = rGradIn;
gradIg[i] = rGradIg;
gradFg[i] = rGradFg;
gradOg[i] = rGradOg;
((__m256 *)grad.stateGrad)[i] = rStateGrad;
if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad;
if (value.prevStateValue) {
if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad;
if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad;
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
active_state);
grad_in[i] = r_grad_in;
grad_ig[i] = r_grad_ig;
grad_fg[i] = r_grad_fg;
grad_og[i] = r_grad_og;
((__m256 *)grad.state_grad)[i] = r_state_grad;
if (grad.prev_state_grad)
((__m256 *)grad.prev_state_grad)[i] = r_prev_state_grad;
if (value.prev_state_value) {
if (grad.check_ig_grad) ((__m256 *)grad.check_ig_grad)[i] += r_checkIGrad;
if (grad.check_fg_grad) ((__m256 *)grad.check_fg_grad)[i] += r_checkFGrad;
}
if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad;
if (grad.check_og_grad) ((__m256 *)grad.check_og_grad)[i] += r_checkOGrad;
}
#endif
}
template <class T, class Op>
void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
active_gate, active_state);
} else {
naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
naive_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
active_gate, active_state);
}
}
template <class T, class Op>
void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frameSize, activation_mode_t active_node,
int frame_size, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, active_node,
active_gate, active_state);
} else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state);
naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size,
active_node, active_gate, active_state);
}
}
......
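
The dispatch at the end of the file above is worth restating: the AVX kernels process eight floats per __m256 register, so they are used only when the element type is float and frame_size is a multiple of 8; otherwise the naive per-element loop runs. A stand-alone sketch of that test (can_use_avx_path is a hypothetical name, not part of the commit):

#include <type_traits>

template <class T>
bool can_use_avx_path(bool op_supports_avx, int frame_size) {
  // !(frame_size & (8 - 1)) in the source is the bitmask form of frame_size % 8 == 0.
  return op_supports_avx && (frame_size % 8 == 0) && std::is_same<T, float>::value;
}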
......@@ -26,189 +26,192 @@ namespace math {
namespace detail {
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template <class T, class Op, bool isBatch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
int batchSize, activation_mode_t active_node,
template <class T, class Op, bool is_batch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
int batch_size, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
value.gateValue += batchIdx * frameSize * 4;
value.outputValue += batchIdx * frameSize;
value.stateValue += batchIdx * frameSize;
value.stateActiveValue += batchIdx * frameSize;
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return;
value.gate_value += batch_idx * frame_size * 4;
value.output_value += batch_idx * frame_size;
value.state_value += batch_idx * frame_size;
value.state_active_value += batch_idx * frame_size;
}
T rState;
T rPrevState = 0;
T rStateAtv;
T rOut;
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
rValueIn = value.gateValue[frameIdx];
rValueIg = value.gateValue[frameIdx + frameSize];
rValueFg = value.gateValue[frameIdx + frameSize * 2];
rValueOg = value.gateValue[frameIdx + frameSize * 3];
if (value.prevStateValue) {
if (isBatch) value.prevStateValue += batchIdx * frameSize;
rPrevState = value.prevStateValue[frameIdx];
T r_state;
T r_prev_state = 0;
T r_state_atv;
T r_out;
T r_value_in;
T r_value_ig;
T r_value_fg;
T r_value_og;
T r_checkI = value.check_ig ? value.check_ig[frame_idx] : 0;
T r_checkF = value.check_fg ? value.check_fg[frame_idx] : 0;
T r_checkO = value.check_og ? value.check_og[frame_idx] : 0;
r_value_in = value.gate_value[frame_idx];
r_value_ig = value.gate_value[frame_idx + frame_size];
r_value_fg = value.gate_value[frame_idx + frame_size * 2];
r_value_og = value.gate_value[frame_idx + frame_size * 3];
if (value.prev_state_value) {
if (is_batch) value.prev_state_value += batch_idx * frame_size;
r_prev_state = value.prev_state_value[frame_idx];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state);
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, active_gate,
active_state);
value.gateValue[frameIdx] = rValueIn;
value.gateValue[frameIdx + frameSize] = rValueIg;
value.gateValue[frameIdx + frameSize * 2] = rValueFg;
value.gateValue[frameIdx + frameSize * 3] = rValueOg;
value.gate_value[frame_idx] = r_value_in;
value.gate_value[frame_idx + frame_size] = r_value_ig;
value.gate_value[frame_idx + frame_size * 2] = r_value_fg;
value.gate_value[frame_idx + frame_size * 3] = r_value_og;
value.stateValue[frameIdx] = rState;
value.stateActiveValue[frameIdx] = rStateAtv;
value.outputValue[frameIdx] = rOut;
value.state_value[frame_idx] = r_state;
value.state_active_value[frame_idx] = r_state_atv;
value.output_value[frame_idx] = r_out;
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template <class T, class Op, bool isBatch>
template <class T, class Op, bool is_batch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node,
LstmMetaGrad<T> grad, int frame_size,
int batch_size, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
value.gateValue += batchIdx * frameSize * 4;
value.stateValue += batchIdx * frameSize;
value.stateActiveValue += batchIdx * frameSize;
grad.gateGrad += batchIdx * frameSize * 4;
grad.stateGrad += batchIdx * frameSize;
grad.outputGrad += batchIdx * frameSize;
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return;
value.gate_value += batch_idx * frame_size * 4;
value.state_value += batch_idx * frame_size;
value.state_active_value += batch_idx * frame_size;
grad.gate_grad += batch_idx * frame_size * 4;
grad.state_grad += batch_idx * frame_size;
grad.output_grad += batch_idx * frame_size;
}
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rGradIn;
T rGradIg;
T rGradFg;
T rGradOg;
T rPrevState = 0;
T rPrevStateGrad;
T rState;
T rStateGrad;
T rStateAtv;
T rOutputGrad;
T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
T rCheckIGrad;
T rCheckFGrad;
T rCheckOGrad;
rValueIn = value.gateValue[frameIdx];
rValueIg = value.gateValue[frameIdx + frameSize];
rValueFg = value.gateValue[frameIdx + frameSize * 2];
rValueOg = value.gateValue[frameIdx + frameSize * 3];
rState = value.stateValue[frameIdx];
rStateAtv = value.stateActiveValue[frameIdx];
rOutputGrad = grad.outputGrad[frameIdx];
rStateGrad = grad.stateGrad[frameIdx];
if (value.prevStateValue) {
if (isBatch) value.prevStateValue += batchIdx * frameSize;
rPrevState = value.prevStateValue[frameIdx];
T r_value_in;
T r_value_ig;
T r_value_fg;
T r_value_og;
T r_grad_in;
T r_grad_ig;
T r_grad_fg;
T r_grad_og;
T r_prev_state = 0;
T r_prev_state_grad;
T r_state;
T r_state_grad;
T r_state_atv;
T r_output_grad;
T r_checkI = value.check_ig ? value.check_ig[frame_idx] : 0;
T r_checkF = value.check_fg ? value.check_fg[frame_idx] : 0;
T r_checkO = value.check_og ? value.check_og[frame_idx] : 0;
T r_checkIGrad;
T r_checkFGrad;
T r_checkOGrad;
r_value_in = value.gate_value[frame_idx];
r_value_ig = value.gate_value[frame_idx + frame_size];
r_value_fg = value.gate_value[frame_idx + frame_size * 2];
r_value_og = value.gate_value[frame_idx + frame_size * 3];
r_state = value.state_value[frame_idx];
r_state_atv = value.state_active_value[frame_idx];
r_output_grad = grad.output_grad[frame_idx];
r_state_grad = grad.state_grad[frame_idx];
if (value.prev_state_value) {
if (is_batch) value.prev_state_value += batch_idx * frame_size;
r_prev_state = value.prev_state_value[frame_idx];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
active_node, active_gate, active_state);
grad.gateGrad[frameIdx] = rGradIn;
grad.gateGrad[frameIdx + frameSize] = rGradIg;
grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
grad.stateGrad[frameIdx] = rStateGrad;
if (grad.prevStateGrad) {
if (isBatch) grad.prevStateGrad += batchIdx * frameSize;
grad.prevStateGrad[frameIdx] = rPrevStateGrad;
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
active_state);
grad.gate_grad[frame_idx] = r_grad_in;
grad.gate_grad[frame_idx + frame_size] = r_grad_ig;
grad.gate_grad[frame_idx + frame_size * 2] = r_grad_fg;
grad.gate_grad[frame_idx + frame_size * 3] = r_grad_og;
grad.state_grad[frame_idx] = r_state_grad;
if (grad.prev_state_grad) {
if (is_batch) grad.prev_state_grad += batch_idx * frame_size;
grad.prev_state_grad[frame_idx] = r_prev_state_grad;
}
if (isBatch) {
if (value.prevStateValue) {
if (grad.checkIgGrad)
paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx,
rCheckIGrad);
if (grad.checkFgGrad)
paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx,
rCheckFGrad);
if (is_batch) {
if (value.prev_state_value) {
if (grad.check_ig_grad)
paddle::platform::CudaAtomicAdd(grad.check_ig_grad + frame_idx,
r_checkIGrad);
if (grad.check_fg_grad)
paddle::platform::CudaAtomicAdd(grad.check_fg_grad + frame_idx,
r_checkFGrad);
}
if (grad.checkOgGrad)
paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad);
if (grad.check_og_grad)
paddle::platform::CudaAtomicAdd(grad.check_og_grad + frame_idx,
r_checkOGrad);
} else {
if (value.prevStateValue) {
if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
if (value.prev_state_value) {
if (grad.check_ig_grad) grad.check_ig_grad[frame_idx] += r_checkIGrad;
if (grad.check_fg_grad) grad.check_fg_grad[frame_idx] += r_checkFGrad;
}
if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
if (grad.check_og_grad) grad.check_og_grad[frame_idx] += r_checkOGrad;
}
}
template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, int frameSize, int batchSize,
LstmMetaValue<T> value, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
dim3 threads;
dim3 grid;
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
if (batch_size == 1) {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grid = dim3(frame_blocks, 1);
} else {
/* framePerBlock = 32 batchPerBlock = 32 */
/* frame_per_block = 32 batch_per_block = 32 */
threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
}
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
if (batchSize == 1) {
if (batch_size == 1) {
KeLstmForward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate,
/* is_batch= */ false><<<grid, threads, 0, stream>>>(
op, value, frame_size, batch_size, active_node, active_gate,
active_state);
} else {
KeLstmForward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate,
/* is_batch= */ true><<<grid, threads, 0, stream>>>(
op, value, frame_size, batch_size, active_node, active_gate,
active_state);
}
}
......@@ -216,34 +219,34 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frameSize, int batchSize,
int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
dim3 threads;
dim3 grid;
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
if (batch_size == 1) {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grid = dim3(frame_blocks, 1);
} else {
/* framePerBlock = 32 batchPerBlock = 16 */
/* frame_per_block = 32 batch_per_block = 16 */
threads = dim3(32, 16);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 16 - 1) / 16);
grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 16 - 1) / 16);
}
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
if (batchSize == 1) {
if (batch_size == 1) {
KeLstmBackward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
/* is_batch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frame_size, batch_size, active_node, active_gate,
active_state);
} else {
KeLstmBackward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
/* is_batch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frame_size, batch_size, active_node, active_gate,
active_state);
}
}
......
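
For reference, the batched launch configuration chosen by gpu_lstm_forward and gpu_lstm_backward above assigns one thread per (frame, batch) element; a sketch of that mapping (illustrative only, not part of the commit):

// Forward uses 32x32 blocks, backward 32x16; a single-sequence input
// (batch_size == 1) instead launches a 1-D grid over frame_size.
dim3 threads(32, 32);
dim3 grid((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
// Inside KeLstmForward / KeLstmBackward:
//   frame_idx = blockIdx.x * blockDim.x + threadIdx.x
//   batch_idx = blockIdx.y * blockDim.y + threadIdx.y
// and threads with frame_idx >= frame_size or batch_idx >= batch_size return early.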
......@@ -27,19 +27,19 @@ namespace forward {
template <class T>
class lstm {
public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output,
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
T &prev_state, T &state, T &state_atv, T &output,
T &checkI, T &checkF, T &checkO,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
valueIn = activation(valueIn, active_node);
valueIg = activation(valueIg + prevState * checkI, active_gate);
valueFg = activation(valueFg + prevState * checkF, active_gate);
state = valueIn * valueIg + prevState * valueFg;
valueOg = activation(valueOg + state * checkO, active_gate);
stateAtv = activation(state, active_state);
output = valueOg * stateAtv;
value_in = activation(value_in, active_node);
value_ig = activation(value_ig + prev_state * checkI, active_gate);
value_fg = activation(value_fg + prev_state * checkF, active_gate);
state = value_in * value_ig + prev_state * value_fg;
value_og = activation(value_og + state * checkO, active_gate);
state_atv = activation(state, active_state);
output = value_og * state_atv;
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
......@@ -48,24 +48,27 @@ class lstm {
// Only float support AVX optimization
static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &prevState, __m256 &state,
__m256 &stateAtv, __m256 &output, __m256 &checkI,
HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig,
__m256 &value_fg, __m256 &value_og,
__m256 &prev_state, __m256 &state,
__m256 &state_atv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
valueIn = activation(valueIn, active_node);
valueIg = activation(
_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)), active_gate);
valueFg = activation(
_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)), active_gate);
state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
_mm256_mul_ps(prevState, valueFg));
valueOg = activation(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)),
value_in = activation(value_in, active_node);
value_ig =
activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
active_gate);
stateAtv = activation(state, active_state);
output = _mm256_mul_ps(valueOg, stateAtv);
value_fg =
activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)),
active_gate);
state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
_mm256_mul_ps(prev_state, value_fg));
value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)),
active_gate);
state_atv = activation(state, active_state);
output = _mm256_mul_ps(value_og, state_atv);
}
#endif
#endif
......@@ -78,25 +81,26 @@ namespace backward {
template <class T>
class lstm {
public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &gradIn, T &gradIg, T &gradFg, T &gradOg,
T &prevState, T &prevStateGrad, T &state,
T &stateGrad, T &stateAtv, T &outputGrad,
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
T &grad_in, T &grad_ig, T &grad_fg, T &grad_og,
T &prev_state, T &prev_state_grad, T &state,
T &state_grad, T &state_atv, T &output_grad,
T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
gradOg = activation(outputGrad * stateAtv, valueOg, active_gate);
stateGrad += activation(outputGrad * valueOg, stateAtv, active_state) +
gradOg * checkO;
gradIn = activation(stateGrad * valueIg, valueIn, active_node);
gradIg = activation(stateGrad * valueIn, valueIg, active_gate);
gradFg = activation(stateGrad * prevState, valueFg, active_gate);
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
grad_og = activation(output_grad * state_atv, value_og, active_gate);
state_grad += activation(output_grad * value_og, state_atv, active_state) +
grad_og * checkO;
grad_in = activation(state_grad * value_ig, value_in, active_node);
grad_ig = activation(state_grad * value_in, value_ig, active_gate);
grad_fg = activation(state_grad * prev_state, value_fg, active_gate);
prev_state_grad =
grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
checkIGrad = grad_ig * prev_state;
checkFGrad = grad_fg * prev_state;
checkOGrad = grad_og * state;
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
......@@ -105,32 +109,32 @@ class lstm {
// Only float support AVX optimization
static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(
__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, __m256 &valueOg,
__m256 &gradIn, __m256 &gradIg, __m256 &gradFg, __m256 &gradOg,
__m256 &prevState, __m256 &prevStateGrad, __m256 &state,
__m256 &stateGrad, __m256 &stateAtv, __m256 &outputGrad, __m256 &checkI,
__m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
__m256 &checkOGrad, activation_mode_t active_node,
__m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og,
__m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og,
__m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
__m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
__m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
__m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_state) {
gradOg =
activation(_mm256_mul_ps(outputGrad, stateAtv), valueOg, active_gate);
stateGrad = _mm256_add_ps(
activation(_mm256_mul_ps(outputGrad, valueOg), stateAtv, active_state),
stateGrad);
stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
gradIn =
activation(_mm256_mul_ps(stateGrad, valueIg), valueIn, active_node);
gradIg =
activation(_mm256_mul_ps(stateGrad, valueIn), valueIg, active_gate);
gradFg =
activation(_mm256_mul_ps(stateGrad, prevState), valueFg, active_gate);
prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
_mm256_mul_ps(gradFg, checkF));
prevStateGrad =
_mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad);
checkIGrad = _mm256_mul_ps(gradIg, prevState);
checkFGrad = _mm256_mul_ps(gradFg, prevState);
checkOGrad = _mm256_mul_ps(gradOg, state);
grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
active_gate);
state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
state_atv, active_state),
state_grad);
state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad);
grad_in =
activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node);
grad_ig =
activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate);
grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg,
active_gate);
prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI),
_mm256_mul_ps(grad_fg, checkF));
prev_state_grad =
_mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad);
checkIGrad = _mm256_mul_ps(grad_ig, prev_state);
checkFGrad = _mm256_mul_ps(grad_fg, prev_state);
checkOGrad = _mm256_mul_ps(grad_og, state);
}
#endif
#endif
......
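
Restating what the renamed forward::lstm functor above computes, with act_node, act_gate, and act_state standing for the three configurable activations, \odot for the element-wise product, and checkI, checkF, checkO playing the role of the peephole weights w_{ic}, w_{fc}, w_{oc}:

g_t = \text{act\_node}(\tilde{g}_t)
i_t = \text{act\_gate}(\tilde{i}_t + w_{ic} \odot c_{t-1})
f_t = \text{act\_gate}(\tilde{f}_t + w_{fc} \odot c_{t-1})
c_t = g_t \odot i_t + f_t \odot c_{t-1}
o_t = \text{act\_gate}(\tilde{o}_t + w_{oc} \odot c_t)
h_t = o_t \odot \text{act\_state}(c_t)

Here \tilde{g}_t, \tilde{i}_t, \tilde{f}_t, \tilde{o}_t are the four pre-activation slices of gate_value (value_in, value_ig, value_fg, value_og), c_t is state, \text{act\_state}(c_t) is state_atv, and h_t is output; backward::lstm is the corresponding chain-rule derivative.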
......@@ -30,12 +30,12 @@ struct LstmUnitFunctor<platform::CPUPlace, T> {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
ActiveType(cand_act), ActiveType(gate_act),
ActiveType(cell_act));
value.gateValue += frame_size * 4;
value.stateValue += frame_size;
value.stateActiveValue += frame_size;
value.outputValue += frame_size;
if (value.prevStateValue) {
value.prevStateValue += frame_size;
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
value.output_value += frame_size;
if (value.prev_state_value) {
value.prev_state_value += frame_size;
}
}
}
......@@ -53,20 +53,20 @@ struct LstmUnitGradFunctor<platform::CPUPlace, T> {
frame_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
value.gateValue += frame_size * 4;
value.stateValue += frame_size;
value.stateActiveValue += frame_size;
value.outputValue += frame_size;
if (value.prevStateValue) {
value.prevStateValue += frame_size;
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
value.output_value += frame_size;
if (value.prev_state_value) {
value.prev_state_value += frame_size;
}
grad.gateGrad += frame_size * 4;
grad.stateGrad += frame_size;
grad.stateActiveGrad += frame_size;
grad.outputGrad += frame_size;
if (grad.prevStateGrad) {
grad.prevStateGrad += frame_size;
grad.gate_grad += frame_size * 4;
grad.state_grad += frame_size;
grad.state_active_grad += frame_size;
grad.output_grad += frame_size;
if (grad.prev_state_grad) {
grad.prev_state_grad += frame_size;
}
}
}
......
......@@ -31,26 +31,26 @@ typedef enum {
template <class T>
struct LstmMetaValue {
T *gateValue;
T *prevStateValue;
T *stateValue;
T *stateActiveValue;
T *outputValue;
T *checkIg;
T *checkFg;
T *checkOg;
T *gate_value;
T *prev_state_value;
T *state_value;
T *state_active_value;
T *output_value;
T *check_ig;
T *check_fg;
T *check_og;
};
template <class T>
struct LstmMetaGrad {
T *gateGrad;
T *prevStateGrad;
T *stateGrad;
T *stateActiveGrad;
T *outputGrad;
T *checkIgGrad;
T *checkFgGrad;
T *checkOgGrad;
T *gate_grad;
T *prev_state_grad;
T *state_grad;
T *state_active_grad;
T *output_grad;
T *check_ig_grad;
T *check_fg_grad;
T *check_og_grad;
};
inline activation_mode_t ActiveType(const std::string &type) {
......