提交 043f47b2 编写于 作者: C chengduoZH

fix concat op

上级 76ae540f
...@@ -70,9 +70,8 @@ __global__ void KernelConcat(T** inputs, const int input_col, ...@@ -70,9 +70,8 @@ __global__ void KernelConcat(T** inputs, const int input_col,
const int output_rows, const int output_cols, const int output_rows, const int output_cols,
T* output) { T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
double inv_input_col = 1.0 / input_col;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col; int split = tid_x * 1.0 / input_col;
int in_offset = tid_x - split * input_col; int in_offset = tid_x - split * input_col;
T* input_ptr = inputs[split]; T* input_ptr = inputs[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y; int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
...@@ -110,17 +109,16 @@ __global__ void KernelConcatGrad(const T* input, const int input_row, ...@@ -110,17 +109,16 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
template <typename T> template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row, __global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int output_cols, const int input_col, const int output_col,
T** outputs) { T** outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
double inv_input_col = 1.0 / input_col;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col; int split = tid_x / output_col;
int in_offset = tid_x - split * input_col; int in_offset = tid_x - split * output_col;
T* output_ptr = outputs[split]; T* output_ptr = outputs[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y; int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * output_cols + in_offset] = output_ptr[tid_y * output_col + in_offset] =
input[tid_y * input_col + tid_x]; input[tid_y * input_col + tid_x];
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册