selected_rows_functor.cu 23.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
typhoonzero 已提交
15
#include <set>
16
#include <vector>
T
typhoonzero 已提交
17

Y
Yi Wang 已提交
18
#include "paddle/fluid/operators/math/selected_rows_functor.h"
19
#include "paddle/fluid/platform/bfloat16.h"
20
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
C
chengduo 已提交
21
#include "paddle/fluid/platform/float16.h"
22
#include "paddle/phi/kernels/funcs/math_function.h"
23 24 25 26 27

namespace paddle {
namespace operators {
namespace math {
// Concatenates two GPU-resident SelectedRows objects.
//
// The output row list is input1's rows followed by input2's rows, and the
// output value buffer receives input1's values followed by input2's values.
// Row ids that appear in both inputs are kept twice; nothing is merged here.
template <typename T>
struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
  // context: supplies the destination place and the stream handed to both
  //          memory::Copy calls.
  // input1/input2: operands; their heights and per-row widths must match.
  // output: receives the height, the concatenated rows and the concatenated
  //         values. Its value tensor is accessed through data<T>(), so it is
  //         presumably expected to be allocated by the caller -- TODO confirm.
  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input1,
                  const phi::SelectedRows& input2,
                  phi::SelectedRows* output) {
    auto height = input1.height();
    PADDLE_ENFORCE_EQ(
        height,
        input2.height(),
        platform::errors::InvalidArgument("The two inputs height must be equal."
                                          "But received first input height  = "
                                          "[%d], second input height = [%d]",
                                          height,
                                          input2.height()));
    output->set_height(height);

    // Build the output row ids on the host: input1's ids, then input2's.
    framework::Vector<int64_t> lhs_rows(input1.rows());
    auto& rhs_rows = input2.rows();
    std::vector<int64_t> merged_rows;
    merged_rows.reserve(lhs_rows.size() + rhs_rows.size());
    merged_rows.insert(merged_rows.end(), lhs_rows.begin(), lhs_rows.end());
    merged_rows.insert(merged_rows.end(), rhs_rows.begin(), rhs_rows.end());
    output->set_rows(merged_rows);

    auto* dst_value = output->mutable_value();
    auto& lhs_value = input1.value();
    auto& rhs_value = input2.value();

    // Per-row element count; both inputs and the output must agree on it.
    auto width = lhs_value.numel() / lhs_rows.size();
    PADDLE_ENFORCE_EQ(
        width,
        rhs_value.numel() / rhs_rows.size(),
        platform::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            width,
            rhs_value.numel() / rhs_rows.size()));
    PADDLE_ENFORCE_EQ(
        width,
        dst_value->numel() / merged_rows.size(),
        platform::errors::InvalidArgument(
            "The input and oupput width must be equal."
            "But received input width = [%d], output width = [%d]",
            width,
            dst_value->numel() / merged_rows.size()));

    auto* dst = dst_value->data<T>();
    auto* lhs_src = lhs_value.data<T>();

    // All three buffers have to be on a GPU place.
    auto lhs_place = input1.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(lhs_place),
                      true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto rhs_place = input2.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(rhs_place),
                      true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto dst_place = context.GetPlace();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(dst_place),
                      true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));

    // First half of the output buffer: input1's values.
    const auto lhs_numel = lhs_value.numel();
    memory::Copy(dst_place,
                 dst,
                 lhs_place,
                 lhs_src,
                 lhs_numel * sizeof(T),
                 context.stream());

    // Second half, starting right after input1's elements: input2's values.
    auto* rhs_src = rhs_value.data<T>();
    memory::Copy(dst_place,
                 dst + lhs_numel,
                 rhs_place,
                 rhs_src,
                 rhs_value.numel() * sizeof(T),
                 context.stream());
  }
};

Q
QI JUN 已提交
112 113
// Explicit instantiations of SelectedRowsAdd on the CUDA device context for
// single- and double-precision floating point values.
template struct SelectedRowsAdd<platform::CUDADeviceContext, float>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, double>;
114 115

