提交 3e552cdc 编写于 作者: G guosheng

Fix gru_op related code style

上级 dcf3ffd9
...@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> {
int frame_size = hidden_dims[1]; int frame_size = hidden_dims[1];
math::hl_gru_value<T> gru_value; math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data); gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.stateWeight = gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size); const_cast<T*>(weight_data + 2 * frame_size * frame_size);
Tensor ordered_h0; Tensor ordered_h0;
const size_t* order = batch_gate->lod()[2].data(); const size_t* order = batch_gate->lod()[2].data();
...@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> {
// to reorder. // to reorder.
ReorderInitState<Place, T>(context.device_context(), *h0, order, ReorderInitState<Place, T>(context.device_context(), *h0, order,
&ordered_h0, true); &ordered_h0, true);
gru_value.prevOutValue = ordered_h0.data<T>(); gru_value.prev_out_value = ordered_h0.data<T>();
} else { } else {
gru_value.prevOutValue = nullptr; gru_value.prev_out_value = nullptr;
} }
auto batch_starts = batch_gate->lod()[0]; auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1; size_t num_batch = batch_starts.size() - 1;
...@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> {
Tensor gate_t = batch_gate->Slice(bstart, bend); Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend); Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.outputValue = hidden_t.data<T>(); gru_value.output_value = hidden_t.data<T>();
gru_value.gateValue = gate_t.data<T>(); gru_value.gate_value = gate_t.data<T>();
gru_value.resetOutputValue = reset_hidden_prev_t.data<T>(); gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<Place, T>::compute( math::GRUUnitFunctor<Place, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, dev_ctx, gru_value, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")), math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation"))); math::ActiveType(context.Attr<std::string>("gate_activation")));
gru_value.prevOutValue = gru_value.outputValue; gru_value.prev_out_value = gru_value.output_value;
} }
math::Batch2LoDTensorFunctor<Place, T> to_seq; math::Batch2LoDTensorFunctor<Place, T> to_seq;
...@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> {
to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse); to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
math::hl_gru_value<T> gru_value; math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data); gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.stateWeight = gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size); const_cast<T*>(weight_data + 2 * frame_size * frame_size);
math::hl_gru_grad<T> gru_grad; math::hl_gru_grad<T> gru_grad;
if (weight_grad) { if (weight_grad) {
gru_grad.gateWeightGrad = gru_grad.gate_weight_grad =
weight_grad->mutable_data<T>(context.GetPlace()); weight_grad->mutable_data<T>(context.GetPlace());
zero(dev_ctx, weight_grad, static_cast<T>(0.0)); zero(dev_ctx, weight_grad, static_cast<T>(0.0));
gru_grad.stateWeightGrad = gru_grad.state_weight_grad =
weight_grad->data<T>() + 2 * frame_size * frame_size; weight_grad->data<T>() + 2 * frame_size * frame_size;
} else { } else {
gru_grad.gateWeightGrad = nullptr; gru_grad.gate_weight_grad = nullptr;
gru_grad.stateWeightGrad = nullptr; gru_grad.state_weight_grad = nullptr;
} }
auto batch_starts = batch_hidden_grad.lod()[0]; auto batch_starts = batch_hidden_grad.lod()[0];
...@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> {
int cur_batch_size = bend - bstart; int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend); Tensor gate_t = batch_gate->Slice(bstart, bend);
gru_value.gateValue = gate_t.data<T>(); gru_value.gate_value = gate_t.data<T>();
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
gru_value.resetOutputValue = reset_hidden_prev_t.data<T>(); gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend); Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend);
gru_grad.outputGrad = hidden_grad_t.data<T>(); gru_grad.output_grad = hidden_grad_t.data<T>();
Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend); Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend);
gru_grad.gateGrad = gate_grad_t.data<T>(); gru_grad.gate_grad = gate_grad_t.data<T>();
Tensor reset_hidden_prev_grad_t = Tensor reset_hidden_prev_grad_t =
batch_reset_hidden_prev_grad.Slice(bstart, bend); batch_reset_hidden_prev_grad.Slice(bstart, bend);
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>(); gru_grad.reset_output_grad = reset_hidden_prev_grad_t.data<T>();
if (n == 0) { if (n == 0) {
gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr; gru_value.prev_out_value = h0 ? ordered_h0.data<T>() : nullptr;
gru_grad.prevOutGrad = gru_grad.prev_out_grad =
h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr; h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
} else { } else {
int bstart_pre = static_cast<int>(batch_starts[n - 1]); int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart); Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
gru_value.prevOutValue = hidden_prev_t.data<T>(); gru_value.prev_out_value = hidden_prev_t.data<T>();
Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart); Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart);
gru_grad.prevOutGrad = hidden_prev_grad_t.data<T>(); gru_grad.prev_out_grad = hidden_prev_grad_t.data<T>();
} }
math::GRUUnitGradFunctor<Place, T>::compute( math::GRUUnitGradFunctor<Place, T>::compute(
......
...@@ -25,393 +25,397 @@ namespace detail { ...@@ -25,393 +25,397 @@ namespace detail {
#ifndef __NVCC__ #ifndef __NVCC__
// Scalar GRU forward, phase 1: apply op_reset_output element-wise to compute
// the update gate, reset gate and reset-output for one batch row.
// Gate layout in gate_value is [update | reset | frame_state], each slice
// frame_size wide. prev_output_value may be null (first step): the previous
// output then contributes 0.
template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
                                       T *gate_value, T *reset_output_value,
                                       T *prev_output_value, int frame_size,
                                       activation_mode_t active_gate) {
  T r_value_update_gate;
  T r_value_reset_gate;
  T r_value_reset_output;
  T r_prev_out = 0;
  T *update_gate = gate_value;
  T *reset_gate = gate_value + frame_size;
  for (int i = 0; i < frame_size; i++) {
    r_value_update_gate = update_gate[i];
    r_value_reset_gate = reset_gate[i];
    if (prev_output_value) {
      r_prev_out = prev_output_value[i];
    }
    op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
                    r_value_reset_output, active_gate);
    // Gates are written back because the op activates them in place.
    update_gate[i] = r_value_update_gate;
    reset_gate[i] = r_value_reset_gate;
    reset_output_value[i] = r_value_reset_output;
  }
}
// Scalar GRU forward, phase 2: apply op_final_output element-wise to combine
// the update gate and frame state (third frame_size-wide slice of gate_value)
// with the previous output into the final output. prev_output_value may be
// null (first step): the previous output then contributes 0.
template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
                                       T *gate_value, T *prev_output_value,
                                       T *output_value, int frame_size,
                                       activation_mode_t active_node) {
  T r_value_update_gate;
  T r_value_frame_state;
  T r_prev_out = 0;
  T r_output;
  T *update_gate = gate_value;
  T *frame_state = gate_value + frame_size * 2;
  for (int i = 0; i < frame_size; i++) {
    r_value_update_gate = update_gate[i];
    r_value_frame_state = frame_state[i];
    if (prev_output_value) {
      r_prev_out = prev_output_value[i];
    }
    op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
                    r_output, active_node);
    // frame_state is written back because the op activates it in place.
    frame_state[i] = r_value_frame_state;
    output_value[i] = r_output;
  }
}
// AVX GRU forward, phase 1: same computation as the naive version but 8
// floats per step via __m256. Callers guarantee frame_size % 8 == 0 and
// sizeof(T) == 4 before dispatching here. Compiles to a no-op body when
// __AVX__ is not defined.
template <class OpResetOutput, typename T>
void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
                                     T *gate_value, T *reset_output_value,
                                     T *prev_output_value, int frame_size,
                                     activation_mode_t active_gate) {
#ifdef __AVX__
  __m256 r_value_update_gate;
  __m256 r_value_reset_gate;
  __m256 r_value_reset_output;
  __m256 r_prev_out = _mm256_set1_ps(0.0f);
  __m256 *update_gate = (__m256 *)gate_value;
  __m256 *reset_gate = (__m256 *)(gate_value + frame_size);
  for (int i = 0; i < frame_size / 8; i++) {
    r_value_update_gate = update_gate[i];
    r_value_reset_gate = reset_gate[i];
    if (prev_output_value) {
      r_prev_out = ((__m256 *)prev_output_value)[i];
    }
    op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
                    r_value_reset_output, active_gate);
    update_gate[i] = r_value_update_gate;
    reset_gate[i] = r_value_reset_gate;
    ((__m256 *)reset_output_value)[i] = r_value_reset_output;
  }
