concat.cu 9.9 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <algorithm>
#include <vector>
C
chengduoZH 已提交
17
#include "paddle/fluid/framework/mixed_vector.h"
C
chengduoZH 已提交
18
#include "paddle/fluid/operators/math/concat.h"
D
dzhwinter 已提交
19
#include "paddle/fluid/platform/cuda_primitives.h"
C
chengduoZH 已提交
20 21 22 23 24 25

namespace paddle {
namespace operators {
namespace math {

/*
 * Column-wise concatenation for inputs whose widths differ.
 *
 * Input k is read as a row-major matrix with output_rows rows and
 * (input_cols[k + 1] - input_cols[k]) columns. input_cols is the prefix-sum
 * table built by ConcatFunctor (input_cols[0] == 0, last entry ==
 * output_cols); the segment search below assumes that last entry is at least
 * output_cols so it always terminates.
 * Output columns are walked along the x dimension of the launch and rows
 * along y, each advancing by the full grid extent in that dimension, so any
 * 2-D grid/block shape covers the whole output.
 *
 * col_size (the length of input_cols) is not read by this kernel.
 */
template <typename T>
__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
                             const int output_rows, const int output_cols,
                             T* output) {
  const int first_row = blockIdx.y * blockDim.y + threadIdx.y;

  // A thread's columns only increase from one iteration to the next, so the
  // segment search resumes where the previous iteration stopped instead of
  // restarting from segment 0.
  int seg = 0;
  int seg_begin = input_cols[0];

  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < output_cols;
       col += blockDim.x * gridDim.x) {
    int seg_end = input_cols[seg + 1];
    // Advance until col lies inside [seg_begin, seg_end).
    while (col >= seg_end) {
      seg_begin = seg_end;
      ++seg;
      seg_end = input_cols[seg + 1];
    }

    const T* src = inputs[seg];
    const int src_col = col - seg_begin;
    const int src_width = seg_end - seg_begin;

    for (int row = first_row; row < output_rows;
         row += blockDim.y * gridDim.y) {
      output[row * output_cols + col] = src[row * src_width + src_col];
    }
  }
}

C
chengduoZH 已提交
51
/*
 * Column-wise concatenation for inputs that all share the same width.
 *
 * Every input is read as a row-major out_rows x fixed_in_col matrix, so
 * output column c comes from input c / fixed_in_col at local column
 * c - (c / fixed_in_col) * fixed_in_col. Output columns are walked along the
 * x dimension of the launch and rows along y, each advancing by the full grid
 * extent in that dimension, so any 2-D grid/block shape covers the output.
 */
template <typename T>
__global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
                             const int out_rows, const int out_cols,
                             T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    // Plain integer division: exact for non-negative ints and avoids a
    // per-column double-precision divide (FP64 throughput is low on most
    // GPUs). Matches the fixed-width KernelConcatGrad.
    int split = tid_x / fixed_in_col;
    int in_offset = tid_x - split * fixed_in_col;
    const T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

/*
 * Gradient of column-wise concatenation for outputs whose widths differ:
 * scatters the columns of input_data back into the per-input gradients.
 *
 * input_data is read as a row-major in_row x in_col matrix. out_cols is the
 * prefix-sum table built by ConcatGradFunctor (out_cols[0] == 0, last entry
 * == in_col); output k receives columns [out_cols[k], out_cols[k + 1]) and is
 * written as a row-major matrix of that width. A nullptr entry in
 * outputs_data is skipped (no gradient requested for that input).
 * Columns are walked along the x dimension of the launch and rows along y,
 * each advancing by the full grid extent in that dimension.
 *
 * out_cols_size (the length of out_cols) is not read by this kernel.
 */
template <typename T>
__global__ void KernelConcatGrad(const T* input_data, const int in_row,
                                 const int in_col, const int* out_cols,
                                 int out_cols_size, T** outputs_data) {
  const int first_row = blockIdx.y * blockDim.y + threadIdx.y;

  // Columns visited by one thread only increase, so the segment search
  // resumes from where the previous iteration stopped.
  int seg = 0;
  int seg_begin = out_cols[0];

  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < in_col;
       col += blockDim.x * gridDim.x) {
    int seg_end = out_cols[seg + 1];
    // Advance until col lies inside [seg_begin, seg_end).
    while (col >= seg_end) {
      seg_begin = seg_end;
      ++seg;
      seg_end = out_cols[seg + 1];
    }

    T* dst = outputs_data[seg];
    if (dst == nullptr) continue;  // this gradient was not requested

    const int dst_col = col - seg_begin;
    const int dst_width = seg_end - seg_begin;
    for (int row = first_row; row < in_row; row += blockDim.y * gridDim.y) {
      dst[row * dst_width + dst_col] = input_data[row * in_col + col];
    }
  }
}

/*
 * Gradient of column-wise concatenation when every output has the same
 * width: input column c goes to output c / fixed_out_col at local column
 * c % fixed_out_col.
 *
 * input_data is read as a row-major in_row x in_col matrix; each output is
 * written as a row-major in_row x fixed_out_col matrix. A nullptr entry in
 * outputs_data is skipped (no gradient requested for that input).
 * Columns are walked along the x dimension of the launch and rows along y,
 * each advancing by the full grid extent in that dimension.
 */
