diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake
index fff1980637d029b8a392c166734d3c3b84fed867..98466d44fc0dd91ef0cc8e8eac2660c42a19267c 100644
--- a/cmake/cudnn.cmake
+++ b/cmake/cudnn.cmake
@@ -96,7 +96,7 @@ if(CUDNN_FOUND)
   endif()
 
   message(STATUS "Current cuDNN header is ${CUDNN_INCLUDE_DIR}/cudnn.h. "
-                 "Current cuDNN version is v${CUDNN_MAJOR_VERSION}. ")
+                 "Current cuDNN version is v${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION}. ")
   endif()
 endif()
diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu
index e925e7bb5917c9433c3c79b9a21a41b4d48a5ba0..fa04dc8d3ca7b301d892eb61dab813722ce2a106 100644
--- a/paddle/fluid/operators/math/concat_and_split.cu
+++ b/paddle/fluid/operators/math/concat_and_split.cu
@@ -24,9 +24,9 @@ namespace operators {
 namespace math {
 
 template <typename T>
-__global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
-                             const int output_rows, const int output_cols,
-                             T* output) {
+__global__ void ConcatKernel(const T** inputs, const int* input_cols,
+                             int col_size, const int output_rows,
+                             const int output_cols, T* output) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   int curr_segment = 0;
   int curr_offset = input_cols[0];
@@ -41,7 +41,7 @@ __global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
     int local_col = tid_x - curr_offset;
     int segment_width = curr_col_offset - curr_offset;
 
-    T* input_ptr = inputs[curr_segment];
+    const T* input_ptr = inputs[curr_segment];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
       output[tid_y * output_cols + tid_x] =
@@ -50,14 +50,14 @@ __global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
 }
 
 template <typename T>
-__global__ void ConcatKernel(T** inputs_data, const int fixed_in_col,
-                             const int out_rows, const int out_cols,
-                             T* output_data) {
+__device__ void ConcatKernelDetail(const T** inputs_data,
+                                   const int fixed_in_col, const int out_rows,
+                                   const int out_cols, T* output_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
     int split = tid_x * 1.0 / fixed_in_col;
     int in_offset = tid_x - split * fixed_in_col;
-    T* input_ptr = inputs_data[split];
+    const T* input_ptr = inputs_data[split];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
       output_data[tid_y * out_cols + tid_x] =
@@ -66,6 +66,25 @@ __global__ void ConcatKernel(T** inputs_data, const int fixed_in_col,
   }
 }
 
+template <typename T>
+__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
+                             const int fixed_in_col, const int out_rows,
+                             const int out_cols, T* output_data) {
+  const T* inputs_data[2];
+  inputs_data[0] = input_addr0;
+  inputs_data[1] = input_addr1;
+  ConcatKernelDetail(inputs_data, fixed_in_col, out_rows, out_cols,
+                     output_data);
+}
+
+template <typename T>
+__global__ void ConcatKernel(const T** inputs_data, const int in_num,
+                             const int fixed_in_col, const int out_rows,
+                             const int out_cols, T* output_data) {
+  ConcatKernelDetail(inputs_data, fixed_in_col, out_rows, out_cols,
+                     output_data);
+}
+
 template <typename T>
 __global__ void SplitKernel(const T* input_data, const int in_row,
                             const int in_col, const int* out_cols,
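Editorial note, not part of the patch: the point of the new two-pointer `ConcatKernel` overload is that the two input addresses travel in the kernel-argument buffer itself, so the common two-input case needs no temporary device allocation and no async host-to-device copy of a pointer table before each launch. A minimal standalone sketch of the same trick (all names here are made up; build with nvcc):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Fast path: both input addresses are plain kernel arguments, so no
// cudaMemcpy of a T** pointer table is needed before the launch.
template <typename T>
__global__ void Concat2(const T* in0, const T* in1, int in_col, int rows,
                        T* out) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  int out_cols = 2 * in_col;
  if (col >= out_cols) return;
  const T* src = (col < in_col) ? in0 : in1;
  int local = (col < in_col) ? col : col - in_col;
  for (int row = 0; row < rows; ++row)
    out[row * out_cols + col] = src[row * in_col + local];
}

int main() {
  const int rows = 2, in_col = 3;
  float h0[rows * in_col] = {0, 1, 2, 3, 4, 5};
  float h1[rows * in_col] = {6, 7, 8, 9, 10, 11};
  float *d0, *d1, *dout;
  cudaMalloc(&d0, sizeof(h0));
  cudaMalloc(&d1, sizeof(h1));
  cudaMalloc(&dout, 2 * sizeof(h0));
  cudaMemcpy(d0, h0, sizeof(h0), cudaMemcpyHostToDevice);
  cudaMemcpy(d1, h1, sizeof(h1), cudaMemcpyHostToDevice);
  Concat2<float><<<1, 32>>>(d0, d1, in_col, rows, dout);
  float hout[rows * 2 * in_col];
  cudaMemcpy(hout, dout, sizeof(hout), cudaMemcpyDeviceToHost);
  for (float v : hout) printf("%g ", v);  // 0 1 2 6 7 8 3 4 5 9 10 11
  printf("\n");
  cudaFree(d0), cudaFree(d1), cudaFree(dout);
  return 0;
}
```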
@@ -94,9 +113,9 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
 }
 
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int fixed_out_col,
-                            T** outputs_data) {
+__device__ void SplitKernelDetail(const T* input_data, const int in_row,
+                                  const int in_col, const int fixed_out_col,
+                                  T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
     int split = tid_x / fixed_out_col;
@@ -111,6 +130,45 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
   }
 }
 
+template <typename T>
+__global__ void SplitKernel(const T* input_data, const int in_row,
+                            const int in_col, const int fixed_out_col,
+                            T** outputs_data) {
+  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
+}
+
+template <typename T>
+__global__ void SplitKernel(const T* input_data, const int in_row,
+                            const int in_col, const int fixed_out_col,
+                            T* outputs_addr0, T* outputs_addr1) {
+  T* outputs_data[2];
+  outputs_data[0] = outputs_addr0;
+  outputs_data[1] = outputs_addr1;
+  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
+}
+
+static inline void GetBlockDims(const platform::CUDADeviceContext& context,
+                                int num_rows, int num_cols, dim3* block_dims,
+                                dim3* grid_dims) {
+  // Set the thread block and grid according to CurrentDeviceId
+  const int kThreadsPerBlock = 1024;
+  int block_cols = kThreadsPerBlock;
+  if (num_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
+    block_cols = ((num_cols + 31) >> 5) << 5;
+  }
+  int block_rows = kThreadsPerBlock / block_cols;
+  *block_dims = dim3(block_cols, block_rows, 1);
+
+  int max_threads = context.GetMaxPhysicalThreadCount();
+  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+  int grid_cols =
+      std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
+  int grid_rows =
+      std::min(max_blocks / grid_cols, std::max(num_rows / block_rows, 1));
+  *grid_dims = dim3(grid_cols, grid_rows, 1);
+}
+
 /*
  * All tensors' dimension should be the same and the values of
  * each dimension must be the same, except the axis dimension.
 */
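For intuition, not part of the patch: the factored-out `GetBlockDims` picks a block of up to 1024 threads whose x-extent is rounded up to a multiple of 32, then caps the grid by the device's physical thread count. A host-only sketch of the same arithmetic, with `max_threads` hardcoded since the real code queries `context.GetMaxPhysicalThreadCount()`:

```cpp
#include <algorithm>
#include <cstdio>

// Mirrors GetBlockDims() above; max_threads is an assumed value here,
// it is device-dependent in the real code.
void GetBlockDimsDemo(int num_rows, int num_cols) {
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (num_cols < kThreadsPerBlock)  // round up to a multiple of 32
    block_cols = ((num_cols + 31) >> 5) << 5;
  int block_rows = kThreadsPerBlock / block_cols;

  int max_threads = 65536;  // assumed for the demo
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
  int grid_cols =
      std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows =
      std::min(max_blocks / grid_cols, std::max(num_rows / block_rows, 1));
  printf("block=(%d,%d) grid=(%d,%d)\n", block_cols, block_rows, grid_cols,
         grid_rows);
}

int main() {
  GetBlockDimsDemo(5, 12);      // tiny concat: block=(32,32), grid=(1,1)
  GetBlockDimsDemo(256, 4096);  // wide concat: block=(1024,1), grid=(4,16)
  return 0;
}
```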
@@ -131,53 +189,47 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     int in_col = input[0].numel() / in_row;
     int out_row = in_row, out_col = 0;
 
-    std::vector<T*> inputs_data;
+    std::vector<const T*> inputs_data(in_num);
     std::vector<int> inputs_col(in_num + 1);
-    inputs_data.reserve(in_num);
 
     inputs_col[0] = 0;
-    bool sameShape = true;
+    bool has_same_shape = true;
     for (int i = 0; i < in_num; ++i) {
       int t_cols = input[i].numel() / in_row;
-      if (sameShape) {
-        if (t_cols != in_col) sameShape = false;
+      if (has_same_shape) {
+        if (t_cols != in_col) has_same_shape = false;
       }
       out_col += t_cols;
       inputs_col[i + 1] = out_col;
-      inputs_data.emplace_back(input[i].data<T>());
+      inputs_data[i] = input[i].data<T>();
     }
 
-    // computation
-    // set the thread block and grid according to CurrentDeviceId
-    const int kThreadsPerBlock = 1024;
-    int block_cols = kThreadsPerBlock;
-    if (out_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
-      block_cols = ((out_col + 31) >> 5) << 5;
+    dim3 block_dims;
+    dim3 grid_dims;
+    GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);
+
+    memory::allocation::AllocationPtr tmp_dev_ins_data;
+    const T** dev_ins_data = nullptr;
+    if (!has_same_shape || (in_num != 2)) {
+      tmp_dev_ins_data =
+          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
+              inputs_data.size() * sizeof(T*));
+      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
+                   tmp_dev_ins_data->ptr(), platform::CPUPlace(),
+                   static_cast<void*>(inputs_data.data()),
+                   inputs_data.size() * sizeof(T*), context.stream());
+      dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
     }
-    int block_rows = kThreadsPerBlock / block_cols;
-    dim3 block_size = dim3(block_cols, block_rows, 1);
-
-    int max_threads = context.GetMaxPhysicalThreadCount();
-    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
-
-    int grid_cols =
-        std::min((out_col + block_cols - 1) / block_cols, max_blocks);
-    int grid_rows =
-        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
-    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
-
-    auto tmp_dev_ins_data =
-        platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
-            inputs_data.size() * sizeof(T*));
-    memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
-                 tmp_dev_ins_data->ptr(), platform::CPUPlace(),
-                 static_cast<void*>(inputs_data.data()),
-                 inputs_data.size() * sizeof(T*), context.stream());
-    T** dev_ins_data = reinterpret_cast<T**>(tmp_dev_ins_data->ptr());
-
-    if (sameShape) {
-      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
-          dev_ins_data, in_col, out_row, out_col, output->data<T>());
+
+    if (has_same_shape) {
+      if (in_num == 2) {
+        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
+            inputs_data[0], inputs_data[1], in_col, out_row, out_col,
+            output->data<T>());
+      } else {
+        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
+            dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
+      }
     } else {
       auto tmp_dev_ins_col_data =
           platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
@@ -188,7 +240,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
                    inputs_col.size() * sizeof(int), context.stream());
       int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());
 
-      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
+      ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
           dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
           out_row, out_col, output->data<T>());
     }
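Editorial note, not part of the patch: the `ConcatFunctor` hunk above reduces to a three-way dispatch, and only the two slower paths pay for staging the pointer table on the device. A compilable toy of that decision, with stub launches standing in for the real kernels (illustrative names, not Paddle API):

```cpp
#include <cstdio>

// Stand-ins for the three launch paths in ConcatFunctor above.
void LaunchTwoInputFastPath() {
  std::puts("2 inputs, same shape: pointers passed as kernel args");
}
void LaunchUniformPath() {
  std::puts("N inputs, same shape: stage T** table on device");
}
void LaunchRaggedPath() {
  std::puts("ragged inputs: stage T** table and column offsets");
}

// Mirrors the has_same_shape / in_num == 2 branching introduced above.
void Dispatch(bool has_same_shape, int in_num) {
  if (has_same_shape && in_num == 2) {
    LaunchTwoInputFastPath();
  } else if (has_same_shape) {
    LaunchUniformPath();
  } else {
    LaunchRaggedPath();  // note: 2 ragged inputs still stage the table
  }
}

int main() {
  Dispatch(true, 2);  // the common case this patch optimizes
  Dispatch(true, 8);
  Dispatch(false, 2);
  return 0;
}
```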
@@ -216,7 +268,7 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     int out0_col = ref_inputs[0]->numel() / out_row;
     int in_col = 0, in_row = out_row;
-    bool sameShape = true;
+    bool has_same_shape = true;
 
     std::vector<T*> outputs_data(o_num);
     std::vector<int> outputs_cols(o_num + 1);
@@ -224,8 +276,8 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     outputs_cols[0] = 0;
     for (int i = 0; i < o_num; ++i) {
       int t_col = ref_inputs.at(i)->numel() / out_row;
-      if (sameShape) {
-        if (t_col != out0_col) sameShape = false;
+      if (has_same_shape) {
+        if (t_col != out0_col) has_same_shape = false;
       }
       in_col += t_col;
       outputs_cols[i + 1] = in_col;
@@ -236,36 +288,32 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
       }
     }
 
-    // computation
-    const int kThreadsPerBlock = 1024;
-    int block_cols = kThreadsPerBlock;
-    if (in_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
-      block_cols = ((in_col + 31) >> 5) << 5;
+    dim3 block_dims;
+    dim3 grid_dims;
+    GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);
+
+    memory::allocation::AllocationPtr tmp_dev_outs_data;
+    T** dev_out_gpu_data = nullptr;
+    if (!has_same_shape || (o_num != 2)) {
+      tmp_dev_outs_data =
+          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
+              outputs_data.size() * sizeof(T*));
+      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
+                   tmp_dev_outs_data->ptr(), platform::CPUPlace(),
+                   reinterpret_cast<void*>(outputs_data.data()),
+                   outputs_data.size() * sizeof(T*), context.stream());
+      dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
     }
-    int block_rows = kThreadsPerBlock / block_cols;
-    dim3 block_size = dim3(block_cols, block_rows, 1);
-
-    int max_threads = context.GetMaxPhysicalThreadCount();
-    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
-
-    int grid_cols =
-        std::min((in_col + block_cols - 1) / block_cols, max_blocks);
-    int grid_rows =
-        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
-    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
-
-    auto tmp_dev_outs_data =
-        platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
-            outputs_data.size() * sizeof(T*));
-    memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
-                 tmp_dev_outs_data->ptr(), platform::CPUPlace(),
-                 reinterpret_cast<void*>(outputs_data.data()),
-                 outputs_data.size() * sizeof(T*), context.stream());
-    T** dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
-
-    if (sameShape) {
-      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
-          input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
+
+    if (has_same_shape) {
+      if (o_num == 2) {
+        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
+            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
+            outputs_data[1]);
+      } else {
+        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
+            input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
+      }
     } else {
       auto tmp_dev_ins_col_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
@@ -277,7 +325,7 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
       int* dev_outs_col_data =
           reinterpret_cast<int*>(tmp_dev_ins_col_data->ptr());
 
-      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
+      SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
     }
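Editorial note, not part of the patch: `SplitFunctor` mirrors the concat changes one-for-one, and conceptually the two functors are inverses over the same flattened [rows x cols] view. A tiny CPU-only check of that round-trip invariant, in plain C++ with no Paddle types:

```cpp
#include <cassert>
#include <vector>

// Concatenating two equal-width row-major blocks column-wise and
// splitting the result again must reproduce the original blocks.
int main() {
  const int rows = 2, col = 3, out_cols = 2 * col;
  std::vector<int> a = {0, 1, 2, 3, 4, 5};    // row-major [2 x 3]
  std::vector<int> b = {6, 7, 8, 9, 10, 11};  // row-major [2 x 3]

  // concat along the column axis
  std::vector<int> out(rows * out_cols);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < out_cols; ++c)
      out[r * out_cols + c] =
          (c < col) ? a[r * col + c] : b[r * col + (c - col)];

  // split back into two blocks
  std::vector<int> a2(rows * col), b2(rows * col);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < out_cols; ++c)
      ((c < col) ? a2[r * col + c] : b2[r * col + c - col]) =
          out[r * out_cols + c];

  assert(a2 == a && b2 == b);  // round trip is the identity
  return 0;
}
```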
diff --git a/paddle/fluid/operators/math/concat_test.cc b/paddle/fluid/operators/math/concat_test.cc
index 8ba9e8e8ec1344edc3beaf7f4a58f99107cc0e9c..411dbca25bb48c99dfd16779f54e46a3e80d0d4e 100644
--- a/paddle/fluid/operators/math/concat_test.cc
+++ b/paddle/fluid/operators/math/concat_test.cc
@@ -17,26 +17,24 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
 
+/**
+ * case 1:
+ *    inputs:
+ *        t_a.shape: [2, 3, 4]
+ *        t_b.shape: [3, 3, 4]
+ *    output:
+ *        out.shape: [5, 3, 4]
+ */
 template <typename DeviceContext, typename Place>
-void testConcat() {
+void ConcatCase1(DeviceContext* context) {
   paddle::framework::Tensor input_a_cpu;
   paddle::framework::Tensor input_b_cpu;
   paddle::framework::Tensor out_cpu;
+
   paddle::framework::Tensor input_a;
   paddle::framework::Tensor input_b;
   paddle::framework::Tensor out;
 
-  DeviceContext* context = new DeviceContext(Place());
-  // DeviceContext context(Place());
-
-  /**
-   * cast1:
-   *    inputs:
-   *        t_a.shape: [2, 3, 4]
-   *        t_b.shape: [3, 3, 4]
-   *    output:
-   *        out.shape: [5, 3, 4]
-   */
   auto dim_a = paddle::framework::make_ddim({2, 3, 4});
   auto dim_b = paddle::framework::make_ddim({3, 3, 4});
   auto dim_out = paddle::framework::make_ddim({5, 3, 4});
@@ -51,8 +49,8 @@ void testConcat() {
     out_cpu.mutable_data<int>(dim_out, paddle::platform::CPUPlace());
   }
 
-  int* a_ptr;
-  int* b_ptr;
+  int* a_ptr = nullptr;
+  int* b_ptr = nullptr;
   if (paddle::platform::is_gpu_place(Place())) {
     a_ptr = input_a_cpu.data<int>();
     b_ptr = input_b_cpu.data<int>();
@@ -84,7 +82,7 @@ void testConcat() {
   PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
   PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
 
-  int* out_ptr;
+  int* out_ptr = nullptr;
   if (paddle::platform::is_gpu_place(Place())) {
     paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
                                       &out_cpu);
@@ -104,28 +102,42 @@ void testConcat() {
       ++idx_a;
     }
   }
-  //
-  /**
-   * cast2:
-   *    inputs:
-   *        t_a.shape: [2, 3, 4]
-   *        t_b.shape: [2, 4, 4]
-   *    output:
-   *        out.shape: [2, 7, 4]
-   */
-  dim_a = paddle::framework::make_ddim({2, 3, 4});
-  dim_b = paddle::framework::make_ddim({2, 4, 4});
-  dim_out = paddle::framework::make_ddim({2, 7, 4});
-
-  input_a.Resize(dim_a);
-  input_b.Resize(dim_b);
-  out.Resize(dim_out);
+}
+
+/**
+ * case 2:
+ *    inputs:
+ *        t_a.shape: [2, 3, 4]
+ *        t_b.shape: [2, 4, 4]
+ *    output:
+ *        out.shape: [2, 7, 4]
+ */
+template <typename DeviceContext, typename Place>
+void ConcatCase2(DeviceContext* context) {
+  paddle::framework::Tensor input_a_cpu;
+  paddle::framework::Tensor input_b_cpu;
+  paddle::framework::Tensor out_cpu;
+
+  paddle::framework::Tensor input_a;
+  paddle::framework::Tensor input_b;
+  paddle::framework::Tensor out;
+
+  auto dim_a = paddle::framework::make_ddim({2, 3, 4});
+  auto dim_b = paddle::framework::make_ddim({2, 4, 4});
+  auto dim_out = paddle::framework::make_ddim({2, 7, 4});
+
+  input_a.mutable_data<int>(dim_a, Place());
+  input_b.mutable_data<int>(dim_b, Place());
+  out.mutable_data<int>(dim_out, Place());
+
   if (paddle::platform::is_gpu_place(Place())) {
-    input_a_cpu.Resize(dim_a);
-    input_b_cpu.Resize(dim_b);
-    out_cpu.Resize(dim_out);
+    input_a_cpu.mutable_data<int>(dim_a, paddle::platform::CPUPlace());
+    input_b_cpu.mutable_data<int>(dim_b, paddle::platform::CPUPlace());
+    out_cpu.mutable_data<int>(dim_out, paddle::platform::CPUPlace());
   }
 
+  int* a_ptr = nullptr;
+  int* b_ptr = nullptr;
   if (paddle::platform::is_gpu_place(Place())) {
     a_ptr = input_a_cpu.data<int>();
     b_ptr = input_b_cpu.data<int>();
@@ -146,16 +158,18 @@ void testConcat() {
     paddle::framework::TensorCopySync(input_b_cpu, Place(), &input_b);
   }
 
-  input.clear();
+  std::vector<paddle::framework::Tensor> input;
   input.push_back(input_a);
   input.push_back(input_b);
 
+  paddle::operators::math::ConcatFunctor<DeviceContext, int> concat_functor;
   concat_functor(*context, input, 1, &out);
 
   // check the dim of input_a, input_b
   PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
   PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
 
+  int* out_ptr = nullptr;
   if (paddle::platform::is_gpu_place(Place())) {
     paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
                                       &out_cpu);
@@ -164,8 +178,8 @@ void testConcat() {
     out_ptr = out.data<int>();
   }
 
-  cols = 3 * 4;
-  idx_a = 0, idx_b = 0;
+  int cols = 3 * 4;
+  int idx_a = 0, idx_b = 0;
   for (int i = 0; i < 2; ++i) {
     for (int j = 0; j < 28; ++j) {
       if (j >= cols) {
@@ -177,28 +191,42 @@ void testConcat() {
       }
     }
   }
+}
+
+/**
+ * case 3:
+ *    inputs:
+ *        t_a.shape: [2, 3, 5]
+ *        t_b.shape: [2, 3, 4]
+ *    output:
+ *        out.shape: [2, 3, 9]
+ */
+template <typename DeviceContext, typename Place>
+void ConcatCase3(DeviceContext* context) {
+  paddle::framework::Tensor input_a_cpu;
+  paddle::framework::Tensor input_b_cpu;
+  paddle::framework::Tensor out_cpu;
+
+  paddle::framework::Tensor input_a;
+  paddle::framework::Tensor input_b;
+  paddle::framework::Tensor out;
+
+  auto dim_a = paddle::framework::make_ddim({2, 3, 4});
+  auto dim_b = paddle::framework::make_ddim({2, 3, 5});
+  auto dim_out = paddle::framework::make_ddim({2, 3, 9});
+
+  input_a.mutable_data<int>(dim_a, Place());
+  input_b.mutable_data<int>(dim_b, Place());
+  out.mutable_data<int>(dim_out, Place());
 
-  /**
-   * cast3:
-   *    inputs:
-   *        t_a.shape: [2, 3, 5]
-   *        t_b.shape: [2, 3, 4]
-   *    output:
-   *        out.shape: [2, 3, 9]
-   */
-  dim_a = paddle::framework::make_ddim({2, 3, 4});
-  dim_b = paddle::framework::make_ddim({2, 3, 5});
-  dim_out = paddle::framework::make_ddim({2, 3, 9});
-
-  input_a.Resize(dim_a);
-  input_b.Resize(dim_b);
-  out.Resize(dim_out);
   if (paddle::platform::is_gpu_place(Place())) {
-    input_a_cpu.Resize(dim_a);
-    input_b_cpu.Resize(dim_b);
-    out_cpu.Resize(dim_out);
+    input_a_cpu.mutable_data<int>(dim_a, paddle::platform::CPUPlace());
+    input_b_cpu.mutable_data<int>(dim_b, paddle::platform::CPUPlace());
+    out_cpu.mutable_data<int>(dim_out, paddle::platform::CPUPlace());
   }
 
+  int* a_ptr = nullptr;
+  int* b_ptr = nullptr;
   if (paddle::platform::is_gpu_place(Place())) {
     a_ptr = input_a_cpu.data<int>();
     b_ptr = input_b_cpu.data<int>();
@@ -219,16 +247,18 @@ void testConcat() {
     paddle::framework::TensorCopySync(input_b_cpu, Place(), &input_b);
   }
 
-  input.clear();
+  std::vector<paddle::framework::Tensor> input;
   input.push_back(input_a);
   input.push_back(input_b);
 
+  paddle::operators::math::ConcatFunctor<DeviceContext, int> concat_functor;
   concat_functor(*context, input, 2, &out);
 
   // check the dim of input_a, input_b
   PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
   PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
 
+  int* out_ptr = nullptr;
   if (paddle::platform::is_gpu_place(Place())) {
     paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
                                       &out_cpu);
@@ -238,8 +268,8 @@ void testConcat() {
   }
 
   // check the data
-  cols = 4;
-  idx_a = 0, idx_b = 0;
+  int cols = 4;
+  int idx_a = 0, idx_b = 0;
   for (int i = 0; i < 6; ++i) {
     for (int j = 0; j < 9; ++j) {
       if (j >= cols) {
@@ -251,29 +281,43 @@ void testConcat() {
       }
     }
   }
+}
+
+/**
+ * case 4:
+ *    inputs:
+ *        axis = 1
+ *        t_a.shape: [2, 3, 4]
+ *        t_b.shape: [2, 3, 4]
+ *    output:
+ *        out.shape: [2, 6, 4]
+ */
+template <typename DeviceContext, typename Place>
+void ConcatCase4(DeviceContext* context) {
+  paddle::framework::Tensor input_a_cpu;
+  paddle::framework::Tensor input_b_cpu;
+  paddle::framework::Tensor out_cpu;
+
+  paddle::framework::Tensor input_a;
+  paddle::framework::Tensor input_b;
+  paddle::framework::Tensor out;
+
+  auto dim_a = paddle::framework::make_ddim({2, 3, 4});
+  auto dim_b = paddle::framework::make_ddim({2, 3, 4});
+  auto dim_out = paddle::framework::make_ddim({2, 6, 4});
+
+  input_a.mutable_data<int>(dim_a, Place());
+  input_b.mutable_data<int>(dim_b, Place());
+  out.mutable_data<int>(dim_out, Place());
 
-  /**
-   * cast4:
-   *    inputs:
-   *        axis = 1
-   *        t_a.shape: [2, 3, 4]
-   *        t_b.shape: [2, 3, 4]
-   *    output:
-   *        out.shape: [2, 6, 4]
-   */
-  dim_a = paddle::framework::make_ddim({2, 3, 4});
-  dim_b = paddle::framework::make_ddim({2, 3, 4});
-  dim_out = paddle::framework::make_ddim({2, 6, 4});
-
-  input_a.Resize(dim_a);
-  input_b.Resize(dim_b);
-  out.Resize(dim_out);
   if (paddle::platform::is_gpu_place(Place())) {
-    input_a_cpu.Resize(dim_a);
-    input_b_cpu.Resize(dim_b);
-    out_cpu.Resize(dim_out);
+    input_a_cpu.mutable_data<int>(dim_a, paddle::platform::CPUPlace());
+    input_b_cpu.mutable_data<int>(dim_b, paddle::platform::CPUPlace());
+    out_cpu.mutable_data<int>(dim_out, paddle::platform::CPUPlace());
   }
 
+  int* a_ptr = nullptr;
+  int* b_ptr = nullptr;
   if (paddle::platform::is_gpu_place(Place())) {
     a_ptr = input_a_cpu.data<int>();
     b_ptr = input_b_cpu.data<int>();
@@ -294,16 +338,19 @@ void testConcat() {
     paddle::framework::TensorCopySync(input_b_cpu, Place(), &input_b);
   }
 
-  input.clear();
+  std::vector<paddle::framework::Tensor> input;
   input.push_back(input_a);
   input.push_back(input_b);
 
+  paddle::operators::math::ConcatFunctor<DeviceContext, int> concat_functor;
   concat_functor(*context, input, 1, &out);
+  context->Wait();
 
   // check the dim of input_a, input_b
   PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
   PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
 
+  int* out_ptr = nullptr;
   if (paddle::platform::is_gpu_place(Place())) {
     paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
                                       &out_cpu);
@@ -313,8 +360,8 @@ void testConcat() {
   }
 
   // check the data
-  cols = 12;
-  idx_a = 0, idx_b = 0;
+  int cols = 12;
+  int idx_a = 0, idx_b = 0;
   for (int i = 0; i < 2; ++i) {
     for (int j = 0; j < 24; ++j) {
       if (j >= cols) {
@@ -328,10 +375,21 @@ void testConcat() {
   }
 }
 
+template <typename DeviceContext, typename Place>
+void TestConcatMain() {
+  DeviceContext* context = new DeviceContext(Place());
+
+  ConcatCase1<DeviceContext, Place>(context);
+  ConcatCase2<DeviceContext, Place>(context);
+  ConcatCase3<DeviceContext, Place>(context);
+  ConcatCase4<DeviceContext, Place>(context);
+}
+
 TEST(math, concat) {
-  testConcat<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
+  TestConcatMain<paddle::platform::CPUDeviceContext,
+                 paddle::platform::CPUPlace>();
 #ifdef PADDLE_WITH_CUDA
-  testConcat<paddle::platform::CUDADeviceContext,
-             paddle::platform::CUDAPlace>();
+  TestConcatMain<paddle::platform::CUDADeviceContext,
+                 paddle::platform::CUDAPlace>();
 #endif
 }
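Editorial note, not part of the patch: the per-case refactor makes it cheap to add coverage for the branch the new fast path bypasses. A hypothetical fifth case with three same-shape inputs would route a GPU build through the staged pointer-table path (`in_num != 2`). Sketched for the CPU place only, so the GPU staging boilerplate of the other cases is omitted; everything here besides the names used in the file above is assumed:

```cpp
/**
 * hypothetical case 5 (not in this patch):
 *    inputs: three tensors of shape [2, 3, 4], axis = 0
 *    output: out.shape: [6, 3, 4]
 */
template <typename DeviceContext, typename Place>
void ConcatCase5(DeviceContext* context) {
  paddle::framework::Tensor input_a, input_b, input_c, out;

  auto dim_in = paddle::framework::make_ddim({2, 3, 4});
  auto dim_out = paddle::framework::make_ddim({6, 3, 4});

  int* a_ptr = input_a.mutable_data<int>(dim_in, Place());
  int* b_ptr = input_b.mutable_data<int>(dim_in, Place());
  int* c_ptr = input_c.mutable_data<int>(dim_in, Place());
  out.mutable_data<int>(dim_out, Place());

  for (int i = 0; i < 24; ++i) {
    a_ptr[i] = i;
    b_ptr[i] = 100 + i;
    c_ptr[i] = 200 + i;
  }

  std::vector<paddle::framework::Tensor> input;
  input.push_back(input_a);
  input.push_back(input_b);
  input.push_back(input_c);

  paddle::operators::math::ConcatFunctor<DeviceContext, int> concat_functor;
  concat_functor(*context, input, 0, &out);

  // axis-0 concat lays the three blocks out back to back
  const int* out_ptr = out.data<int>();
  for (int i = 0; i < 24; ++i) {
    PADDLE_ENFORCE_EQ(out_ptr[i], i);
    PADDLE_ENFORCE_EQ(out_ptr[24 + i], 100 + i);
    PADDLE_ENFORCE_EQ(out_ptr[48 + i], 200 + i);
  }
}
```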