#endif
}
// AVX GRU forward, phase 2: vectorized (8 floats / __m256) counterpart of
// hl_naive_gru_forward_final_output. Callers guarantee frame_size % 8 == 0
// and sizeof(T) == 4. No-op body when __AVX__ is not defined.
template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
                                     T *gate_value, T *prev_output_value,
                                     T *output_value, int frame_size,
                                     activation_mode_t active_node) {
#ifdef __AVX__
  __m256 r_value_update_gate;
  __m256 r_value_frame_state;
  __m256 r_prev_out = _mm256_set1_ps(0.0f);
  __m256 r_output;
  __m256 *update_gate = (__m256 *)gate_value;
  __m256 *frame_state = (__m256 *)(gate_value + frame_size * 2);
  for (int i = 0; i < frame_size / 8; i++) {
    r_value_update_gate = update_gate[i];
    r_value_frame_state = frame_state[i];
    if (prev_output_value) {
      r_prev_out = ((__m256 *)prev_output_value)[i];
    }
    op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
                    r_output, active_node);
    frame_state[i] = r_value_frame_state;
    ((__m256 *)output_value)[i] = r_output;
  }
#endif
}
template <class OpResetOutput, typename T> template <class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput opResetOutput, inline void forward_reset_output(OpResetOutput op_reset_output,
hl_gru_value<T> value, int frameSize, hl_gru_value<T> value, int frame_size,
int batchSize, activation_mode_t active_gate) { int batch_size,
for (int b = 0; b < batchSize; b++) { activation_mode_t active_gate) {
if (OpResetOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { for (int b = 0; b < batch_size; b++) {
if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_forward_reset_output( hl_avx_gru_forward_reset_output(
opResetOutput, value.gateValue, value.resetOutputValue, op_reset_output, value.gate_value, value.reset_output_value,
value.prevOutValue, frameSize, active_gate); value.prev_out_value, frame_size, active_gate);
} else { } else {
hl_naive_gru_forward_reset_output( hl_naive_gru_forward_reset_output(
opResetOutput, value.gateValue, value.resetOutputValue, op_reset_output, value.gate_value, value.reset_output_value,
value.prevOutValue, frameSize, active_gate); value.prev_out_value, frame_size, active_gate);
} }
value.gateValue += frameSize * 3; value.gate_value += frame_size * 3;
value.resetOutputValue += frameSize; value.reset_output_value += frame_size;
if (value.prevOutValue) { if (value.prev_out_value) {
value.prevOutValue += frameSize; value.prev_out_value += frame_size;
} }
} }
} }
template <class OpFinalOutput, typename T> template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput opFinalOutput, inline void forward_final_output(OpFinalOutput op_final_output,
hl_gru_value<T> value, int frameSize, hl_gru_value<T> value, int frame_size,
int batchSize, activation_mode_t active_node) { int batch_size,
for (int b = 0; b < batchSize; b++) { activation_mode_t active_node) {
if (OpFinalOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { for (int b = 0; b < batch_size; b++) {
hl_avx_gru_forward_final_output(opFinalOutput, value.gateValue, if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
value.prevOutValue, value.outputValue, hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
frameSize, active_node); value.prev_out_value, value.output_value,
frame_size, active_node);
} else { } else {
hl_naive_gru_forward_final_output(opFinalOutput, value.gateValue, hl_naive_gru_forward_final_output(
value.prevOutValue, value.outputValue, op_final_output, value.gate_value, value.prev_out_value,
frameSize, active_node); value.output_value, frame_size, active_node);
} }
value.gateValue += frameSize * 3; value.gate_value += frame_size * 3;
value.outputValue += frameSize; value.output_value += frame_size;
if (value.prevOutValue) { if (value.prev_out_value) {
value.prevOutValue += frameSize; value.prev_out_value += frame_size;
} }
} }
} }
// Scalar GRU backward, phase 1: given the output gradient, compute gradients
// for the update gate and frame state (slices 0 and 2 of the gate buffers),
// and accumulate into the previous-output gradient when it is present.
// prev_out_value / prev_out_grad may independently be null at the first step.
template <class OpStateGrad, typename T>
void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                      T *gate_grad, T *prev_out_value,
                                      T *prev_out_grad, T *output_grad,
                                      int frame_size,
                                      activation_mode_t active_node) {
  T r_update_gate_value;
  T r_update_gate_grad;
  T r_frame_state_value;
  T r_frame_state_grad;
  T r_out_grad;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T *update_gate_value = gate_value;
  T *update_gate_grad = gate_grad;
  T *frame_state_value = gate_value + frame_size * 2;
  T *frame_state_grad = gate_grad + frame_size * 2;
  for (int i = 0; i < frame_size; i++) {
    r_update_gate_value = update_gate_value[i];
    r_frame_state_value = frame_state_value[i];
    r_out_grad = output_grad[i];
    if (prev_out_value) {
      r_prev_out_value = prev_out_value[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = prev_out_grad[i];
    }
    op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
                  r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
                  r_out_grad, active_node);
    update_gate_grad[i] = r_update_gate_grad;
    frame_state_grad[i] = r_frame_state_grad;
    if (prev_out_grad) {
      prev_out_grad[i] = r_prev_out_grad;
    }
  }
}
// Scalar GRU backward, phase 2: compute gradients for the update and reset
// gates (slices 0 and 1 of the gate buffers). The reset-output gradient is
// only consumed when both prev_out_value and prev_out_grad exist, matching
// the forward pass where the reset output feeds the previous-state GEMM.
template <class OpResetGrad, typename T>
void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                      T *gate_grad, T *prev_out_value,
                                      T *prev_out_grad, T *reset_output_grad,
                                      int frame_size,
                                      activation_mode_t active_gate) {
  T r_update_gate_value;
  T r_update_gate_grad;
  T r_reset_gate_value;
  T r_reset_gate_grad;
  T r_reset_output_grad = 0;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T *update_gate_value = gate_value;
  T *update_gate_grad = gate_grad;
  T *reset_gate_value = gate_value + frame_size;
  T *reset_gate_grad = gate_grad + frame_size;
  for (int i = 0; i < frame_size; i++) {
    r_update_gate_value = update_gate_value[i];
    r_update_gate_grad = update_gate_grad[i];
    r_reset_gate_value = reset_gate_value[i];
    if (prev_out_value && prev_out_grad) {
      r_reset_output_grad = reset_output_grad[i];
    }
    if (prev_out_value) {
      r_prev_out_value = prev_out_value[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = prev_out_grad[i];
    }
    op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
                  r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
                  r_reset_output_grad, active_gate);
    update_gate_grad[i] = r_update_gate_grad;
    reset_gate_grad[i] = r_reset_gate_grad;
    if (prev_out_grad) {
      prev_out_grad[i] = r_prev_out_grad;
    }
  }
}
// AVX GRU backward, phase 1: vectorized (8 floats / __m256) counterpart of
// hl_naive_gru_backward_state_grad. Callers guarantee frame_size % 8 == 0 and
// sizeof(T) == 4. No-op body when __AVX__ is not defined.
template <class OpStateGrad, typename T>
void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                    T *gate_grad, T *prev_out_value,
                                    T *prev_out_grad, T *output_grad,
                                    int frame_size,
                                    activation_mode_t active_node) {
#ifdef __AVX__
  __m256 r_update_gate_value;
  __m256 r_update_gate_grad;
  __m256 r_frame_state_value;
  __m256 r_frame_state_grad;
  __m256 r_out_grad;
  __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
  __m256 *update_gate_value = (__m256 *)gate_value;
  __m256 *update_gate_grad = (__m256 *)gate_grad;
  __m256 *frame_state_value = (__m256 *)(gate_value + frame_size * 2);
  __m256 *frame_state_grad = (__m256 *)(gate_grad + frame_size * 2);
  for (int i = 0; i < frame_size / 8; i++) {
    r_update_gate_value = update_gate_value[i];
    r_frame_state_value = frame_state_value[i];
    r_out_grad = ((__m256 *)output_grad)[i];
    if (prev_out_value) {
      r_prev_out_value = ((__m256 *)prev_out_value)[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = ((__m256 *)prev_out_grad)[i];
    }
    op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
                  r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
                  r_out_grad, active_node);
    update_gate_grad[i] = r_update_gate_grad;
    frame_state_grad[i] = r_frame_state_grad;
    if (prev_out_grad) {
      ((__m256 *)prev_out_grad)[i] = r_prev_out_grad;
    }
  }
#endif
}
// AVX GRU backward, phase 2: vectorized (8 floats / __m256) counterpart of
// hl_naive_gru_backward_reset_grad. Callers guarantee frame_size % 8 == 0 and
// sizeof(T) == 4. No-op body when __AVX__ is not defined.
template <class OpResetGrad, typename T>
void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                    T *gate_grad, T *prev_out_value,
                                    T *prev_out_grad, T *reset_output_grad,
                                    int frame_size,
                                    activation_mode_t active_gate) {
#ifdef __AVX__
  __m256 r_update_gate_value;
  __m256 r_update_gate_grad;
  __m256 r_reset_gate_value;
  __m256 r_reset_gate_grad;
  __m256 r_reset_output_grad = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
  __m256 *update_gate_value = (__m256 *)gate_value;
  __m256 *update_gate_grad = (__m256 *)gate_grad;
  __m256 *reset_gate_value = (__m256 *)(gate_value + frame_size);
  __m256 *reset_gate_grad = (__m256 *)(gate_grad + frame_size);
  for (int i = 0; i < frame_size / 8; i++) {
    r_update_gate_value = update_gate_value[i];
    r_update_gate_grad = update_gate_grad[i];
    r_reset_gate_value = reset_gate_value[i];
    if (prev_out_value && prev_out_grad) {
      r_reset_output_grad = ((__m256 *)reset_output_grad)[i];
    }
    if (prev_out_value) {
      r_prev_out_value = ((__m256 *)prev_out_value)[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = ((__m256 *)prev_out_grad)[i];
    }
    op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
                  r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
                  r_reset_output_grad, active_gate);
    update_gate_grad[i] = r_update_gate_grad;
    reset_gate_grad[i] = r_reset_gate_grad;
    if (prev_out_grad) {
      ((__m256 *)prev_out_grad)[i] = r_prev_out_grad;
    }
  }
#endif
}
template <class OpStateGrad, typename T> template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad opStateGrad, hl_gru_value<T> value, inline void backward_state_grad(OpStateGrad op_state_grad,
hl_gru_grad<T> grad, int frameSize, hl_gru_value<T> value, hl_gru_grad<T> grad,
int batchSize, activation_mode_t active_node) { int frame_size, int batch_size,
for (int b = 0; b < batchSize; b++) { activation_mode_t active_node) {
if (OpStateGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { for (int b = 0; b < batch_size; b++) {
if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_state_grad( hl_avx_gru_backward_state_grad(
opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue, op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
grad.prevOutGrad, grad.outputGrad, frameSize, active_node); grad.prev_out_grad, grad.output_grad, frame_size, active_node);
} else { } else {
hl_naive_gru_backward_state_grad( hl_naive_gru_backward_state_grad(
opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue, op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
grad.prevOutGrad, grad.outputGrad, frameSize, active_node); grad.prev_out_grad, grad.output_grad, frame_size, active_node);
} }
value.gateValue += frameSize * 3; value.gate_value += frame_size * 3;
if (value.prevOutValue) { if (value.prev_out_value) {
value.prevOutValue += frameSize; value.prev_out_value += frame_size;
} }
grad.gateGrad += frameSize * 3; grad.gate_grad += frame_size * 3;
grad.outputGrad += frameSize; grad.output_grad += frame_size;
if (grad.prevOutGrad) { if (grad.prev_out_grad) {
grad.prevOutGrad += frameSize; grad.prev_out_grad += frame_size;
} }
} }
} }
template <class OpResetGrad, typename T> template <class OpResetGrad, typename T>
inline void backward_reset_grad(OpResetGrad opResetGrad, hl_gru_value<T> value, inline void backward_reset_grad(OpResetGrad op_reset_grad,
hl_gru_grad<T> grad, int frameSize, hl_gru_value<T> value, hl_gru_grad<T> grad,
int batchSize, activation_mode_t active_gate) { int frame_size, int batch_size,
for (int b = 0; b < batchSize; b++) { activation_mode_t active_gate) {
if (OpResetGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { for (int b = 0; b < batch_size; b++) {
if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_reset_grad( hl_avx_gru_backward_reset_grad(
opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue, op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate); grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
} else { } else {
hl_naive_gru_backward_reset_grad( hl_naive_gru_backward_reset_grad(
opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue, op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate); grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
} }
value.gateValue += frameSize * 3; value.gate_value += frame_size * 3;
if (value.prevOutValue) { if (value.prev_out_value) {
value.prevOutValue += frameSize; value.prev_out_value += frame_size;
} }
grad.gateGrad += frameSize * 3; grad.gate_grad += frame_size * 3;
grad.resetOutputGrad += frameSize; grad.reset_output_grad += frame_size;
if (grad.prevOutGrad) { if (grad.prev_out_grad) {
grad.prevOutGrad += frameSize; grad.prev_out_grad += frame_size;
} }
} }
} }
......
...@@ -27,174 +27,174 @@ namespace math { ...@@ -27,174 +27,174 @@ namespace math {
namespace detail { namespace detail {
/* /*
* threads(framePerBlock, batchPerBlock) * threads(frame_per_block, batch_per_block)
* grid(frameBlocks, batchBlocks) * grid(frame_blocks, batch_blocks)
*/ */
// CUDA kernel for GRU forward phase 1: one thread per frame element, with an
// optional batch dimension on the y axis when is_batch is true.
// BUG FIX: the camelCase->snake_case rename had converted the CUDA built-in
// variables blockIdx/blockDim/threadIdx to block_idx/block_dim/thread_idx,
// which do not exist and break nvcc compilation. The built-ins must keep
// their camelCase spelling; only the locals are snake_case.
template <class OpResetOutput, bool is_batch, typename T>
__global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
                                        T *gate_value, T *reset_output_value,
                                        T *prev_output_value, int frame_size,
                                        int batch_size,
                                        activation_mode_t active_gate) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;

  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    // Offset the buffers to this batch row; 3 gates per frame element.
    gate_value += batch_idx * 3 * frame_size;
    reset_output_value += batch_idx * frame_size;
  }

  T r_prev_out = 0;
  T r_value_reset_output;
  T r_value_update_gate = gate_value[frame_idx + frame_size * 0];
  T r_value_reset_gate = gate_value[frame_idx + frame_size * 1];

  if (prev_output_value) {
    if (is_batch) prev_output_value += batch_idx * frame_size;
    r_prev_out = prev_output_value[frame_idx];
  }

  op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
                  r_value_reset_output, active_gate);

  // Gates are written back because the op activates them in place.
  gate_value[frame_idx + frame_size * 0] = r_value_update_gate;
  gate_value[frame_idx + frame_size * 1] = r_value_reset_gate;
  reset_output_value[frame_idx] = r_value_reset_output;
}
/* /*
* threads(framePerBlock, batchPerBlock) * threads(frame_per_block, batch_per_block)
* grid(frameBlocks, batchBlocks) * grid(frame_blocks, batch_blocks)
*/ */
// Computes the final GRU output h_t for one frame element (and, when
// is_batch, one batch row). Reads the update gate (slot 0) and the frame
// state (slot 2) from gate_value, applies op_final_output, writes the
// activated frame state back into slot 2 and the output into output_value.
// NOTE: the grid/block built-ins must keep their CUDA names
// (blockIdx/blockDim/threadIdx) — they are compiler built-ins, not
// renameable identifiers.
template <class OpFinalOutput, bool is_batch, typename T>
__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
                                        T *gate_value, T *prev_output_value,
                                        T *output_value, int frame_size,
                                        int batch_size,
                                        activation_mode_t active_node) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    // Each batch row owns 3 gate slots of frame_size values.
    gate_value += batch_idx * 3 * frame_size;
    output_value += batch_idx * frame_size;
  }
  T r_output;
  T r_prev_out = 0;  // stays 0 when there is no previous output (first step)
  T r_value_update_gate = gate_value[frame_idx + frame_size * 0];
  T r_value_frame_state = gate_value[frame_idx + frame_size * 2];
  if (prev_output_value) {
    if (is_batch) prev_output_value += batch_idx * frame_size;
    r_prev_out = prev_output_value[frame_idx];
  }
  op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
                  r_output, active_node);
  gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
  output_value[frame_idx] = r_output;
}
/*
 * threads(frame_per_block, batch_per_block)
 * grid(frame_blocks, batch_blocks)
 */
