/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <set>
#include <vector>

#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {
namespace math {

// Adds two SelectedRows by concatenation: output rows are input1.rows()
// followed by input2.rows(), and the value tensors are copied back-to-back
// on the device. Duplicate row ids are NOT merged here (see scatter::MergeAdd
// for deduplicating accumulation). All three tensors must live on the GPU.
template <typename T>
struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input1,
                  const phi::SelectedRows& input2,
                  phi::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height,
        input2.height(),
        platform::errors::InvalidArgument("The two inputs height must be equal."
                                          "But received first input height  = "
                                          "[%d], second input height = [%d]",
                                          in1_height,
                                          input2.height()));
    output->set_height(in1_height);

    framework::Vector<int64_t> in1_rows(input1.rows());
    auto& in2_rows = input2.rows();
    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        in2_value.numel() / in2_rows.size(),
        platform::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            in2_value.numel() / in2_rows.size()));
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        out_value->numel() / out_rows.size(),
        platform::errors::InvalidArgument(
            // Fixed typo in the message: "oupput" -> "output".
            "The input and output width must be equal."
            "But received input width = [%d], output width = [%d]",
            in1_row_numel,
            out_value->numel() / out_rows.size()));

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in1_place),
                      true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto in2_place = input2.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in2_place),
                      true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(out_place),
                      true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));

    // Both copies are enqueued on the context stream, so they are
    // asynchronous with respect to the host.
    memory::Copy(out_place,
                 out_data,
                 in1_place,
                 in1_data,
                 in1_value.numel() * sizeof(T),
                 context.stream());

    auto* in2_data = in2_value.data<T>();
    memory::Copy(out_place,
                 out_data + in1_value.numel(),
                 in2_place,
                 in2_data,
                 in2_value.numel() * sizeof(T),
                 context.stream());
  }
};

template struct SelectedRowsAdd<platform::CUDADeviceContext, float>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, double>;

namespace {
// Accumulates a SelectedRows value matrix into a dense tensor.
// Launch layout: one thread block per selected row (grid.x == #rows);
// threads of a block stride over the row's elements with step `block_size`
// (block_size must equal blockDim.x).
template <typename T, int block_size>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
                                            const int64_t* rows,
                                            T* tensor_out,
                                            int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;

  for (int index = tid; index < row_numel; index += block_size) {
    // Since index in rows of SelectedRows can be duplicate, we can not use
    // tensor_out[index] += selected_rows[index]; Instead, we have to use
    // AtomicAdd to avoid concurrent write error.
    paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
  }
}
}  // namespace

// Computes output = densify(input1) + input2, where input1 is a SelectedRows
// and input2/output are dense tensors of matching height and width.
// Duplicate row ids in input1 are accumulated atomically by the kernel.
template <typename T>
struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input1,
                  const framework::Tensor& input2,
                  framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        platform::errors::InvalidArgument(
            // Fixed message: second placeholder is the second input's height.
            "The two inputs height must be equal."
            "But received first input height = [%d], second input height = "
            "[%d]",
            in1_height,
            in2_dims[0]));
    PADDLE_ENFORCE_EQ(
        in1_height,
        out_dims[0],
        platform::errors::InvalidArgument(
            "The input and output height must be equal."
            "But received input height = [%d], output height = [%d]",
            in1_height,
            out_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2.numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2.numel() / in1_height));
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        output->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The input and output width must be equal."
            "But received input width = [%d], output width = [%d]",
            in1_row_numel,
            output->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2.data<T>();
    auto* out_data = output->data<T>();

    // Zero the output first: the kernel accumulates into it.
    phi::funcs::SetConstant<platform::CUDADeviceContext, T> functor;
    functor(context, output, static_cast<T>(0));

    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(in1_rows.size(), 1);
    paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
    SelectedRowsAddTensorKernel<T, block_size>
        <<<grid, threads, 0, context.stream()>>>(
            in1_data,
            mixv_in1_rows.CUDAData(context.GetPlace()),
            out_data,
            in1_row_numel);

    // Add the dense input on top of the scattered rows.
    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

