提交 1ab03d49 编写于 作者: G guosheng

Fix gru_op related code style in gpu_kernel

上级 3e552cdc
...@@ -36,12 +36,12 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, ...@@ -36,12 +36,12 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
T *prev_output_value, int frame_size, T *prev_output_value, int frame_size,
int batch_size, int batch_size,
activation_mode_t active_gate) { activation_mode_t active_gate) {
const int frame_idx = block_idx.x * block_dim.x + thread_idx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return; if (frame_idx >= frame_size) return;
int batch_idx = 0; int batch_idx = 0;
if (is_batch) { if (is_batch) {
batch_idx = block_idx.y * block_dim.y + thread_idx.y; batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return; if (batch_idx >= batch_size) return;
gate_value += batch_idx * 3 * frame_size; gate_value += batch_idx * 3 * frame_size;
reset_output_value += batch_idx * frame_size; reset_output_value += batch_idx * frame_size;
...@@ -75,11 +75,11 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, ...@@ -75,11 +75,11 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
T *output_value, int frame_size, T *output_value, int frame_size,
int batch_size, int batch_size,
activation_mode_t active_node) { activation_mode_t active_node) {
const int frame_idx = block_idx.x * block_dim.x + thread_idx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return; if (frame_idx >= frame_size) return;
int batch_idx = 0; int batch_idx = 0;
if (is_batch) { if (is_batch) {
batch_idx = block_idx.y * block_dim.y + thread_idx.y; batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return; if (batch_idx >= batch_size) return;
gate_value += batch_idx * 3 * frame_size; gate_value += batch_idx * 3 * frame_size;
output_value += batch_idx * frame_size; output_value += batch_idx * frame_size;
...@@ -112,11 +112,11 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, ...@@ -112,11 +112,11 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
T *prev_out_grad, T *output_grad, T *prev_out_grad, T *output_grad,
int frame_size, int batch_size, int frame_size, int batch_size,
activation_mode_t active_node) { activation_mode_t active_node) {
const int frame_idx = block_idx.x * block_dim.x + thread_idx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return; if (frame_idx >= frame_size) return;
int batch_idx = 0; int batch_idx = 0;
if (is_batch) { if (is_batch) {
batch_idx = block_idx.y * block_dim.y + thread_idx.y; batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return; if (batch_idx >= batch_size) return;
gate_value += batch_idx * 3 * frame_size; gate_value += batch_idx * 3 * frame_size;
gate_grad += batch_idx * 3 * frame_size; gate_grad += batch_idx * 3 * frame_size;
...@@ -160,11 +160,11 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value, ...@@ -160,11 +160,11 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
T *prev_out_grad, T *reset_output_grad, T *prev_out_grad, T *reset_output_grad,
int frame_size, int batch_size, int frame_size, int batch_size,
activation_mode_t active_gate) { activation_mode_t active_gate) {
const int frame_idx = block_idx.x * block_dim.x + thread_idx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return; if (frame_idx >= frame_size) return;
int batch_idx = 0; int batch_idx = 0;
if (is_batch) { if (is_batch) {
batch_idx = block_idx.y * block_dim.y + thread_idx.y; batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return; if (batch_idx >= batch_size) return;
gate_value += batch_idx * 3 * frame_size; gate_value += batch_idx * 3 * frame_size;
gate_grad += batch_idx * 3 * frame_size; gate_grad += batch_idx * 3 * frame_size;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册