/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/funcs/selected_rows_functor.h"

#include "paddle/phi/core/mixed_vector.h"

#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#endif

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/phi/backends/onednn/axpy_handler.h"
#endif

namespace phi {
namespace funcs {
template <typename T>
struct SelectedRowsAdd<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  const phi::SelectedRows& input2,
                  phi::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height,
        input2.height(),
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     input2.height()));
    output->set_height(in1_height);

    auto& in1_rows = input1.rows();
    auto& in2_rows = input2.rows();
    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        in2_value.numel() / in2_rows.size(),
        phi::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            in2_value.numel() / in2_rows.size()));
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        out_value->numel() / out_rows.size(),
        phi::errors::InvalidArgument(
            "The input and output width must be equal. "
            "But received input width = [%d], output width = [%d]",
            in1_row_numel,
            out_value->numel() / out_rows.size()));

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(in1_place.GetType() == phi::AllocationType::CPU,
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));
    auto in2_place = input2.place();
    PADDLE_ENFORCE_EQ(in2_place.GetType() == phi::AllocationType::CPU,
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE_EQ(out_place.GetType() == phi::AllocationType::CPU,
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();
    memory_utils::Copy(out_place,
                       out_data,
                       in1_place,
                       in1_data,
                       in1_value.numel() * sizeof(T));

    auto* in2_data = in2_value.data<T>();
    memory_utils::Copy(out_place,
                       out_data + in1_value.numel(),
                       in2_place,
                       in2_data,
                       in2_value.numel() * sizeof(T));
  }
};

template struct SelectedRowsAdd<phi::CPUContext, float>;
template struct SelectedRowsAdd<phi::CPUContext, double>;
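
// A minimal usage sketch (hedged: `ctx` is assumed to be an initialized
// phi::CPUContext, `x` and `y` are CPU SelectedRows with equal height and
// row width, and `width` is that shared row width). The caller must
// pre-size the output value tensor before invoking the functor:
//
//   phi::SelectedRows out;
//   out.mutable_value()->Resize(phi::make_ddim(
//       {static_cast<int64_t>(x.rows().size() + y.rows().size()), width}));
//   ctx.Alloc<float>(out.mutable_value());
//   phi::funcs::SelectedRowsAdd<phi::CPUContext, float> add;
//   add(ctx, x, y, &out);  // out.rows() becomes concat(x.rows(), y.rows())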

template <typename T>
struct SelectedRowsAddTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  const phi::DenseTensor& input2,
                  phi::DenseTensor* output) {
    auto in1_height = input1.height();
    const auto& in2_dims = input2.dims();
    const auto& out_dims = output->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));
    PADDLE_ENFORCE_EQ(
        in1_height,
        out_dims[0],
        phi::errors::InvalidArgument(
            "The input and output height must be equal. "
            "But received input height = [%d], output height = [%d]",
            in1_height,
            out_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2.numel() / in1_height,
        phi::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2.numel() / in1_height));
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        output->numel() / in1_height,
        phi::errors::InvalidArgument(
            "The input and output width must be equal. "
            "But received input width = [%d], output width = [%d]",
            in1_row_numel,
            output->numel() / in1_height));

    phi::funcs::SetConstant<phi::CPUContext, T> functor;
    functor(context, output, 0.0);

    auto* in1_data = in1_value.data<T>();
    auto* out_data = output->data<T>();

    for (size_t i = 0; i < in1_rows.size(); i++) {
      for (int64_t j = 0; j < in1_row_numel; j++) {
        out_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
      }
    }

    auto out_eigen = EigenVector<T>::Flatten(*output);
    auto in2_eigen = EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<phi::CPUContext, float>;
template struct SelectedRowsAddTensor<phi::CPUContext, double>;
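
// Usage sketch (hedged: `ctx`, `grad`, `param`, and `out` are assumed to be
// set up by the caller, with param.dims() == out.dims() and dims()[0] equal
// to grad.height()). The functor zeroes `out`, scatter-adds the sparse rows
// of `grad`, then adds the dense `param`:
//
//   phi::funcs::SelectedRowsAddTensor<phi::CPUContext, float> add_tensor;
//   add_tensor(ctx, grad, param, &out);  // out = scatter(grad) + param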

template <typename T>
struct SelectedRowsAddTo<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  const int64_t input2_offset,
                  phi::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height,
        input2->height(),
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     input2->height()));

    auto& in1_rows = input1.rows();
    auto& in2_rows = *(input2->mutable_rows());

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();

    // concat rows
    phi::MixVector<int64_t> mixv_in2_rows(&in2_rows);
    mixv_in2_rows.Extend(in1_rows.begin(), in1_rows.end());

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(in1_place.GetType() == phi::AllocationType::CPU,
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));
    auto in2_place = input2->place();
    PADDLE_ENFORCE_EQ(in2_place.GetType() == phi::AllocationType::CPU,
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();
    memory_utils::Copy(in2_place,
                       in2_data + input2_offset,
                       in1_place,
                       in1_data,
                       in1_value.numel() * sizeof(T));
  }
};

template struct SelectedRowsAddTo<phi::CPUContext, float>;
template struct SelectedRowsAddTo<phi::CPUContext, double>;
template struct SelectedRowsAddTo<phi::CPUContext, int>;
template struct SelectedRowsAddTo<phi::CPUContext, int64_t>;
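
// Usage sketch (hedged: `src` and `dst` are assumed to be SelectedRows of
// equal height, with dst's value tensor pre-sized to hold both). Note the
// offset is counted in elements of T, not bytes:
//
//   phi::funcs::SelectedRowsAddTo<phi::CPUContext, float> add_to;
//   add_to(ctx, src, /*input2_offset=*/dst_elems_before_src, &dst);
//   // dst.rows() is extended by src.rows(); src values land at the offset.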

template <typename T>
struct SelectedRowsSumTo<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const std::vector<phi::SelectedRows*>& input1,
                  const std::vector<int64_t>& input2_offsets,
                  phi::SelectedRows* input2) {
    // Ensure all selected rows have the same height
    size_t size = 0u;
    for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
      auto& in_rows = (*iter)->rows();
      size += in_rows.end() - in_rows.begin();
      auto in1_height = (*iter)->height();
      PADDLE_ENFORCE_EQ(in1_height,
                        input2->height(),
                        phi::errors::InvalidArgument(
                            "The two inputs height must be equal. "
                            "But received first input height = [%d], second "
                            "input height = [%d]",
                            in1_height,
                            input2->height()));
    }
    // concat rows
    std::vector<int64_t> in2_rows;
    in2_rows.reserve(in2_rows.size() + size);
    for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
      const phi::Vector<int64_t>& in_rows = (*iter)->rows();
      in2_rows.insert(in2_rows.end(), in_rows.begin(), in_rows.end());
    }
    input2->set_rows(in2_rows);

    auto* in2_value = input2->mutable_value();
    auto* in2_data = in2_value->data<T>();
    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
    size_t offset = 0u;
    for (size_t i = 0u; i != input1.size(); ++i) {
      auto& in_value = input1[i]->value();
      const auto* in_data = in_value.data<T>();
      offset += input2_offsets[i];
      blas.VCOPY(in_value.numel(), in_data, in2_data + offset);
    }
  }
};

template struct SelectedRowsSumTo<phi::CPUContext, float>;
template struct SelectedRowsSumTo<phi::CPUContext, double>;
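
// Note: unlike SelectedRowsAddTo, which extends input2's row list,
// SelectedRowsSumTo rebuilds that list from all inputs and copies each
// input's values with blas.VCOPY at the accumulated offsets, so input2's
// value tensor must be pre-sized to hold every input.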

template <typename T>
struct SelectedRowsAddToTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  phi::DenseTensor* input2) {
    if (UNLIKELY(input1.rows().size() == 0)) {
      LOG(WARNING) << "input selected rows is empty!";
      return;
    }
    auto in1_height = input1.height();
    const auto& in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        phi::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* input2_data = input2->data<T>();

    for (size_t i = 0; i < in1_rows.size(); i++) {
      for (int64_t j = 0; j < in1_row_numel; j++) {
        input2_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
      }
    }
  }
};

#ifdef PADDLE_WITH_XPU
template <typename T>
struct SelectedRowsAddToTensor<phi::XPUContext, T> {
  void operator()(const phi::XPUContext& context,
                  const phi::SelectedRows& input1,
                  phi::DenseTensor* input2) {
    if (UNLIKELY(input1.rows().size() == 0)) {
      LOG(WARNING) << "input selected rows is empty!";
      return;
    }
    using XPUType = typename XPUTypeTrait<T>::Type;
    auto in1_height = input1.height();
    const auto& in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();
    int64_t* in1_rows_data = nullptr;
    xpu::VectorParam<int64_t> in1_rows_vec{
        in1_rows.data(), static_cast<int>(in1_rows.size()), in1_rows_data};

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        phi::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* out_data = input2->data<T>();

    int h = in1_rows.size();
    int w = in1_row_numel;
    const std::vector<int> xshape{h, w};

    int r = xpu::scatter<XPUType, int64_t>(
        context.x_context(),
        nullptr,
        reinterpret_cast<const XPUType*>(in1_data),
        reinterpret_cast<XPUType*>(out_data),
        in1_rows_vec,
        xshape,
        0,
        false);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "scatter");
  }
};

#endif

template struct SelectedRowsAddToTensor<phi::CPUContext, float>;
template struct SelectedRowsAddToTensor<phi::CPUContext, double>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int64_t>;
template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::bfloat16>;

#ifdef PADDLE_WITH_XPU
template struct SelectedRowsAddToTensor<phi::XPUContext, float>;
#endif
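
// Usage sketch (hedged: `ctx`, `sparse_grad`, and `dense_param` are assumed
// to be set up by the caller, with dense_param.dims()[0] equal to
// sparse_grad.height()). Each sparse row is scatter-added into the dense
// tensor in place, a typical step when applying sparse gradients:
//
//   phi::funcs::SelectedRowsAddToTensor<phi::CPUContext, float> functor;
//   functor(ctx, sparse_grad, &dense_param);
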
// This is a separate namespace for manipulating SelectedRows-typed
// data, e.g. merging duplicated rows or adding two SelectedRows.
//
// Another group of functors is called "scatter updates", which means
// using SelectedRows to update a dense tensor with different ops, like
// add or mul.
namespace scatter {

template <typename T, typename DeviceContext>
typename std::enable_if<!std::is_integral<T>::value>::type elementwise_add_to(
    phi::funcs::BlasT<DeviceContext, T>* blas,
    size_t data_len,
    const T* in,
    T* out) {
  blas->AXPY(data_len, T(1.f), in, out);
}

template <typename T, typename DeviceContext>
typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to(
    phi::funcs::BlasT<DeviceContext, T>* blas,
    size_t data_len,
    const T* in,
    T* out) {
  for (size_t i = 0; i < data_len; i++) {
    out[i] += in[i];
  }
}
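
// The two elementwise_add_to overloads above are selected via SFINAE on
// std::enable_if: non-integral element types take the BLAS AXPY fast path,
// while integral types fall back to a plain loop, since BLAS provides no
// integer AXPY.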

template <typename T, typename DeviceContext>
typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
                  const std::unordered_map<int64_t, size_t>& rows_to_id,
                  int64_t input_width,
                  const DeviceContext& context,
                  T* out_data) {
#ifndef PADDLE_WITH_MKLDNN
  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
#endif
  for (auto* input : inputs) {
    if (input->rows().size() == 0) {
      continue;
    }
    auto* input_data = input->value().data<T>();
    auto& input_rows = input->rows();

#ifdef PADDLE_WITH_MKLDNN
    OneDNNContext onednn_context(context.GetPlace());
    funcs::OneDNNAXPYHandler<T> axpy_handler(
        input_width, T(1.f), onednn_context.GetEngine());
    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      axpy_handler(&input_data[i * input_width],
                   &out_data[out_i * input_width]);
    }
#else
    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      elementwise_add_to<T, DeviceContext>(&blas,
                                           static_cast<size_t>(input_width),
                                           &input_data[i * input_width],
                                           &out_data[out_i * input_width]);
    }
#endif
  }
}

template <typename T, typename DeviceContext>
typename std::enable_if<!std::is_same<T, phi::dtype::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
                  const std::unordered_map<int64_t, size_t>& rows_to_id,
                  int64_t input_width,
                  const DeviceContext& context,
                  T* out_data) {
  VLOG(4) << "[CPU] add_sparse_inputs <" << typeid(T).name();
  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
  for (auto* input : inputs) {
    if (input->rows().size() == 0) {
      continue;
    }
    auto* input_data = input->value().data<T>();
    auto& input_rows = input->rows();

    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      elementwise_add_to<T, DeviceContext>(&blas,
                                           static_cast<size_t>(input_width),
                                           &input_data[i * input_width],
                                           &out_data[out_i * input_width]);
    }
  }
}

template <typename DeviceContext, typename T>
struct MergeAddImpl {
  phi::SelectedRows operator()(const DeviceContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result = false) {
    phi::SelectedRows out;
    (*this)(context, input, &out, sorted_result);
    return out;
  }

  void operator()(const DeviceContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    std::vector<const phi::SelectedRows*> inputs;
    inputs.push_back(&input);
    (*this)(context, inputs, output, sorted_result);
  }

  void operator()(const DeviceContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have the same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All inputs should have the same height."));
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    out.set_height(input_height);
    DenseTensor* out_tensor = out.mutable_value();
    out_tensor->Resize(phi::make_ddim(
        {static_cast<int64_t>(merged_row_set.size()), input_width}));
    auto* out_data = context.template Alloc<T>(out_tensor);

    if (merged_row_set.size() == row_num && !sorted_result) {
      // no duplicated ids, just concat the result together
      std::vector<int64_t> merge_rows;
      merge_rows.reserve(row_num);
      // concat rows
      for (auto* in : inputs) {
        merge_rows.insert(
            merge_rows.end(), in->rows().begin(), in->rows().end());
      }
      out.set_rows(merge_rows);
      auto in_place = inputs[0]->place();
      auto out_place = out.place();
      int64_t copied_numel = 0;
      for (auto* in : inputs) {
        auto* in_data = in->value().data<T>();
        auto in_numel = in->rows().size() * input_width;
        memory_utils::Copy(out_place,
                           out_data + copied_numel,
                           in_place,
                           in_data,
                           in_numel * sizeof(T));
        copied_numel += in_numel;
      }
    } else {
      std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                      merged_row_set.end());

      if (sorted_result) {
        std::sort(merge_rows.begin(), merge_rows.end());
      }

      out.set_rows(merge_rows);

      phi::funcs::SetConstant<DeviceContext, T> constant_functor;
      constant_functor(context, out.mutable_value(), static_cast<T>(0.f));

      std::unordered_map<int64_t, size_t> rows_to_id;
      for (size_t i = 0; i < merge_rows.size(); ++i) {
        rows_to_id[merge_rows[i]] = i;
      }

      add_sparse_inputs<T, DeviceContext>(
          inputs, rows_to_id, input_width, context, out_data);
    }
  }
};

template <typename T>
struct MergeAdd<phi::CPUContext, T> {
  // unary functor, merge by adding duplicated rows in
  // the input SelectedRows object.
  phi::SelectedRows operator()(const phi::CPUContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result) {
    return MergeAddImpl<phi::CPUContext, T>()(context, input, sorted_result);
  }

  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<phi::CPUContext, T>()(context, input, output, sorted_result);
  }

  void operator()(const phi::CPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<phi::CPUContext, T>()(context, inputs, output, sorted_result);
  }
};

#define TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(dtype)    \
  template struct MergeAddImpl<phi::CPUContext, dtype>; \
  template struct MergeAdd<phi::CPUContext, dtype>;

TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(float)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(double)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int64_t)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::bfloat16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::complex<float>)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::complex<double>)
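
// Usage sketch (hedged: `ctx` and `grad` are assumed to be set up by the
// caller). MergeAdd sums the values of duplicated row ids so that each row
// id appears at most once, the usual normalization applied to a sparse
// gradient before it is consumed:
//
//   phi::funcs::scatter::MergeAdd<phi::CPUContext, float> merge_add;
//   phi::SelectedRows merged = merge_add(ctx, grad, /*sorted_result=*/false);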

#ifdef PADDLE_WITH_XPU
template <typename T>
struct MergeAdd<phi::XPUContext, T> {
  phi::SelectedRows operator()(const phi::XPUContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result = false) {
    phi::SelectedRows out;
    (*this)(context, input, &out, sorted_result);
    return out;
  }

  void operator()(const phi::XPUContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    phi::Vector<int64_t> input_rows(input.rows());
    if (input_rows.size() == 0) {
      return;
    }

    phi::SelectedRows& out = *output;
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    DenseTensor* out_tensor = out.mutable_value();
    out_tensor->Resize(
        phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}));
    context.template Alloc<T>(out_tensor);

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

    auto* y_data = out.mutable_value()->data<T>();
    auto* x_data = input.value().data<T>();
    int xm = input_rows.size();
    int ym = merge_rows.size();
    int n = input_width;

    xpu::ctx_guard RAII_GUARD(context.x_context());
    int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
    int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
    memory_utils::Copy(context.GetPlace(),
                       y_rows_data,
                       phi::CPUPlace(),
                       merge_rows.data(),
                       ym * sizeof(int64_t));
    memory_utils::Copy(context.GetPlace(),
                       x_rows_data,
                       phi::CPUPlace(),
                       input_rows.data(),
                       xm * sizeof(int64_t));
    int r = xpu::merge_dup_rows<T, int64_t>(context.x_context(),
                                            x_data,
                                            y_data,
                                            x_rows_data,
                                            y_rows_data,
                                            xm,
                                            n,
                                            ym);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
  }

  void operator()(const phi::XPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have the same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All inputs should have the same height."));
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                    merged_row_set.end());

    if (sorted_result) {
      std::sort(merge_rows.begin(), merge_rows.end());
    }

    out.set_rows(merge_rows);
    out.set_height(input_height);

    DenseTensor* out_tensor = out.mutable_value();
    out_tensor->Resize(phi::make_ddim(
        {static_cast<int64_t>(merged_row_set.size()), input_width}));
    context.template Alloc<T>(out_tensor);

    float* y_data = reinterpret_cast<float*>(out_tensor->data<T>());

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto& input_rows = input->rows();

      auto* x_data = input->value().data<T>();
      int xm = input_rows.size();
      int ym = merge_rows.size();
      int n = input_width;

      xpu::ctx_guard RAII_GUARD(context.x_context());
      int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
      int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
      memory_utils::Copy(context.GetPlace(),
                         y_rows_data,
                         phi::CPUPlace(),
                         merge_rows.data(),
                         ym * sizeof(int64_t));
      memory_utils::Copy(context.GetPlace(),
                         x_rows_data,
                         phi::CPUPlace(),
                         input_rows.data(),
                         xm * sizeof(int64_t));
      int r = xpu::merge_dup_rows<T, int64_t>(context.x_context(),
                                              x_data,
                                              y_data,
                                              x_rows_data,
                                              y_rows_data,
                                              xm,
                                              n,
                                              ym);
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
    }
  }
};

#endif
template <typename T>
struct MergeAverage<phi::CPUContext, T> {
  phi::SelectedRows operator()(const phi::CPUContext& context,
                               const phi::SelectedRows& input) {
    phi::SelectedRows out;
    (*this)(context, input, &out);
    return out;
  }

  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output) {
    std::vector<const phi::SelectedRows*> inputs;
    inputs.push_back(&input);
    (*this)(context, inputs, output);
  }

  void operator()(const phi::CPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have the same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All inputs should have the same height."));
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    out.set_height(input_height);

    DenseTensor* out_tensor = out.mutable_value();
    out_tensor->Resize(phi::make_ddim(
        {static_cast<int64_t>(merged_row_set.size()), input_width}));
    auto* out_data = context.template Alloc<T>(out_tensor);

    std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                    merged_row_set.end());
    std::sort(merge_rows.begin(), merge_rows.end());

    out.set_rows(merge_rows);

    phi::funcs::SetConstant<phi::CPUContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), 0.0);

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto* input_data = input->value().data<T>();
      auto& input_rows = input->rows();

      for (size_t i = 0; i < input_rows.size(); i++) {
        size_t out_i = rows_to_id[input_rows[i]];
        elementwise_add_to<T>(&blas,
                              static_cast<size_t>(input_width),
                              &input_data[i * input_width],
                              &out_data[out_i * input_width]);
      }
    }
    size_t input_width_cast = static_cast<size_t>(input_width);
    T count = static_cast<T>(inputs.size());
    for (size_t i = 0; i < merge_rows.size(); i++) {
      for (size_t j = 0; j < input_width_cast; j++) {
        out_data[i * input_width + j] = out_data[i * input_width + j] / count;
      }
    }
  }
};

#ifdef PADDLE_WITH_XPU
template struct MergeAdd<phi::XPUContext, float>;
#endif

template struct MergeAverage<phi::CPUContext, int>;
template struct MergeAverage<phi::CPUContext, int64_t>;
template struct MergeAverage<phi::CPUContext, float>;
template struct MergeAverage<phi::CPUContext, double>;
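
// Usage sketch (hedged: `ctx` and `grad` are assumed to be set up by the
// caller). MergeAverage behaves like MergeAdd but divides each merged row
// by the number of inputs, yielding an element-wise mean:
//
//   phi::funcs::scatter::MergeAverage<phi::CPUContext, float> merge_avg;
//   phi::SelectedRows averaged = merge_avg(ctx, grad);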

template <typename T>
struct UpdateToTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const ScatterOps& op,
                  const phi::SelectedRows& input1,
                  phi::DenseTensor* input2) {
    auto in1_height = input1.height();
    const auto& in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        phi::errors::InvalidArgument("The two inputs width must be equal. "
                                     "But received first input width = [%d], "
                                     "second input width = [%d]",
                                     in1_row_numel,
                                     input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* input2_data = input2->data<T>();

    // FIXME(typhoonzero): use a macro to fix the messy code below.
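    // INLINE_FOR2(m, n) is assumed to be a header-provided macro that
    // expands to a nested loop over i in [0, m) and j in [0, n), i.e. over
    // rows and the elements within each row.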
    switch (op) {
      case ScatterOps::ASSIGN:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::ADD:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::SUB:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] -=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::SUBBY:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j] -
            input2_data[in1_rows[i] * in1_row_numel + j];
        break;
      case ScatterOps::MUL:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] *=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::DIV:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] /=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::DIVBY:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j] /
            input2_data[in1_rows[i] * in1_row_numel + j];
        break;
    }
  }
};
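
// Usage sketch (hedged: ScatterOps is assumed to be the enum declared in
// the accompanying header). This applies a SelectedRows update to a dense
// tensor with the chosen element-wise op:
//
//   phi::funcs::scatter::UpdateToTensor<phi::CPUContext, float> update;
//   update(ctx, phi::funcs::scatter::ScatterOps::ADD, sparse_grad, &dense);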

}  // namespace scatter
}  // namespace funcs
}  // namespace phi