namespace {
// Adds every row of `selected_rows` into the row of `tensor_out` named by
// `rows`.
//
// Block blockIdx.x handles entry blockIdx.x of `rows`; its threads walk the
// row's `row_numel` elements with a stride of `block_size` (the template
// parameter, not blockDim.x), so the launch is presumably expected to use
// `block_size` threads per block -- confirm at the call site.
template <typename T, int block_size>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
                                            const int64_t* rows,
                                            T* tensor_out,
                                            int64_t row_numel) {
  const int row = blockIdx.x;

  const T* src = selected_rows + row * row_numel;
  T* dst = tensor_out + rows[row] * row_numel;

  int col = threadIdx.x;
  while (col < row_numel) {
    // The same row id may occur more than once in `rows`, so several blocks
    // can target the same destination element; a plain `dst[col] += ...`
    // would race, hence the atomic add.
    paddle::platform::CudaAtomicAdd(dst + col, src[col]);
    col += block_size;
  }
}
}  // namespace

H
hong 已提交
136 137 138 139
// Computes output = input2 + scatter(input1) on the GPU, where input1 is a
// SelectedRows and input2/output are dense tensors.
//
// The output is first zero-filled, then each row of input1 is atomically
// accumulated into the output row named by its row id, and finally input2 is
// added element-wise to the whole output.
template <typename T>
struct SelectedRowsAddTensor<phi::GPUContext, T> {
  // context: supplies the stream, the place and the Eigen device.
  // input1:  sparse operand; its height must equal dim 0 of input2 and of
  //          output, and its per-row width must match theirs.
  // input2:  dense operand.
  // output:  dense result; accessed through data<T>(), so it is presumably
  //          expected to be allocated by the caller -- TODO confirm.
  void operator()(const phi::GPUContext& context,
                  const phi::SelectedRows& input1,
                  const framework::Tensor& input2,
                  framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    // The second value reported is the dense tensor's height, so it is
    // labelled "second input height" (it used to repeat "first input height").
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        platform::errors::InvalidArgument(
            "The two inputs height must be equal."
            "But received first input height = [%d], second input height = "
            "[%d]",
            in1_height,
            in2_dims[0]));
    PADDLE_ENFORCE_EQ(
        in1_height,
        out_dims[0],
        platform::errors::InvalidArgument(
            "The input and output height must be equal."
            "But received input height = [%d], output height = [%d]",
            in1_height,
            out_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    // Number of elements in one row of the sparse input.
    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2.numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2.numel() / in1_height));
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        output->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The input and output width must be equal."
            "But received input width = [%d], output width = [%d]",
            in1_row_numel,
            output->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    // NOTE(review): in2_data is not read below (input2 is consumed through
    // the Eigen view). The call is kept because data<T>() may validate the
    // tensor before any device work is issued -- confirm before removing.
    auto* in2_data = input2.data<T>();
    auto* out_data = output->data<T>();

    // Start from zeros so the kernel below can accumulate into the output.
    phi::funcs::SetConstant<phi::GPUContext, T> functor;
    functor(context, output, static_cast<T>(0));

    // One block per sparse row; block_size threads stride over the row.
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(in1_rows.size(), 1);
    paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
    SelectedRowsAddTensorKernel<T, block_size>
        <<<grid, threads, 0, context.stream()>>>(
            in1_data,
            mixv_in1_rows.CUDAData(context.GetPlace()),
            out_data,
            in1_row_numel);