// Backward pass for the GRU state: given the forward gate values and the
// output gradient, computes gradients for the update gate (slot 0), the
// frame state (slot 2), and — when present — the previous output.
// NOTE: blockIdx/blockDim/threadIdx are CUDA built-ins and must keep
// these exact names.
template <class OpStateGrad, bool is_batch, typename T>
__global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *output_grad,
                                       int frame_size, int batch_size,
                                       activation_mode_t active_node) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    // Offset the per-batch-row views: 3 gate slots for values/grads,
    // one frame for the output gradient.
    gate_value += batch_idx * 3 * frame_size;
    gate_grad += batch_idx * 3 * frame_size;
    output_grad += batch_idx * frame_size;
  }
  T r_update_gate_grad;
  T r_frame_state_grad;
  T r_prev_out_value = 0;  // zeros when there is no previous step
  T r_prev_out_grad = 0;
  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
  T r_frame_state_value = gate_value[frame_idx + frame_size * 2];
  T r_out_grad = output_grad[frame_idx];
  if (prev_out_value && prev_out_grad) {
    if (is_batch) prev_out_value += batch_idx * frame_size;
    r_prev_out_value = prev_out_value[frame_idx];
    if (is_batch) prev_out_grad += batch_idx * frame_size;
    r_prev_out_grad = prev_out_grad[frame_idx];
  }
  op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
                r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
                r_out_grad, active_node);
  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
  if (prev_out_grad) {
    prev_out_grad[frame_idx] = r_prev_out_grad;
  }
}
/*
 * threads(frame_per_block, batch_per_block)
 * grid(frame_blocks, batch_blocks)
 */