// phi::GPUContext variant of SelectedRowsAddTensor; identical logic to the
// platform::CUDADeviceContext specialization above.
template <typename T>
struct SelectedRowsAddTensor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext& context,
                  const phi::SelectedRows& input1,
                  const framework::Tensor& input2,
                  framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        platform::errors::InvalidArgument(
            // Fixed message: second placeholder is the second input's height.
            "The two inputs height must be equal."
            "But received first input height = [%d], second input height = "
            "[%d]",
            in1_height,
            in2_dims[0]));
    PADDLE_ENFORCE_EQ(
        in1_height,
        out_dims[0],
        platform::errors::InvalidArgument(
            "The input and output height must be equal."
            "But received input height = [%d], output height = [%d]",
            in1_height,
            out_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2.numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2.numel() / in1_height));
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        output->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The input and output width must be equal."
            "But received input width = [%d], output width = [%d]",
            in1_row_numel,
            output->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2.data<T>();
    auto* out_data = output->data<T>();

    // Zero the output first: the kernel accumulates into it.
    phi::funcs::SetConstant<phi::GPUContext, T> functor;
    functor(context, output, static_cast<T>(0));

    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(in1_rows.size(), 1);
    paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
    SelectedRowsAddTensorKernel<T, block_size>
        <<<grid, threads, 0, context.stream()>>>(
            in1_data,
            mixv_in1_rows.CUDAData(context.GetPlace()),
            out_data,
            in1_row_numel);

    // Add the dense input on top of the scattered rows.
    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

// Explicit instantiations for the dtypes used by Paddle operators.
template struct SelectedRowsAddTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, platform::float16>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext,
                                      platform::float16>;

template struct SelectedRowsAddTensor<phi::GPUContext, float>;
template struct SelectedRowsAddTensor<phi::GPUContext, double>;
template struct SelectedRowsAdd<phi::GPUContext, platform::float16>;
template struct SelectedRowsAddTensor<phi::GPUContext, platform::float16>;
// Appends input1 into input2 in place: input1's rows are appended to
// input2's row list and input1's values are copied into input2's value
// tensor starting at element offset `input2_offset`. The caller must have
// sized input2's value tensor to hold both inputs.
template <typename T>
struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input1,
                  const int64_t input2_offset,
                  phi::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height,
        input2->height(),
        platform::errors::InvalidArgument("The two inputs height must be equal."
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height,
                                          input2->height()));

    auto& in1_rows = input1.rows();
    auto& in2_rows = *(input2->mutable_rows());

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();

    // concat rows
    paddle::framework::MixVector<int64_t> mixv_in2_rows(&in2_rows);
    if (in1_rows.size()) {
      mixv_in2_rows.Extend(in1_rows.begin(), in1_rows.end());
    }

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in1_place),
                      true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto in2_place = input2->place();
    // Bug fix: this check previously re-tested in1_place, leaving input2's
    // placement unvalidated before the device copy below.
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in2_place),
                      true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();
    // Asynchronous device-to-device copy on the context stream.
    memory::Copy(in2_place,
                 in2_data + input2_offset,
                 in1_place,
                 in1_data,
                 in1_value.numel() * sizeof(T),
                 context.stream());
  }
};

template struct SelectedRowsAddTo<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext,
                                  platform::float16>;

namespace {
// Accumulates a SelectedRows value matrix into a dense tensor in place.
// Launch layout: one thread block per selected row (grid.x == #rows);
// threads stride over the row's elements with step `block_size`
// (block_size must equal blockDim.x).
template <typename T, int block_size>
__global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
                                              const int64_t* rows,
                                              T* tensor_out,
                                              int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;

  for (int index = tid; index < row_numel; index += block_size) {
    // Since index in rows of SelectedRows can be duplicate, we have to use
    // Atomic Operation to avoid concurrent write error.
    paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
  }
}
}  // namespace

// Accumulates input1 (SelectedRows) into the dense tensor input2 in place.
// Unlike SelectedRowsAddTensor, input2 is NOT zeroed first.
template <typename T>
struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input1,
                  framework::Tensor* input2) {
    auto in1_height = input1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        platform::errors::InvalidArgument("The two inputs height must be equal."
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height,
                                          in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2->data<T>();
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(in1_rows.size(), 1);
    paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
    SelectedRowsAddToTensorKernel<T, block_size>
        <<<grid, threads, 0, context.stream()>>>(
            in1_data,
            mixv_in1_rows.CUDAData(context.GetPlace()),
            in2_data,
            in1_row_numel);
  }
};

// phi::GPUContext variant of SelectedRowsAddToTensor; identical logic to the
// platform::CUDADeviceContext specialization above.
template <typename T>
struct SelectedRowsAddToTensor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext& context,
                  const phi::SelectedRows& input1,
                  framework::Tensor* input2) {
    auto in1_height = input1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        platform::errors::InvalidArgument("The two inputs height must be equal."
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height,
                                          in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2->data<T>();
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(in1_rows.size(), 1);
    paddle::framework::MixVector<int64_t> mixv_in1_rows(&in1_rows);
    SelectedRowsAddToTensorKernel<T, block_size>
        <<<grid, threads, 0, context.stream()>>>(
            in1_data,
            mixv_in1_rows.CUDAData(context.GetPlace()),
            in2_data,
            in1_row_numel);
  }
};