    // Add the dense operand on top of the scattered sparse rows.
    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

// Explicit instantiations on phi::GPUContext.
// NOTE(review): the SelectedRowsAdd specialization defined in this file is
// written against platform::CUDADeviceContext, while the float16
// instantiation below names phi::GPUContext -- confirm that both resolve to
// the intended specialization.
template struct SelectedRowsAddTensor<phi::GPUContext, float>;
template struct SelectedRowsAddTensor<phi::GPUContext, double>;
template struct SelectedRowsAdd<phi::GPUContext, platform::float16>;
template struct SelectedRowsAddTensor<phi::GPUContext, platform::float16>;

Q
QI JUN 已提交
212
// Appends one GPU-resident SelectedRows (input1) to another (input2).
//
// input1's row ids are appended to input2's row list, and input1's values are
// copied into input2's value buffer starting at element `input2_offset`.
template <typename T>
struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
  // context:       supplies the stream handed to memory::Copy.
  // input1:        source rows/values; must be on a GPU place.
  // input2_offset: element offset (in units of T) into input2's value buffer
  //                at which input1's values are written.
  // input2:        destination; must be on a GPU place and have the same
  //                height as input1. Its value buffer is presumably expected
  //                to be large enough for the copy -- this is not checked
  //                here.
  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input1,
                  const int64_t input2_offset,
                  phi::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height,
        input2->height(),
        platform::errors::InvalidArgument("The two inputs height must be equal."
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height,
                                          input2->height()));

    auto& in1_rows = input1.rows();
    auto& in2_rows = *(input2->mutable_rows());

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();

    // concat rows
    paddle::framework::MixVector<int64_t> mixv_in2_rows(&in2_rows);
    if (in1_rows.size()) {
      mixv_in2_rows.Extend(in1_rows.begin(), in1_rows.end());
    }

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in1_place),
                      true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto in2_place = input2->place();
    // Validate the destination place; this check previously re-tested
    // in1_place, leaving in2_place unverified.
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in2_place),
                      true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();
    memory::Copy(in2_place,
                 in2_data + input2_offset,
                 in1_place,
                 in1_data,
                 in1_value.numel() * sizeof(T),
                 context.stream());
  }
};

Q
QI JUN 已提交
262 263 264 265
// Explicit instantiations of SelectedRowsAddTo on the CUDA device context.
template struct SelectedRowsAddTo<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext,
                                  platform::float16>;
Q
QI JUN 已提交
268 269 270 271 272 273 274

namespace {
// Accumulates every row of `selected_rows` into the row of `tensor_out`
// selected by the matching entry of `rows`.
//
// Block blockIdx.x owns entry blockIdx.x of `rows`; its threads cover the
// `row_numel` elements of that row with a stride of `block_size` (the
// template parameter, not blockDim.x), so the launch is presumably expected
// to use `block_size` threads per block -- confirm at the call site.
template <typename T, int block_size>
__global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
                                              const int64_t* rows,
                                              T* tensor_out,
                                              int64_t row_numel) {
  const int row = blockIdx.x;

  const T* src = selected_rows + row * row_numel;
  T* dst = tensor_out + rows[row] * row_numel;

  int col = threadIdx.x;
  while (col < row_numel) {
    // Row ids in a SelectedRows may repeat, so different blocks can hit the
    // same destination element; the atomic add avoids lost updates.
    paddle::platform::CudaAtomicAdd(dst + col, src[col]);
    col += block_size;
  }
}
}  // namespace

H
hong 已提交
289 290 291
// Accumulates a SelectedRows into a dense tensor in place on the GPU:
// for every (row id, row values) pair of input1, the values are atomically
// added to the corresponding row of input2.
template <typename T>
struct SelectedRowsAddToTensor<phi::GPUContext, T> {
  // context: supplies the stream and the place used for the row-id buffer.
  // input1:  sparse operand; its height must equal dim 0 of input2 and its
  //          per-row width must equal input2's.
  // input2:  dense tensor updated in place.
  void operator()(const phi::GPUContext& context,
                  const phi::SelectedRows& input1,
                  framework::Tensor* input2) {
    auto sparse_height = input1.height();
    auto dense_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        sparse_height,
        dense_dims[0],
        platform::errors::InvalidArgument("The two inputs height must be equal."
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          sparse_height,
                                          dense_dims[0]));

    auto& sparse_value = input1.value();
    auto& sparse_rows = input1.rows();

