From 043f47b27fa827cd87df93027124dce6d1d22d7e Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 23 Mar 2018 18:29:15 +0800 Subject: [PATCH] fix concat op --- paddle/fluid/operators/math/concat.cu | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu index 60b266f08..aede38000 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat.cu @@ -70,9 +70,8 @@ __global__ void KernelConcat(T** inputs, const int input_col, const int output_rows, const int output_cols, T* output) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - double inv_input_col = 1.0 / input_col; for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { - int split = tid_x * inv_input_col; + int split = tid_x * 1.0 / input_col; int in_offset = tid_x - split * input_col; T* input_ptr = inputs[split]; int tid_y = blockIdx.y * blockDim.y + threadIdx.y; @@ -110,17 +109,16 @@ __global__ void KernelConcatGrad(const T* input, const int input_row, template __global__ void KernelConcatGrad(const T* input, const int input_row, - const int input_col, const int output_cols, + const int input_col, const int output_col, T** outputs) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - double inv_input_col = 1.0 / input_col; for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { - int split = tid_x * inv_input_col; - int in_offset = tid_x - split * input_col; + int split = tid_x / output_col; + int in_offset = tid_x - split * output_col; T* output_ptr = outputs[split]; int tid_y = blockIdx.y * blockDim.y + threadIdx.y; for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) - output_ptr[tid_y * output_cols + in_offset] = + output_ptr[tid_y * output_col + in_offset] = input[tid_y * input_col + tid_x]; } } -- GitLab