diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc index b672c79afd97e36894af647fd4bc6edfb885ff13..bfce56f9fdcafa0800c9742b9fae41fd6a572b40 100644 --- a/paddle/fluid/operators/math/concat.cc +++ b/paddle/fluid/operators/math/concat.cc @@ -20,7 +20,7 @@ namespace math { /* * All tensors' dimension should be the same and the values of - * each dimension are the same, except the axis dimension. + * each dimension must be the same, except the axis dimension. */ template class ConcatFunctor { @@ -63,7 +63,7 @@ class ConcatFunctor { /* * All tensors' dimension should be the same and the values of - * each dimension are the same, except the axis dimension. + * each dimension must be the same, except the axis dimension. */ template class ConcatGradFunctor { diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu index 60b266f08fb2d4217c5933902d69de96fc2abe22..c0786757b34195d47c3b1cadc938f0e9fcfd6038 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat.cu @@ -66,68 +66,66 @@ __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size, } template -__global__ void KernelConcat(T** inputs, const int input_col, - const int output_rows, const int output_cols, - T* output) { +__global__ void KernelConcat(T** inputs_data, const int fixed_in_col, + const int out_rows, const int out_cols, + T* output_data) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - double inv_input_col = 1.0 / input_col; - for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { - int split = tid_x * inv_input_col; - int in_offset = tid_x - split * input_col; - T* input_ptr = inputs[split]; + for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) { + int split = tid_x * 1.0 / fixed_in_col; + int in_offset = tid_x - split * fixed_in_col; + T* input_ptr = inputs_data[split]; int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) { - output[tid_y * output_cols + tid_x] = - input_ptr[tid_y * input_col + in_offset]; + for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) { + output_data[tid_y * out_cols + tid_x] = + input_ptr[tid_y * fixed_in_col + in_offset]; } } } template -__global__ void KernelConcatGrad(const T* input, const int input_row, - const int input_col, const int* output_cols, - int col_size, T** outputs) { +__global__ void KernelConcatGrad(const T* input_data, const int in_row, + const int in_col, const int* out_cols, + int out_cols_size, T** outputs_data) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - int segment = upper_bound(output_cols, col_size, tid_x) - 1; - int curr_offset = output_cols[segment]; + int segment = upper_bound(out_cols, out_cols_size, tid_x) - 1; + int curr_offset = out_cols[segment]; int curr_segment = segment; - for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { + for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { T curr_col_offset; - while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) { + while ((curr_col_offset = out_cols[curr_segment + 1]) <= tid_x) { curr_offset = curr_col_offset; ++curr_segment; } int local_col = tid_x - curr_offset; int segment_width = curr_col_offset - curr_offset; - T* output_ptr = outputs[curr_segment]; + T* output_ptr = outputs_data[curr_segment]; int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) + for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) output_ptr[tid_y * segment_width + local_col] = - input[tid_y * input_col + tid_x]; + input_data[tid_y * in_col + tid_x]; } } template -__global__ void KernelConcatGrad(const T* input, const int input_row, - const int input_col, const int output_cols, - T** outputs) { +__global__ void KernelConcatGrad(const T* input_data, const int in_row, + const int in_col, const int fixed_out_col, + T** outputs_data) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - double inv_input_col = 1.0 / input_col; - for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { - int split = tid_x * inv_input_col; - int in_offset = tid_x - split * input_col; - T* output_ptr = outputs[split]; + for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { + int split = tid_x / fixed_out_col; + int in_offset = tid_x - split * fixed_out_col; + T* output_ptr = outputs_data[split]; int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) - output_ptr[tid_y * output_cols + in_offset] = - input[tid_y * input_col + tid_x]; + for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * fixed_out_col + in_offset] = + input_data[tid_y * in_col + tid_x]; } } /* * All tensors' dimension should be the same and the values of - * each dimension are the same, except the axis dimension. + * each dimension must be the same, except the axis dimension. */ template class ConcatFunctor { @@ -136,41 +134,40 @@ class ConcatFunctor { const std::vector& input, const int axis, framework::Tensor* output) { // TODO(zcd): Add input data validity checking - int num = input.size(); - int rows = 1; + int in_num = input.size(); + int in_row = 1; auto dim_0 = input[0].dims(); for (int i = 0; i < axis; ++i) { - rows *= dim_0[i]; + in_row *= dim_0[i]; } - int cols = input[0].numel() / rows; - int out_rows = rows, out_cols = 0; + int in_col = input[0].numel() / in_row; + int out_row = in_row, out_col = 0; - framework::Vector inputs_data(num * sizeof(T*) / 2); - framework::Vector inputs_cols(num + 1); - inputs_cols[0] = 0; + framework::Vector inputs_data(in_num * sizeof(T*) / 2); + framework::Vector inputs_col(in_num + 1); T** inputs_ptr = reinterpret_cast(inputs_data.data()); + inputs_col[0] = 0; bool sameShape = true; - for (int i = 0; i < num; ++i) { - int t_cols = input[i].numel() / rows; + for (int i = 0; i < in_num; ++i) { + int t_cols = input[i].numel() / in_row; if (sameShape) { - if (t_cols != cols) sameShape = false; + if (t_cols != in_col) sameShape = false; } - out_cols += t_cols; - inputs_cols[i + 1] = out_cols; + out_col += t_cols; + inputs_col[i + 1] = out_col; inputs_ptr[i] = const_cast(input[i].data()); } - T** ins_gpu = + T** dev_ins_data = reinterpret_cast(inputs_data.CUDAMutableData(context.GetPlace())); - const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace()); // computation // set the thread block and grid according to CurrentDeviceId const int kThreadsPerBlock = 1024; int block_cols = kThreadsPerBlock; - if (out_cols < kThreadsPerBlock) { // block_cols is aligned by 32. - block_cols = ((out_cols + 31) >> 5) << 5; + if (out_col < kThreadsPerBlock) { // block_cols is aligned by 32. + block_cols = ((out_col + 31) >> 5) << 5; } int block_rows = kThreadsPerBlock / block_cols; dim3 block_size = dim3(block_cols, block_rows, 1); @@ -179,25 +176,26 @@ class ConcatFunctor { int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); int grid_cols = - std::min((out_cols + block_cols - 1) / block_cols, max_blocks); + std::min((out_col + block_cols - 1) / block_cols, max_blocks); int grid_rows = - std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1)); + std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1)); dim3 grid_size = dim3(grid_cols, grid_rows, 1); if (sameShape) { KernelConcat<<>>( - ins_gpu, cols, out_rows, out_cols, output->data()); + dev_ins_data, in_col, out_row, out_col, output->data()); } else { + const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace()); KernelConcat<<>>( - ins_gpu, ins_col_gpu, static_cast(inputs_cols.size()), out_rows, - out_cols, output->data()); + dev_ins_data, dev_ins_col_data, static_cast(inputs_col.size()), + out_row, out_col, output->data()); } } }; /* * All tensors' dimension should be the same and the values of - * each dimension are the same, except the axis dimension. + * each dimension must be the same, except the axis dimension. */ template class ConcatGradFunctor { @@ -206,41 +204,40 @@ class ConcatGradFunctor { const framework::Tensor& input, const int axis, std::vector& outputs) { // TODO(zcd): Add input data validity checking - int num = outputs.size(); - int input_row = 1; + int o_num = outputs.size(); + int out_row = 1; auto dim_0 = outputs[0].dims(); for (int i = 0; i < axis; ++i) { - input_row *= dim_0[i]; + out_row *= dim_0[i]; } - int output_col_0 = outputs[0].numel() / input_row; - int input_col = 0; + int out_col = outputs[0].numel() / out_row; + int in_col = 0, in_row = out_row; bool sameShape = true; - framework::Vector outputs_data(num * sizeof(T*) / 2); - framework::Vector outputs_cols(num + 1); - outputs_cols[0] = 0; + framework::Vector outputs_data(o_num * sizeof(T*) / 2); + framework::Vector outputs_cols(o_num + 1); T** outputs_ptr = reinterpret_cast(outputs_data.data()); - for (int i = 0; i < num; ++i) { - int t_col = outputs[i].numel() / input_row; + outputs_cols[0] = 0; + for (int i = 0; i < o_num; ++i) { + int t_col = outputs[i].numel() / out_row; if (sameShape) { - if (t_col != output_col_0) sameShape = false; + if (t_col != out_col) sameShape = false; } - input_col += t_col; - outputs_cols[i + 1] = input_col; + in_col += t_col; + outputs_cols[i + 1] = in_col; outputs_ptr[i] = outputs[i].data(); } - T** outs_gpu = + T** dev_out_gpu_data = reinterpret_cast(outputs_data.CUDAMutableData(context.GetPlace())); - const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace()); // computation const int kThreadsPerBlock = 1024; int block_cols = kThreadsPerBlock; - if (input_col < kThreadsPerBlock) { // block_cols is aligned by 32. - block_cols = ((input_col + 31) >> 5) << 5; + if (in_col < kThreadsPerBlock) { // block_cols is aligned by 32. + block_cols = ((in_col + 31) >> 5) << 5; } int block_rows = kThreadsPerBlock / block_cols; dim3 block_size = dim3(block_cols, block_rows, 1); @@ -249,18 +246,19 @@ class ConcatGradFunctor { int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); int grid_cols = - std::min((input_col + block_cols - 1) / block_cols, max_blocks); + std::min((in_col + block_cols - 1) / block_cols, max_blocks); int grid_rows = - std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1)); + std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1)); dim3 grid_size = dim3(grid_cols, grid_rows, 1); if (sameShape) { KernelConcatGrad<<>>( - input.data(), input_row, input_col, output_col_0, outs_gpu); + input.data(), in_row, in_col, out_col, dev_out_gpu_data); } else { + const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace()); KernelConcatGrad<<>>( - input.data(), input_row, input_col, outs_col_gpu, - static_cast(outputs_cols.size()), outs_gpu); + input.data(), in_row, in_col, dev_outs_col_data, + static_cast(outputs_cols.size()), dev_out_gpu_data); } } }; diff --git a/python/paddle/fluid/tests/unittests/test_concat_op.py b/python/paddle/fluid/tests/unittests/test_concat_op.py index 558f3a4dcbb8fe39c427d8b100f4488440e8b8cb..1e00d67d5480bfa77a60e1aed52cafac6e8242ca 100644 --- a/python/paddle/fluid/tests/unittests/test_concat_op.py +++ b/python/paddle/fluid/tests/unittests/test_concat_op.py @@ -20,19 +20,35 @@ from op_test import OpTest class TestConcatOp(OpTest): def setUp(self): self.op_type = "concat" - x0 = np.random.random((2, 1, 4, 5)).astype('float32') - x1 = np.random.random((2, 2, 4, 5)).astype('float32') - x2 = np.random.random((2, 3, 4, 5)).astype('float32') - axis = 1 - self.inputs = {'X': [('x0', x0), ('x1', x1), ('x2', x2)]} - self.attrs = {'axis': axis} - self.outputs = {'Out': np.concatenate((x0, x1, x2), axis=axis)} + self.init_test_data() + self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]} + self.attrs = {'axis': self.axis} + self.outputs = { + 'Out': np.concatenate( + (self.x0, self.x1, self.x2), axis=self.axis) + } def test_check_output(self): self.check_output() def test_check_grad(self): self.check_grad(['x0'], 'Out') + self.check_grad(['x1'], 'Out') + self.check_grad(['x2'], 'Out') + + def init_test_data(self): + self.x0 = np.random.random((2, 1, 4, 5)).astype('float32') + self.x1 = np.random.random((2, 2, 4, 5)).astype('float32') + self.x2 = np.random.random((2, 3, 4, 5)).astype('float32') + self.axis = 1 + + +class TestConcatOp2(OpTest): + def init_test_data(self): + self.x0 = np.random.random((2, 3, 4, 5)).astype('float32') + self.x1 = np.random.random((2, 3, 4, 5)).astype('float32') + self.x2 = np.random.random((2, 3, 4, 5)).astype('float32') + self.axis = 1 if __name__ == '__main__':