    // Elements per sparse row; must match the dense tensor's row width.
    int64_t row_width = sparse_value.numel() / sparse_rows.size();
    PADDLE_ENFORCE_EQ(
        row_width,
        input2->numel() / sparse_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            row_width,
            input2->numel() / sparse_height));

    auto* src = sparse_value.data<T>();
    auto* dst = input2->data<T>();

    // Launch shape: one block per sparse row, block_size threads per block.
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(sparse_rows.size(), 1);

    paddle::framework::MixVector<int64_t> device_rows(&sparse_rows);
    SelectedRowsAddToTensorKernel<T, block_size>
        <<<grid, threads, 0, context.stream()>>>(
            src,
            device_rows.CUDAData(context.GetPlace()),
            dst,
            row_width);
  }
};

// Explicit instantiations of SelectedRowsAddToTensor on phi::GPUContext for
// the supported element types.
template struct SelectedRowsAddToTensor<phi::GPUContext, float>;
template struct SelectedRowsAddToTensor<phi::GPUContext, double>;
template struct SelectedRowsAddToTensor<phi::GPUContext, int>;
template struct SelectedRowsAddToTensor<phi::GPUContext, int64_t>;
template struct SelectedRowsAddToTensor<phi::GPUContext, platform::float16>;
T
typhoonzero 已提交
338 339 340 341

namespace scatter {

// Adds one input row into the output row that carries the same row id.
//
// Block blockIdx.x handles input row blockIdx.x. Thread 0 scans `out_rows`
// for the input row's id and publishes the matching position through shared
// memory; after the barrier all threads of the block accumulate the row with
// a stride of `block_size` (the template parameter, not blockDim.x).
//
// The scan does not stop at the first hit, so if `out_rows` contained the id
// more than once the last position would be used. If the id is absent,
// `out_idx` is never written.
template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input,
                               const int64_t* input_rows,
                               T* out,
                               const int64_t* out_rows,
                               size_t out_rows_size,
                               int64_t row_numel) {
  const int src_row = blockIdx.x;
  const int lane = threadIdx.x;
  __shared__ size_t out_idx;

  if (lane == 0) {
    size_t pos = 0;
    while (pos < out_rows_size) {
      if (out_rows[pos] == input_rows[src_row]) {
        out_idx = pos;
      }
      ++pos;
    }
  }

  // Make thread 0's result visible to the whole block.
  __syncthreads();

  const T* src = input + src_row * row_numel;
  T* dst = out + out_idx * row_numel;
  int col = lane;
  while (col < row_numel) {
    // Several input rows can map to the same output row, so accumulate
    // atomically.
    paddle::platform::CudaAtomicAdd(dst + col, src[col]);
    col += block_size;
  }
}

369 370 371
// Merges SelectedRows entries that share a row id by summing their values on
// the device.
//
// Each overload collects the distinct row ids on the host with a std::set (so
// the output rows come out in ascending order), zero-fills an output value
// tensor of shape [number of distinct rows, width] and launches
// MergeAddKernel with one block per input row to accumulate into it.
//
// `sorted_result` is accepted by every overload but is not consulted by this
// implementation.
template <typename DeviceContext, typename T>
struct MergeAddImpl {
  // Convenience overload: merges `input` and returns the result by value.
  // If `input` has no rows, a default-constructed SelectedRows is returned.
  phi::SelectedRows operator()(const DeviceContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result = false) {
    phi::SelectedRows out;
    // Forward the caller's flag instead of silently falling back to the
    // callee's default.
    (*this)(context, input, &out, sorted_result);
    return out;
  }

