提交 3e552cdc 编写于 作者: G guosheng

Fix gru_op related code style

上级 dcf3ffd9
...@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> {
int frame_size = hidden_dims[1]; int frame_size = hidden_dims[1];
math::hl_gru_value<T> gru_value; math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data); gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.stateWeight = gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size); const_cast<T*>(weight_data + 2 * frame_size * frame_size);
Tensor ordered_h0; Tensor ordered_h0;
const size_t* order = batch_gate->lod()[2].data(); const size_t* order = batch_gate->lod()[2].data();
...@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> {
// to reorder. // to reorder.
ReorderInitState<Place, T>(context.device_context(), *h0, order, ReorderInitState<Place, T>(context.device_context(), *h0, order,
&ordered_h0, true); &ordered_h0, true);
gru_value.prevOutValue = ordered_h0.data<T>(); gru_value.prev_out_value = ordered_h0.data<T>();
} else { } else {
gru_value.prevOutValue = nullptr; gru_value.prev_out_value = nullptr;
} }
auto batch_starts = batch_gate->lod()[0]; auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1; size_t num_batch = batch_starts.size() - 1;
...@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> {
Tensor gate_t = batch_gate->Slice(bstart, bend); Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend); Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.outputValue = hidden_t.data<T>(); gru_value.output_value = hidden_t.data<T>();
gru_value.gateValue = gate_t.data<T>(); gru_value.gate_value = gate_t.data<T>();
gru_value.resetOutputValue = reset_hidden_prev_t.data<T>(); gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<Place, T>::compute( math::GRUUnitFunctor<Place, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, dev_ctx, gru_value, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")), math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation"))); math::ActiveType(context.Attr<std::string>("gate_activation")));
gru_value.prevOutValue = gru_value.outputValue; gru_value.prev_out_value = gru_value.output_value;
} }
math::Batch2LoDTensorFunctor<Place, T> to_seq; math::Batch2LoDTensorFunctor<Place, T> to_seq;
...@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> {
to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse); to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
math::hl_gru_value<T> gru_value; math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data); gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.stateWeight = gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size); const_cast<T*>(weight_data + 2 * frame_size * frame_size);
math::hl_gru_grad<T> gru_grad; math::hl_gru_grad<T> gru_grad;
if (weight_grad) { if (weight_grad) {
gru_grad.gateWeightGrad = gru_grad.gate_weight_grad =
weight_grad->mutable_data<T>(context.GetPlace()); weight_grad->mutable_data<T>(context.GetPlace());
zero(dev_ctx, weight_grad, static_cast<T>(0.0)); zero(dev_ctx, weight_grad, static_cast<T>(0.0));
gru_grad.stateWeightGrad = gru_grad.state_weight_grad =
weight_grad->data<T>() + 2 * frame_size * frame_size; weight_grad->data<T>() + 2 * frame_size * frame_size;
} else { } else {
gru_grad.gateWeightGrad = nullptr; gru_grad.gate_weight_grad = nullptr;
gru_grad.stateWeightGrad = nullptr; gru_grad.state_weight_grad = nullptr;
} }
auto batch_starts = batch_hidden_grad.lod()[0]; auto batch_starts = batch_hidden_grad.lod()[0];
...@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> {
int cur_batch_size = bend - bstart; int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend); Tensor gate_t = batch_gate->Slice(bstart, bend);
gru_value.gateValue = gate_t.data<T>(); gru_value.gate_value = gate_t.data<T>();
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
gru_value.resetOutputValue = reset_hidden_prev_t.data<T>(); gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend); Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend);
gru_grad.outputGrad = hidden_grad_t.data<T>(); gru_grad.output_grad = hidden_grad_t.data<T>();
Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend); Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend);
gru_grad.gateGrad = gate_grad_t.data<T>(); gru_grad.gate_grad = gate_grad_t.data<T>();
Tensor reset_hidden_prev_grad_t = Tensor reset_hidden_prev_grad_t =
batch_reset_hidden_prev_grad.Slice(bstart, bend); batch_reset_hidden_prev_grad.Slice(bstart, bend);
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>(); gru_grad.reset_output_grad = reset_hidden_prev_grad_t.data<T>();
if (n == 0) { if (n == 0) {
gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr; gru_value.prev_out_value = h0 ? ordered_h0.data<T>() : nullptr;
gru_grad.prevOutGrad = gru_grad.prev_out_grad =
h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr; h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
} else { } else {
int bstart_pre = static_cast<int>(batch_starts[n - 1]); int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart); Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
gru_value.prevOutValue = hidden_prev_t.data<T>(); gru_value.prev_out_value = hidden_prev_t.data<T>();
Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart); Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart);
gru_grad.prevOutGrad = hidden_prev_grad_t.data<T>(); gru_grad.prev_out_grad = hidden_prev_grad_t.data<T>();
} }
math::GRUUnitGradFunctor<Place, T>::compute( math::GRUUnitGradFunctor<Place, T>::compute(
......
...@@ -27,174 +27,174 @@ namespace math { ...@@ -27,174 +27,174 @@ namespace math {
namespace detail { namespace detail {
/* /*
* threads(framePerBlock, batchPerBlock) * threads(frame_per_block, batch_per_block)
* grid(frameBlocks, batchBlocks) * grid(frame_blocks, batch_blocks)
*/ */
template <class OpResetOutput, bool isBatch, typename T> template <class OpResetOutput, bool is_batch, typename T>
__global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput, __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
T *gateValue, T *resetOutputValue, T *gate_value, T *reset_output_value,
T *prevOutputValue, int frameSize, T *prev_output_value, int frame_size,
int batchSize, int batch_size,
activation_mode_t active_gate) { activation_mode_t active_gate) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = block_idx.x * block_dim.x + thread_idx.x;
if (frameIdx >= frameSize) return; if (frame_idx >= frame_size) return;
int batchIdx = 0; int batch_idx = 0;
if (isBatch) { if (is_batch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y; batch_idx = block_idx.y * block_dim.y + thread_idx.y;
if (batchIdx >= batchSize) return; if (batch_idx >= batch_size) return;
gateValue += batchIdx * 3 * frameSize; gate_value += batch_idx * 3 * frame_size;
resetOutputValue += batchIdx * frameSize; reset_output_value += batch_idx * frame_size;
} }
T rPrevOut = 0; T r_prev_out = 0;
T rValueResetOutput; T r_value_reset_output;
T rValueUpdateGate = gateValue[frameIdx + frameSize * 0]; T r_value_update_gate = gate_value[frame_idx + frame_size * 0];
T rValueResetGate = gateValue[frameIdx + frameSize * 1]; T r_value_reset_gate = gate_value[frame_idx + frame_size * 1];
if (prevOutputValue) { if (prev_output_value) {
if (isBatch) prevOutputValue += batchIdx * frameSize; if (is_batch) prev_output_value += batch_idx * frame_size;
rPrevOut = prevOutputValue[frameIdx]; r_prev_out = prev_output_value[frame_idx];
} }
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput, op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
active_gate); r_value_reset_output, active_gate);
gateValue[frameIdx + frameSize * 0] = rValueUpdateGate; gate_value[frame_idx + frame_size * 0] = r_value_update_gate;
gateValue[frameIdx + frameSize * 1] = rValueResetGate; gate_value[frame_idx + frame_size * 1] = r_value_reset_gate;
resetOutputValue[frameIdx] = rValueResetOutput; reset_output_value[frame_idx] = r_value_reset_output;
} }
/* /*
* threads(framePerBlock, batchPerBlock) * threads(frame_per_block, batch_per_block)
* grid(frameBlocks, batchBlocks) * grid(frame_blocks, batch_blocks)
*/ */
template <class OpFinalOutput, bool isBatch, typename T> template <class OpFinalOutput, bool is_batch, typename T>
__global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput, __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
T *gateValue, T *prevOutputValue, T *gate_value, T *prev_output_value,
T *outputValue, int frameSize, T *output_value, int frame_size,
int batchSize, int batch_size,
activation_mode_t active_node) { activation_mode_t active_node) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = block_idx.x * block_dim.x + thread_idx.x;
if (frameIdx >= frameSize) return; if (frame_idx >= frame_size) return;
int batchIdx = 0; int batch_idx = 0;
if (isBatch) { if (is_batch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y; batch_idx = block_idx.y * block_dim.y + thread_idx.y;
if (batchIdx >= batchSize) return; if (batch_idx >= batch_size) return;
gateValue += batchIdx * 3 * frameSize; gate_value += batch_idx * 3 * frame_size;
outputValue += batchIdx * frameSize; output_value += batch_idx * frame_size;
} }
T rOutput; T r_output;
T rPrevOut = 0; T r_prev_out = 0;
T rValueUpdateGate = gateValue[frameIdx + frameSize * 0]; T r_value_update_gate = gate_value[frame_idx + frame_size * 0];
T rValueFrameState = gateValue[frameIdx + frameSize * 2]; T r_value_frame_state = gate_value[frame_idx + frame_size * 2];
if (prevOutputValue) { if (prev_output_value) {
if (isBatch) prevOutputValue += batchIdx * frameSize; if (is_batch) prev_output_value += batch_idx * frame_size;
rPrevOut = prevOutputValue[frameIdx]; r_prev_out = prev_output_value[frame_idx];
} }
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
active_node); r_output, active_node);
gateValue[frameIdx + frameSize * 2] = rValueFrameState; gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
outputValue[frameIdx] = rOutput; output_value[frame_idx] = r_output;
} }
/* /*
* threads(framePerBlock, batchPerBlock) * threads(frame_per_block, batch_per_block)
* grid(frameBlocks, batchBlocks) * grid(frame_blocks, batch_blocks)
*/ */
template <class OpStateGrad, bool isBatch, typename T> template <class OpStateGrad, bool is_batch, typename T>
__global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue, __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
T *gateGrad, T *prevOutValue, T *gate_grad, T *prev_out_value,
T *prevOutGrad, T *outputGrad, T *prev_out_grad, T *output_grad,
int frameSize, int batchSize, int frame_size, int batch_size,
activation_mode_t active_node) { activation_mode_t active_node) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = block_idx.x * block_dim.x + thread_idx.x;
if (frameIdx >= frameSize) return; if (frame_idx >= frame_size) return;
int batchIdx = 0; int batch_idx = 0;
if (isBatch) { if (is_batch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y; batch_idx = block_idx.y * block_dim.y + thread_idx.y;
if (batchIdx >= batchSize) return; if (batch_idx >= batch_size) return;
gateValue += batchIdx * 3 * frameSize; gate_value += batch_idx * 3 * frame_size;
gateGrad += batchIdx * 3 * frameSize; gate_grad += batch_idx * 3 * frame_size;
outputGrad += batchIdx * frameSize; output_grad += batch_idx * frame_size;
} }
T rUpdateGateGrad; T r_update_gate_grad;
T rFrameStateGrad; T r_frame_state_grad;
T rPrevOutValue = 0; T r_prev_out_value = 0;
T rPrevOutGrad = 0; T r_prev_out_grad = 0;
T rUpdateGateValue = gateValue[frameIdx + frameSize * 0]; T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
T rFrameStateValue = gateValue[frameIdx + frameSize * 2]; T r_frame_state_value = gate_value[frame_idx + frame_size * 2];
T rOutGrad = outputGrad[frameIdx]; T r_out_grad = output_grad[frame_idx];
if (prevOutValue && prevOutGrad) { if (prev_out_value && prev_out_grad) {
if (isBatch) prevOutValue += batchIdx * frameSize; if (is_batch) prev_out_value += batch_idx * frame_size;
rPrevOutValue = prevOutValue[frameIdx]; r_prev_out_value = prev_out_value[frame_idx];
if (isBatch) prevOutGrad += batchIdx * frameSize; if (is_batch) prev_out_grad += batch_idx * frame_size;
rPrevOutGrad = prevOutGrad[frameIdx]; r_prev_out_grad = prev_out_grad[frame_idx];
} }
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
active_node); r_out_grad, active_node);
gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad; gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
gateGrad[frameIdx + frameSize * 2] = rFrameStateGrad; gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
if (prevOutGrad) { if (prev_out_grad) {
prevOutGrad[frameIdx] = rPrevOutGrad; prev_out_grad[frame_idx] = r_prev_out_grad;
} }
} }
/* /*
* threads(framePerBlock, batchPerBlock) * threads(frame_per_block, batch_per_block)
* grid(frameBlocks, batchBlocks) * grid(frame_blocks, batch_blocks)
*/ */
template <class OpResetGrad, bool isBatch, typename T> template <class OpResetGrad, bool is_batch, typename T>
__global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue, __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
T *gateGrad, T *prevOutValue, T *gate_grad, T *prev_out_value,
T *prevOutGrad, T *resetOutputGrad, T *prev_out_grad, T *reset_output_grad,
int frameSize, int batchSize, int frame_size, int batch_size,
activation_mode_t active_gate) { activation_mode_t active_gate) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = block_idx.x * block_dim.x + thread_idx.x;
if (frameIdx >= frameSize) return; if (frame_idx >= frame_size) return;
int batchIdx = 0; int batch_idx = 0;
if (isBatch) { if (is_batch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y; batch_idx = block_idx.y * block_dim.y + thread_idx.y;
if (batchIdx >= batchSize) return; if (batch_idx >= batch_size) return;
gateValue += batchIdx * 3 * frameSize; gate_value += batch_idx * 3 * frame_size;
gateGrad += batchIdx * 3 * frameSize; gate_grad += batch_idx * 3 * frame_size;
resetOutputGrad += batchIdx * frameSize; reset_output_grad += batch_idx * frame_size;
} }
T rResetGateGrad; T r_reset_gate_grad;
T rPrevOutValue = 0; T r_prev_out_value = 0;
T rPrevOutGrad = 0; T r_prev_out_grad = 0;
T rResetOutputGrad = 0; T r_reset_output_grad = 0;
T rUpdateGateValue = gateValue[frameIdx + frameSize * 0]; T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
T rUpdateGateGrad = gateGrad[frameIdx + frameSize * 0]; T r_update_gate_grad = gate_grad[frame_idx + frame_size * 0];
T rResetGateValue = gateValue[frameIdx + frameSize * 1]; T r_reset_gate_value = gate_value[frame_idx + frame_size * 1];
if (prevOutValue && prevOutGrad) { if (prev_out_value && prev_out_grad) {
if (isBatch) prevOutValue += batchIdx * frameSize; if (is_batch) prev_out_value += batch_idx * frame_size;
if (isBatch) prevOutGrad += batchIdx * frameSize; if (is_batch) prev_out_grad += batch_idx * frame_size;
rPrevOutValue = prevOutValue[frameIdx]; r_prev_out_value = prev_out_value[frame_idx];
rPrevOutGrad = prevOutGrad[frameIdx]; r_prev_out_grad = prev_out_grad[frame_idx];
rResetOutputGrad = resetOutputGrad[frameIdx]; r_reset_output_grad = reset_output_grad[frame_idx];
} }
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
active_gate); r_reset_output_grad, active_gate);
gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad; gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
gateGrad[frameIdx + frameSize * 1] = rResetGateGrad; gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
if (prevOutGrad) { if (prev_out_grad) {
prevOutGrad[frameIdx] = rPrevOutGrad; prev_out_grad[frame_idx] = r_prev_out_grad;
} }
} }
} // namespace detail } // namespace detail
......
...@@ -28,23 +28,25 @@ namespace forward { ...@@ -28,23 +28,25 @@ namespace forward {
template <typename T> template <typename T>
class gru_resetOutput { class gru_resetOutput {
public: public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueResetGate, T &prevOut, HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate,
T &valueResetOutput, activation_mode_t actGate) { T &prev_out, T &value_reset_output,
valueUpdateGate = activation(valueUpdateGate, actGate); activation_mode_t act_gate) {
valueResetGate = activation(valueResetGate, actGate); value_update_gate = activation(value_update_gate, act_gate);
valueResetOutput = prevOut * valueResetGate; value_reset_gate = activation(value_reset_gate, act_gate);
value_reset_output = prev_out * value_reset_gate;
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueResetGate, HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &prevOut, __m256 &valueResetOutput, __m256 &value_reset_gate, __m256 &prev_out,
activation_mode_t actGate) { __m256 &value_reset_output,
valueUpdateGate = activation(valueUpdateGate, actGate); activation_mode_t act_gate) {
valueResetGate = activation(valueResetGate, actGate); value_update_gate = activation(value_update_gate, act_gate);
valueResetOutput = _mm256_mul_ps(prevOut, valueResetGate); value_reset_gate = activation(value_reset_gate, act_gate);
value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate);
} }
#endif #endif
#endif #endif
...@@ -53,24 +55,26 @@ class gru_resetOutput { ...@@ -53,24 +55,26 @@ class gru_resetOutput {
template <typename T> template <typename T>
class gru_finalOutput { class gru_finalOutput {
public: public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueFrameState, T &prevOut, HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state,
T &valueOutput, activation_mode_t actInput) { T &prev_out, T &value_output,
valueFrameState = activation(valueFrameState, actInput); activation_mode_t act_input) {
valueOutput = prevOut - (valueUpdateGate * prevOut) + value_frame_state = activation(value_frame_state, act_input);
(valueUpdateGate * valueFrameState); value_output = prev_out - (value_update_gate * prev_out) +
(value_update_gate * value_frame_state);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueFrameState, HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &prevOut, __m256 &valueOutput, __m256 &value_frame_state, __m256 &prev_out,
activation_mode_t actInput) { __m256 &value_output,
valueFrameState = activation(valueFrameState, actInput); activation_mode_t act_input) {
valueOutput = _mm256_add_ps( value_frame_state = activation(value_frame_state, act_input);
_mm256_sub_ps(prevOut, _mm256_mul_ps(valueUpdateGate, prevOut)), value_output = _mm256_add_ps(
_mm256_mul_ps(valueUpdateGate, valueFrameState)); _mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)),
_mm256_mul_ps(value_update_gate, value_frame_state));
} }
#endif #endif
#endif #endif
...@@ -82,34 +86,37 @@ namespace backward { ...@@ -82,34 +86,37 @@ namespace backward {
template <typename T> template <typename T>
class gru_stateGrad { class gru_stateGrad {
public: public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate, HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
T &valueFrameState, T &gradFrameState, T &value_frame_state, T &grad_frame_state,
T &valuePrevOut, T &gradPrevOut, T &gradOutput, T &value_prev_out, T &grad_prev_out,
activation_mode_t actInput) { T &grad_output, activation_mode_t act_input) {
gradUpdateGate = (gradOutput * valueFrameState); grad_update_gate = (grad_output * value_frame_state);
gradUpdateGate -= (gradOutput * valuePrevOut); grad_update_gate -= (grad_output * value_prev_out);
gradPrevOut -= (gradOutput * valueUpdateGate); grad_prev_out -= (grad_output * value_update_gate);
gradPrevOut += gradOutput; grad_prev_out += grad_output;
gradFrameState = grad_frame_state = activation(grad_output * value_update_gate,
activation(gradOutput * valueUpdateGate, valueFrameState, actInput); value_frame_state, act_input);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate, HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &valueFrameState, __m256 &gradFrameState, __m256 &grad_update_gate,
__m256 &valuePrevOut, __m256 &gradPrevOut, __m256 &value_frame_state,
__m256 &gradOutput, activation_mode_t actInput) { __m256 &grad_frame_state, __m256 &value_prev_out,
gradUpdateGate = _mm256_mul_ps(gradOutput, valueFrameState); __m256 &grad_prev_out, __m256 &grad_output,
gradUpdateGate = activation_mode_t act_input) {
_mm256_sub_ps(gradUpdateGate, _mm256_mul_ps(gradOutput, valuePrevOut)); grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state);
gradPrevOut = _mm256_add_ps( grad_update_gate = _mm256_sub_ps(
_mm256_sub_ps(gradPrevOut, _mm256_mul_ps(gradOutput, valueUpdateGate)), grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out));
gradOutput); grad_prev_out = _mm256_add_ps(
gradFrameState = activation(_mm256_mul_ps(gradOutput, valueUpdateGate), _mm256_sub_ps(grad_prev_out,
valueFrameState, actInput); _mm256_mul_ps(grad_output, value_update_gate)),
grad_output);
grad_frame_state = activation(_mm256_mul_ps(grad_output, value_update_gate),
value_frame_state, act_input);
} }
#endif #endif
#endif #endif
...@@ -118,30 +125,32 @@ class gru_stateGrad { ...@@ -118,30 +125,32 @@ class gru_stateGrad {
template <typename T> template <typename T>
class gru_resetGrad { class gru_resetGrad {
public: public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate, HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
T &valueResetGate, T &gradResetGate, T &value_reset_gate, T &grad_reset_gate,
T &valuePrevOut, T &gradPrevOut, T &value_prev_out, T &grad_prev_out,
T &gradResetOutput, activation_mode_t actGate) { T &grad_reset_output, activation_mode_t act_gate) {
gradResetGate = (gradResetOutput * valuePrevOut); grad_reset_gate = (grad_reset_output * value_prev_out);
gradPrevOut += (gradResetOutput * valueResetGate); grad_prev_out += (grad_reset_output * value_reset_gate);
gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate); grad_update_gate =
gradResetGate = activation(gradResetGate, valueResetGate, actGate); activation(grad_update_gate, value_update_gate, act_gate);
grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate, HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &valueResetGate, __m256 &gradResetGate, __m256 &grad_update_gate, __m256 &value_reset_gate,
__m256 &valuePrevOut, __m256 &gradPrevOut, __m256 &grad_reset_gate, __m256 &value_prev_out,
__m256 &gradResetOutput, __m256 &grad_prev_out, __m256 &grad_reset_output,
activation_mode_t actGate) { activation_mode_t act_gate) {
gradResetGate = _mm256_mul_ps(gradResetOutput, valuePrevOut); grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out);
gradPrevOut = _mm256_add_ps(gradPrevOut, grad_prev_out = _mm256_add_ps(
_mm256_mul_ps(gradResetOutput, valueResetGate)); grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate));
gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate); grad_update_gate =
gradResetGate = activation(gradResetGate, valueResetGate, actGate); activation(grad_update_gate, value_update_gate, act_gate);
grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate);
} }
#endif #endif
#endif #endif
......
...@@ -21,29 +21,29 @@ namespace math { ...@@ -21,29 +21,29 @@ namespace math {
template <typename T> template <typename T>
struct GRUUnitFunctor<platform::CPUPlace, T> { struct GRUUnitFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize, hl_gru_value<T> value, int frame_size, int batch_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate) { activation_mode_t active_gate) {
#ifndef __NVCC__ #ifndef __NVCC__
if (value.prevOutValue) { if (value.prev_out_value) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, false, false, batchSize, frameSize * 2, frameSize, 1, context, false, false, batch_size, frame_size * 2, frame_size, 1,
value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1, value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
value.gateValue, frameSize * 3); 1, value.gate_value, frame_size * 3);
} }
detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value, detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
frameSize, batchSize, active_gate); frame_size, batch_size, active_gate);
if (value.prevOutValue) { if (value.prev_out_value) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, false, false, batchSize, frameSize, frameSize, 1, context, false, false, batch_size, frame_size, frame_size, 1,
value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1, value.reset_output_value, frame_size, value.state_weight, frame_size,
value.gateValue + frameSize * 2, frameSize * 3); 1, value.gate_value + frame_size * 2, frame_size * 3);
} }
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value, detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
frameSize, batchSize, active_node); frame_size, batch_size, active_node);
#endif #endif
} }
}; };
...@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> { ...@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> {
template <typename T> template <typename T>
struct GRUUnitGradFunctor<platform::CPUPlace, T> { struct GRUUnitGradFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize, hl_gru_value<T> value, hl_gru_grad<T> grad,
int batchSize, activation_mode_t active_node, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) { activation_mode_t active_gate) {
#ifndef __NVCC__ #ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value, detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frameSize, batchSize, active_node); grad, frame_size, batch_size, active_node);
if (value.prevOutValue && grad.prevOutGrad) { if (value.prev_out_value && grad.prev_out_grad) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize, 1, context, false, true, batch_size, frame_size, frame_size, 1,
grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight, grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
frameSize, 0, grad.resetOutputGrad, frameSize); frame_size, 0, grad.reset_output_grad, frame_size);
if (grad.stateWeightGrad) { if (grad.state_weight_grad) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, true, false, frameSize, frameSize, batchSize, 1, context, true, false, frame_size, frame_size, batch_size, 1,
value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2, value.reset_output_value, frame_size,
frameSize * 3, 1, grad.stateWeightGrad, frameSize); grad.gate_grad + frame_size * 2, frame_size * 3, 1,
grad.state_weight_grad, frame_size);
} }
} }
detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value, detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value,
grad, frameSize, batchSize, active_gate); grad, frame_size, batch_size, active_gate);
if (grad.prevOutGrad && value.prevOutValue) { if (grad.prev_out_grad && value.prev_out_value) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize * 2, 1, context, false, true, batch_size, frame_size, frame_size * 2, 1,
grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1, grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
grad.prevOutGrad, frameSize); grad.prev_out_grad, frame_size);
if (grad.gateWeightGrad) { if (grad.gate_weight_grad) {
math::gemm<platform::CPUPlace, T>( math::gemm<platform::CPUPlace, T>(
context, true, false, frameSize, frameSize * 2, batchSize, 1, context, true, false, frame_size, frame_size * 2, batch_size, 1,
value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1, value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1,
grad.gateWeightGrad, frameSize * 2); grad.gate_weight_grad, frame_size * 2);
} }
} }
#endif #endif
......
...@@ -21,66 +21,66 @@ namespace math { ...@@ -21,66 +21,66 @@ namespace math {
template <typename T> template <typename T>
struct GRUUnitFunctor<platform::GPUPlace, T> { struct GRUUnitFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize, hl_gru_value<T> value, int frame_size, int batch_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate) { activation_mode_t active_gate) {
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(context).stream(); reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
if (batchSize == 1) { if (batch_size == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024; int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024; int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1); threads = dim3(frame_per_block, 1);
grid = dim3(frameBlocks, 1); grid = dim3(frame_blocks, 1);
} else { } else {
threads = dim3(32, 32); threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
} }
if (value.prevOutValue) { if (value.prev_out_value) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, false, false, batchSize, frameSize * 2, frameSize, 1, context, false, false, batch_size, frame_size * 2, frame_size, 1,
value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1, value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
value.gateValue, frameSize * 3); 1, value.gate_value, frame_size * 3);
} }
if (batchSize == 1) { if (batch_size == 1) {
detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>, detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
/* isBatch= */ false, /* is_batch= */ false,
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_resetOutput<T>(), value.gateValue, detail::forward::gru_resetOutput<T>(), value.gate_value,
value.resetOutputValue, value.prevOutValue, frameSize, batchSize, value.reset_output_value, value.prev_out_value, frame_size,
active_gate); batch_size, active_gate);
} else { } else {
detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>, detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
/* isBatch= */ true, /* is_batch= */ true,
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_resetOutput<T>(), value.gateValue, detail::forward::gru_resetOutput<T>(), value.gate_value,
value.resetOutputValue, value.prevOutValue, frameSize, batchSize, value.reset_output_value, value.prev_out_value, frame_size,
active_gate); batch_size, active_gate);
} }
if (value.prevOutValue) { if (value.prev_out_value) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, false, false, batchSize, frameSize, frameSize, 1, context, false, false, batch_size, frame_size, frame_size, 1,
value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1, value.reset_output_value, frame_size, value.state_weight, frame_size,
value.gateValue + frameSize * 2, frameSize * 3); 1, value.gate_value + frame_size * 2, frame_size * 3);
} }
if (batchSize == 1) { if (batch_size == 1) {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>, detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* isBatch= */ false, /* is_batch= */ false,
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gateValue, detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prevOutValue, value.outputValue, frameSize, batchSize, value.prev_out_value, value.output_value, frame_size, batch_size,
active_node); active_node);
} else { } else {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>, detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* isBatch= */ true, /* is_batch= */ true,
T><<<grid, threads, 0, stream>>>( T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gateValue, detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prevOutValue, value.outputValue, frameSize, batchSize, value.prev_out_value, value.output_value, frame_size, batch_size,
active_node); active_node);
} }
} }
...@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> { ...@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> {
template <typename T> template <typename T>
struct GRUUnitGradFunctor<platform::GPUPlace, T> { struct GRUUnitGradFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize, hl_gru_value<T> value, hl_gru_grad<T> grad,
int batchSize, activation_mode_t active_node, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) { activation_mode_t active_gate) {
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(context).stream(); reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
if (batchSize == 1) { if (batch_size == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024; int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024; int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1); threads = dim3(frame_per_block, 1);
grid = dim3(frameBlocks, 1); grid = dim3(frame_blocks, 1);
} else { } else {
threads = dim3(32, 32); threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
} }
if (batchSize == 1) { if (batch_size == 1) {
detail::KeGruBackwardStateGrad< detail::KeGruBackwardStateGrad<
detail::backward::gru_stateGrad<T>, detail::backward::gru_stateGrad<T>,
/* isBatch= */ false><<<grid, threads, 0, stream>>>( /* is_batch= */ false><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad, detail::backward::gru_stateGrad<T>(), value.gate_value,
value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize, grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
batchSize, active_node); grad.output_grad, frame_size, batch_size, active_node);
} else { } else {
detail::KeGruBackwardStateGrad< detail::KeGruBackwardStateGrad<
detail::backward::gru_stateGrad<T>, detail::backward::gru_stateGrad<T>,
/* isBatch= */ true><<<grid, threads, 0, stream>>>( /* is_batch= */ true><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad, detail::backward::gru_stateGrad<T>(), value.gate_value,
value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize, grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
batchSize, active_node); grad.output_grad, frame_size, batch_size, active_node);
} }
if (value.prevOutValue && grad.prevOutGrad) { if (value.prev_out_value && grad.prev_out_grad) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize, 1, context, false, true, batch_size, frame_size, frame_size, 1,
grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight, grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
frameSize, 0, grad.resetOutputGrad, frameSize); frame_size, 0, grad.reset_output_grad, frame_size);
if (grad.stateWeightGrad) { if (grad.state_weight_grad) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, true, false, frameSize, frameSize, batchSize, 1, context, true, false, frame_size, frame_size, batch_size, 1,
value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2, value.reset_output_value, frame_size,
frameSize * 3, 1, grad.stateWeightGrad, frameSize); grad.gate_grad + frame_size * 2, frame_size * 3, 1,
grad.state_weight_grad, frame_size);
} }
} }
if (batchSize == 1) { if (batch_size == 1) {
detail::KeGruBackwardResetGrad< detail::KeGruBackwardResetGrad<
detail::backward::gru_resetGrad<T>, detail::backward::gru_resetGrad<T>,
/* isBatch= */ false><<<grid, threads, 0, stream>>>( /* is_batch= */ false><<<grid, threads, 0, stream>>>(
detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad, detail::backward::gru_resetGrad<T>(), value.gate_value,
value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize, grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
batchSize, active_gate); grad.reset_output_grad, frame_size, batch_size, active_gate);
} else { } else {
detail::KeGruBackwardResetGrad< detail::KeGruBackwardResetGrad<
detail::backward::gru_resetGrad<T>, detail::backward::gru_resetGrad<T>,
/* isBatch= */ true><<<grid, threads, 0, stream>>>( /* is_batch= */ true><<<grid, threads, 0, stream>>>(
detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad, detail::backward::gru_resetGrad<T>(), value.gate_value,
value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize, grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
batchSize, active_gate); grad.reset_output_grad, frame_size, batch_size, active_gate);
} }
if (grad.prevOutGrad && value.prevOutValue) { if (grad.prev_out_grad && value.prev_out_value) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize * 2, 1, context, false, true, batch_size, frame_size, frame_size * 2, 1,
grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1, grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
grad.prevOutGrad, frameSize); grad.prev_out_grad, frame_size);
if (grad.gateWeightGrad) { if (grad.gate_weight_grad) {
math::gemm<platform::GPUPlace, T>( math::gemm<platform::GPUPlace, T>(
context, true, false, frameSize, frameSize * 2, batchSize, 1, context, true, false, frame_size, frame_size * 2, batch_size, 1,
value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1, value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1,
grad.gateWeightGrad, frameSize * 2); grad.gate_weight_grad, frame_size * 2);
} }
} }
} }
......
...@@ -22,28 +22,28 @@ namespace math { ...@@ -22,28 +22,28 @@ namespace math {
// TODO(guosheng): refine code style in gru_compute // TODO(guosheng): refine code style in gru_compute
template <typename T> template <typename T>
struct hl_gru_value { struct hl_gru_value {
T *gateWeight; T *gate_weight;
T *stateWeight; T *state_weight;
T *gateValue; T *gate_value;
T *resetOutputValue; T *reset_output_value;
T *outputValue; T *output_value;
T *prevOutValue; T *prev_out_value;
}; };
template <typename T> template <typename T>
struct hl_gru_grad { struct hl_gru_grad {
T *gateWeightGrad; T *gate_weight_grad;
T *stateWeightGrad; T *state_weight_grad;
T *gateGrad; T *gate_grad;
T *resetOutputGrad; T *reset_output_grad;
T *outputGrad; T *output_grad;
T *prevOutGrad; T *prev_out_grad;
}; };
template <typename Place, typename T> template <typename Place, typename T>
struct GRUUnitFunctor { struct GRUUnitFunctor {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize, hl_gru_value<T> value, int frame_size, int batch_size,
activation_mode_t active_node, activation_mode_t active_node,
activation_mode_t active_gate); activation_mode_t active_gate);
}; };
...@@ -51,8 +51,9 @@ struct GRUUnitFunctor { ...@@ -51,8 +51,9 @@ struct GRUUnitFunctor {
template <typename Place, typename T> template <typename Place, typename T>
struct GRUUnitGradFunctor { struct GRUUnitGradFunctor {
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize, hl_gru_value<T> value, hl_gru_grad<T> grad,
int batchSize, activation_mode_t active_node, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate); activation_mode_t active_gate);
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册