template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext,
                                        platform::float16>;
template struct SelectedRowsAddToTensor<phi::GPUContext, float>;
template struct SelectedRowsAddToTensor<phi::GPUContext, double>;
template struct SelectedRowsAddToTensor<phi::GPUContext, int>;
template struct SelectedRowsAddToTensor<phi::GPUContext, int64_t>;
template struct SelectedRowsAddToTensor<phi::GPUContext, platform::float16>;

namespace scatter {

// Scatters one input row into the merged output row with the same row id.
// Launch layout: one thread block per input row (grid.x == #input rows).
// Precondition: every id in input_rows must appear in out_rows, otherwise
// out_idx is read uninitialized (callers build out_rows from input_rows,
// which guarantees a match).
template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input,
                               const int64_t* input_rows,
                               T* out,
                               const int64_t* out_rows,
                               size_t out_rows_size,
                               int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;
  __shared__ size_t out_idx;

  // Thread 0 does a linear search for this row's slot in the output;
  // the barrier below publishes out_idx to the whole block.
  if (tid == 0) {
    for (size_t i = 0; i < out_rows_size; i++) {
      if (input_rows[ty] == out_rows[i]) {
        out_idx = i;
      }
    }
  }

  __syncthreads();

  input += ty * row_numel;
  out += out_idx * row_numel;
  for (int index = tid; index < row_numel; index += block_size) {
    // Atomic add: several input rows may share the same output row.
    paddle::platform::CudaAtomicAdd(out + index, input[index]);
  }
}

// Merges duplicate rows of one or many SelectedRows into a single
// SelectedRows whose rows are unique (sorted, since they come from std::set)
// and whose values are the sums of the duplicates.
// NOTE: `sorted_result` is accepted for interface compatibility; the GPU
// implementation always produces sorted unique rows.
template <typename DeviceContext, typename T>
struct MergeAddImpl {
  // Unary value-returning form; delegates to the in-place overload.
  phi::SelectedRows operator()(const DeviceContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result = false) {
    phi::SelectedRows out;
    (*this)(context, input, &out);
    return out;
  }

  // Merges a single input into *output.
  void operator()(const DeviceContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    framework::Vector<int64_t> input_rows(input.rows());
    if (input_rows.size() == 0) {
      // Nothing to merge; leave output untouched.
      return;
    }

    phi::SelectedRows& out = *output;
    // std::set both deduplicates and sorts the row ids.
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows_cpu(row_set.begin(), row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    out.mutable_value()->mutable_data<T>(
        phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    // Zero the merged buffer: MergeAddKernel accumulates into it.
    phi::funcs::SetConstant<DeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    auto* out_data = out.mutable_value()->data<T>();
    auto* input_data = input.value().data<T>();

    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid1(input_rows.size(), 1);

    paddle::framework::MixVector<int64_t> mix_vector_input(&input_rows);
    paddle::framework::MixVector<int64_t> mix_vector_out(out.mutable_rows());
    MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
        input_data,
        mix_vector_input.CUDAData(context.GetPlace()),
        out_data,
        mix_vector_out.CUDAMutableData(context.GetPlace()),
        out.rows().size(),
        input_width);
    mix_vector_out.CopyToCPU();
  }

  // Merges a list of inputs (all with the same height and width) into
  // *output. Inputs with no rows are skipped.
  void operator()(const DeviceContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    // Find the first non-empty input to take width/height from.
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(input_width,
                        input->value().dims()[1],
                        platform::errors::InvalidArgument(
                            "All input should have same "
                            "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(input_height,
                        input->height(),
                        platform::errors::InvalidArgument(
                            "All input should have same height."));
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }
    std::vector<int64_t> merge_rows_cpu(merged_row_set.begin(),
                                        merged_row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    out.set_rows(merge_rows);
    out.set_height(input_height);
    out.mutable_value()->mutable_data<T>(
        phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    // Zero the merged buffer: MergeAddKernel accumulates into it.
    phi::funcs::SetConstant<DeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    auto* out_data = out.mutable_value()->data<T>();

    const int block_size = 256;
    dim3 threads(block_size, 1);

    // One kernel launch per input; all accumulate into the same output.
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto* input_data = input->value().data<T>();
      auto& input_rows = input->rows();
      dim3 grid1(input_rows.size(), 1);

      paddle::framework::MixVector<int64_t> mix_vector_input(&input_rows);
      paddle::framework::MixVector<int64_t> mix_vector_out(out.mutable_rows());
      MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
          input_data,
          mix_vector_input.CUDAData(context.GetPlace()),
          out_data,
          mix_vector_out.CUDAMutableData(context.GetPlace()),
          out.rows().size(),
          input_width);
      mix_vector_out.CopyToCPU();
    }
  }
};

// Thin context-specific facades over MergeAddImpl so both the legacy
// platform::CUDADeviceContext and the phi::GPUContext entry points share one
// implementation.
template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
  // unary functor, merge by adding duplicated rows in
  // the input SelectedRows object.
  phi::SelectedRows operator()(const platform::CUDADeviceContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result) {
    return MergeAddImpl<platform::CUDADeviceContext, T>()(
        context, input, sorted_result);
  }