  // Merges duplicated rows of a single `input` into `output`.
  // If `input` has no rows, `output` is left untouched.
  void operator()(const DeviceContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    framework::Vector<int64_t> input_rows(input.rows());
    if (input_rows.size() == 0) {
      return;
    }

    phi::SelectedRows& out = *output;
    // Distinct row ids, ascending.
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows_cpu(row_set.begin(), row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    out.mutable_value()->mutable_data<T>(
        phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    // The kernel accumulates, so the destination must start at zero.
    phi::funcs::SetConstant<DeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    auto* out_data = out.mutable_value()->data<T>();
    auto* input_data = input.value().data<T>();

    // The kernel strides by its block_size template argument, so the same
    // constant is used for the template argument and the launch shape.
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid1(input_rows.size(), 1);

    paddle::framework::MixVector<int64_t> mix_vector_input(&input_rows);
    paddle::framework::MixVector<int64_t> mix_vector_out(out.mutable_rows());
    MergeAddKernel<T, block_size><<<grid1, threads, 0, context.stream()>>>(
        input_data,
        mix_vector_input.CUDAData(context.GetPlace()),
        out_data,
        mix_vector_out.CUDAMutableData(context.GetPlace()),
        out.rows().size(),
        input_width);
    mix_vector_out.CopyToCPU();
  }

  // Merges several SelectedRows into `output`, summing rows that share an id
  // across all inputs. Inputs without rows are skipped. If `inputs` is empty
  // or none of the inputs has rows, `output` is left untouched.
  void operator()(const DeviceContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    // The first non-empty input provides the reference width and height.
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;

    // Union of all row ids; every non-empty input must agree on width and
    // height with the reference input.
    std::set<int64_t> merged_row_set;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(input_width,
                        input->value().dims()[1],
                        platform::errors::InvalidArgument(
                            "All input should have same "
                            "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(input_height,
                        input->height(),
                        platform::errors::InvalidArgument(
                            "All input should have same height."));
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }
    std::vector<int64_t> merge_rows_cpu(merged_row_set.begin(),
                                        merged_row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    out.set_rows(merge_rows);
    out.set_height(input_height);
    out.mutable_value()->mutable_data<T>(
        phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    // The kernel accumulates, so the destination must start at zero.
    phi::funcs::SetConstant<DeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    auto* out_data = out.mutable_value()->data<T>();

    // The kernel strides by its block_size template argument, so the same
    // constant is used for the template argument and the launch shape.
    const int block_size = 256;
    dim3 threads(block_size, 1);

    // One kernel launch per non-empty input, all accumulating into `out`.
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto* input_data = input->value().data<T>();
      auto& input_rows = input->rows();
      dim3 grid1(input_rows.size(), 1);

      paddle::framework::MixVector<int64_t> mix_vector_input(&input_rows);
      paddle::framework::MixVector<int64_t> mix_vector_out(out.mutable_rows());
      MergeAddKernel<T, block_size><<<grid1, threads, 0, context.stream()>>>(
          input_data,
          mix_vector_input.CUDAData(context.GetPlace()),
          out_data,
          mix_vector_out.CUDAMutableData(context.GetPlace()),
          out.rows().size(),
          input_width);
      mix_vector_out.CopyToCPU();
    }
  }
};

501 502 503 504 505 506 507 508 509 510 511
// GPU specialization of MergeAdd; every overload delegates to the matching
// MergeAddImpl<phi::GPUContext, T> overload.
template <typename T>
struct MergeAdd<phi::GPUContext, T> {
  // Unary form: merges rows of `input` that share a row id by adding them and
  // returns the merged SelectedRows by value.
  phi::SelectedRows operator()(const phi::GPUContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result) {
    MergeAddImpl<phi::GPUContext, T> impl;
    return impl(context, input, sorted_result);
  }

  // Same merge as above, written into `output`.
  void operator()(const phi::GPUContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<phi::GPUContext, T> impl;
    impl(context, input, output, sorted_result);
  }

  // Merges all of `inputs` into `output`.
  void operator()(const phi::GPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<phi::GPUContext, T> impl;
    impl(context, inputs, output, sorted_result);
  }
};

526 527
// Explicitly instantiates both MergeAddImpl and MergeAdd on phi::GPUContext
// for one element type.
#define TEMPLATE_SPECIALIZED_FOR_MERGEADD(dtype)        \
  template struct MergeAddImpl<phi::GPUContext, dtype>; \
  template struct MergeAdd<phi::GPUContext, dtype>;

TEMPLATE_SPECIALIZED_FOR_MERGEADD(float)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(double)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(int)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(int64_t)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::float16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::bfloat16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex<float>)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex<double>)
T
wip  
typhoonzero 已提交
538 539 540

// Applies `op` element-wise between each row of `selected_rows` and the row
// of `tensor_out` named by the matching entry of `rows`, writing the result
// back into `tensor_out`.
//
// Block blockIdx.x handles entry blockIdx.x of `rows`; its threads cover the
// `row_numel` elements of the row with a stride of `block_size` (the template
// parameter, not blockDim.x).
//
// The updates are plain read-modify-write operations, not atomics, so two
// blocks targeting the same output row would race. The caller in this file
// merges duplicated rows with MergeAdd before launching.
//
// `op` is taken by value: a reference parameter on a __global__ function
// would make the device dereference the address of the host-side argument.
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
                                     const int64_t* rows,
                                     const ScatterOps op,
                                     T* tensor_out,
                                     int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;
  // FIXME(typhoonzero): use macro fix the below messy code.
  switch (op) {
    case ScatterOps::ASSIGN:
      // out = in
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index];
      }
      break;
    case ScatterOps::ADD:
      // out = out + in
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] += selected_rows[index];
      }
      break;
    case ScatterOps::SUB:
      // out = out - in
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] -= selected_rows[index];
      }
      break;
    case ScatterOps::SUBBY:
      // out = in - out
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] - tensor_out[index];
      }
      break;
    case ScatterOps::MUL:
      // out = out * in
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] *= selected_rows[index];
      }
      break;
    case ScatterOps::DIV:
      // out = out / in
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] /= selected_rows[index];
      }
      break;
    case ScatterOps::DIVBY:
      // out = in / out
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] / tensor_out[index];
      }
      break;
  }
}

