selected_rows_functor.cu 17.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
typhoonzero 已提交
15
#include <set>
16
#include <vector>
T
typhoonzero 已提交
17

Y
Yi Wang 已提交
18 19
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
D
dzhwinter 已提交
20
#include "paddle/fluid/platform/cuda_primitives.h"
C
chengduo 已提交
21
#include "paddle/fluid/platform/float16.h"
22 23 24 25 26

namespace paddle {
namespace operators {
namespace math {
// SelectedRowsAdd (CUDA): concatenates two SelectedRows into `output`.
//
//   output->rows()  = input1.rows() followed by input2.rows()
//   output->value() = input1's value rows followed by input2's value rows
//
// Duplicate row indices are kept as they are (nothing is merged here). Both
// inputs must share the same height and the same number of elements per row,
// and both inputs as well as the context place must be GPU places.
// The output value buffer is written through data<T>() without allocating it
// here, so presumably the caller has sized it for all concatenated rows —
// confirm at call sites. The per-row element counts are computed by dividing
// by the row counts, so inputs with zero rows are not supported.
template <typename T>
struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::SelectedRows& input2,
                  framework::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(in1_height, input2.height());
    output->set_height(in1_height);

    framework::Vector<int64_t> in1_rows(input1.rows());
    auto& in2_rows = input2.rows();

    // Row indices of the result: all of input1's rows, then all of input2's.
    std::vector<int64_t> out_rows(in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    // Every participant must agree on the number of elements per row.
    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
    PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());

    auto* dst = out_value->data<T>();
    auto* first_src = in1_value.data<T>();

    auto in1_place = input1.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
    auto in2_place = input2.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in2_place));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE(platform::is_gpu_place(out_place));

    // input1's values fill the front of the output buffer and input2's
    // values follow them; both copies are issued on the context's stream.
    auto dst_gpu = boost::get<platform::CUDAPlace>(out_place);
    const auto first_numel = in1_value.numel();
    memory::Copy(dst_gpu, dst, boost::get<platform::CUDAPlace>(in1_place),
                 first_src, first_numel * sizeof(T), context.stream());

    auto* second_src = in2_value.data<T>();
    memory::Copy(dst_gpu, dst + first_numel,
                 boost::get<platform::CUDAPlace>(in2_place), second_src,
                 in2_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAdd<platform::CUDADeviceContext, float>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, double>;

namespace {
// Scatter-add kernel used by SelectedRowsAddTensor.
//
// Expected launch (as done by SelectedRowsAddTensor in this file): gridDim.x
// equals the number of selected rows, with `block_size` threads per block.
// Block `b` reads value row `b` of `selected_rows` (row_numel contiguous
// elements) and adds it element-wise into row `rows[b]` of `tensor_out`.
template <typename T, int block_size>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
                                            const int64_t* rows, T* tensor_out,
                                            int64_t row_numel) {
  const int row_in = blockIdx.x;
  const T* src = selected_rows + row_in * row_numel;
  T* dst = tensor_out + rows[row_in] * row_numel;

  // rows[] may repeat an index. Several blocks can then update the same
  // dense row at the same time, so a plain `+=` would race; use the atomic.
  for (int col = threadIdx.x; col < row_numel; col += block_size) {
    paddle::platform::CudaAtomicAdd(dst + col, src[col]);
  }
}
}  // namespace

// SelectedRowsAddTensor (CUDA): output = dense(input1) + input2.
//
// dense(input1) is the (height x row_numel) tensor obtained by scattering each
// value row of input1 into the row named by input1.rows(). Rows listed more
// than once are summed. input2 and output must both have input1.height()
// rows. output is zero-filled and then written in place without being
// allocated here, so presumably the caller allocates it — confirm at call
// sites.
//
// If input1 has no rows, it contributes nothing and the result is
// output == input2.
template <typename T>
struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::Tensor& input2, framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
    PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);

    auto& in1_rows = input1.rows();

    // Start from an all-zero output; the scatter kernel accumulates into it.
    SetConstant<platform::CUDADeviceContext, T> functor;
    functor(context, output, static_cast<T>(0));

    // Skip the scatter for an empty SelectedRows. Otherwise the per-row width
    // below would divide by a zero row count, and a zero-sized grid is an
    // invalid kernel launch configuration.
    if (in1_rows.size() > 0) {
      auto& in1_value = input1.value();
      int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
      PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
      PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);

      auto* in1_data = in1_value.data<T>();
      auto* out_data = output->data<T>();

      // One block per selected row.
      const int block_size = 256;
      dim3 threads(block_size, 1);
      dim3 grid(in1_rows.size(), 1);
      SelectedRowsAddTensorKernel<
          T, block_size><<<grid, threads, 0, context.stream()>>>(
          in1_data, in1_rows.CUDAData(context.GetPlace()), out_data,
          in1_row_numel);
    }

    // output = dense(input1) + input2.
    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, platform::float16>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext,
                                      platform::float16>;