  void operator()(const platform::CUDADeviceContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<platform::CUDADeviceContext, T>()(
        context, input, output, sorted_result);
  }

  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<platform::CUDADeviceContext, T>()(
        context, inputs, output, sorted_result);
  }
};

template <typename T>
struct MergeAdd<phi::GPUContext, T> {
  // unary functor, merge by adding duplicated rows in
  // the input SelectedRows object.
  phi::SelectedRows operator()(const phi::GPUContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result) {
    return MergeAddImpl<phi::GPUContext, T>()(context, input, sorted_result);
  }

  void operator()(const phi::GPUContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<phi::GPUContext, T>()(context, input, output, sorted_result);
  }

  void operator()(const phi::GPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<phi::GPUContext, T>()(context, inputs, output, sorted_result);
  }
};

// Instantiates the impl and both facades for one dtype.
#define TEMPLATE_SPECIALIZED_FOR_MERGEADD(dtype)                    \
  template struct MergeAddImpl<platform::CUDADeviceContext, dtype>; \
  template struct MergeAddImpl<phi::GPUContext, dtype>;             \
  template struct MergeAdd<platform::CUDADeviceContext, dtype>;     \
  template struct MergeAdd<phi::GPUContext, dtype>;

TEMPLATE_SPECIALIZED_FOR_MERGEADD(float)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(double)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(int)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(int64_t)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::float16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::bfloat16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex<float>)
TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex<double>)

// Applies the scatter operation `op` between one selected row and the
// matching row of a dense tensor. Launch layout: one thread block per
// selected row (grid.x == #rows); threads stride with step `block_size`.
// Non-ASSIGN/ADD ops are NOT atomic, so row ids must be unique (callers
// run MergeAdd first).
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
                                     const int64_t* rows,
                                     const ScatterOps& op,
                                     T* tensor_out,
                                     int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;
  // FIXME(typhoonzero): use macro fix the below messy code.
  switch (op) {
    case ScatterOps::ASSIGN:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index];
      }
      break;
    case ScatterOps::ADD:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] += selected_rows[index];
      }
      break;
    case ScatterOps::SUB:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] -= selected_rows[index];
      }
      break;
    case ScatterOps::SUBBY:
      // Reversed subtraction: out = selected - out.
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] - tensor_out[index];
      }
      break;
    case ScatterOps::MUL:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] *= selected_rows[index];
      }
      break;
    case ScatterOps::DIV:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] /= selected_rows[index];
      }
      break;
    case ScatterOps::DIVBY:
      // Reversed division: out = selected / out.
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] / tensor_out[index];
      }
      break;
  }
}

// Applies `op` between input1 (SelectedRows, merged first so row ids are
// unique) and the dense tensor input2, updating input2 in place.
template <typename T>
struct UpdateToTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const ScatterOps& op,
                  const phi::SelectedRows& input1,
                  framework::Tensor* input2) {
    // NOTE: Use SelectedRowsAddToTensor for better performance
    //       no additional MergeAdd called.
    MergeAdd<platform::CUDADeviceContext, T> merge_func;
    auto merged_in1 = merge_func(context, input1);

    auto in1_height = merged_in1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        platform::errors::InvalidArgument("The two inputs height must be equal."
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height,
                                          in2_dims[0]));

    auto& in1_value = merged_in1.value();
    auto& in1_rows = merged_in1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2->numel() / in1_height));

    auto* in1_data = in1_value.template data<T>();
    auto* in2_data = input2->data<T>();

    dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
    dim3 grid(in1_rows.size(), 1);
    UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS>
        <<<grid, threads, 0, context.stream()>>>(
            in1_data, in1_rows.cuda_data(), op, in2_data, in1_row_numel);
  }
};
}  // namespace scatter
}  // namespace math
}  // namespace operators
}  // namespace paddle