// Applies a ScatterOps operation between a SelectedRows and a dense tensor on
// the GPU, updating the dense tensor in place. Duplicated rows of the sparse
// input are merged with MergeAdd first, then UpdateToTensorKernel is launched
// with one block per merged row.
template <typename T>
struct UpdateToTensor<platform::CUDADeviceContext, T> {
  // context: supplies the stream for the kernel launch.
  // op:      which element-wise operation to apply.
  // input1:  sparse operand; its height must equal dim 0 of input2 and its
  //          per-row width must equal input2's.
  // input2:  dense tensor updated in place.
  void operator()(const platform::CUDADeviceContext& context,
                  const ScatterOps& op,
                  const phi::SelectedRows& input1,
                  framework::Tensor* input2) {
    // NOTE: Use SelectedRowsAddToTensor for better performance
    //       no additional MergeAdd called.
    MergeAdd<platform::CUDADeviceContext, T> merger;
    auto merged = merger(context, input1);

    auto merged_height = merged.height();
    auto dense_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        merged_height,
        dense_dims[0],
        platform::errors::InvalidArgument("The two inputs height must be equal."
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          merged_height,
                                          dense_dims[0]));

    auto& merged_value = merged.value();
    auto& merged_rows = merged.rows();

    // Elements per merged row; must match the dense tensor's row width.
    int64_t row_width = merged_value.numel() / merged_rows.size();
    PADDLE_ENFORCE_EQ(
        row_width,
        input2->numel() / merged_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            row_width,
            input2->numel() / merged_height));

    auto* src = merged_value.template data<T>();
    auto* dst = input2->data<T>();

    // One block per merged row, PADDLE_CUDA_NUM_THREADS threads per block.
    dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
    dim3 grid(merged_rows.size(), 1);
    UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS>
        <<<grid, threads, 0, context.stream()>>>(
            src, merged_rows.cuda_data(), op, dst, row_width);
  }
};
T
typhoonzero 已提交
635
}  // namespace scatter
636 637 638
}  // namespace math
}  // namespace operators
}  // namespace paddle