// SelectedRowsAddTo (CUDA): appends input1 onto input2 in place.
//
// input1's row indices are appended to input2's rows. input1's whole value
// buffer is copied into input2's value buffer starting at element offset
// `input2_offset`. Both inputs must have the same height and live on GPU
// places. input2's value tensor is not resized here; presumably the caller
// has already sized it to hold the appended data — TODO confirm.
template <typename T>
struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const int64_t input2_offset,
                  framework::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(in1_height, input2->height());

    const auto& src_rows = input1.rows();
    auto* dst_rows = input2->mutable_rows();
    const auto& src_value = input1.value();
    auto* dst_value = input2->mutable_value();

    // Append the row indices; nothing to extend when input1 has no rows.
    if (src_rows.size() != 0) {
      dst_rows->Extend(src_rows.begin(), src_rows.end());
    }

    auto in1_place = input1.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
    auto in2_place = input2->place();
    PADDLE_ENFORCE(platform::is_gpu_place(in2_place));

    auto* src = src_value.data<T>();
    auto* dst = dst_value->data<T>();
    // Device-to-device copy of input1's values into input2 at the offset.
    memory::Copy(boost::get<platform::CUDAPlace>(in2_place),
                 dst + input2_offset,
                 boost::get<platform::CUDAPlace>(in1_place), src,
                 src_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAddTo<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext,
                                  platform::float16>;

namespace {
// In-place accumulation kernel used by SelectedRowsAddToTensor: block `b`
// adds value row `b` of `selected_rows` into dense row `rows[b]` of
// `tensor_out`. Expected launch (as done in this file): gridDim.x equal to
// the number of selected rows, `block_size` threads per block.
template <typename T, int block_size>
__global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
                                              const int64_t* rows,
                                              T* tensor_out,
                                              int64_t row_numel) {
  const int src_row = blockIdx.x;
  const int64_t dst_row = rows[src_row];
  const T* in = selected_rows + src_row * row_numel;
  T* out = tensor_out + dst_row * row_numel;

  // Duplicate indices in rows[] make different blocks target the same dense
  // row, so the accumulation has to be atomic.
  int col = threadIdx.x;
  while (col < row_numel) {
    paddle::platform::CudaAtomicAdd(out + col, in[col]);
    col += block_size;
  }
}
}  // namespace

// SelectedRowsAddToTensor (CUDA): input2 += dense(input1), in place.
//
// Each value row of input1 is added into row input1.rows()[i] of the dense
// tensor input2. Rows listed more than once are all accumulated. input2 must
// have input1.height() rows and the same number of elements per row.
// An input1 with no rows leaves input2 unchanged.
template <typename T>
struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  framework::Tensor* input2) {
    auto in1_height = input1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);

    auto& in1_rows = input1.rows();
    // Adding an empty SelectedRows is a no-op. Returning early also avoids
    // dividing by a zero row count below and launching a kernel with a
    // zero-sized grid (an invalid launch configuration).
    if (in1_rows.size() == 0) {
      return;
    }

    auto& in1_value = input1.value();
    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2->data<T>();

    // One block per selected row.
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(in1_rows.size(), 1);
    SelectedRowsAddToTensorKernel<
        T, block_size><<<grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data,
        in1_row_numel);
  }
};

template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext,
                                        platform::float16>;