template <typename T>
__global__ void KernelConcatGrad(const T* input_data, const int in_row,
                                 const int in_col, const int fixed_out_col,
                                 T** outputs_data) {
  const int first_row = blockIdx.y * blockDim.y + threadIdx.y;

  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < in_col;
       col += blockDim.x * gridDim.x) {
    T* dst = outputs_data[col / fixed_out_col];
    if (dst == nullptr) continue;  // this gradient was not requested

    const int dst_col = col % fixed_out_col;
    for (int row = first_row; row < in_row; row += blockDim.y * gridDim.y) {
      dst[row * fixed_out_col + dst_col] = input_data[row * in_col + col];
    }
  }
}

C
chengduoZH 已提交
113
/*
 * All tensors' dimensions should be the same, and the size of every
 * dimension must match across inputs except along the concat axis.
 *
 * Each input is viewed as a 2-D row-major matrix with
 * prod(dims[0:axis]) rows and numel / rows columns; the output is the
 * column-wise concatenation of these matrices. When every input has the same
 * column count, the cheaper fixed-width kernel is used.
 *
 * Inputs with no elements to copy (no inputs, a zero-sized leading dimension,
 * or zero total columns) return without launching a kernel.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, const int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    // Nothing to concatenate; also guards the input[0] access below.
    if (in_num == 0) return;

    int in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    // A zero-sized leading dimension means there is no data, and the
    // numel() / in_row divisions below would divide by zero.
    if (in_row == 0) return;

    int in_col = input[0].numel() / in_row;
    int out_row = in_row, out_col = 0;

    // Host-side table of input data pointers, stored in an int16_t Vector
    // (sizeof(T*) / 2 elements per pointer) and reinterpreted as T*.
    // NOTE(review): presumably because framework::Vector is not usable with
    // pointer element types - confirm.
    framework::Vector<int16_t> inputs_data(in_num * sizeof(T*) / 2);
    // Prefix sums of input column counts: inputs_col[i]..inputs_col[i + 1]
    // is the output column range of input i.
    framework::Vector<int> inputs_col(in_num + 1);
    T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());

    inputs_col[0] = 0;
    bool sameShape = true;
    for (int i = 0; i < in_num; ++i) {
      int t_cols = input[i].numel() / in_row;
      if (t_cols != in_col) sameShape = false;
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
    }
    // No columns at all: nothing to copy, and block_cols below would be 0,
    // making kThreadsPerBlock / block_cols divide by zero.
    if (out_col == 0) return;

    T** dev_ins_data =
        reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));

    // Launch configuration: output columns map to x, rows to y. The sizes
    // are derived from the device's max physical thread count; the kernels
    // loop over both dimensions, so a smaller grid still covers everything.
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (out_col < kThreadsPerBlock) {  // round up to a multiple of 32.
      block_cols = ((out_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((out_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    if (sameShape) {
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, in_col, out_row, out_col, output->data<T>());
    } else {
      const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
          out_row, out_col, output->data<T>());
    }
  }
};

C
chengduoZH 已提交
183 184
/*
 * Backward pass of ConcatFunctor: splits `input` (the gradient of the
 * concatenated tensor) back into per-input gradients along `axis`.
 *
 * All tensors' dimensions should be the same, and the size of every
 * dimension must match except along the concat axis. ref_inputs supplies the
 * shapes of the forward inputs; (*outputs)[i] receives the slice matching
 * ref_inputs[i]. A nullptr entry in *outputs is skipped by the kernels.
 *
 * Cases with no elements to copy (no outputs, a zero-sized leading
 * dimension, or zero total columns) return without launching a kernel.
 */
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::Tensor*>& ref_inputs,
                  const int axis, std::vector<framework::Tensor*>* outputs) {
    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();
    // Nothing to split; also guards the ref_inputs[0] access below.
    if (o_num == 0) return;

    int out_row = 1;
    auto dim_0 = ref_inputs[0]->dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }
    // A zero-sized leading dimension means there is no data, and the
    // numel() / out_row divisions below would divide by zero.
    if (out_row == 0) return;

    int out0_col = ref_inputs[0]->numel() / out_row;
    int in_col = 0, in_row = out_row;
    bool sameShape = true;

    // Host-side table of output data pointers, stored in an int16_t Vector
    // (sizeof(T*) / 2 elements per pointer) and reinterpreted as T*.
    // NOTE(review): presumably because framework::Vector is not usable with
    // pointer element types - confirm.
    framework::Vector<int16_t> outputs_data(o_num * sizeof(T*) / 2);
    // Prefix sums of output column counts: outputs_cols[i]..outputs_cols[i+1]
    // is the input column range that belongs to output i.
    framework::Vector<int> outputs_cols(o_num + 1);
    T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int t_col = ref_inputs.at(i)->numel() / out_row;
      if (t_col != out0_col) sameShape = false;
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      framework::Tensor* out = outputs->at(i);
      outputs_ptr[i] = (out != nullptr) ? out->data<T>() : nullptr;
    }
    // No columns at all: nothing to copy, and block_cols below would be 0,
    // making kThreadsPerBlock / block_cols divide by zero.
    if (in_col == 0) return;

    T** dev_out_gpu_data =
        reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));

    // Launch configuration: input columns map to x, rows to y. The sizes are
    // derived from the device's max physical thread count; the kernels loop
    // over both dimensions, so a smaller grid still covers everything.
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (in_col < kThreadsPerBlock) {  // round up to a multiple of 32.
      block_cols = ((in_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((in_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    if (sameShape) {
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
    } else {
      const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
    }
  }
};

// Explicit instantiations of the CUDA concat forward/backward functors,
// grouped by element type.
template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int>;

template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;

template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;

template class ConcatFunctor<platform::CUDADeviceContext, double>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;

C
chengduoZH 已提交
268 269 270
}  // namespace math
}  // namespace operators
}  // namespace paddle