/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <set>
#include <vector>

#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input1,
                  const phi::SelectedRows& input2, phi::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height, input2.height(),
        platform::errors::InvalidArgument(
            "The two inputs' height must be equal. "
            "But received first input height = [%d], "
            "second input height = [%d]",
            in1_height, input2.height()));
    output->set_height(in1_height);

    framework::Vector<int64_t> in1_rows(input1.rows());
    auto& in2_rows = input2.rows();
    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel, in2_value.numel() / in2_rows.size(),
        platform::errors::InvalidArgument(
            "The two inputs' width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel, in2_value.numel() / in2_rows.size()));
    PADDLE_ENFORCE_EQ(
        in1_row_numel, out_value->numel() / out_rows.size(),
        platform::errors::InvalidArgument(
            "The input and output width must be equal. "
            "But received input width = [%d], output width = [%d]",
            in1_row_numel, out_value->numel() / out_rows.size()));

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in1_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto in2_place = input2.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in2_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(out_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));

    memory::Copy(out_place, out_data, in1_place, in1_data,
                 in1_value.numel() * sizeof(T), context.stream());

    auto* in2_data = in2_value.data<T>();
    memory::Copy(out_place, out_data + in1_value.numel(), in2_place, in2_data,
                 in2_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAdd<platform::CUDADeviceContext, float>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, double>;
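
// A minimal usage sketch (illustrative only; `ctx`, `a`, and `b` are
// hypothetical names): SelectedRowsAdd concatenates the rows of two
// SelectedRows of equal height and width; duplicate row indices are kept,
// not merged.
//
//   math::SelectedRowsAdd<platform::CUDADeviceContext, float> add_func;
//   phi::SelectedRows out;
//   add_func(ctx, a, b, &out);  // out.rows() is a.rows() followed by b.rows()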

namespace {
template <typename T, int block_size>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
                                            const int64_t* rows, T* tensor_out,
                                            int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;

  for (int index = tid; index < row_numel; index += block_size) {
    // Since the row indices of a SelectedRows can contain duplicates, we
    // cannot use tensor_out[index] += selected_rows[index]; instead, we have
    // to use CudaAtomicAdd to avoid concurrent-write errors.
    paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
  }
}
}  // namespace

template <typename T>
struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input1,
                  const framework::Tensor& input2, framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(
        in1_height, in2_dims[0],
        platform::errors::InvalidArgument(
            "The two inputs' height must be equal. "
            "But received first input height = [%d], "
            "second input height = [%d]",
            in1_height, in2_dims[0]));
    PADDLE_ENFORCE_EQ(
        in1_height, out_dims[0],
        platform::errors::InvalidArgument(
            "The input and output height must be equal. "
            "But received input height = [%d], output height = [%d]",
            in1_height, out_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel, input2.numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs' width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel, input2.numel() / in1_height));
    PADDLE_ENFORCE_EQ(
        in1_row_numel, output->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The input and output width must be equal. "
            "But received input width = [%d], output width = [%d]",
            in1_row_numel, output->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2.data<T>();
    auto* out_data = output->data<T>();

    phi::funcs::SetConstant<platform::CUDADeviceContext, T> functor;
    functor(context, output, static_cast<T>(0));

    const int block_size = 256;
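    // Launch one thread block per SelectedRows row; the block's threads
    // stride across that row's row_numel columns.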
    dim3 threads(block_size, 1);
    dim3 grid(in1_rows.size(), 1);
    paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
    SelectedRowsAddTensorKernel<
        T, block_size><<<grid, threads, 0, context.stream()>>>(
        in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), out_data,
        in1_row_numel);

    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, platform::float16>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext,
                                      platform::float16>;
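
// A minimal usage sketch (illustrative only; `ctx`, `rows_in`, `dense_in`,
// and `dense_out` are hypothetical names): SelectedRowsAddTensor scatter-adds
// the rows of a SelectedRows into a zeroed buffer and then adds the dense
// input, i.e. dense_out = scatter_add(rows_in) + dense_in.
//
//   math::SelectedRowsAddTensor<platform::CUDADeviceContext, float> add_func;
//   add_func(ctx, rows_in, dense_in, &dense_out);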

template <typename T>
struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input1, const int64_t input2_offset,
                  phi::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height, input2->height(),
        platform::errors::InvalidArgument(
            "The two inputs' height must be equal. "
            "But received first input height = [%d], "
            "second input height = [%d]",
            in1_height, input2->height()));

    auto& in1_rows = input1.rows();
    auto& in2_rows = *(input2->mutable_rows());

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();

    // concat rows
    paddle::framework::MixVector<int64_t> mixv_in2_rows(&in2_rows);
    if (in1_rows.size()) {
      mixv_in2_rows.Extend(in1_rows.begin(), in1_rows.end());
    }

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in1_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto in2_place = input2->place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in2_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();
    memory::Copy(in2_place, in2_data + input2_offset, in1_place, in1_data,
                 in1_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAddTo<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext,
                                  platform::float16>;
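
// A minimal usage sketch (illustrative only; `ctx`, `src`, `dst`, and
// `offset_in_elements` are hypothetical names): SelectedRowsAddTo appends
// src's rows to dst in place and copies src's values into dst's value tensor
// starting at the given element offset, so dst's value buffer must already
// be large enough.
//
//   math::SelectedRowsAddTo<platform::CUDADeviceContext, float> add_to_func;
//   add_to_func(ctx, src, /*input2_offset=*/offset_in_elements, &dst);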

namespace {
template <typename T, int block_size>
__global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
                                              const int64_t* rows,
                                              T* tensor_out,
                                              int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;

  for (int index = tid; index < row_numel; index += block_size) {
    // Since the row indices of a SelectedRows can contain duplicates, we have
    // to use an atomic operation to avoid concurrent-write errors.
    paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
  }
}
}  // namespace

template <typename T>
struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input1, framework::Tensor* input2) {
    auto in1_height = input1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height, in2_dims[0],
        platform::errors::InvalidArgument(
            "The two inputs' height must be equal. "
            "But received first input height = [%d], "
            "second input height = [%d]",
            in1_height, in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel, input2->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs' width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel, input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2->data<T>();
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(in1_rows.size(), 1);
    paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
    SelectedRowsAddToTensorKernel<
        T, block_size><<<grid, threads, 0, context.stream()>>>(
        in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), in2_data,
        in1_row_numel);
  }
};

template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext,
                                        platform::float16>;

namespace scatter {

template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                               T* out, const int64_t* out_rows,
                               size_t out_rows_size, int64_t row_numel) {
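  // Each thread block handles one input row: thread 0 linearly searches
  // out_rows for the merged position of input_rows[blockIdx.x] and publishes
  // it via shared memory; all threads of the block then atomically
  // accumulate the row's values into that output slot.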
  const int ty = blockIdx.x;
  int tid = threadIdx.x;
  __shared__ size_t out_idx;

  if (tid == 0) {
    for (size_t i = 0; i < out_rows_size; i++) {
      if (input_rows[ty] == out_rows[i]) {
        out_idx = i;
      }
    }
  }

  __syncthreads();

  input += ty * row_numel;
  out += out_idx * row_numel;
  for (int index = tid; index < row_numel; index += block_size) {
    paddle::platform::CudaAtomicAdd(out + index, input[index]);
  }
}

template <typename DeviceContext, typename T>
struct MergeAddImpl {
  phi::SelectedRows operator()(const DeviceContext& context,
325 326 327
                               const phi::SelectedRows& input,
                               const bool sorted_result = false) {
    phi::SelectedRows out;
    (*this)(context, input, &out);
    return out;
  }

  void operator()(const DeviceContext& context, const phi::SelectedRows& input,
                  phi::SelectedRows* output, const bool sorted_result = false) {
    framework::Vector<int64_t> input_rows(input.rows());
    if (input_rows.size() == 0) {
      return;
    }

    phi::SelectedRows& out = *output;
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows_cpu(row_set.begin(), row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    out.mutable_value()->mutable_data<T>(
        phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    phi::funcs::SetConstant<DeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    auto* out_data = out.mutable_value()->data<T>();
    auto* input_data = input.value().data<T>();

    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid1(input_rows.size(), 1);

    paddle::framework::MixVector<int64_t> mix_vector_input(&input_rows);
    paddle::framework::MixVector<int64_t> mix_vector_out(out.mutable_rows());
    MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
        input_data, mix_vector_input.CUDAData(context.GetPlace()), out_data,
        mix_vector_out.CUDAMutableData(context.GetPlace()), out.rows().size(),
        input_width);
    mix_vector_out.CopyToCPU();
  }

  void operator()(const DeviceContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output, const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
                        platform::errors::InvalidArgument(
                            "All inputs should have the same "
                            "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(input_height, input->height(),
                        platform::errors::InvalidArgument(
                            "All inputs should have the same height."));
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }
    std::vector<int64_t> merge_rows_cpu(merged_row_set.begin(),
                                        merged_row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    out.set_rows(merge_rows);
    out.set_height(input_height);
    out.mutable_value()->mutable_data<T>(
        phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    phi::funcs::SetConstant<DeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    auto* out_data = out.mutable_value()->data<T>();

    const int block_size = 256;
    dim3 threads(block_size, 1);

    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto* input_data = input->value().data<T>();
      auto& input_rows = input->rows();
      dim3 grid1(input_rows.size(), 1);

      paddle::framework::MixVector<int64_t> mix_vector_input(&input_rows);
      paddle::framework::MixVector<int64_t> mix_vector_out(out.mutable_rows());
      MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
          input_data, mix_vector_input.CUDAData(context.GetPlace()), out_data,
          mix_vector_out.CUDAMutableData(context.GetPlace()), out.rows().size(),
          input_width);
      mix_vector_out.CopyToCPU();
    }
  }
};

template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
  // unary functor, merge by adding duplicated rows in
  // the input SelectedRows object.
  phi::SelectedRows operator()(const platform::CUDADeviceContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result) {
    return MergeAddImpl<platform::CUDADeviceContext, T>()(context, input,
                                                          sorted_result);
  }

  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input, phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<platform::CUDADeviceContext, T>()(context, input, output,
                                                   sorted_result);
  }

  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output, const bool sorted_result) {
    MergeAddImpl<platform::CUDADeviceContext, T>()(context, inputs, output,
                                                   sorted_result);
  }
};

template <typename T>
struct MergeAdd<phi::GPUContext, T> {
  // unary functor, merge by adding duplicated rows in
  // the input SelectedRows object.
  phi::SelectedRows operator()(const phi::GPUContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result) {
    return MergeAddImpl<phi::GPUContext, T>()(context, input, sorted_result);
  }

  void operator()(const phi::GPUContext& context,
                  const phi::SelectedRows& input, phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<phi::GPUContext, T>()(context, input, output, sorted_result);
  }

  void operator()(const phi::GPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output, const bool sorted_result) {
    MergeAddImpl<phi::GPUContext, T>()(context, inputs, output, sorted_result);
  }
};

#define TEMPLATE_SPECIALIZED_FOR_MERGEADD(dtype)                    \
  template struct MergeAddImpl<platform::CUDADeviceContext, dtype>; \
  template struct MergeAddImpl<phi::GPUContext, dtype>;             \
  template struct MergeAdd<platform::CUDADeviceContext, dtype>;     \
  template struct MergeAdd<phi::GPUContext, dtype>;

TEMPLATE_SPECIALIZED_FOR_MERGEADD(float)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(double)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(int)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(int64_t)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::float16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::bfloat16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex<float>)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex<double>)
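
// A minimal usage sketch (illustrative only; `ctx` and `grad` are
// hypothetical names): MergeAdd sums the values of duplicated rows, e.g.
// rows {0, 0, 3} with row values {v0, v1, v2} merge into rows {0, 3} with
// values {v0 + v1, v2}.
//
//   scatter::MergeAdd<platform::CUDADeviceContext, float> merge_func;
//   phi::SelectedRows merged = merge_func(ctx, grad, /*sorted_result=*/false);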

template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
                                     const int64_t* rows, const ScatterOps& op,
                                     T* tensor_out, int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;
  // FIXME(typhoonzero): use a macro to clean up the repetitive switch below.
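  // Per-op semantics, with x = tensor_out and s = selected_rows:
  //   ASSIGN: x = s;  ADD: x += s;  SUB: x -= s;  SUBBY: x = s - x;
  //   MUL: x *= s;  DIV: x /= s;  DIVBY: x = s / x.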
  switch (op) {
    case ScatterOps::ASSIGN:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index];
      }
      break;
    case ScatterOps::ADD:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] += selected_rows[index];
      }
      break;
    case ScatterOps::SUB:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] -= selected_rows[index];
      }
      break;
    case ScatterOps::SUBBY:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] - tensor_out[index];
      }
      break;
    case ScatterOps::MUL:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] *= selected_rows[index];
      }
      break;
    case ScatterOps::DIV:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] /= selected_rows[index];
      }
      break;
    case ScatterOps::DIVBY:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] / tensor_out[index];
      }
      break;
  }
}

template <typename T>
struct UpdateToTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const ScatterOps& op, const phi::SelectedRows& input1,
                  framework::Tensor* input2) {
    // NOTE: Use SelectedRowsAddToTensor for better performance
    //       no additional MergeAdd called.
    MergeAdd<platform::CUDADeviceContext, T> merge_func;
    auto merged_in1 = merge_func(context, input1);

    auto in1_height = merged_in1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height, in2_dims[0],
        platform::errors::InvalidArgument(
            "The two inputs' height must be equal. "
            "But received first input height = [%d], "
            "second input height = [%d]",
            in1_height, in2_dims[0]));

    auto& in1_value = merged_in1.value();
    auto& in1_rows = merged_in1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel, input2->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs' width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel, input2->numel() / in1_height));

    auto* in1_data = in1_value.template data<T>();
    auto* in2_data = input2->data<T>();

    dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
    dim3 grid(in1_rows.size(), 1);
    // Use MixVector to obtain device-visible row indices, matching the other
    // kernel-launch sites in this file.
    paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
    UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
        grid, threads, 0, context.stream()>>>(
        in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), op, in2_data,
        in1_row_numel);
  }
};
}  // namespace scatter
}  // namespace math
}  // namespace operators
}  // namespace paddle