namespace scatter {

// Accumulates input value rows into a merged output buffer.
//
// Expected launch (as done by MergeAdd in this file): one block per input row
// (gridDim.x == number of input rows), `block_size` threads per block.
// Thread 0 of block `b` locates input_rows[b] inside out_rows and publishes
// the position through shared memory. The whole block then atomically adds
// input value row `b` into that output row. Atomics are needed because
// several input rows can map to the same output row.
//
// Precondition: out_rows is sorted ascending, contains no duplicates, and
// contains every value in input_rows. The MergeAdd callers build out_rows
// from a std::set, which guarantees this. That allows a binary search
// (O(log out_rows_size)) instead of a serial scan over every output row in
// every block.
template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                               T* out, const int64_t* out_rows,
                               size_t out_rows_size, int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;
  __shared__ size_t out_idx;

  if (tid == 0) {
    // Lower-bound search for input_rows[ty] in the sorted out_rows.
    const int64_t target = input_rows[ty];
    size_t lo = 0;
    size_t hi = out_rows_size;
    while (lo < hi) {
      const size_t mid = lo + (hi - lo) / 2;
      if (out_rows[mid] < target) {
        lo = mid + 1;
      } else {
        hi = mid;
      }
    }
    out_idx = lo;
  }

  // Make out_idx (written by thread 0) visible to the whole block.
  __syncthreads();

  input += ty * row_numel;
  out += out_idx * row_numel;
  for (int index = tid; index < row_numel; index += block_size) {
    paddle::platform::CudaAtomicAdd(out + index, input[index]);
  }
}

// MergeAdd (CUDA): sums value rows that share the same row index.
//
// The result's rows are the distinct row indices of the input(s) in ascending
// order (collected through a std::set). Its height is the input height, and
// result row i is the sum of every input value row whose index equals
// rows[i]. `sorted_result` is accepted for interface compatibility but is not
// read, because the result is always sorted.
template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
  // Returns the merged SelectedRows by value. If `input` has no rows, a
  // default-constructed SelectedRows is returned.
  framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                     const framework::SelectedRows& input,
                                     const bool sorted_result = false) {
    framework::SelectedRows out;
    (*this)(context, input, &out);
    return out;
  }

  // Merges a single SelectedRows into `output`. `output` is left untouched
  // when `input` has no rows.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input,
                  framework::SelectedRows* output,
                  const bool sorted_result = false) {
    framework::Vector<int64_t> input_rows(input.rows());
    if (input_rows.size() == 0) {
      return;
    }

    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows_cpu(row_set.begin(), row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    auto input_width = input.value().dims()[1];
    T* out_data =
        ResetOutput(context, merge_rows, input.height(), input_width, output);

    // The kernel only reads the merged indices, so ask for a read-only device
    // copy. CUDAMutableData would mark the host copy stale and force a
    // needless device-to-host transfer the next time output->rows() is read.
    const auto& out_rows = output->rows();
    LaunchMergeAddKernel(context, input.value().data<T>(),
                         input_rows.CUDAData(context.GetPlace()),
                         input_rows.size(), out_data,
                         out_rows.CUDAData(context.GetPlace()),
                         out_rows.size(), input_width);
  }

  // Merges several SelectedRows into `output`. Inputs without rows are
  // skipped. The remaining inputs must agree on value width and height.
  // `output` is left untouched when no input has rows.
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<const framework::SelectedRows*>& inputs,
                  framework::SelectedRows* output,
                  const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const framework::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();

    std::set<int64_t> merged_row_set;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
                        "all input should have same "
                        "dimension except for the first one");
      PADDLE_ENFORCE_EQ(input_height, input->height(),
                        "all input should have same height");
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }
    std::vector<int64_t> merge_rows_cpu(merged_row_set.begin(),
                                        merged_row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    T* out_data =
        ResetOutput(context, merge_rows, input_height, input_width, output);

    // Read-only device view of the merged indices, fetched once for all
    // inputs (see the single-input overload for why it is not mutable).
    const auto& out_rows = output->rows();
    const int64_t* out_rows_gpu = out_rows.CUDAData(context.GetPlace());
    const size_t out_rows_size = out_rows.size();

    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto& input_rows = input->rows();
      LaunchMergeAddKernel(context, input->value().data<T>(),
                           input_rows.CUDAData(context.GetPlace()),
                           input_rows.size(), out_data, out_rows_gpu,
                           out_rows_size, input_width);
    }
  }

 private:
  // Sets `out`'s rows and height, allocates its value as a zero-filled
  // (rows.size() x width) tensor on the context's place, and returns a
  // pointer to that value buffer.
  static T* ResetOutput(const platform::CUDADeviceContext& context,
                        const framework::Vector<int64_t>& rows,
                        int64_t height, int64_t width,
                        framework::SelectedRows* out) {
    out->set_rows(rows);
    out->set_height(height);
    out->mutable_value()->mutable_data<T>(
        framework::make_ddim({static_cast<int64_t>(rows.size()), width}),
        context.GetPlace());

    math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
    constant_functor(context, out->mutable_value(), static_cast<T>(0));
    return out->mutable_value()->data<T>();
  }

  // Launches MergeAddKernel with one block per input row. Each input value
  // row is atomically added into the output row with the same index.
  static void LaunchMergeAddKernel(const platform::CUDADeviceContext& context,
                                   const T* input_data,
                                   const int64_t* input_rows_gpu,
                                   size_t input_rows_size, T* out_data,
                                   const int64_t* out_rows_gpu,
                                   size_t out_rows_size, int64_t row_numel) {
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(input_rows_size, 1);
    MergeAddKernel<T, block_size><<<grid, threads, 0, context.stream()>>>(
        input_data, input_rows_gpu, out_data, out_rows_gpu, out_rows_size,
        row_numel);
  }
};

