From ad1ad738d89bb6b347ee0c53ef0245acb86f158d Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 17 Jun 2018 10:48:34 +0800 Subject: [PATCH] add gpu support for concat --- paddle/fluid/operators/math/concat.cc | 2 +- paddle/fluid/operators/math/concat.cu | 41 ++++++++++++++++----------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc index c10cff9c9b1..14964fc62af 100644 --- a/paddle/fluid/operators/math/concat.cc +++ b/paddle/fluid/operators/math/concat.cc @@ -98,7 +98,7 @@ class ConcatGradFunctor { int col_idx = 0; for (int j = 0; j < num; ++j) { int col_len = output_cols[j]; - auto* out_tensor = (*outputs)[j]; + auto* out_tensor = outputs->at(j); if (out_tensor != nullptr) { T* dst_ptr = out_tensor->data() + k * col_len; memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx, diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu index 4285d38dcd6..f66baa6573f 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat.cu @@ -102,10 +102,12 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row, int local_col = tid_x - curr_offset; int segment_width = curr_col_offset - curr_offset; T* output_ptr = outputs_data[curr_segment]; - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) - output_ptr[tid_y * segment_width + local_col] = - input_data[tid_y * in_col + tid_x]; + if (output_ptr != nullptr) { + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * segment_width + local_col] = + input_data[tid_y * in_col + tid_x]; + } } } @@ -118,10 +120,12 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row, int split = tid_x / fixed_out_col; int in_offset = tid_x - split * fixed_out_col; T* output_ptr = outputs_data[split]; - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) - output_ptr[tid_y * fixed_out_col + in_offset] = - input_data[tid_y * in_col + tid_x]; + if (output_ptr != nullptr) { + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * fixed_out_col + in_offset] = + input_data[tid_y * in_col + tid_x]; + } } } @@ -203,17 +207,18 @@ template class ConcatGradFunctor { public: void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& input, const int axis, - std::vector* outputs) { + const framework::Tensor& input, + const std::vector& ref_inputs, + const int axis, std::vector* outputs) { // TODO(zcd): Add input data validity checking int o_num = outputs->size(); int out_row = 1; - auto dim_0 = outputs->at(0).dims(); + auto dim_0 = ref_inputs[0]->dims(); for (int i = 0; i < axis; ++i) { out_row *= dim_0[i]; } - int out_col = outputs->at(0).numel() / out_row; + int out0_col = ref_inputs[0]->numel() / out_row; int in_col = 0, in_row = out_row; bool sameShape = true; @@ -223,13 +228,17 @@ class ConcatGradFunctor { outputs_cols[0] = 0; for (int i = 0; i < o_num; ++i) { - int t_col = outputs->at(i).numel() / out_row; + int t_col = outputs->at(i)->numel() / out_row; if (sameShape) { - if (t_col != out_col) sameShape = false; + if (t_col != out0_col) sameShape = false; } in_col += t_col; outputs_cols[i + 1] = in_col; - outputs_ptr[i] = outputs->at(i).data(); + if (outputs->at(i) != nullptr) { + outputs_ptr[i] = outputs->at(i)->data(); + } else { + outputs_ptr[i] = nullptr; + } } T** dev_out_gpu_data = @@ -255,7 +264,7 @@ class ConcatGradFunctor { if (sameShape) { KernelConcatGrad<<>>( - input.data(), in_row, in_col, out_col, dev_out_gpu_data); + input.data(), in_row, in_col, out0_col, dev_out_gpu_data); } else { const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace()); KernelConcatGrad<<>>( -- GitLab