From 9c90dc9728813cbac15b9cf90d5bafb236056b3e Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Tue, 19 Jun 2018 21:42:40 +0800 Subject: [PATCH] Make the CUDA kernel of concat correct and fix unit tests. (#11541) * Make the CUDA kernel of concat correct and fix unit tests. --- paddle/fluid/operators/math/concat.cu | 41 +++++-------------- .../fluid/tests/unittests/test_concat_op.py | 13 +++++- 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu index f66baa6573..6205f3cd85 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat.cu @@ -22,43 +22,24 @@ namespace paddle { namespace operators { namespace math { -template -__device__ T upper_bound(const T* first, T count, T val) { - const T* orig = first; - const T* it = nullptr; - T step = 0; - while (count > 0) { - it = first; - step = count / 2; - it += step; - if (!(val < *it)) { - first = ++it; - count -= step + 1; - } else { - count = step; - } - } - return first - orig; -} - template __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size, const int output_rows, const int output_cols, T* output) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - int segment = upper_bound(input_cols, col_size, tid_x) - 1; - - int curr_offset = input_cols[segment]; - int curr_segment = segment; + int curr_segment = 0; + int curr_offset = input_cols[0]; for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { - T curr_col_offset; - while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) { + int curr_col_offset = input_cols[curr_segment + 1]; + while (curr_col_offset <= tid_x) { curr_offset = curr_col_offset; ++curr_segment; + curr_col_offset = input_cols[curr_segment + 1]; } int local_col = tid_x - curr_offset; int segment_width = curr_col_offset - curr_offset; + T* input_ptr = inputs[curr_segment]; int tid_y = blockIdx.y * blockDim.y + threadIdx.y; for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) @@ -89,14 +70,14 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row, const int in_col, const int* out_cols, int out_cols_size, T** outputs_data) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - int segment = upper_bound(out_cols, out_cols_size, tid_x) - 1; - int curr_offset = out_cols[segment]; - int curr_segment = segment; + int curr_segment = 0; + int curr_offset = out_cols[0]; for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { - T curr_col_offset; - while ((curr_col_offset = out_cols[curr_segment + 1]) <= tid_x) { + int curr_col_offset = out_cols[curr_segment + 1]; + while (curr_col_offset <= tid_x) { curr_offset = curr_col_offset; ++curr_segment; + curr_col_offset = out_cols[curr_segment + 1]; } int local_col = tid_x - curr_offset; diff --git a/python/paddle/fluid/tests/unittests/test_concat_op.py b/python/paddle/fluid/tests/unittests/test_concat_op.py index 1e00d67d54..e9f3c45dc4 100644 --- a/python/paddle/fluid/tests/unittests/test_concat_op.py +++ b/python/paddle/fluid/tests/unittests/test_concat_op.py @@ -43,7 +43,7 @@ class TestConcatOp(OpTest): self.axis = 1 -class TestConcatOp2(OpTest): +class TestConcatOp2(TestConcatOp): def init_test_data(self): self.x0 = np.random.random((2, 3, 4, 5)).astype('float32') self.x1 = np.random.random((2, 3, 4, 5)).astype('float32') @@ -51,5 +51,16 @@ class TestConcatOp2(OpTest): self.axis = 1 +class TestConcatOp3(TestConcatOp): + def init_test_data(self): + self.x0 = np.random.random((1, 256, 170, 256)).astype('float32') + self.x1 = np.random.random((1, 128, 170, 256)).astype('float32') + self.x2 = np.random.random((1, 128, 170, 256)).astype('float32') + self.axis = 1 + + def test_check_grad(self): + pass + + if __name__ == '__main__': unittest.main() -- GitLab