diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h
index 19d877dfb6e6ad1d1e27969fa60a59adaadd0232..a65b1987cb547d7d02c454df1c6758e74037a1b6 100644
--- a/paddle/fluid/operators/concat_op.h
+++ b/paddle/fluid/operators/concat_op.h
@@ -32,6 +32,7 @@ class ConcatKernel : public framework::OpKernel<T> {
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
+    std::vector<framework::Tensor> inputs(ins.size());
     for (size_t j = 0; j < ins.size(); ++j) {
       inputs[j] = *ins[j];
     }
@@ -49,17 +50,17 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
 
-    size_t input_offset = 0;
-    auto in_stride = framework::stride_numel(in->dims());
-    for (auto& out : outs) {
-      out->mutable_data<T>(ctx.GetPlace());
-      auto out_stride = framework::stride_numel(out->dims());
-      StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
-                                  out_stride, in->data<T>() + input_offset,
-                                  in_stride, out_stride[axis]);
-      input_offset += out_stride[axis];
+    std::vector<framework::Tensor> outputs(outs.size());
+    for (size_t j = 0; j < outs.size(); ++j) {
+      outs[j]->mutable_data<T>(ctx.GetPlace());
+      outputs[j] = *outs[j];
     }
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
+        concat_grad_functor;
+    concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
   }
 };
diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc
index 32059aa2f0ce7f76de14de62ad73aeef90e3504c..5c5c6489d601d63aa975c8811e3320c4c03922c2 100644
--- a/paddle/fluid/operators/math/concat.cc
+++ b/paddle/fluid/operators/math/concat.cc
@@ -25,16 +25,12 @@ template <typename T>
 class ConcatFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, const int axis,
                   framework::Tensor* output) {
     // assume the the max size of input is less than 8 and see the performance
     // save origin dim
     int num = input.size();
     std::vector<framework::DDim> origin_dim(num);
-    // for (int j = 0; j < num; ++j) {
-    //   origin_dim[j] = input[j].dims();
-    // }
-    auto out_dim = output->dims();
 
     // get the matrix size
     int rows = 1;
@@ -42,40 +38,72 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
     for (int i = 0; i < axis; ++i) {
       rows *= dim_0[i];
     }
-    int cols = input[0].numel() / rows;
     int out_rows = rows, out_cols = 0;
-    bool sameShape = true;
 
-    // reshape to matrix
+    // get input's cols
+    std::vector<int> input_cols(input.size());
     for (int i = 0; i < num; ++i) {
       int t_cols = input[i].numel() / rows;
-      if (sameShape) {
-        if (t_cols != cols) sameShape = false;
-      }
       out_cols += t_cols;
-      input[i].Resize({rows, t_cols});
+      input_cols[i] = t_cols;
     }
-    output->Resize({out_rows, out_cols});
     auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
+
     // computation
-    for (int k = 0; k < rows; ++k) {
-      // offset k * out_cols
+    for (int k = 0; k < out_rows; ++k) {
       T* dst_ptr = output->data<T>() + k * out_cols;
       int col_idx = 0;
       for (int j = 0; j < num; ++j) {
-        int col_len = input[j].dims()[1];
+        int col_len = input_cols[j];
         const T* src_prt = input[j].data<T>() + k * col_len;
         memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt,
                      sizeof(T) * col_len);
         col_idx += col_len;
       }
     }
+  }
+};
+
+template <typename T>
+class ConcatGradFunctor<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::Tensor& input, const int axis,
+                  std::vector<framework::Tensor>& outputs) {
+    // assume the the max size of input is less than 8 and see the performance
+    // save origin dim
+    int num = outputs.size();
+    std::vector<framework::DDim> origin_dim(num);
 
-    // recover origin dim
-    // for (int j = 0; j < num; ++j) {
-    //   input[j]->Resize(origin_dim[j]);
-    // }
-    output->Resize(out_dim);
+    // get the matrix size
+    int input_rows = 1;
+    auto dim_0 = outputs[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      input_rows *= dim_0[i];
+    }
+    int input_cols = 0;
+
+    // get outputs' cols
+    std::vector<int> output_cols(outputs.size());
+    for (int i = 0; i < num; ++i) {
+      int t_cols = outputs[i].numel() / input_rows;
+      input_cols += t_cols;
+      output_cols[i] = t_cols;
+    }
+    auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
+
+    // computation
+    for (int k = 0; k < input_rows; ++k) {
+      const T* src_ptr = input.data<T>() + k * input_cols;
+      int col_idx = 0;
+      for (int j = 0; j < num; ++j) {
+        int col_len = output_cols[j];
+        T* dst_ptr = outputs[j].data<T>() + k * col_len;
+        memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
+                     sizeof(T) * col_len);
+        col_idx += col_len;
+      }
+    }
   }
 };
@@ -84,6 +112,11 @@ template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
 template class ConcatFunctor<platform::CPUDeviceContext, float>;
 template class ConcatFunctor<platform::CPUDeviceContext, double>;
 
+template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
+template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
+template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
+template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu
index 6932e22f84ddfa873485914ef01aad6022bea2f1..8af7233426c89df3e2147b56592b856de803a8a6 100644
--- a/paddle/fluid/operators/math/concat.cu
+++ b/paddle/fluid/operators/math/concat.cu
@@ -22,7 +22,7 @@ namespace math {
 // TODO(zcd): This can be replaced by tensor,
 // if that, maybe we should add int8 to VarType::Type.
 // Or replaced by tensorArray.
-static constexpr int MaxSize = 32;
+static constexpr int MaxSize = 8;
 template <typename T>
 struct CUDADeviceArray {
   T data[MaxSize];
@@ -54,7 +54,6 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
                              const int output_rows, const int output_cols,
                              T* output) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
 
   int segment = upper_bound(input_cols.data, input_cols.size, tid_x) - 1;
   int curr_offset = input_cols.data[segment];
@@ -69,13 +68,73 @@
     int local_col = tid_x - curr_offset;
     int segment_width = curr_col_offset - curr_offset;
     const T* input_ptr = inputs.data[curr_segment];
-
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
       output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
   }
 }
 
+template <typename T>
+__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
+                             const int input_col, const int output_rows,
+                             const int output_cols, T* output) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  float inv_input_col = 1.0 / input_col;
+  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x * inv_input_col;
+    int in_offset = tid_x - split * input_col;
+    const T* input_ptr = inputs.data[split];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
+      output[tid_y * output_cols + tid_x] =
+          input_ptr[tid_y * input_col + in_offset];
+  }
+}
+
+template <typename T>
+__global__ void KernelConcatGrad(const T* input, const int input_row,
+                                 const int input_col,
+                                 CUDADeviceArray<int> output_cols,
+                                 CUDADeviceArray<T*> outputs) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int segment = upper_bound(output_cols.data, output_cols.size, tid_x) - 1;
+  int curr_offset = output_cols.data[segment];
+  int curr_segment = segment;
+  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+    T curr_col_offset;
+    while ((curr_col_offset = output_cols.data[curr_segment + 1]) <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+    T* output_ptr = outputs.data[curr_segment];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * segment_width + local_col] =
+          input[tid_y * input_col + tid_x];
+  }
+}
+
+template <typename T>
+__global__ void KernelConcatGrad(const T* input, const int input_row,
+                                 const int input_col, const int output_cols,
+                                 CUDADeviceArray<T*> outputs) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  float inv_input_col = 1.0 / input_col;
+  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x * inv_input_col;
+    int in_offset = tid_x - split * input_col;
+    T* output_ptr = outputs.data[split];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * output_cols + in_offset] =
+          input[tid_y * input_col + tid_x];
+  }
+}
+
 /*
  * All tensors' dimension should be the same.
 */
@@ -83,17 +142,13 @@ template <typename T>
 class ConcatFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, const int axis,
                   framework::Tensor* output) {
     // assume the the max size of input is less than 8 and see the performance
     // save origin dim
     int num = input.size();
-    // std::vector<framework::DDim> origin_dim(num);
-    // for (int j = 0; j < num; ++j) {
-    //   origin_dim[j] = input[j].dims();
-    // }
-    auto out_dim = output->dims();
-
+    PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
+                      MaxSize);
     // get the matrix size
     int rows = 1;
     auto dim_0 = input[0].dims();
@@ -117,30 +172,96 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
         if (t_cols != cols) sameShape = false;
       }
       out_cols += t_cols;
-      input[i].Resize({rows, t_cols});
       inputs_cols.data[i + 1] = out_cols;
       inputs_data.data[i] = input[i].data<T>();
     }
-    output->Resize({out_rows, out_cols});
 
     // computation
-    const int kThreadsPerBlock = 256;
+    // set the thread block and grid according to CurrentDeviceId
+    const int kThreadsPerBlock = 1024;
     int block_cols = std::min(out_cols, kThreadsPerBlock);
     int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
     dim3 block_size = dim3(block_cols, block_rows, 1);
 
-    int grid_cols = (out_cols + block_cols - 1) / block_cols;
-    int grid_rows = (out_rows + block_rows - 1) / block_rows;
+    int dev_id = paddle::platform::GetCurrentDeviceId();
+    int multi_process = paddle::platform::GetCUDAMultiProcessors(dev_id);
+    int max_threads_per_mp =
+        paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(dev_id);
+    int max_threads = multi_process * max_threads_per_mp;
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    int grid_cols =
+        std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
+    int grid_rows =
+        std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
 
-    KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-        inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
+    if (sameShape) {
+      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
+          inputs_data, cols, out_rows, out_cols, output->data<T>());
+    } else {
+      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
+          inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
+    }
+  }
+};
+
+template <typename T>
+class ConcatGradFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input, const int axis,
+                  std::vector<framework::Tensor>& outputs) {
+    // assume the the max size of input is less than 8 and see the performance
+    // save origin dim
+    int num = outputs.size();
+    PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
+                      MaxSize);
+
+    // get the matrix size
+    int input_row = 1;
+    auto dim_0 = outputs[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      input_row *= dim_0[i];
+    }
+
+    int output_col_0 = outputs[0].numel() / input_row;
+    int input_col = 0;
+    bool sameShape = true;
+
+    CUDADeviceArray<T*> outputs_data;
+    CUDADeviceArray<int> outputs_cols;
+    outputs_data.size = num;
+    outputs_cols.size = num + 1;
+    outputs_cols.data[0] = 0;
 
-    // recover origin dim
-    // for (int j = 0; j < num; ++j) {
-    //   input[j].Resize(origin_dim[j]);
-    // }
-    output->Resize(out_dim);
+    for (int i = 0; i < num; ++i) {
+      int t_col = outputs[i].numel() / input_row;
+      if (sameShape) {
+        if (t_col != output_col_0) sameShape = false;
+      }
+      input_col += t_col;
+      outputs_cols.data[i + 1] = input_col;
+      outputs_data.data[i] = outputs[i].data<T>();
+    }
+
+    // computation
+    const int kThreadsPerBlock = 256;
+    int block_cols = std::min(input_col, kThreadsPerBlock);
+    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+    dim3 block_size = dim3(block_cols, block_rows, 1);
+
+    int grid_cols = (input_col + block_cols - 1) / block_cols;
+    int grid_rows = (input_row + block_rows - 1) / block_rows;
+    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
+
+    if (sameShape) {
+      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
+          input.data<T>(), input_row, input_col, output_col_0, outputs_data);
+    } else {
+      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
+          input.data<T>(), input_row, input_col, outputs_cols, outputs_data);
+    }
   }
 };
@@ -149,6 +270,11 @@ template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
 template class ConcatFunctor<platform::CUDADeviceContext, float>;
 template class ConcatFunctor<platform::CUDADeviceContext, double>;
 
+template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/math/concat.h b/paddle/fluid/operators/math/concat.h
index 50c75dd208d74f60692e10230a2ca49fa492f591..bc878318883d197d17823d2e6862251f1b02e6b3 100644
--- a/paddle/fluid/operators/math/concat.h
+++ b/paddle/fluid/operators/math/concat.h
@@ -20,18 +20,23 @@ namespace operators {
 namespace math {
 
 /*
- * the tensor's shape of input will be changed,
- * so the second parameter is not const.
 *
 */
 template <typename DeviceContext, typename T>
 class ConcatFunctor {
  public:
   void operator()(const DeviceContext& context,
-                  std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, const int axis,
                   framework::Tensor* output);
 };
 
+template <typename DeviceContext, typename T>
+class ConcatGradFunctor {
+ public:
+  void operator()(const DeviceContext& context, const framework::Tensor& input,
+                  const int axis, std::vector<framework::Tensor>& outputs);
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/math/concat_test.cc b/paddle/fluid/operators/math/concat_test.cc
index 815070b1133e2ec07ed7ab93588331a7ec557409..1741af8148bb90863f294ba4930006a58b5ddbf9 100644
--- a/paddle/fluid/operators/math/concat_test.cc
+++ b/paddle/fluid/operators/math/concat_test.cc
@@ -251,6 +251,80 @@ void testConcat() {
       }
     }
   }
+
+  /**
+   * cast4:
+   *   inputs:
+   *     axis = 1
+   *     t_a.shape: [2, 3, 4]
+   *     t_b.shape: [2, 3, 4]
+   *   output:
+   *     out.shape: [2, 6, 4]
+   */
+  dim_a = make_ddim({2, 3, 4});
+  dim_b = make_ddim({2, 3, 4});
+  dim_out = make_ddim({2, 6, 4});
+
+  input_a.Resize(dim_a);
+  input_b.Resize(dim_b);
+  out.Resize(dim_out);
+  if (paddle::platform::is_gpu_place(Place())) {
+    input_a_cpu.Resize(dim_a);
+    input_b_cpu.Resize(dim_b);
+    out_cpu.Resize(dim_out);
+  }
+
+  if (paddle::platform::is_gpu_place(Place())) {
+    a_ptr = input_a_cpu.data<int>();
+    b_ptr = input_b_cpu.data<int>();
+  } else {
+    a_ptr = input_a.data<int>();
+    b_ptr = input_b.data<int>();
+  }
+
+  for (int i = 0; i < 2 * 3 * 4; ++i) {
+    a_ptr[i] = i;
+  }
+  for (int i = 0; i < 2 * 3 * 4; ++i) {
+    b_ptr[i] = i;
+  }
+
+  if (paddle::platform::is_gpu_place(Place())) {
+    TensorCopy(input_a_cpu, Place(), *context, &input_a);
+    TensorCopy(input_b_cpu, Place(), *context, &input_b);
+  }
+
+  input.clear();
+  input.push_back(input_a);
+  input.push_back(input_b);
+
+  concat_functor(*context, input, 1, &out);
+
+  // check the dim of input_a, input_b
+  PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
+  PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
+
+  if (paddle::platform::is_gpu_place(Place())) {
+    TensorCopy(out, CPUPlace(), *context, &out_cpu);
+    out_ptr = out_cpu.data<int>();
+  } else {
+    out_ptr = out.data<int>();
+  }
+
+  // check the data
+  cols = 12;
+  idx_a = 0, idx_b = 0;
+  for (int i = 0; i < 2; ++i) {
+    for (int j = 0; j < 24; ++j) {
+      if (j >= cols) {
+        PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], b_ptr[idx_b]);
+        ++idx_b;
+      } else {
+        PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], a_ptr[idx_a]);
+        ++idx_a;
+      }
+    }
+  }
 }
 
 TEST(math, concat) {
diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc
index 05e1eae853e20b3fd86438c03f52628179a311ca..da4041bad0d82fe1c8c7a12fd0c7177e6dbddef3 100644
--- a/paddle/fluid/platform/gpu_info.cc
+++ b/paddle/fluid/platform/gpu_info.cc
@@ -33,6 +33,26 @@ int GetCUDADeviceCount() {
   return count;
 }
 
+int GetCUDAMultiProcessors(int id) {
+  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
+  int count;
+  PADDLE_ENFORCE(
+      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id),
+      "cudaDeviceGetAttribute failed in "
+      "paddle::platform::GetCUDAMultiProcessors");
+  return count;
+}
+
+int GetCUDAMaxThreadsPerMultiProcessor(int id) {
+  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
+  int count;
+  PADDLE_ENFORCE(cudaDeviceGetAttribute(
+                     &count, cudaDevAttrMaxThreadsPerMultiProcessor, id),
+                 "cudaDeviceGetAttribute failed in "
+                 "paddle::platform::GetCUDAMaxThreadsPerMultiProcessor");
+  return count;
+}
+
 int GetCurrentDeviceId() {
   int device_id;
   PADDLE_ENFORCE(
diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h
index 3d4883d8078daa2b55d8ea792b47e93e4f4feec8..c38ccf0f2ade1d2405177b541b33fd84283726ff 100644
--- a/paddle/fluid/platform/gpu_info.h
+++ b/paddle/fluid/platform/gpu_info.h
@@ -30,6 +30,12 @@ const std::string kEnvFractionGpuMemoryToUse =
 //! Get the total number of GPU devices in system.
 int GetCUDADeviceCount();
 
+//! Get the MultiProcessors of the ith GPU.
+int GetCUDAMultiProcessors(int i);
+
+//! Get the MaxThreads of each MultiProcessor of the ith GPU.
+int GetCUDAMaxThreadsPerMultiProcessor(int i);
+
 //! Get the current GPU device id in system.
 int GetCurrentDeviceId();
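
Note: both ConcatFunctor and ConcatGradFunctor in this patch reduce an N-D concat along `axis` to a 2-D copy: every dimension before `axis` is folded into a row count, everything from `axis` onward becomes a per-tensor column count, and each output row is assembled by copying one contiguous slab per input (the backward pass runs the same copies in the opposite direction). The following is a minimal, self-contained C++ sketch of that forward copy pattern; the ConcatRows helper and the sample values are illustrative only and are not part of this patch.

#include <cstdio>
#include <cstring>
#include <vector>

// Each input j contributes cols[j] contiguous elements to every one of the
// `rows` output rows, so row k of the result is built slab by slab.
void ConcatRows(const std::vector<const float*>& inputs,
                const std::vector<int>& cols, int rows, float* output) {
  int out_cols = 0;
  for (int c : cols) out_cols += c;
  for (int k = 0; k < rows; ++k) {
    int col_idx = 0;
    for (size_t j = 0; j < inputs.size(); ++j) {
      std::memcpy(output + k * out_cols + col_idx, inputs[j] + k * cols[j],
                  sizeof(float) * cols[j]);
      col_idx += cols[j];
    }
  }
}

int main() {
  // Concatenate a [2, 2] and a [2, 3] matrix along axis 1 -> [2, 5].
  std::vector<float> a = {0, 1, 2, 3};        // 2 rows x 2 cols
  std::vector<float> b = {4, 5, 6, 7, 8, 9};  // 2 rows x 3 cols
  std::vector<float> out(2 * 5);
  ConcatRows({a.data(), b.data()}, {2, 3}, /*rows=*/2, out.data());
  for (float v : out) std::printf("%g ", v);  // 0 1 4 5 6 2 3 7 8 9
  std::printf("\n");
  return 0;
}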