template struct MergeAdd<platform::CUDADeviceContext, float>;
template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;

// Applies an element-wise ScatterOps update from each selected row into the
// dense row it addresses:  tensor_out[rows[b]] <op>= selected_rows[b].
//
// Expected launch (as done by UpdateToTensor in this file): one block per
// selected row, with `block_size` threads striding over the row's elements.
// The updates are plain (non-atomic) read-modify-writes, so rows[] must not
// contain duplicates. UpdateToTensor ensures this by running MergeAdd first.
//
// `op` is passed by value. A __global__ parameter declared as a reference
// would refer to the caller's host-side object, and the device cannot
// dereference that host address.
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
                                     const int64_t* rows, ScatterOps op,
                                     T* tensor_out, int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;
  // The switch is kept outside the loops so each loop body stays branch-free.
  // FIXME(typhoonzero): use macro fix the below messy code.
  switch (op) {
    case ScatterOps::ASSIGN:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index];
      }
      break;
    case ScatterOps::ADD:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] += selected_rows[index];
      }
      break;
    case ScatterOps::SUB:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] -= selected_rows[index];
      }
      break;
    case ScatterOps::SUBBY:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] - tensor_out[index];
      }
      break;
    case ScatterOps::MUL:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] *= selected_rows[index];
      }
      break;
    case ScatterOps::DIV:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] /= selected_rows[index];
      }
      break;
    case ScatterOps::DIVBY:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] / tensor_out[index];
      }
      break;
  }
}

// UpdateToTensor (CUDA): applies `op` from a SelectedRows into a dense
// tensor, row by row:  input2[rows[i]] <op>= merged_input1[i].
//
// input1 is first passed through MergeAdd, so duplicate row indices are
// summed before the update. That lets the kernel use non-atomic updates.
// input2 must have input1.height() rows and the same number of elements per
// row. An input1 without rows leaves input2 unchanged.
template <typename T>
struct UpdateToTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const ScatterOps& op, const framework::SelectedRows& input1,
                  framework::Tensor* input2) {
    // Nothing to scatter. MergeAdd would return a default-constructed result
    // (height 0), and the row-width computation below would divide by a zero
    // row count.
    if (input1.rows().size() == 0) {
      return;
    }

    // NOTE: Use SelectedRowsAddToTensor for better performance
    //       no additional MergeAdd called.
    MergeAdd<platform::CUDADeviceContext, T> merge_func;
    auto merged_in1 = merge_func(context, input1);

    auto in1_height = merged_in1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);

    auto& in1_value = merged_in1.value();
    auto& in1_rows = merged_in1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);

    auto* in1_data = in1_value.template data<T>();
    auto* in2_data = input2->data<T>();

    // One block per merged row. The row indices are fetched for the context's
    // device, matching every other launch in this file.
    dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
    dim3 grid(in1_rows.size(), 1);
    UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
        grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), op, in2_data,
        in1_row_numel);
  }
};
}  // namespace scatter
470 471 472
}  // namespace math
}  // namespace operators
}  // namespace paddle