// Backward pass for the GRU reset gate: reads the update-gate value/grad
// (slot 0) and the reset-gate value (slot 1), applies op_reset_grad, then
// writes back both gate gradients and (optionally) the previous-output
// gradient.
// NOTE: blockIdx/blockDim/threadIdx are CUDA built-ins and must keep
// these exact names.
template <class OpResetGrad, bool is_batch, typename T>
__global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *reset_output_grad,
                                       int frame_size, int batch_size,
                                       activation_mode_t active_gate) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    gate_value += batch_idx * 3 * frame_size;
    gate_grad += batch_idx * 3 * frame_size;
    reset_output_grad += batch_idx * frame_size;
  }
  T r_reset_gate_grad;
  T r_prev_out_value = 0;   // zeros when there is no previous step
  T r_prev_out_grad = 0;
  T r_reset_output_grad = 0;
  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
  T r_update_gate_grad = gate_grad[frame_idx + frame_size * 0];
  T r_reset_gate_value = gate_value[frame_idx + frame_size * 1];
  if (prev_out_value && prev_out_grad) {
    if (is_batch) prev_out_value += batch_idx * frame_size;
    if (is_batch) prev_out_grad += batch_idx * frame_size;
    r_prev_out_value = prev_out_value[frame_idx];
    r_prev_out_grad = prev_out_grad[frame_idx];
    r_reset_output_grad = reset_output_grad[frame_idx];
  }
  op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
                r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
                r_reset_output_grad, active_gate);
  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
  if (prev_out_grad) {
    prev_out_grad[frame_idx] = r_prev_out_grad;
  }
}
} // namespace detail } // namespace detail
......
...@@ -28,23 +28,25 @@ namespace forward { ...@@ -28,23 +28,25 @@ namespace forward {
template <typename T> template <typename T>
class gru_resetOutput { class gru_resetOutput {
public: public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueResetGate, T &prevOut, HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate,
T &valueResetOutput, activation_mode_t actGate) { T &prev_out, T &value_reset_output,
valueUpdateGate = activation(valueUpdateGate, actGate); activation_mode_t act_gate) {
valueResetGate = activation(valueResetGate, actGate); value_update_gate = activation(value_update_gate, act_gate);
valueResetOutput = prevOut * valueResetGate; value_reset_gate = activation(value_reset_gate, act_gate);
value_reset_output = prev_out * value_reset_gate;
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueResetGate, HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &prevOut, __m256 &valueResetOutput, __m256 &value_reset_gate, __m256 &prev_out,
activation_mode_t actGate) { __m256 &value_reset_output,
valueUpdateGate = activation(valueUpdateGate, actGate); activation_mode_t act_gate) {
valueResetGate = activation(valueResetGate, actGate); value_update_gate = activation(value_update_gate, act_gate);
valueResetOutput = _mm256_mul_ps(prevOut, valueResetGate); value_reset_gate = activation(value_reset_gate, act_gate);
value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate);
} }
#endif #endif
#endif #endif
...@@ -53,24 +55,26 @@ class gru_resetOutput { ...@@ -53,24 +55,26 @@ class gru_resetOutput {
template <typename T> template <typename T>
class gru_finalOutput { class gru_finalOutput {
public: public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueFrameState, T &prevOut, HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state,
T &valueOutput, activation_mode_t actInput) { T &prev_out, T &value_output,
valueFrameState = activation(valueFrameState, actInput); activation_mode_t act_input) {
valueOutput = prevOut - (valueUpdateGate * prevOut) + value_frame_state = activation(value_frame_state, act_input);
(valueUpdateGate * valueFrameState); value_output = prev_out - (value_update_gate * prev_out) +
(value_update_gate * value_frame_state);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueFrameState, HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &prevOut, __m256 &valueOutput, __m256 &value_frame_state, __m256 &prev_out,
activation_mode_t actInput) { __m256 &value_output,
valueFrameState = activation(valueFrameState, actInput); activation_mode_t act_input) {
valueOutput = _mm256_add_ps( value_frame_state = activation(value_frame_state, act_input);
_mm256_sub_ps(prevOut, _mm256_mul_ps(valueUpdateGate, prevOut)), value_output = _mm256_add_ps(
_mm256_mul_ps(valueUpdateGate, valueFrameState)); _mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)),
_mm256_mul_ps(value_update_gate, value_frame_state));
} }
#endif #endif
#endif #endif
...@@ -82,34 +86,37 @@ namespace backward { ...@@ -82,34 +86,37 @@ namespace backward {
template <typename T> template <typename T>
class gru_stateGrad { class gru_stateGrad {
public: public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate, HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
T &valueFrameState, T &gradFrameState, T &value_frame_state, T &grad_frame_state,
T &valuePrevOut, T &gradPrevOut, T &gradOutput, T &value_prev_out, T &grad_prev_out,
activation_mode_t actInput) { T &grad_output, activation_mode_t act_input) {
gradUpdateGate = (gradOutput * valueFrameState); grad_update_gate = (grad_output * value_frame_state);
gradUpdateGate -= (gradOutput * valuePrevOut); grad_update_gate -= (grad_output * value_prev_out);
gradPrevOut -= (gradOutput * valueUpdateGate); grad_prev_out -= (grad_output * value_update_gate);
gradPrevOut += gradOutput; grad_prev_out += grad_output;
gradFrameState = grad_frame_state = activation(grad_output * value_update_gate,
activation(gradOutput * valueUpdateGate, valueFrameState, actInput); value_frame_state, act_input);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate, HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &valueFrameState, __m256 &gradFrameState, __m256 &grad_update_gate,
__m256 &valuePrevOut, __m256 &gradPrevOut, __m256 &value_frame_state,
__m256 &gradOutput, activation_mode_t actInput) { __m256 &grad_frame_state, __m256 &value_prev_out,
gradUpdateGate = _mm256_mul_ps(gradOutput, valueFrameState); __m256 &grad_prev_out, __m256 &grad_output,
gradUpdateGate = activation_mode_t act_input) {
_mm256_sub_ps(gradUpdateGate, _mm256_mul_ps(gradOutput, valuePrevOut)); grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state);
gradPrevOut = _mm256_add_ps( grad_update_gate = _mm256_sub_ps(
_mm256_sub_ps(gradPrevOut, _mm256_mul_ps(gradOutput, valueUpdateGate)), grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out));
gradOutput); grad_prev_out = _mm256_add_ps(
gradFrameState = activation(_mm256_mul_ps(gradOutput, valueUpdateGate), _mm256_sub_ps(grad_prev_out,
valueFrameState, actInput); _mm256_mul_ps(grad_output, value_update_gate)),
grad_output);
grad_frame_state = activation(_mm256_mul_ps(grad_output, value_update_gate),
value_frame_state, act_input);
} }
#endif #endif
#endif #endif
...@@ -118,30 +125,32 @@ class gru_stateGrad { ...@@ -118,30 +125,32 @@ class gru_stateGrad {
template <typename T> template <typename T>
class gru_resetGrad { class gru_resetGrad {
public: public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate, HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
T &valueResetGate, T &gradResetGate, T &value_reset_gate, T &grad_reset_gate,
T &valuePrevOut, T &gradPrevOut, T &value_prev_out, T &grad_prev_out,
T &gradResetOutput, activation_mode_t actGate) { T &grad_reset_output, activation_mode_t act_gate) {
gradResetGate = (gradResetOutput * valuePrevOut); grad_reset_gate = (grad_reset_output * value_prev_out);
gradPrevOut += (gradResetOutput * valueResetGate); grad_prev_out += (grad_reset_output * value_reset_gate);
gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate); grad_update_gate =
gradResetGate = activation(gradResetGate, valueResetGate, actGate); activation(grad_update_gate, value_update_gate, act_gate);
grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate, HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &valueResetGate, __m256 &gradResetGate, __m256 &grad_update_gate, __m256 &value_reset_gate,
__m256 &valuePrevOut, __m256 &gradPrevOut, __m256 &grad_reset_gate, __m256 &value_prev_out,
__m256 &gradResetOutput, __m256 &grad_prev_out, __m256 &grad_reset_output,
activation_mode_t actGate) { activation_mode_t act_gate) {
gradResetGate = _mm256_mul_ps(gradResetOutput, valuePrevOut); grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out);
gradPrevOut = _mm256_add_ps(gradPrevOut, grad_prev_out = _mm256_add_ps(
_mm256_mul_ps(gradResetOutput, valueResetGate)); grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate));
gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate); grad_update_gate =
gradResetGate = activation(gradResetGate, valueResetGate, actGate); activation(grad_update_gate, value_update_gate, act_gate);
grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate);
} }
#endif #endif
#endif #endif
......
...@@ -21,29 +21,29 @@ namespace math { ...@@ -21,29 +21,29 @@ namespace math {
template <typename T> template <typename T>
struct GRUUnitFunctor<platform::CPUPlace, T> { struct GRUUnitFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize, hl_gru_value<T> value, int frame_size, int batch_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate) { activation_mode_t active_gate) {
#ifndef __NVCC__ #ifndef __NVCC__
if (value.prevOutValue) { if (value.prev_out_value) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, false, false, batchSize, frameSize * 2, frameSize, 1, context, false, false, batch_size, frame_size * 2, frame_size, 1,
value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1, value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
value.gateValue, frameSize * 3); 1, value.gate_value, frame_size * 3);
} }
detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value, detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
frameSize, batchSize, active_gate); frame_size, batch_size, active_gate);
if (value.prevOutValue) { if (value.prev_out_value) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, false, false, batchSize, frameSize, frameSize, 1, context, false, false, batch_size, frame_size, frame_size, 1,
value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1, value.reset_output_value, frame_size, value.state_weight, frame_size,
value.gateValue + frameSize * 2, frameSize * 3); 1, value.gate_value + frame_size * 2, frame_size * 3);
} }
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value, detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
frameSize, batchSize, active_node); frame_size, batch_size, active_node);
#endif #endif
} }
}; };
...@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> { ...@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> {
template <typename T> template <typename T>
struct GRUUnitGradFunctor<platform::CPUPlace, T> { struct GRUUnitGradFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize, hl_gru_value<T> value, hl_gru_grad<T> grad,
int batchSize, activation_mode_t active_node, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) { activation_mode_t active_gate) {
#ifndef __NVCC__ #ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value, detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frameSize, batchSize, active_node); grad, frame_size, batch_size, active_node);
if (value.prevOutValue && grad.prevOutGrad) { if (value.prev_out_value && grad.prev_out_grad) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize, 1, context, false, true, batch_size, frame_size, frame_size, 1,
grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight, grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
frameSize, 0, grad.resetOutputGrad, frameSize); frame_size, 0, grad.reset_output_grad, frame_size);
if (grad.stateWeightGrad) { if (grad.state_weight_grad) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, true, false, frameSize, frameSize, batchSize, 1, context, true, false, frame_size, frame_size, batch_size, 1,
value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2, value.reset_output_value, frame_size,
frameSize * 3, 1, grad.stateWeightGrad, frameSize); grad.gate_grad + frame_size * 2, frame_size * 3, 1,
grad.state_weight_grad, frame_size);
} }
} }
detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value, detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value,
grad, frameSize, batchSize, active_gate); grad, frame_size, batch_size, active_gate);
if (grad.prevOutGrad && value.prevOutValue) { if (grad.prev_out_grad && value.prev_out_value) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize * 2, 1, context, false, true, batch_size, frame_size, frame_size * 2, 1,
grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1, grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
grad.prevOutGrad, frameSize); grad.prev_out_grad, frame_size);
if (grad.gateWeightGrad) { if (grad.gate_weight_grad) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, true, false, frameSize, frameSize * 2, batchSize, 1, context, true, false, frame_size, frame_size * 2, batch_size, 1,
value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1, value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1,
grad.gateWeightGrad, frameSize * 2); grad.gate_weight_grad, frame_size * 2);
} }
} }
#endif #endif
......
...@@ -21,66 +21,66 @@ namespace math { ...@@ -21,66 +21,66 @@ namespace math {
template <typename T> template <typename T>
struct GRUUnitFunctor<platform::GPUPlace, T> { struct GRUUnitFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize, hl_gru_value<T> value, int frame_size, int batch_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate) { activation_mode_t active_gate) {
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(context).stream(); reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
if (batchSize == 1) { if (batch_size == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024; int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024; int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1); threads = dim3(frame_per_block, 1);
grid = dim3(frameBlocks, 1); grid = dim3(frame_blocks, 1);
} else { } else {
threads = dim3(32, 32); threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
} }
if (value.prevOutValue) { if (value.prev_out_value) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, false, false, batchSize, frameSize * 2, frameSize, 1, context, false, false, batch_size, frame_size * 2, frame_size, 1,
value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1, value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
value.gateValue, frameSize * 3); 1, value.gate_value, frame_size * 3);
} }
if (batchSize == 1) { if (batch_size == 1) {
detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>, detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
/* isBatch= */ false, /* is_batch= */ false,
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_resetOutput<T>(), value.gateValue, detail::forward::gru_resetOutput<T>(), value.gate_value,
value.resetOutputValue, value.prevOutValue, frameSize, batchSize, value.reset_output_value, value.prev_out_value, frame_size,
active_gate); batch_size, active_gate);
} else { } else {
detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>, detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
/* isBatch= */ true, /* is_batch= */ true,
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_resetOutput<T>(), value.gateValue, detail::forward::gru_resetOutput<T>(), value.gate_value,
value.resetOutputValue, value.prevOutValue, frameSize, batchSize, value.reset_output_value, value.prev_out_value, frame_size,
active_gate); batch_size, active_gate);
} }
if (value.prevOutValue) { if (value.prev_out_value) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, false, false, batchSize, frameSize, frameSize, 1, context, false, false, batch_size, frame_size, frame_size, 1,
value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1, value.reset_output_value, frame_size, value.state_weight, frame_size,
value.gateValue + frameSize * 2, frameSize * 3); 1, value.gate_value + frame_size * 2, frame_size * 3);
} }
if (batchSize == 1) { if (batch_size == 1) {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>, detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* isBatch= */ false, /* is_batch= */ false,
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gateValue, detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prevOutValue, value.outputValue, frameSize, batchSize, value.prev_out_value, value.output_value, frame_size, batch_size,
active_node); active_node);
} else { } else {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>, detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* isBatch= */ true, /* is_batch= */ true,
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gateValue, detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prevOutValue, value.outputValue, frameSize, batchSize, value.prev_out_value, value.output_value, frame_size, batch_size,
active_node); active_node);
} }
} }
...@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> { ...@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> {
template <typename T> template <typename T>
struct GRUUnitGradFunctor<platform::GPUPlace, T> { struct GRUUnitGradFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize, hl_gru_value<T> value, hl_gru_grad<T> grad,
int batchSize, activation_mode_t active_node, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) { activation_mode_t active_gate) {
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(context).stream(); reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
if (batchSize == 1) { if (batch_size == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024; int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024; int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1); threads = dim3(frame_per_block, 1);
grid = dim3(frameBlocks, 1); grid = dim3(frame_blocks, 1);
} else { } else {
threads = dim3(32, 32); threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
} }
if (batchSize == 1) { if (batch_size == 1) {
detail::KeGruBackwardStateGrad< detail::KeGruBackwardStateGrad<
detail::backward::gru_stateGrad<T>, detail::backward::gru_stateGrad<T>,
/* isBatch= */ false><<<grid, threads, 0, stream>>>( /* is_batch= */ false><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad, detail::backward::gru_stateGrad<T>(), value.gate_value,
value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize, grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
batchSize, active_node); grad.output_grad, frame_size, batch_size, active_node);
} else { } else {
detail::KeGruBackwardStateGrad< detail::KeGruBackwardStateGrad<
detail::backward::gru_stateGrad<T>, detail::backward::gru_stateGrad<T>,
/* isBatch= */ true><<<grid, threads, 0, stream>>>( /* is_batch= */ true><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad, detail::backward::gru_stateGrad<T>(), value.gate_value,
value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize, grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
batchSize, active_node); grad.output_grad, frame_size, batch_size, active_node);
} }
if (value.prevOutValue && grad.prevOutGrad) { if (value.prev_out_value && grad.prev_out_grad) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize, 1, context, false, true, batch_size, frame_size, frame_size, 1,
grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight, grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
frameSize, 0, grad.resetOutputGrad, frameSize); frame_size, 0, grad.reset_output_grad, frame_size);
if (grad.stateWeightGrad) { if (grad.state_weight_grad) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, true, false, frameSize, frameSize, batchSize, 1, context, true, false, frame_size, frame_size, batch_size, 1,
value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2, value.reset_output_value, frame_size,
frameSize * 3, 1, grad.stateWeightGrad, frameSize); grad.gate_grad + frame_size * 2, frame_size * 3, 1,
grad.state_weight_grad, frame_size);
} }
} }
if (batchSize == 1) { if (batch_size == 1) {
detail::KeGruBackwardResetGrad< detail::KeGruBackwardResetGrad<
detail::backward::gru_resetGrad<T>, detail::backward::gru_resetGrad<T>,
/* isBatch= */ false><<<grid, threads, 0, stream>>>( /* is_batch= */ false><<<grid, threads, 0, stream>>>(
detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad, detail::backward::gru_resetGrad<T>(), value.gate_value,
value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize, grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
batchSize, active_gate); grad.reset_output_grad, frame_size, batch_size, active_gate);
} else { } else {
detail::KeGruBackwardResetGrad< detail::KeGruBackwardResetGrad<
detail::backward::gru_resetGrad<T>, detail::backward::gru_resetGrad<T>,
/* isBatch= */ true><<<grid, threads, 0, stream>>>( /* is_batch= */ true><<<grid, threads, 0, stream>>>(
detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad, detail::backward::gru_resetGrad<T>(), value.gate_value,
value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize, grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
batchSize, active_gate); grad.reset_output_grad, frame_size, batch_size, active_gate);
} }
if (grad.prevOutGrad && value.prevOutValue) { if (grad.prev_out_grad && value.prev_out_value) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize * 2, 1, context, false, true, batch_size, frame_size, frame_size * 2, 1,
grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1, grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
grad.prevOutGrad, frameSize); grad.prev_out_grad, frame_size);
if (grad.gateWeightGrad) { if (grad.gate_weight_grad) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, true, false, frameSize, frameSize * 2, batchSize, 1, context, true, false, frame_size, frame_size * 2, batch_size, 1,
value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1, value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1,
grad.gateWeightGrad, frameSize * 2); grad.gate_weight_grad, frame_size * 2);
} }
} }
} }
......
...@@ -22,28 +22,28 @@ namespace math { ...@@ -22,28 +22,28 @@ namespace math {
// TODO(guosheng): refine code style in gru_compute
// Non-owning pointer bundle for the GRU forward pass. gate_value holds the
// three gate slots (update, reset, frame state) laid out contiguously per
// batch row; the weight pointers reference the shared parameter matrices.
// Lifetime and sizes are managed by the caller.
template <typename T>
struct hl_gru_value {
  T *gate_weight;         // [frame_size, 2*frame_size] update/reset weights
  T *state_weight;        // [frame_size, frame_size] candidate-state weights
  T *gate_value;          // [batch, 3*frame_size] gate activations
  T *reset_output_value;  // [batch, frame_size] r ⊙ h_{t-1}
  T *output_value;        // [batch, frame_size] h_t
  T *prev_out_value;      // [batch, frame_size] h_{t-1}; may be null
};
// Non-owning pointer bundle for the GRU backward pass; each field is the
// gradient buffer matching the same-named field of hl_gru_value. Weight
// gradient pointers may be null when those gradients are not requested.
template <typename T>
struct hl_gru_grad {
  T *gate_weight_grad;
  T *state_weight_grad;
  T *gate_grad;
  T *reset_output_grad;
  T *output_grad;
  T *prev_out_grad;
};
template <typename Place, typename T> template <typename Place, typename T>
struct GRUUnitFunctor { struct GRUUnitFunctor {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize, hl_gru_value<T> value, int frame_size, int batch_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate); activation_mode_t active_gate);
}; };
...@@ -51,8 +51,9 @@ struct GRUUnitFunctor { ...@@ -51,8 +51,9 @@ struct GRUUnitFunctor {
template <typename Place, typename T> template <typename Place, typename T>
struct GRUUnitGradFunctor { struct GRUUnitGradFunctor {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize, hl_gru_value<T> value, hl_gru_grad<T> grad,
int batchSize, activation_mode_t active_node, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate); activation_mode_t active_gate);
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册