Unverified commit 0066bbba, authored by Guo Sheng, committed by GitHub

Merge pull request #6021 from guoshengCS/fix-GRUOp-codestyle

Fix gru_op related code style
@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> {
     int frame_size = hidden_dims[1];
     math::hl_gru_value<T> gru_value;
-    gru_value.gateWeight = const_cast<T*>(weight_data);
-    gru_value.stateWeight =
+    gru_value.gate_weight = const_cast<T*>(weight_data);
+    gru_value.state_weight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
     Tensor ordered_h0;
     const size_t* order = batch_gate->lod()[2].data();
@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> {
       // to reorder.
       ReorderInitState<Place, T>(context.device_context(), *h0, order,
                                  &ordered_h0, true);
-      gru_value.prevOutValue = ordered_h0.data<T>();
+      gru_value.prev_out_value = ordered_h0.data<T>();
     } else {
-      gru_value.prevOutValue = nullptr;
+      gru_value.prev_out_value = nullptr;
     }
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> {
       Tensor gate_t = batch_gate->Slice(bstart, bend);
       Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
       Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-      gru_value.outputValue = hidden_t.data<T>();
-      gru_value.gateValue = gate_t.data<T>();
-      gru_value.resetOutputValue = reset_hidden_prev_t.data<T>();
+      gru_value.output_value = hidden_t.data<T>();
+      gru_value.gate_value = gate_t.data<T>();
+      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
       math::GRUUnitFunctor<Place, T>::compute(
           dev_ctx, gru_value, frame_size, cur_batch_size,
           math::ActiveType(context.Attr<std::string>("activation")),
           math::ActiveType(context.Attr<std::string>("gate_activation")));
-      gru_value.prevOutValue = gru_value.outputValue;
+      gru_value.prev_out_value = gru_value.output_value;
     }
     math::Batch2LoDTensorFunctor<Place, T> to_seq;
@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> {
     to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
     math::hl_gru_value<T> gru_value;
-    gru_value.gateWeight = const_cast<T*>(weight_data);
-    gru_value.stateWeight =
+    gru_value.gate_weight = const_cast<T*>(weight_data);
+    gru_value.state_weight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
     math::hl_gru_grad<T> gru_grad;
     if (weight_grad) {
-      gru_grad.gateWeightGrad =
+      gru_grad.gate_weight_grad =
           weight_grad->mutable_data<T>(context.GetPlace());
       zero(dev_ctx, weight_grad, static_cast<T>(0.0));
-      gru_grad.stateWeightGrad =
+      gru_grad.state_weight_grad =
           weight_grad->data<T>() + 2 * frame_size * frame_size;
     } else {
-      gru_grad.gateWeightGrad = nullptr;
-      gru_grad.stateWeightGrad = nullptr;
+      gru_grad.gate_weight_grad = nullptr;
+      gru_grad.state_weight_grad = nullptr;
     }
     auto batch_starts = batch_hidden_grad.lod()[0];
@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> {
       int cur_batch_size = bend - bstart;
       Tensor gate_t = batch_gate->Slice(bstart, bend);
-      gru_value.gateValue = gate_t.data<T>();
+      gru_value.gate_value = gate_t.data<T>();
       Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
-      gru_value.resetOutputValue = reset_hidden_prev_t.data<T>();
+      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
       Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend);
-      gru_grad.outputGrad = hidden_grad_t.data<T>();
+      gru_grad.output_grad = hidden_grad_t.data<T>();
       Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend);
-      gru_grad.gateGrad = gate_grad_t.data<T>();
+      gru_grad.gate_grad = gate_grad_t.data<T>();
       Tensor reset_hidden_prev_grad_t =
          batch_reset_hidden_prev_grad.Slice(bstart, bend);
-      gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
+      gru_grad.reset_output_grad = reset_hidden_prev_grad_t.data<T>();
       if (n == 0) {
-        gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr;
-        gru_grad.prevOutGrad =
+        gru_value.prev_out_value = h0 ? ordered_h0.data<T>() : nullptr;
+        gru_grad.prev_out_grad =
            h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
       } else {
        int bstart_pre = static_cast<int>(batch_starts[n - 1]);
        Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
-        gru_value.prevOutValue = hidden_prev_t.data<T>();
+        gru_value.prev_out_value = hidden_prev_t.data<T>();
        Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart);
-        gru_grad.prevOutGrad = hidden_prev_grad_t.data<T>();
+        gru_grad.prev_out_grad = hidden_prev_grad_t.data<T>();
       }
       math::GRUUnitGradFunctor<Place, T>::compute(
......
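A note on the buffer layout the kernel above relies on: gate_weight and state_weight are two slices of the op's single weight tensor, with the candidate-state weights starting at offset 2 * frame_size * frame_size, and each step's gate buffer packs update gate, reset gate, and frame state side by side. A minimal standalone C++ sketch of that partitioning (struct and function names here are illustrative, not part of the PR):

#include <cstdio>
#include <vector>

// Illustrative mirror of hl_gru_value<float>: all pointers are slices of two
// flat buffers, assuming the weight tensor holds frame_size * 3 * frame_size
// floats and the per-step gate buffer holds batch_size * 3 * frame_size.
struct GruValueSketch {
  float *gate_weight;     // update + reset weights: frame_size x (2 * frame_size)
  float *state_weight;    // candidate-state weights: frame_size x frame_size
  float *gate_value;      // per-step gates: batch_size x (3 * frame_size)
  float *prev_out_value;  // previous hidden state, or nullptr at the first step
};

int main() {
  const int frame_size = 4, batch_size = 2;
  std::vector<float> weight(3 * frame_size * frame_size);
  std::vector<float> gate(batch_size * 3 * frame_size);
  GruValueSketch v;
  v.gate_weight = weight.data();
  // Same arithmetic as the kernel: state weights follow the two gate blocks.
  v.state_weight = weight.data() + 2 * frame_size * frame_size;
  v.gate_value = gate.data();
  v.prev_out_value = nullptr;  // no initial hidden state provided
  std::printf("state_weight offset: %td floats\n",
              v.state_weight - v.gate_weight);
  return 0;
}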
@@ -25,393 +25,397 @@ namespace detail {
 #ifndef __NVCC__
 template <class OpResetOutput, typename T>
-void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput,
-                                       T *gateValue, T *resetOutputValue,
-                                       T *prevOutputValue, int frameSize,
+void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
+                                       T *gate_value, T *reset_output_value,
+                                       T *prev_output_value, int frame_size,
                                        activation_mode_t active_gate) {
-  T rValueUpdateGate;
-  T rValueResetGate;
-  T rValueResetOutput;
-  T rPrevOut = 0;
-  T *updateGate = gateValue;
-  T *resetGate = gateValue + frameSize;
-  for (int i = 0; i < frameSize; i++) {
-    rValueUpdateGate = updateGate[i];
-    rValueResetGate = resetGate[i];
-    if (prevOutputValue) {
-      rPrevOut = prevOutputValue[i];
+  T r_value_update_gate;
+  T r_value_reset_gate;
+  T r_value_reset_output;
+  T r_prev_out = 0;
+  T *update_gate = gate_value;
+  T *reset_gate = gate_value + frame_size;
+  for (int i = 0; i < frame_size; i++) {
+    r_value_update_gate = update_gate[i];
+    r_value_reset_gate = reset_gate[i];
+    if (prev_output_value) {
+      r_prev_out = prev_output_value[i];
     }
-    opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut,
-                  rValueResetOutput, active_gate);
+    op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
+                    r_value_reset_output, active_gate);
-    updateGate[i] = rValueUpdateGate;
-    resetGate[i] = rValueResetGate;
-    resetOutputValue[i] = rValueResetOutput;
+    update_gate[i] = r_value_update_gate;
+    reset_gate[i] = r_value_reset_gate;
+    reset_output_value[i] = r_value_reset_output;
   }
 }
 template <class OpFinalOutput, typename T>
-void hl_naive_gru_forward_final_output(OpFinalOutput opFinalOutput,
-                                       T *gateValue, T *prevOutputValue,
-                                       T *outputValue, int frameSize,
+void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
+                                       T *gate_value, T *prev_output_value,
+                                       T *output_value, int frame_size,
                                        activation_mode_t active_node) {
-  T rValueUpdateGate;
-  T rValueFrameState;
-  T rPrevOut = 0;
-  T rOutput;
-  T *updateGate = gateValue;
-  T *frameState = gateValue + frameSize * 2;
-  for (int i = 0; i < frameSize; i++) {
-    rValueUpdateGate = updateGate[i];
-    rValueFrameState = frameState[i];
-    if (prevOutputValue) {
-      rPrevOut = prevOutputValue[i];
+  T r_value_update_gate;
+  T r_value_frame_state;
+  T r_prev_out = 0;
+  T r_output;
+  T *update_gate = gate_value;
+  T *frame_state = gate_value + frame_size * 2;
+  for (int i = 0; i < frame_size; i++) {
+    r_value_update_gate = update_gate[i];
+    r_value_frame_state = frame_state[i];
+    if (prev_output_value) {
+      r_prev_out = prev_output_value[i];
     }
-    opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
-                  active_node);
+    op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
+                    r_output, active_node);
-    frameState[i] = rValueFrameState;
-    outputValue[i] = rOutput;
+    frame_state[i] = r_value_frame_state;
+    output_value[i] = r_output;
   }
 }
 template <class OpResetOutput, typename T>
-void hl_avx_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue,
-                                     T *resetOutputValue, T *prevOutputValue,
-                                     int frameSize,
+void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
+                                     T *gate_value, T *reset_output_value,
+                                     T *prev_output_value, int frame_size,
                                      activation_mode_t active_gate) {
 #ifdef __AVX__
-  __m256 rValueUpdateGate;
-  __m256 rValueResetGate;
-  __m256 rValueResetOutput;
-  __m256 rPrevOut = _mm256_set1_ps(0.0f);
-  __m256 *updateGate = (__m256 *)gateValue;
-  __m256 *resetGate = (__m256 *)(gateValue + frameSize);
-  for (int i = 0; i < frameSize / 8; i++) {
-    rValueUpdateGate = updateGate[i];
-    rValueResetGate = resetGate[i];
-    if (prevOutputValue) {
-      rPrevOut = ((__m256 *)prevOutputValue)[i];
+  __m256 r_value_update_gate;
+  __m256 r_value_reset_gate;
+  __m256 r_value_reset_output;
+  __m256 r_prev_out = _mm256_set1_ps(0.0f);
+  __m256 *update_gate = (__m256 *)gate_value;
+  __m256 *reset_gate = (__m256 *)(gate_value + frame_size);
+  for (int i = 0; i < frame_size / 8; i++) {
+    r_value_update_gate = update_gate[i];
+    r_value_reset_gate = reset_gate[i];
+    if (prev_output_value) {
+      r_prev_out = ((__m256 *)prev_output_value)[i];
     }
-    opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut,
-                  rValueResetOutput, active_gate);
+    op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
+                    r_value_reset_output, active_gate);
-    updateGate[i] = rValueUpdateGate;
-    resetGate[i] = rValueResetGate;
-    ((__m256 *)resetOutputValue)[i] = rValueResetOutput;
+    update_gate[i] = r_value_update_gate;
+    reset_gate[i] = r_value_reset_gate;
+    ((__m256 *)reset_output_value)[i] = r_value_reset_output;
   }
 #endif
 }
 template <class OpFinalOutput, typename T>
-void hl_avx_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue,
-                                     T *prevOutputValue, T *outputValue,
-                                     int frameSize,
+void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
+                                     T *gate_value, T *prev_output_value,
+                                     T *output_value, int frame_size,
                                      activation_mode_t active_node) {
 #ifdef __AVX__
-  __m256 rValueUpdateGate;
-  __m256 rValueFrameState;
-  __m256 rPrevOut = _mm256_set1_ps(0.0f);
-  __m256 rOutput;
-  __m256 *updateGate = (__m256 *)gateValue;
-  __m256 *frameState = (__m256 *)(gateValue + frameSize * 2);
-  for (int i = 0; i < frameSize / 8; i++) {
-    rValueUpdateGate = updateGate[i];
-    rValueFrameState = frameState[i];
-    if (prevOutputValue) {
-      rPrevOut = ((__m256 *)prevOutputValue)[i];
+  __m256 r_value_update_gate;
+  __m256 r_value_frame_state;
+  __m256 r_prev_out = _mm256_set1_ps(0.0f);
+  __m256 r_output;
+  __m256 *update_gate = (__m256 *)gate_value;
+  __m256 *frame_state = (__m256 *)(gate_value + frame_size * 2);
+  for (int i = 0; i < frame_size / 8; i++) {
+    r_value_update_gate = update_gate[i];
+    r_value_frame_state = frame_state[i];
+    if (prev_output_value) {
+      r_prev_out = ((__m256 *)prev_output_value)[i];
    }
-    opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
-                  active_node);
+    op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
+                    r_output, active_node);
-    frameState[i] = rValueFrameState;
-    ((__m256 *)outputValue)[i] = rOutput;
+    frame_state[i] = r_value_frame_state;
+    ((__m256 *)output_value)[i] = r_output;
   }
 #endif
 }
 template <class OpResetOutput, typename T>
-inline void forward_reset_output(OpResetOutput opResetOutput,
-                                 hl_gru_value<T> value, int frameSize,
-                                 int batchSize, activation_mode_t active_gate) {
-  for (int b = 0; b < batchSize; b++) {
-    if (OpResetOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
+inline void forward_reset_output(OpResetOutput op_reset_output,
+                                 hl_gru_value<T> value, int frame_size,
+                                 int batch_size,
+                                 activation_mode_t active_gate) {
+  for (int b = 0; b < batch_size; b++) {
+    if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
       hl_avx_gru_forward_reset_output(
-          opResetOutput, value.gateValue, value.resetOutputValue,
-          value.prevOutValue, frameSize, active_gate);
+          op_reset_output, value.gate_value, value.reset_output_value,
+          value.prev_out_value, frame_size, active_gate);
     } else {
       hl_naive_gru_forward_reset_output(
-          opResetOutput, value.gateValue, value.resetOutputValue,
-          value.prevOutValue, frameSize, active_gate);
+          op_reset_output, value.gate_value, value.reset_output_value,
+          value.prev_out_value, frame_size, active_gate);
     }
-    value.gateValue += frameSize * 3;
-    value.resetOutputValue += frameSize;
-    if (value.prevOutValue) {
-      value.prevOutValue += frameSize;
+    value.gate_value += frame_size * 3;
+    value.reset_output_value += frame_size;
+    if (value.prev_out_value) {
+      value.prev_out_value += frame_size;
     }
   }
 }
 template <class OpFinalOutput, typename T>
-inline void forward_final_output(OpFinalOutput opFinalOutput,
-                                 hl_gru_value<T> value, int frameSize,
-                                 int batchSize, activation_mode_t active_node) {
-  for (int b = 0; b < batchSize; b++) {
-    if (OpFinalOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
-      hl_avx_gru_forward_final_output(opFinalOutput, value.gateValue,
-                                      value.prevOutValue, value.outputValue,
-                                      frameSize, active_node);
+inline void forward_final_output(OpFinalOutput op_final_output,
+                                 hl_gru_value<T> value, int frame_size,
+                                 int batch_size,
+                                 activation_mode_t active_node) {
+  for (int b = 0; b < batch_size; b++) {
+    if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
+      hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
+                                      value.prev_out_value, value.output_value,
+                                      frame_size, active_node);
     } else {
-      hl_naive_gru_forward_final_output(opFinalOutput, value.gateValue,
-                                        value.prevOutValue, value.outputValue,
-                                        frameSize, active_node);
+      hl_naive_gru_forward_final_output(
+          op_final_output, value.gate_value, value.prev_out_value,
+          value.output_value, frame_size, active_node);
     }
-    value.gateValue += frameSize * 3;
-    value.outputValue += frameSize;
-    if (value.prevOutValue) {
-      value.prevOutValue += frameSize;
+    value.gate_value += frame_size * 3;
+    value.output_value += frame_size;
+    if (value.prev_out_value) {
+      value.prev_out_value += frame_size;
    }
  }
 }
 template <class OpStateGrad, typename T>
-void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
-                                      T *gateGrad, T *prevOutValue,
-                                      T *prevOutGrad, T *outputGrad,
-                                      int frameSize,
+void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
+                                      T *gate_grad, T *prev_out_value,
+                                      T *prev_out_grad, T *output_grad,
+                                      int frame_size,
                                       activation_mode_t active_node) {
-  T rUpdateGateValue;
-  T rUpdateGateGrad;
-  T rFrameStateValue;
-  T rFrameStateGrad;
-  T rOutGrad;
-  T rPrevOutValue = 0;
-  T rPrevOutGrad = 0;
-  T *updateGateValue = gateValue;
-  T *updateGateGrad = gateGrad;
-  T *frameStateValue = gateValue + frameSize * 2;
-  T *frameStateGrad = gateGrad + frameSize * 2;
-  for (int i = 0; i < frameSize; i++) {
-    rUpdateGateValue = updateGateValue[i];
-    rFrameStateValue = frameStateValue[i];
-    rOutGrad = outputGrad[i];
-    if (prevOutValue) {
-      rPrevOutValue = prevOutValue[i];
+  T r_update_gate_value;
+  T r_update_gate_grad;
+  T r_frame_state_value;
+  T r_frame_state_grad;
+  T r_out_grad;
+  T r_prev_out_value = 0;
+  T r_prev_out_grad = 0;
+  T *update_gate_value = gate_value;
+  T *update_gate_grad = gate_grad;
+  T *frame_state_value = gate_value + frame_size * 2;
+  T *frame_state_grad = gate_grad + frame_size * 2;
+  for (int i = 0; i < frame_size; i++) {
+    r_update_gate_value = update_gate_value[i];
+    r_frame_state_value = frame_state_value[i];
+    r_out_grad = output_grad[i];
+    if (prev_out_value) {
+      r_prev_out_value = prev_out_value[i];
     }
-    if (prevOutGrad) {
-      rPrevOutGrad = prevOutGrad[i];
+    if (prev_out_grad) {
+      r_prev_out_grad = prev_out_grad[i];
     }
-    opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
-                rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
-                active_node);
+    op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
+                  r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
+                  r_out_grad, active_node);
-    updateGateGrad[i] = rUpdateGateGrad;
-    frameStateGrad[i] = rFrameStateGrad;
-    if (prevOutGrad) {
-      prevOutGrad[i] = rPrevOutGrad;
+    update_gate_grad[i] = r_update_gate_grad;
+    frame_state_grad[i] = r_frame_state_grad;
+    if (prev_out_grad) {
+      prev_out_grad[i] = r_prev_out_grad;
    }
  }
 }
 template <class OpResetGrad, typename T>
-void hl_naive_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
-                                      T *gateGrad, T *prevOutValue,
-                                      T *prevOutGrad, T *resetOutputGrad,
-                                      int frameSize,
+void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
+                                      T *gate_grad, T *prev_out_value,
+                                      T *prev_out_grad, T *reset_output_grad,
+                                      int frame_size,
                                       activation_mode_t active_gate) {
-  T rUpdateGateValue;
-  T rUpdateGateGrad;
-  T rResetGateValue;
-  T rResetGateGrad;
-  T rResetOutputGrad = 0;
-  T rPrevOutValue = 0;
-  T rPrevOutGrad = 0;
-  T *updateGateValue = gateValue;
-  T *updateGateGrad = gateGrad;
-  T *resetGateValue = gateValue + frameSize;
-  T *resetGateGrad = gateGrad + frameSize;
-  for (int i = 0; i < frameSize; i++) {
-    rUpdateGateValue = updateGateValue[i];
-    rUpdateGateGrad = updateGateGrad[i];
-    rResetGateValue = resetGateValue[i];
-    if (prevOutValue && prevOutGrad) {
-      rResetOutputGrad = resetOutputGrad[i];
+  T r_update_gate_value;
+  T r_update_gate_grad;
+  T r_reset_gate_value;
+  T r_reset_gate_grad;
+  T r_reset_output_grad = 0;
+  T r_prev_out_value = 0;
+  T r_prev_out_grad = 0;
+  T *update_gate_value = gate_value;
+  T *update_gate_grad = gate_grad;
+  T *reset_gate_value = gate_value + frame_size;
+  T *reset_gate_grad = gate_grad + frame_size;
+  for (int i = 0; i < frame_size; i++) {
+    r_update_gate_value = update_gate_value[i];
+    r_update_gate_grad = update_gate_grad[i];
+    r_reset_gate_value = reset_gate_value[i];
+    if (prev_out_value && prev_out_grad) {
+      r_reset_output_grad = reset_output_grad[i];
     }
-    if (prevOutValue) {
-      rPrevOutValue = prevOutValue[i];
+    if (prev_out_value) {
+      r_prev_out_value = prev_out_value[i];
     }
-    if (prevOutGrad) {
-      rPrevOutGrad = prevOutGrad[i];
+    if (prev_out_grad) {
+      r_prev_out_grad = prev_out_grad[i];
     }
-    opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
-                rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
-                active_gate);
+    op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
+                  r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
+                  r_reset_output_grad, active_gate);
-    updateGateGrad[i] = rUpdateGateGrad;
-    resetGateGrad[i] = rResetGateGrad;
-    if (prevOutGrad) {
-      prevOutGrad[i] = rPrevOutGrad;
+    update_gate_grad[i] = r_update_gate_grad;
+    reset_gate_grad[i] = r_reset_gate_grad;
+    if (prev_out_grad) {
+      prev_out_grad[i] = r_prev_out_grad;
    }
  }
 }
 template <class OpStateGrad, typename T>
-void hl_avx_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
-                                    T *gateGrad, T *prevOutValue,
-                                    T *prevOutGrad, T *outputGrad,
-                                    int frameSize,
+void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
+                                    T *gate_grad, T *prev_out_value,
+                                    T *prev_out_grad, T *output_grad,
+                                    int frame_size,
                                     activation_mode_t active_node) {
 #ifdef __AVX__
-  __m256 rUpdateGateValue;
-  __m256 rUpdateGateGrad;
-  __m256 rFrameStateValue;
-  __m256 rFrameStateGrad;
-  __m256 rOutGrad;
-  __m256 rPrevOutValue = _mm256_set1_ps(0.0f);
-  __m256 rPrevOutGrad = _mm256_set1_ps(0.0f);
-  __m256 *updateGateValue = (__m256 *)gateValue;
-  __m256 *updateGateGrad = (__m256 *)gateGrad;
-  __m256 *frameStateValue = (__m256 *)(gateValue + frameSize * 2);
-  __m256 *frameStateGrad = (__m256 *)(gateGrad + frameSize * 2);
-  for (int i = 0; i < frameSize / 8; i++) {
-    rUpdateGateValue = updateGateValue[i];
-    rFrameStateValue = frameStateValue[i];
-    rOutGrad = ((__m256 *)outputGrad)[i];
-    if (prevOutValue) {
-      rPrevOutValue = ((__m256 *)prevOutValue)[i];
+  __m256 r_update_gate_value;
+  __m256 r_update_gate_grad;
+  __m256 r_frame_state_value;
+  __m256 r_frame_state_grad;
+  __m256 r_out_grad;
+  __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
+  __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
+  __m256 *update_gate_value = (__m256 *)gate_value;
+  __m256 *update_gate_grad = (__m256 *)gate_grad;
+  __m256 *frame_state_value = (__m256 *)(gate_value + frame_size * 2);
+  __m256 *frame_state_grad = (__m256 *)(gate_grad + frame_size * 2);
+  for (int i = 0; i < frame_size / 8; i++) {
+    r_update_gate_value = update_gate_value[i];
+    r_frame_state_value = frame_state_value[i];
+    r_out_grad = ((__m256 *)output_grad)[i];
+    if (prev_out_value) {
+      r_prev_out_value = ((__m256 *)prev_out_value)[i];
     }
-    if (prevOutGrad) {
-      rPrevOutGrad = ((__m256 *)prevOutGrad)[i];
+    if (prev_out_grad) {
+      r_prev_out_grad = ((__m256 *)prev_out_grad)[i];
     }
-    opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
-                rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
-                active_node);
+    op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
+                  r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
+                  r_out_grad, active_node);
-    updateGateGrad[i] = rUpdateGateGrad;
-    frameStateGrad[i] = rFrameStateGrad;
-    if (prevOutGrad) {
-      ((__m256 *)prevOutGrad)[i] = rPrevOutGrad;
+    update_gate_grad[i] = r_update_gate_grad;
+    frame_state_grad[i] = r_frame_state_grad;
+    if (prev_out_grad) {
+      ((__m256 *)prev_out_grad)[i] = r_prev_out_grad;
    }
  }
 #endif
 }
 template <class OpResetGrad, typename T>
-void hl_avx_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
-                                    T *gateGrad, T *prevOutValue,
-                                    T *prevOutGrad, T *resetOutputGrad,
-                                    int frameSize,
+void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
+                                    T *gate_grad, T *prev_out_value,
+                                    T *prev_out_grad, T *reset_output_grad,
+                                    int frame_size,
                                     activation_mode_t active_gate) {
 #ifdef __AVX__
-  __m256 rUpdateGateValue;
-  __m256 rUpdateGateGrad;
-  __m256 rResetGateValue;
-  __m256 rResetGateGrad;
-  __m256 rResetOutputGrad = _mm256_set1_ps(0.0f);
-  __m256 rPrevOutValue = _mm256_set1_ps(0.0f);
-  __m256 rPrevOutGrad = _mm256_set1_ps(0.0f);
-  __m256 *updateGateValue = (__m256 *)gateValue;
-  __m256 *updateGateGrad = (__m256 *)gateGrad;
-  __m256 *resetGateValue = (__m256 *)(gateValue + frameSize);
-  __m256 *resetGateGrad = (__m256 *)(gateGrad + frameSize);
-  for (int i = 0; i < frameSize / 8; i++) {
-    rUpdateGateValue = updateGateValue[i];
-    rUpdateGateGrad = updateGateGrad[i];
-    rResetGateValue = resetGateValue[i];
-    if (prevOutValue && prevOutGrad) {
-      rResetOutputGrad = ((__m256 *)resetOutputGrad)[i];
+  __m256 r_update_gate_value;
+  __m256 r_update_gate_grad;
+  __m256 r_reset_gate_value;
+  __m256 r_reset_gate_grad;
+  __m256 r_reset_output_grad = _mm256_set1_ps(0.0f);
+  __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
+  __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
+  __m256 *update_gate_value = (__m256 *)gate_value;
+  __m256 *update_gate_grad = (__m256 *)gate_grad;
+  __m256 *reset_gate_value = (__m256 *)(gate_value + frame_size);
+  __m256 *reset_gate_grad = (__m256 *)(gate_grad + frame_size);
+  for (int i = 0; i < frame_size / 8; i++) {
+    r_update_gate_value = update_gate_value[i];
+    r_update_gate_grad = update_gate_grad[i];
+    r_reset_gate_value = reset_gate_value[i];
+    if (prev_out_value && prev_out_grad) {
+      r_reset_output_grad = ((__m256 *)reset_output_grad)[i];
     }
-    if (prevOutValue) {
-      rPrevOutValue = ((__m256 *)prevOutValue)[i];
+    if (prev_out_value) {
+      r_prev_out_value = ((__m256 *)prev_out_value)[i];
    }
-    if (prevOutGrad) {
-      rPrevOutGrad = ((__m256 *)prevOutGrad)[i];
+    if (prev_out_grad) {
+      r_prev_out_grad = ((__m256 *)prev_out_grad)[i];
    }
-    opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
-                rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
-                active_gate);
+    op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
+                  r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
+                  r_reset_output_grad, active_gate);
-    updateGateGrad[i] = rUpdateGateGrad;
-    resetGateGrad[i] = rResetGateGrad;
-    if (prevOutGrad) {
-      ((__m256 *)prevOutGrad)[i] = rPrevOutGrad;
+    update_gate_grad[i] = r_update_gate_grad;
+    reset_gate_grad[i] = r_reset_gate_grad;
+    if (prev_out_grad) {
+      ((__m256 *)prev_out_grad)[i] = r_prev_out_grad;
    }
  }
 #endif
 }
 template <class OpStateGrad, typename T>
-inline void backward_state_grad(OpStateGrad opStateGrad, hl_gru_value<T> value,
-                                hl_gru_grad<T> grad, int frameSize,
-                                int batchSize, activation_mode_t active_node) {
-  for (int b = 0; b < batchSize; b++) {
-    if (OpStateGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
+inline void backward_state_grad(OpStateGrad op_state_grad,
+                                hl_gru_value<T> value, hl_gru_grad<T> grad,
+                                int frame_size, int batch_size,
+                                activation_mode_t active_node) {
+  for (int b = 0; b < batch_size; b++) {
+    if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
       hl_avx_gru_backward_state_grad(
-          opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
-          grad.prevOutGrad, grad.outputGrad, frameSize, active_node);
+          op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
+          grad.prev_out_grad, grad.output_grad, frame_size, active_node);
     } else {
       hl_naive_gru_backward_state_grad(
-          opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
-          grad.prevOutGrad, grad.outputGrad, frameSize, active_node);
+          op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
+          grad.prev_out_grad, grad.output_grad, frame_size, active_node);
     }
-    value.gateValue += frameSize * 3;
-    if (value.prevOutValue) {
-      value.prevOutValue += frameSize;
+    value.gate_value += frame_size * 3;
+    if (value.prev_out_value) {
+      value.prev_out_value += frame_size;
     }
-    grad.gateGrad += frameSize * 3;
-    grad.outputGrad += frameSize;
-    if (grad.prevOutGrad) {
-      grad.prevOutGrad += frameSize;
+    grad.gate_grad += frame_size * 3;
+    grad.output_grad += frame_size;
+    if (grad.prev_out_grad) {
+      grad.prev_out_grad += frame_size;
    }
  }
 }
 template <class OpResetGrad, typename T>
-inline void backward_reset_grad(OpResetGrad opResetGrad, hl_gru_value<T> value,
-                                hl_gru_grad<T> grad, int frameSize,
-                                int batchSize, activation_mode_t active_gate) {
-  for (int b = 0; b < batchSize; b++) {
-    if (OpResetGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
+inline void backward_reset_grad(OpResetGrad op_reset_grad,
+                                hl_gru_value<T> value, hl_gru_grad<T> grad,
+                                int frame_size, int batch_size,
+                                activation_mode_t active_gate) {
+  for (int b = 0; b < batch_size; b++) {
+    if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
       hl_avx_gru_backward_reset_grad(
-          opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
-          grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate);
+          op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
+          grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
     } else {
       hl_naive_gru_backward_reset_grad(
-          opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
-          grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate);
+          op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
+          grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
     }
-    value.gateValue += frameSize * 3;
-    if (value.prevOutValue) {
-      value.prevOutValue += frameSize;
+    value.gate_value += frame_size * 3;
+    if (value.prev_out_value) {
+      value.prev_out_value += frame_size;
     }
-    grad.gateGrad += frameSize * 3;
-    grad.resetOutputGrad += frameSize;
-    if (grad.prevOutGrad) {
-      grad.prevOutGrad += frameSize;
+    grad.gate_grad += frame_size * 3;
+    grad.reset_output_grad += frame_size;
+    if (grad.prev_out_grad) {
+      grad.prev_out_grad += frame_size;
    }
  }
 }
......
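The dispatch in forward_reset_output and the other wrappers above hinges on one predicate: the AVX path requires the op to provide an __m256 overload, frame_size to be a multiple of 8 (eight floats per 256-bit register), and T to be a 4-byte float. A standalone sketch of just that check (function name illustrative, not part of the PR):

#include <cstdio>

// Sketch of the dispatch predicate used by forward_reset_output and friends:
// take the AVX path only when T is a 4-byte float and frame_size is a
// multiple of 8, i.e. the frame fills whole __m256 registers.
template <typename T>
bool UseAvxPath(bool op_supports_avx, int frame_size) {
  return op_supports_avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4);
}

int main() {
  std::printf("%d\n", UseAvxPath<float>(true, 16));   // 1: 16 floats = 2 registers
  std::printf("%d\n", UseAvxPath<float>(true, 12));   // 0: not a multiple of 8
  std::printf("%d\n", UseAvxPath<double>(true, 16));  // 0: not 4-byte floats
  return 0;
}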
@@ -27,174 +27,174 @@ namespace math {
 namespace detail {
 /*
- * threads(framePerBlock, batchPerBlock)
- * grid(frameBlocks, batchBlocks)
+ * threads(frame_per_block, batch_per_block)
+ * grid(frame_blocks, batch_blocks)
  */
-template <class OpResetOutput, bool isBatch, typename T>
-__global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput,
-                                        T *gateValue, T *resetOutputValue,
-                                        T *prevOutputValue, int frameSize,
-                                        int batchSize,
+template <class OpResetOutput, bool is_batch, typename T>
+__global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
+                                        T *gate_value, T *reset_output_value,
+                                        T *prev_output_value, int frame_size,
+                                        int batch_size,
                                         activation_mode_t active_gate) {
-  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (frameIdx >= frameSize) return;
-  int batchIdx = 0;
-  if (isBatch) {
-    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
-    if (batchIdx >= batchSize) return;
-    gateValue += batchIdx * 3 * frameSize;
-    resetOutputValue += batchIdx * frameSize;
+  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frame_idx >= frame_size) return;
+  int batch_idx = 0;
+  if (is_batch) {
+    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batch_idx >= batch_size) return;
+    gate_value += batch_idx * 3 * frame_size;
+    reset_output_value += batch_idx * frame_size;
   }
-  T rPrevOut = 0;
-  T rValueResetOutput;
-  T rValueUpdateGate = gateValue[frameIdx + frameSize * 0];
-  T rValueResetGate = gateValue[frameIdx + frameSize * 1];
+  T r_prev_out = 0;
+  T r_value_reset_output;
+  T r_value_update_gate = gate_value[frame_idx + frame_size * 0];
+  T r_value_reset_gate = gate_value[frame_idx + frame_size * 1];
-  if (prevOutputValue) {
-    if (isBatch) prevOutputValue += batchIdx * frameSize;
-    rPrevOut = prevOutputValue[frameIdx];
+  if (prev_output_value) {
+    if (is_batch) prev_output_value += batch_idx * frame_size;
+    r_prev_out = prev_output_value[frame_idx];
   }
-  opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput,
-                active_gate);
+  op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
+                  r_value_reset_output, active_gate);
-  gateValue[frameIdx + frameSize * 0] = rValueUpdateGate;
-  gateValue[frameIdx + frameSize * 1] = rValueResetGate;
-  resetOutputValue[frameIdx] = rValueResetOutput;
+  gate_value[frame_idx + frame_size * 0] = r_value_update_gate;
+  gate_value[frame_idx + frame_size * 1] = r_value_reset_gate;
+  reset_output_value[frame_idx] = r_value_reset_output;
 }
 /*
- * threads(framePerBlock, batchPerBlock)
- * grid(frameBlocks, batchBlocks)
+ * threads(frame_per_block, batch_per_block)
+ * grid(frame_blocks, batch_blocks)
  */
-template <class OpFinalOutput, bool isBatch, typename T>
-__global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput,
-                                        T *gateValue, T *prevOutputValue,
-                                        T *outputValue, int frameSize,
-                                        int batchSize,
+template <class OpFinalOutput, bool is_batch, typename T>
+__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
+                                        T *gate_value, T *prev_output_value,
+                                        T *output_value, int frame_size,
+                                        int batch_size,
                                         activation_mode_t active_node) {
-  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (frameIdx >= frameSize) return;
-  int batchIdx = 0;
-  if (isBatch) {
-    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
-    if (batchIdx >= batchSize) return;
-    gateValue += batchIdx * 3 * frameSize;
-    outputValue += batchIdx * frameSize;
+  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frame_idx >= frame_size) return;
+  int batch_idx = 0;
+  if (is_batch) {
+    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batch_idx >= batch_size) return;
+    gate_value += batch_idx * 3 * frame_size;
+    output_value += batch_idx * frame_size;
   }
-  T rOutput;
-  T rPrevOut = 0;
-  T rValueUpdateGate = gateValue[frameIdx + frameSize * 0];
-  T rValueFrameState = gateValue[frameIdx + frameSize * 2];
+  T r_output;
+  T r_prev_out = 0;
+  T r_value_update_gate = gate_value[frame_idx + frame_size * 0];
+  T r_value_frame_state = gate_value[frame_idx + frame_size * 2];
-  if (prevOutputValue) {
-    if (isBatch) prevOutputValue += batchIdx * frameSize;
-    rPrevOut = prevOutputValue[frameIdx];
+  if (prev_output_value) {
+    if (is_batch) prev_output_value += batch_idx * frame_size;
+    r_prev_out = prev_output_value[frame_idx];
   }
-  opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
-                active_node);
+  op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
+                  r_output, active_node);
-  gateValue[frameIdx + frameSize * 2] = rValueFrameState;
-  outputValue[frameIdx] = rOutput;
+  gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
+  output_value[frame_idx] = r_output;
 }
 /*
- * threads(framePerBlock, batchPerBlock)
- * grid(frameBlocks, batchBlocks)
+ * threads(frame_per_block, batch_per_block)
+ * grid(frame_blocks, batch_blocks)
  */
-template <class OpStateGrad, bool isBatch, typename T>
-__global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue,
-                                       T *gateGrad, T *prevOutValue,
-                                       T *prevOutGrad, T *outputGrad,
-                                       int frameSize, int batchSize,
+template <class OpStateGrad, bool is_batch, typename T>
+__global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
+                                       T *gate_grad, T *prev_out_value,
+                                       T *prev_out_grad, T *output_grad,
+                                       int frame_size, int batch_size,
                                        activation_mode_t active_node) {
-  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (frameIdx >= frameSize) return;
-  int batchIdx = 0;
-  if (isBatch) {
-    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
-    if (batchIdx >= batchSize) return;
-    gateValue += batchIdx * 3 * frameSize;
-    gateGrad += batchIdx * 3 * frameSize;
-    outputGrad += batchIdx * frameSize;
+  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frame_idx >= frame_size) return;
+  int batch_idx = 0;
+  if (is_batch) {
+    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batch_idx >= batch_size) return;
+    gate_value += batch_idx * 3 * frame_size;
+    gate_grad += batch_idx * 3 * frame_size;
+    output_grad += batch_idx * frame_size;
   }
-  T rUpdateGateGrad;
-  T rFrameStateGrad;
-  T rPrevOutValue = 0;
-  T rPrevOutGrad = 0;
-  T rUpdateGateValue = gateValue[frameIdx + frameSize * 0];
-  T rFrameStateValue = gateValue[frameIdx + frameSize * 2];
-  T rOutGrad = outputGrad[frameIdx];
+  T r_update_gate_grad;
+  T r_frame_state_grad;
+  T r_prev_out_value = 0;
+  T r_prev_out_grad = 0;
+  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
+  T r_frame_state_value = gate_value[frame_idx + frame_size * 2];
+  T r_out_grad = output_grad[frame_idx];
-  if (prevOutValue && prevOutGrad) {
-    if (isBatch) prevOutValue += batchIdx * frameSize;
-    rPrevOutValue = prevOutValue[frameIdx];
+  if (prev_out_value && prev_out_grad) {
+    if (is_batch) prev_out_value += batch_idx * frame_size;
+    r_prev_out_value = prev_out_value[frame_idx];
-    if (isBatch) prevOutGrad += batchIdx * frameSize;
-    rPrevOutGrad = prevOutGrad[frameIdx];
+    if (is_batch) prev_out_grad += batch_idx * frame_size;
+    r_prev_out_grad = prev_out_grad[frame_idx];
   }
-  opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
-              rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
-              active_node);
+  op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
+                r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
+                r_out_grad, active_node);
-  gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
-  gateGrad[frameIdx + frameSize * 2] = rFrameStateGrad;
-  if (prevOutGrad) {
-    prevOutGrad[frameIdx] = rPrevOutGrad;
+  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
+  gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
+  if (prev_out_grad) {
+    prev_out_grad[frame_idx] = r_prev_out_grad;
   }
 }
 /*
- * threads(framePerBlock, batchPerBlock)
- * grid(frameBlocks, batchBlocks)
+ * threads(frame_per_block, batch_per_block)
+ * grid(frame_blocks, batch_blocks)
  */
-template <class OpResetGrad, bool isBatch, typename T>
-__global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue,
-                                       T *gateGrad, T *prevOutValue,
-                                       T *prevOutGrad, T *resetOutputGrad,
-                                       int frameSize, int batchSize,
+template <class OpResetGrad, bool is_batch, typename T>
+__global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
+                                       T *gate_grad, T *prev_out_value,
+                                       T *prev_out_grad, T *reset_output_grad,
+                                       int frame_size, int batch_size,
                                        activation_mode_t active_gate) {
-  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (frameIdx >= frameSize) return;
-  int batchIdx = 0;
-  if (isBatch) {
-    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
-    if (batchIdx >= batchSize) return;
-    gateValue += batchIdx * 3 * frameSize;
-    gateGrad += batchIdx * 3 * frameSize;
-    resetOutputGrad += batchIdx * frameSize;
+  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frame_idx >= frame_size) return;
+  int batch_idx = 0;
+  if (is_batch) {
+    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batch_idx >= batch_size) return;
+    gate_value += batch_idx * 3 * frame_size;
+    gate_grad += batch_idx * 3 * frame_size;
+    reset_output_grad += batch_idx * frame_size;
   }
-  T rResetGateGrad;
-  T rPrevOutValue = 0;
-  T rPrevOutGrad = 0;
-  T rResetOutputGrad = 0;
-  T rUpdateGateValue = gateValue[frameIdx + frameSize * 0];
-  T rUpdateGateGrad = gateGrad[frameIdx + frameSize * 0];
-  T rResetGateValue = gateValue[frameIdx + frameSize * 1];
-  if (prevOutValue && prevOutGrad) {
-    if (isBatch) prevOutValue += batchIdx * frameSize;
-    if (isBatch) prevOutGrad += batchIdx * frameSize;
-    rPrevOutValue = prevOutValue[frameIdx];
-    rPrevOutGrad = prevOutGrad[frameIdx];
-    rResetOutputGrad = resetOutputGrad[frameIdx];
+  T r_reset_gate_grad;
+  T r_prev_out_value = 0;
+  T r_prev_out_grad = 0;
+  T r_reset_output_grad = 0;
+  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
+  T r_update_gate_grad = gate_grad[frame_idx + frame_size * 0];
+  T r_reset_gate_value = gate_value[frame_idx + frame_size * 1];
+  if (prev_out_value && prev_out_grad) {
+    if (is_batch) prev_out_value += batch_idx * frame_size;
+    if (is_batch) prev_out_grad += batch_idx * frame_size;
+    r_prev_out_value = prev_out_value[frame_idx];
+    r_prev_out_grad = prev_out_grad[frame_idx];
+    r_reset_output_grad = reset_output_grad[frame_idx];
   }
-  opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
-              rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
-              active_gate);
+  op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
+                r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
+                r_reset_output_grad, active_gate);
-  gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
-  gateGrad[frameIdx + frameSize * 1] = rResetGateGrad;
-  if (prevOutGrad) {
-    prevOutGrad[frameIdx] = rPrevOutGrad;
+  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
+  gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
+  if (prev_out_grad) {
+    prev_out_grad[frame_idx] = r_prev_out_grad;
   }
 }
 }  // namespace detail
......
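The GRUUnitFunctor/GRUUnitGradFunctor specializations further down pick their launch shape from batch_size and frame_size; the kernels above then recover (frame_idx, batch_idx) from it. A plain-C++ sketch of that launch-shape arithmetic (dim3 replaced by ints; names illustrative, not part of the PR):

#include <cstdio>

// Sketch of the launch-shape arithmetic in the GPU functors: a 1-D launch
// over frames when batch_size == 1, otherwise a 32x32 2-D tiling of
// (frame, batch).
void LaunchShape(int frame_size, int batch_size) {
  int threads_x, threads_y, grid_x, grid_y;
  if (batch_size == 1) {
    threads_x = frame_size <= 1024 ? frame_size : 1024;
    threads_y = 1;
    grid_x = (frame_size + 1024 - 1) / 1024;
    grid_y = 1;
  } else {
    threads_x = threads_y = 32;
    grid_x = (frame_size + 32 - 1) / 32;
    grid_y = (batch_size + 32 - 1) / 32;
  }
  std::printf("threads(%d, %d) grid(%d, %d)\n", threads_x, threads_y, grid_x,
              grid_y);
}

int main() {
  LaunchShape(100, 1);  // threads(100, 1) grid(1, 1)
  LaunchShape(100, 4);  // threads(32, 32) grid(4, 1)
  return 0;
}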
@@ -28,23 +28,25 @@ namespace forward {
 template <typename T>
 class gru_resetOutput {
  public:
-  HOSTDEVICE void operator()(T &valueUpdateGate, T &valueResetGate, T &prevOut,
-                             T &valueResetOutput, activation_mode_t actGate) {
-    valueUpdateGate = activation(valueUpdateGate, actGate);
-    valueResetGate = activation(valueResetGate, actGate);
-    valueResetOutput = prevOut * valueResetGate;
+  HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate,
+                             T &prev_out, T &value_reset_output,
+                             activation_mode_t act_gate) {
+    value_update_gate = activation(value_update_gate, act_gate);
+    value_reset_gate = activation(value_reset_gate, act_gate);
+    value_reset_output = prev_out * value_reset_gate;
   }
 #ifndef __NVCC__
 #ifndef __AVX__
   static const bool avx = false;
 #else
   static const bool avx = true;
-  HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueResetGate,
-                             __m256 &prevOut, __m256 &valueResetOutput,
-                             activation_mode_t actGate) {
-    valueUpdateGate = activation(valueUpdateGate, actGate);
-    valueResetGate = activation(valueResetGate, actGate);
-    valueResetOutput = _mm256_mul_ps(prevOut, valueResetGate);
+  HOSTDEVICE void operator()(__m256 &value_update_gate,
+                             __m256 &value_reset_gate, __m256 &prev_out,
+                             __m256 &value_reset_output,
+                             activation_mode_t act_gate) {
+    value_update_gate = activation(value_update_gate, act_gate);
+    value_reset_gate = activation(value_reset_gate, act_gate);
+    value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate);
   }
 #endif
 #endif
@@ -53,24 +55,26 @@ class gru_resetOutput {
 template <typename T>
 class gru_finalOutput {
  public:
-  HOSTDEVICE void operator()(T &valueUpdateGate, T &valueFrameState, T &prevOut,
-                             T &valueOutput, activation_mode_t actInput) {
-    valueFrameState = activation(valueFrameState, actInput);
-    valueOutput = prevOut - (valueUpdateGate * prevOut) +
-                  (valueUpdateGate * valueFrameState);
+  HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state,
+                             T &prev_out, T &value_output,
+                             activation_mode_t act_input) {
+    value_frame_state = activation(value_frame_state, act_input);
+    value_output = prev_out - (value_update_gate * prev_out) +
+                   (value_update_gate * value_frame_state);
   }
 #ifndef __NVCC__
 #ifndef __AVX__
   static const bool avx = false;
 #else
   static const bool avx = true;
-  HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueFrameState,
-                             __m256 &prevOut, __m256 &valueOutput,
-                             activation_mode_t actInput) {
-    valueFrameState = activation(valueFrameState, actInput);
-    valueOutput = _mm256_add_ps(
-        _mm256_sub_ps(prevOut, _mm256_mul_ps(valueUpdateGate, prevOut)),
-        _mm256_mul_ps(valueUpdateGate, valueFrameState));
+  HOSTDEVICE void operator()(__m256 &value_update_gate,
+                             __m256 &value_frame_state, __m256 &prev_out,
+                             __m256 &value_output,
+                             activation_mode_t act_input) {
+    value_frame_state = activation(value_frame_state, act_input);
+    value_output = _mm256_add_ps(
+        _mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)),
+        _mm256_mul_ps(value_update_gate, value_frame_state));
   }
 #endif
 #endif
@@ -82,34 +86,37 @@ namespace backward {
 template <typename T>
 class gru_stateGrad {
  public:
-  HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
-                             T &valueFrameState, T &gradFrameState,
-                             T &valuePrevOut, T &gradPrevOut, T &gradOutput,
-                             activation_mode_t actInput) {
-    gradUpdateGate = (gradOutput * valueFrameState);
-    gradUpdateGate -= (gradOutput * valuePrevOut);
-    gradPrevOut -= (gradOutput * valueUpdateGate);
-    gradPrevOut += gradOutput;
-    gradFrameState =
-        activation(gradOutput * valueUpdateGate, valueFrameState, actInput);
+  HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
+                             T &value_frame_state, T &grad_frame_state,
+                             T &value_prev_out, T &grad_prev_out,
+                             T &grad_output, activation_mode_t act_input) {
+    grad_update_gate = (grad_output * value_frame_state);
+    grad_update_gate -= (grad_output * value_prev_out);
+    grad_prev_out -= (grad_output * value_update_gate);
+    grad_prev_out += grad_output;
+    grad_frame_state = activation(grad_output * value_update_gate,
+                                  value_frame_state, act_input);
   }
 #ifndef __NVCC__
 #ifndef __AVX__
   static const bool avx = false;
 #else
   static const bool avx = true;
-  HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate,
-                             __m256 &valueFrameState, __m256 &gradFrameState,
-                             __m256 &valuePrevOut, __m256 &gradPrevOut,
-                             __m256 &gradOutput, activation_mode_t actInput) {
-    gradUpdateGate = _mm256_mul_ps(gradOutput, valueFrameState);
-    gradUpdateGate =
-        _mm256_sub_ps(gradUpdateGate, _mm256_mul_ps(gradOutput, valuePrevOut));
-    gradPrevOut = _mm256_add_ps(
-        _mm256_sub_ps(gradPrevOut, _mm256_mul_ps(gradOutput, valueUpdateGate)),
-        gradOutput);
-    gradFrameState = activation(_mm256_mul_ps(gradOutput, valueUpdateGate),
-                                valueFrameState, actInput);
+  HOSTDEVICE void operator()(__m256 &value_update_gate,
+                             __m256 &grad_update_gate,
+                             __m256 &value_frame_state,
+                             __m256 &grad_frame_state, __m256 &value_prev_out,
+                             __m256 &grad_prev_out, __m256 &grad_output,
+                             activation_mode_t act_input) {
+    grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state);
+    grad_update_gate = _mm256_sub_ps(
+        grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out));
+    grad_prev_out = _mm256_add_ps(
+        _mm256_sub_ps(grad_prev_out,
+                      _mm256_mul_ps(grad_output, value_update_gate)),
+        grad_output);
+    grad_frame_state = activation(_mm256_mul_ps(grad_output, value_update_gate),
+                                  value_frame_state, act_input);
   }
 #endif
 #endif
@@ -118,30 +125,32 @@ class gru_stateGrad {
 template <typename T>
 class gru_resetGrad {
  public:
-  HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
-                             T &valueResetGate, T &gradResetGate,
-                             T &valuePrevOut, T &gradPrevOut,
-                             T &gradResetOutput, activation_mode_t actGate) {
-    gradResetGate = (gradResetOutput * valuePrevOut);
-    gradPrevOut += (gradResetOutput * valueResetGate);
-    gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate);
-    gradResetGate = activation(gradResetGate, valueResetGate, actGate);
+  HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
+                             T &value_reset_gate, T &grad_reset_gate,
+                             T &value_prev_out, T &grad_prev_out,
+                             T &grad_reset_output, activation_mode_t act_gate) {
+    grad_reset_gate = (grad_reset_output * value_prev_out);
+    grad_prev_out += (grad_reset_output * value_reset_gate);
+    grad_update_gate =
+        activation(grad_update_gate, value_update_gate, act_gate);
+    grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate);
   }
 #ifndef __NVCC__
 #ifndef __AVX__
   static const bool avx = false;
 #else
   static const bool avx = true;
-  HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate,
-                             __m256 &valueResetGate, __m256 &gradResetGate,
-                             __m256 &valuePrevOut, __m256 &gradPrevOut,
-                             __m256 &gradResetOutput,
-                             activation_mode_t actGate) {
-    gradResetGate = _mm256_mul_ps(gradResetOutput, valuePrevOut);
-    gradPrevOut = _mm256_add_ps(gradPrevOut,
-                                _mm256_mul_ps(gradResetOutput, valueResetGate));
-    gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate);
-    gradResetGate = activation(gradResetGate, valueResetGate, actGate);
+  HOSTDEVICE void operator()(__m256 &value_update_gate,
+                             __m256 &grad_update_gate, __m256 &value_reset_gate,
+                             __m256 &grad_reset_gate, __m256 &value_prev_out,
+                             __m256 &grad_prev_out, __m256 &grad_reset_output,
+                             activation_mode_t act_gate) {
+    grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out);
+    grad_prev_out = _mm256_add_ps(
+        grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate));
+    grad_update_gate =
+        activation(grad_update_gate, value_update_gate, act_gate);
+    grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate);
   }
 #endif
 #endif
......
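With identity activations, the two forward functors reduce to a few lines of scalar algebra; the final output is the standard GRU interpolation (1 - u) * prev + u * state, spelled prev - u * prev + u * state above. A scalar sketch (names illustrative, not part of the PR):

#include <cstdio>

// Scalar restatement of gru_resetOutput and gru_finalOutput with identity
// activations, to make the algebra visible.
void GruStep(float update_gate, float reset_gate, float frame_state,
             float prev_out) {
  float reset_output = prev_out * reset_gate;          // gru_resetOutput
  float output = prev_out - update_gate * prev_out +   // gru_finalOutput:
                 update_gate * frame_state;            // (1 - u) * prev + u * state
  std::printf("reset_output=%f output=%f\n", reset_output, output);
}

int main() {
  GruStep(0.25f, 0.5f, 1.0f, 2.0f);  // reset_output=1.0, output=1.75
  return 0;
}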
@@ -21,29 +21,29 @@ namespace math {
 template <typename T>
 struct GRUUnitFunctor<platform::CPUPlace, T> {
   static void compute(const platform::DeviceContext &context,
-                      hl_gru_value<T> value, int frameSize, int batchSize,
+                      hl_gru_value<T> value, int frame_size, int batch_size,
                       activation_mode_t active_node,
                       activation_mode_t active_gate) {
 #ifndef __NVCC__
-    if (value.prevOutValue) {
+    if (value.prev_out_value) {
       math::gemm<platform::CPUPlace, T>(
-          context, false, false, batchSize, frameSize * 2, frameSize, 1,
-          value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1,
-          value.gateValue, frameSize * 3);
+          context, false, false, batch_size, frame_size * 2, frame_size, 1,
+          value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
+          1, value.gate_value, frame_size * 3);
     }
     detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
-                                 frameSize, batchSize, active_gate);
+                                 frame_size, batch_size, active_gate);
-    if (value.prevOutValue) {
+    if (value.prev_out_value) {
       math::gemm<platform::CPUPlace, T>(
-          context, false, false, batchSize, frameSize, frameSize, 1,
-          value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1,
-          value.gateValue + frameSize * 2, frameSize * 3);
+          context, false, false, batch_size, frame_size, frame_size, 1,
+          value.reset_output_value, frame_size, value.state_weight, frame_size,
+          1, value.gate_value + frame_size * 2, frame_size * 3);
     }
     detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
-                                 frameSize, batchSize, active_node);
+                                 frame_size, batch_size, active_node);
 #endif
   }
 };
@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> {
 template <typename T>
 struct GRUUnitGradFunctor<platform::CPUPlace, T> {
   static void compute(const platform::DeviceContext &context,
-                      hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
-                      int batchSize, activation_mode_t active_node,
+                      hl_gru_value<T> value, hl_gru_grad<T> grad,
+                      int frame_size, int batch_size,
+                      activation_mode_t active_node,
                       activation_mode_t active_gate) {
 #ifndef __NVCC__
     detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
-                                grad, frameSize, batchSize, active_node);
+                                grad, frame_size, batch_size, active_node);
-    if (value.prevOutValue && grad.prevOutGrad) {
+    if (value.prev_out_value && grad.prev_out_grad) {
       math::gemm<platform::CPUPlace, T>(
-          context, false, true, batchSize, frameSize, frameSize, 1,
-          grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight,
-          frameSize, 0, grad.resetOutputGrad, frameSize);
+          context, false, true, batch_size, frame_size, frame_size, 1,
+          grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
+          frame_size, 0, grad.reset_output_grad, frame_size);
-      if (grad.stateWeightGrad) {
+      if (grad.state_weight_grad) {
         math::gemm<platform::CPUPlace, T>(
-            context, true, false, frameSize, frameSize, batchSize, 1,
-            value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2,
-            frameSize * 3, 1, grad.stateWeightGrad, frameSize);
+            context, true, false, frame_size, frame_size, batch_size, 1,
+            value.reset_output_value, frame_size,
+            grad.gate_grad + frame_size * 2, frame_size * 3, 1,
+            grad.state_weight_grad, frame_size);
       }
     }
     detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value,
-                                grad, frameSize, batchSize, active_gate);
+                                grad, frame_size, batch_size, active_gate);
-    if (grad.prevOutGrad && value.prevOutValue) {
+    if (grad.prev_out_grad && value.prev_out_value) {
       math::gemm<platform::CPUPlace, T>(
-          context, false, true, batchSize, frameSize, frameSize * 2, 1,
-          grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1,
-          grad.prevOutGrad, frameSize);
+          context, false, true, batch_size, frame_size, frame_size * 2, 1,
+          grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
+          grad.prev_out_grad, frame_size);
-      if (grad.gateWeightGrad) {
+      if (grad.gate_weight_grad) {
         math::gemm<platform::CPUPlace, T>(
-            context, true, false, frameSize, frameSize * 2, batchSize, 1,
-            value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1,
-            grad.gateWeightGrad, frameSize * 2);
+            context, true, false, frame_size, frame_size * 2, batch_size, 1,
+            value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3,
+            1, grad.gate_weight_grad, frame_size * 2);
       }
     }
 #endif
......
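The gemm calls above encode the gate layout in their strides: for instance, the first call multiplies prev_out [batch, frame] by gate_weight [frame, 2 * frame] and accumulates into gate_value with ldc = frame_size * 3, so only the update/reset blocks are written. A naive loop equivalent of that one call (assuming row-major float buffers; function name illustrative, not part of the PR):

#include <vector>

// Naive stand-in for the first math::gemm call in the CPU forward pass:
// accumulate prev_out [batch, frame] x gate_weight [frame, 2 * frame] into
// the update/reset columns of gate_value, whose row stride is 3 * frame;
// the frame-state third is left untouched.
void GateGemmSketch(const std::vector<float> &prev_out,
                    const std::vector<float> &gate_weight,
                    std::vector<float> &gate_value, int batch, int frame) {
  for (int b = 0; b < batch; ++b) {
    for (int j = 0; j < 2 * frame; ++j) {
      float acc = gate_value[b * 3 * frame + j];  // beta == 1: accumulate
      for (int k = 0; k < frame; ++k) {
        acc += prev_out[b * frame + k] * gate_weight[k * 2 * frame + j];
      }
      gate_value[b * 3 * frame + j] = acc;
    }
  }
}

int main() {
  const int batch = 2, frame = 4;
  std::vector<float> prev_out(batch * frame, 1.0f);
  std::vector<float> gate_weight(frame * 2 * frame, 0.5f);
  std::vector<float> gate_value(batch * 3 * frame, 0.0f);
  GateGemmSketch(prev_out, gate_weight, gate_value, batch, frame);
  return 0;
}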
@@ -21,66 +21,66 @@ namespace math {
 template <typename T>
 struct GRUUnitFunctor<platform::GPUPlace, T> {
   static void compute(const platform::DeviceContext &context,
-                      hl_gru_value<T> value, int frameSize, int batchSize,
+                      hl_gru_value<T> value, int frame_size, int batch_size,
                       activation_mode_t active_node,
                       activation_mode_t active_gate) {
     auto stream =
         reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
     dim3 threads;
     dim3 grid;
-    if (batchSize == 1) {
-      int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
-      int frameBlocks = (frameSize + 1024 - 1) / 1024;
-      threads = dim3(framePerBlock, 1);
-      grid = dim3(frameBlocks, 1);
+    if (batch_size == 1) {
+      int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
+      int frame_blocks = (frame_size + 1024 - 1) / 1024;
+      threads = dim3(frame_per_block, 1);
+      grid = dim3(frame_blocks, 1);
     } else {
       threads = dim3(32, 32);
-      grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
+      grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
     }
-    if (value.prevOutValue) {
+    if (value.prev_out_value) {
       math::gemm<platform::GPUPlace, T>(
-          context, false, false, batchSize, frameSize * 2, frameSize, 1,
-          value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1,
-          value.gateValue, frameSize * 3);
+          context, false, false, batch_size, frame_size * 2, frame_size, 1,
+          value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
+          1, value.gate_value, frame_size * 3);
     }
-    if (batchSize == 1) {
+    if (batch_size == 1) {
       detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
-                                      /* isBatch= */ false,
+                                      /* is_batch= */ false,
                                       T><<<grid, threads, 0, stream>>>(
-          detail::forward::gru_resetOutput<T>(), value.gateValue,
-          value.resetOutputValue, value.prevOutValue, frameSize, batchSize,
-          active_gate);
+          detail::forward::gru_resetOutput<T>(), value.gate_value,
+          value.reset_output_value, value.prev_out_value, frame_size,
+          batch_size, active_gate);
     } else {
       detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
-                                      /* isBatch= */ true,
+                                      /* is_batch= */ true,
                                       T><<<grid, threads, 0, stream>>>(
-          detail::forward::gru_resetOutput<T>(), value.gateValue,
-          value.resetOutputValue, value.prevOutValue, frameSize, batchSize,
-          active_gate);
+          detail::forward::gru_resetOutput<T>(), value.gate_value,
+          value.reset_output_value, value.prev_out_value, frame_size,
+          batch_size, active_gate);
     }
-    if (value.prevOutValue) {
+    if (value.prev_out_value) {
       math::gemm<platform::GPUPlace, T>(
-          context, false, false, batchSize, frameSize, frameSize, 1,
-          value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1,
-          value.gateValue + frameSize * 2, frameSize * 3);
+          context, false, false, batch_size, frame_size, frame_size, 1,
+          value.reset_output_value, frame_size, value.state_weight, frame_size,
+          1, value.gate_value + frame_size * 2, frame_size * 3);
     }
-    if (batchSize == 1) {
+    if (batch_size == 1) {
       detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
-                                      /* isBatch= */ false,
+                                      /* is_batch= */ false,
                                       T><<<grid, threads, 0, stream>>>(
-          detail::forward::gru_finalOutput<T>(), value.gateValue,
-          value.prevOutValue, value.outputValue, frameSize, batchSize,
+          detail::forward::gru_finalOutput<T>(), value.gate_value,
+          value.prev_out_value, value.output_value, frame_size, batch_size,
           active_node);
     } else {
       detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
-                                      /* isBatch= */ true,
+                                      /* is_batch= */ true,
                                       T><<<grid, threads, 0, stream>>>(
-          detail::forward::gru_finalOutput<T>(), value.gateValue,
-          value.prevOutValue, value.outputValue, frameSize, batchSize,
+          detail::forward::gru_finalOutput<T>(), value.gate_value,
+          value.prev_out_value, value.output_value, frame_size, batch_size,
           active_node);
     }
   }
@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> {
 template <typename T>
 struct GRUUnitGradFunctor<platform::GPUPlace, T> {
   static void compute(const platform::DeviceContext &context,
-                      hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
-                      int batchSize, activation_mode_t active_node,
+                      hl_gru_value<T> value, hl_gru_grad<T> grad,
+                      int frame_size, int batch_size,
+                      activation_mode_t active_node,
                       activation_mode_t active_gate) {
     auto stream =
         reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
     dim3 threads;
     dim3 grid;
-    if (batchSize == 1) {
-      int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
-      int frameBlocks = (frameSize + 1024 - 1) / 1024;
-      threads = dim3(framePerBlock, 1);
-      grid = dim3(frameBlocks, 1);
+    if (batch_size == 1) {
+      int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
+      int frame_blocks = (frame_size + 1024 - 1) / 1024;
+      threads = dim3(frame_per_block, 1);
+      grid = dim3(frame_blocks, 1);
     } else {
       threads = dim3(32, 32);
-      grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
+      grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
     }
-    if (batchSize == 1) {
+    if (batch_size == 1) {
       detail::KeGruBackwardStateGrad<
           detail::backward::gru_stateGrad<T>,
-          /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-          detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad,
-          value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize,
-          batchSize, active_node);
+          /* is_batch= */ false><<<grid, threads, 0, stream>>>(
+          detail::backward::gru_stateGrad<T>(), value.gate_value,
+          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
+          grad.output_grad, frame_size, batch_size, active_node);
     } else {
       detail::KeGruBackwardStateGrad<
          detail::backward::gru_stateGrad<T>,
-          /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-          detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad,
-          value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize,
-          batchSize, active_node);
+          /* is_batch= */ true><<<grid, threads, 0, stream>>>(
+          detail::backward::gru_stateGrad<T>(), value.gate_value,
+          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
+          grad.output_grad, frame_size, batch_size, active_node);
     }
-    if (value.prevOutValue && grad.prevOutGrad) {
+    if (value.prev_out_value && grad.prev_out_grad) {
       math::gemm<platform::GPUPlace, T>(
-          context, false, true, batchSize, frameSize, frameSize, 1,
-          grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight,
-          frameSize, 0, grad.resetOutputGrad, frameSize);
+          context, false, true, batch_size, frame_size, frame_size, 1,
+          grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
+          frame_size, 0, grad.reset_output_grad, frame_size);
-      if (grad.stateWeightGrad) {
+      if (grad.state_weight_grad) {
         math::gemm<platform::GPUPlace, T>(
-            context, true, false, frameSize, frameSize, batchSize, 1,
-            value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2,
-            frameSize * 3, 1, grad.stateWeightGrad, frameSize);
+            context, true, false, frame_size, frame_size, batch_size, 1,
+            value.reset_output_value, frame_size,
+            grad.gate_grad + frame_size * 2, frame_size * 3, 1,
+            grad.state_weight_grad, frame_size);
       }
     }
-    if (batchSize == 1) {
+    if (batch_size == 1) {
       detail::KeGruBackwardResetGrad<
          detail::backward::gru_resetGrad<T>,
-          /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-          detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad,
-          value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize,
-          batchSize, active_gate);
+          /* is_batch= */ false><<<grid, threads, 0, stream>>>(
+          detail::backward::gru_resetGrad<T>(), value.gate_value,
+          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
+          grad.reset_output_grad, frame_size, batch_size, active_gate);
     } else {
       detail::KeGruBackwardResetGrad<
          detail::backward::gru_resetGrad<T>,
-          /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-          detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad,
-          value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize,
-          batchSize, active_gate);
+          /* is_batch= */ true><<<grid, threads, 0, stream>>>(
+          detail::backward::gru_resetGrad<T>(), value.gate_value,
+          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
+          grad.reset_output_grad, frame_size, batch_size, active_gate);
     }
-    if (grad.prevOutGrad && value.prevOutValue) {
+    if (grad.prev_out_grad && value.prev_out_value) {
       math::gemm<platform::GPUPlace, T>(
-          context, false, true, batchSize, frameSize, frameSize * 2, 1,
-          grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1,
-          grad.prevOutGrad, frameSize);
+          context, false, true, batch_size, frame_size, frame_size * 2, 1,
+          grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
+          grad.prev_out_grad, frame_size);
-      if (grad.gateWeightGrad) {
+      if (grad.gate_weight_grad) {
         math::gemm<platform::GPUPlace, T>(
-            context, true, false, frameSize, frameSize * 2, batchSize, 1,
-            value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1,
-            grad.gateWeightGrad, frameSize * 2);
+            context, true, false, frame_size, frame_size * 2, batch_size, 1,
+            value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3,
+            1, grad.gate_weight_grad, frame_size * 2);
       }
     }
   }
......
@@ -22,28 +22,28 @@ namespace math {
 // TODO(guosheng): refine code style in gru_compute
 template <typename T>
 struct hl_gru_value {
-  T *gateWeight;
-  T *stateWeight;
-  T *gateValue;
-  T *resetOutputValue;
-  T *outputValue;
-  T *prevOutValue;
+  T *gate_weight;
+  T *state_weight;
+  T *gate_value;
+  T *reset_output_value;
+  T *output_value;
+  T *prev_out_value;
 };
 template <typename T>
 struct hl_gru_grad {
-  T *gateWeightGrad;
-  T *stateWeightGrad;
-  T *gateGrad;
-  T *resetOutputGrad;
-  T *outputGrad;
-  T *prevOutGrad;
+  T *gate_weight_grad;
+  T *state_weight_grad;
+  T *gate_grad;
+  T *reset_output_grad;
+  T *output_grad;
+  T *prev_out_grad;
 };
 template <typename Place, typename T>
 struct GRUUnitFunctor {
   static void compute(const platform::DeviceContext &context,
-                      hl_gru_value<T> value, int frameSize, int batchSize,
+                      hl_gru_value<T> value, int frame_size, int batch_size,
                       activation_mode_t active_node,
                       activation_mode_t active_gate);
 };
@@ -51,8 +51,9 @@ struct GRUUnitFunctor {
 template <typename Place, typename T>
 struct GRUUnitGradFunctor {
   static void compute(const platform::DeviceContext &context,
-                      hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
-                      int batchSize, activation_mode_t active_node,
+                      hl_gru_value<T> value, hl_gru_grad<T> grad,
+                      int frame_size, int batch_size,
+                      activation_mode_t active_node,
                       activation_mode_t active_gate);
 };
......
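For reference, the buffer sizes implied by the hl_gru_value / hl_gru_grad fields above, assuming the layouts used throughout this PR (three gate blocks of width frame_size; weights split as [frame, 2 * frame] plus [frame, frame]) — a small sketch that just prints them:

#include <cstdio>

// Sizing sketch for the buffers behind hl_gru_value / hl_gru_grad (and their
// grad counterparts, which mirror these shapes one for one).
int main() {
  const int frame_size = 16, batch_size = 4;
  std::printf("gate_weight:   %d floats\n", frame_size * frame_size * 2);
  std::printf("state_weight:  %d floats\n", frame_size * frame_size);
  std::printf("gate_value:    %d floats\n", batch_size * frame_size * 3);
  std::printf("reset_output:  %d floats\n", batch_size * frame_size);
  std::printf("output/prev:   %d floats\n", batch_size * frame_size);
  return 0;
}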