/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
                             const int output_rows, const int output_cols,
                             T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int curr_offset = input_cols[0];
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset = input_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = input_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;

    T* input_ptr = inputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}
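
// Note on the variable-width path: `input_cols` here (and `out_cols` in
// KernelConcatGrad below) holds cumulative column offsets starting at 0,
// with the total width as the last entry. For hypothetical widths {3, 5}
// the array is {0, 3, 8}: output columns [0, 3) belong to segment 0 and
// columns [3, 8) to segment 1, which is what the while-loop search over
// curr_segment recovers for each thread.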

template <typename T>
__global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
                             const int out_rows, const int out_cols,
                             T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    // Every input has the same width, so the source tensor and the column
    // within it follow directly from integer division.
    int split = tid_x / fixed_in_col;
    int in_offset = tid_x - split * fixed_in_col;
    T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

template <typename T>
__global__ void KernelConcatGrad(const T* input_data, const int in_row,
                                 const int in_col, const int* out_cols,
                                 int out_cols_size, T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int curr_offset = out_cols[0];
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset = out_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = out_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs_data[curr_segment];
    if (output_ptr != nullptr) {
      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * segment_width + local_col] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

template <typename T>
__global__ void KernelConcatGrad(const T* input_data, const int in_row,
                                 const int in_col, const int fixed_out_col,
                                 T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_out_col;
    int in_offset = tid_x - split * fixed_out_col;
    T* output_ptr = outputs_data[split];
    if (output_ptr != nullptr) {
      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * fixed_out_col + in_offset] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

/*
 * All tensors must have the same rank, and every dimension except the
 * `axis` dimension must be equal. A usage sketch follows this class.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, const int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    int in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    int in_col = input[0].numel() / in_row;
    int out_row = in_row, out_col = 0;

    framework::Vector<int16_t> inputs_data(in_num * sizeof(T*) / 2);
    framework::Vector<int> inputs_col(in_num + 1);
    T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());

    inputs_col[0] = 0;
    bool sameShape = true;
    for (int i = 0; i < in_num; ++i) {
      int t_cols = input[i].numel() / in_row;
      if (sameShape) {
        if (t_cols != in_col) sameShape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
    }

    T** dev_ins_data =
        reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));

    // computation
    // Set the thread block and grid according to the current device.
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (out_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
      block_cols = ((out_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((out_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    if (sameShape) {
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, in_col, out_row, out_col, output->data<T>());
    } else {
      const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
          out_row, out_col, output->data<T>());
    }
    // Wait() must be called because `inputs_data` may be destructed before
    // the kernel ends.
    context.Wait();
  }
};
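
// A minimal usage sketch (illustrative only, not part of this file's API):
// concatenate two float tensors along axis 1 on the GPU. The shapes and the
// `gpu_place` / `dev_ctx` variables below are hypothetical.
//
//   std::vector<framework::Tensor> ins(2);
//   ins[0].mutable_data<float>(framework::make_ddim({4, 3}), gpu_place);
//   ins[1].mutable_data<float>(framework::make_ddim({4, 5}), gpu_place);
//   framework::Tensor out;
//   out.mutable_data<float>(framework::make_ddim({4, 8}), gpu_place);
//   ConcatFunctor<platform::CUDADeviceContext, float> concat;
//   concat(dev_ctx, ins, 1 /*axis*/, &out);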

/*
 * All tensors must have the same rank, and every dimension except the
 * `axis` dimension must be equal. A usage sketch follows this class.
 */
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::LoDTensor*>& ref_inputs,
                  const int axis, std::vector<framework::Tensor*>* outputs) {
    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();
    int out_row = 1;
    auto dim_0 = ref_inputs[0]->dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }

    int out0_col = ref_inputs[0]->numel() / out_row;
    int in_col = 0, in_row = out_row;
    bool sameShape = true;

    framework::Vector<int16_t> outputs_data(o_num * sizeof(T*) / 2);
    framework::Vector<int> outputs_cols(o_num + 1);
    T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int t_col = ref_inputs.at(i)->numel() / out_row;
      if (sameShape) {
        if (t_col != out0_col) sameShape = false;
      }
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      if (outputs->at(i) != nullptr) {
        outputs_ptr[i] = outputs->at(i)->data<T>();
      } else {
        outputs_ptr[i] = nullptr;
      }
    }

    T** dev_out_gpu_data =
        reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));

    // computation
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (in_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
      block_cols = ((in_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((in_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    if (sameShape) {
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
    } else {
      const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
    }
    // Wait() must be called because `outputs_data` may be destructed before
    // the kernel ends.
    context.Wait();
  }
};
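
// A minimal usage sketch (illustrative only, not part of this file's API):
// scatter the gradient of the concatenated output back into per-input
// gradient tensors. `dout`, `x0`, `x1`, `dx0`, `dx1`, and `dev_ctx` are
// hypothetical and assumed to be allocated with the forward-pass shapes.
// An entry of `grads` may be nullptr when that input needs no gradient;
// the kernels above simply skip it.
//
//   std::vector<const framework::LoDTensor*> refs = {&x0, &x1};
//   std::vector<framework::Tensor*> grads = {&dx0, &dx1};
//   ConcatGradFunctor<platform::CUDADeviceContext, float> concat_grad;
//   concat_grad(dev_ctx, dout, refs, 1 /*axis*/, &grads);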

template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;

template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle