/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/funcs/selected_rows_functor.h"

#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/device/device_wrapper.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/phi/backends/onednn/axpy_handler.h"
#endif

namespace phi {
namespace funcs {
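
// SelectedRowsAdd concatenates two SelectedRows objects: the row ids of
// both inputs are appended (duplicates are kept, not merged) and both
// value tensors are copied into the pre-sized output value tensor.
//
// A minimal usage sketch (hypothetical inputs `a` and `b`; assumes equal
// height and width, and that out.mutable_value() has already been resized
// to hold a.rows().size() + b.rows().size() rows):
//
//   phi::CPUContext ctx;
//   phi::SelectedRows a, b, out;
//   // ... fill a and b, allocate out's value tensor ...
//   phi::funcs::SelectedRowsAdd<phi::CPUContext, float>()(ctx, a, b, &out);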
template <typename T>
struct SelectedRowsAdd<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  const phi::SelectedRows& input2,
                  phi::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height,
        input2.height(),
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     input2.height()));
    output->set_height(in1_height);

    auto& in1_rows = input1.rows();
    auto& in2_rows = input2.rows();
    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        in2_value.numel() / in2_rows.size(),
        phi::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            in2_value.numel() / in2_rows.size()));
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        out_value->numel() / out_rows.size(),
        phi::errors::InvalidArgument(
            "The input and output width must be equal. "
            "But received input width = [%d], output width = [%d]",
            in1_row_numel,
            out_value->numel() / out_rows.size()));

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in1_place),
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));
    auto in2_place = input2.place();
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in2_place),
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(out_place),
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();
    paddle::memory::Copy(out_place,
                         out_data,
                         in1_place,
                         in1_data,
                         in1_value.numel() * sizeof(T));

    auto* in2_data = in2_value.data<T>();
    paddle::memory::Copy(out_place,
                         out_data + in1_value.numel(),
                         in2_place,
                         in2_data,
                         in2_value.numel() * sizeof(T));
  }
};

template struct SelectedRowsAdd<phi::CPUContext, float>;
template struct SelectedRowsAdd<phi::CPUContext, double>;

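// SelectedRowsAddTensor computes a dense output: the result is zeroed,
// each sparse row of input1 is added at its row id, and then the dense
// input2 is added element-wise.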
template <typename T>
struct SelectedRowsAddTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  const phi::DenseTensor& input2,
                  phi::DenseTensor* output) {
    auto in1_height = input1.height();
    const auto& in2_dims = input2.dims();
    const auto& out_dims = output->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));
    PADDLE_ENFORCE_EQ(
        in1_height,
        out_dims[0],
        phi::errors::InvalidArgument(
            "The input and output height must be equal. "
            "But received input height = [%d], output height = [%d]",
            in1_height,
            out_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2.numel() / in1_height,
        phi::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2.numel() / in1_height));
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        output->numel() / in1_height,
        phi::errors::InvalidArgument(
            "The input and output width must be equal. "
            "But received input width = [%d], output width = [%d]",
            in1_row_numel,
            output->numel() / in1_height));

    phi::funcs::SetConstant<phi::CPUContext, T> functor;
    functor(context, output, 0.0);

    auto* in1_data = in1_value.data<T>();
    auto* out_data = output->data<T>();

    for (size_t i = 0; i < in1_rows.size(); i++) {
      for (int64_t j = 0; j < in1_row_numel; j++) {
        out_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
      }
    }

    auto out_eigen = EigenVector<T>::Flatten(*output);
    auto in2_eigen = EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<phi::CPUContext, float>;
template struct SelectedRowsAddTensor<phi::CPUContext, double>;

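// SelectedRowsAddTo appends input1 to input2 in place: input1's row ids
// are appended to input2's rows, and input1's values are copied into
// input2's value tensor starting at element offset input2_offset. Rows
// are concatenated, not merged.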
template <typename T>
struct SelectedRowsAddTo<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  const int64_t input2_offset,
                  phi::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height,
        input2->height(),
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     input2->height()));

    auto& in1_rows = input1.rows();
    auto& in2_rows = *(input2->mutable_rows());

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();

    // concat rows
    paddle::framework::MixVector<int64_t> mixv_in2_rows(&in2_rows);
    mixv_in2_rows.Extend(in1_rows.begin(), in1_rows.end());

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in1_place),
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));
    auto in2_place = input2->place();
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in2_place),
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();
    paddle::memory::Copy(in2_place,
                         in2_data + input2_offset,
                         in1_place,
                         in1_data,
                         in1_value.numel() * sizeof(T));
  }
};

template struct SelectedRowsAddTo<phi::CPUContext, float>;
template struct SelectedRowsAddTo<phi::CPUContext, double>;
template struct SelectedRowsAddTo<phi::CPUContext, int>;
template struct SelectedRowsAddTo<phi::CPUContext, int64_t>;

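// SelectedRowsSumTo copies a batch of SelectedRows into input2: all row
// ids are concatenated, and each input's values are written at an offset
// accumulated from input2_offsets.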
template <typename T>
struct SelectedRowsSumTo<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const std::vector<phi::SelectedRows*>& input1,
                  const std::vector<int64_t>& input2_offsets,
                  phi::SelectedRows* input2) {
    // Ensure all selected rows have the same height
    size_t size = 0u;
    for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
      auto& in_rows = (*iter)->rows();
      size += in_rows.end() - in_rows.begin();
      auto in1_height = (*iter)->height();
      PADDLE_ENFORCE_EQ(in1_height,
                        input2->height(),
                        phi::errors::InvalidArgument(
                            "The two inputs height must be equal. "
                            "But received first input height = [%d], second "
                            "input height = [%d]",
                            in1_height,
                            input2->height()));
    }
    // concat rows
    std::vector<int64_t> in2_rows;
    in2_rows.reserve(in2_rows.size() + size);
    for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
      const paddle::framework::Vector<int64_t>& in_rows = (*iter)->rows();
      in2_rows.insert(in2_rows.end(), in_rows.begin(), in_rows.end());
    }
    input2->set_rows(in2_rows);

    auto* in2_value = input2->mutable_value();
    auto* in2_data = in2_value->data<T>();
    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
    size_t offset = 0u;
    for (size_t i = 0u; i != input1.size(); ++i) {
      auto& in_value = input1[i]->value();
      const auto* in_data = in_value.data<T>();
      offset += input2_offsets[i];
      blas.VCOPY(in_value.numel(), in_data, in2_data + offset);
    }
  }
};

template struct SelectedRowsSumTo<phi::CPUContext, float>;
template struct SelectedRowsSumTo<phi::CPUContext, double>;

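// SelectedRowsAddToTensor scatter-adds a SelectedRows object into a dense
// tensor in place: input2[row_id] += input1_row for every sparse row.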
template <typename T>
struct SelectedRowsAddToTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  phi::DenseTensor* input2) {
    if (UNLIKELY(input1.rows().size() == 0)) {
      LOG(WARNING) << "input selected rows is empty!";
      return;
    }
    auto in1_height = input1.height();
    const auto& in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        phi::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* input2_data = input2->data<T>();

    for (size_t i = 0; i < in1_rows.size(); i++) {
      for (int64_t j = 0; j < in1_row_numel; j++) {
        input2_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
      }
    }
  }
};

#ifdef PADDLE_WITH_XPU
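// XPU specialization: the scatter-add is performed on the device by
// xpu::scatter, indexed with the SelectedRows row ids.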
template <typename T>
struct SelectedRowsAddToTensor<phi::XPUContext, T> {
  void operator()(const phi::XPUContext& context,
                  const phi::SelectedRows& input1,
                  phi::DenseTensor* input2) {
    if (UNLIKELY(input1.rows().size() == 0)) {
      LOG(WARNING) << "input selected rows is empty!";
      return;
    }
    using XPUType = typename XPUTypeTrait<T>::Type;
    auto in1_height = input1.height();
    const auto& in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();
    int64_t* in1_rows_data = nullptr;
    xpu::VectorParam<int64_t> in1_rows_vec{
        in1_rows.data(), static_cast<int>(in1_rows.size()), in1_rows_data};

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        phi::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* out_data = input2->data<T>();

    int h = in1_rows.size();
    int w = in1_row_numel;
    const std::vector<int> xshape{h, w};

    int r = xpu::scatter<XPUType, int64_t>(
        context.x_context(),
        nullptr,
        reinterpret_cast<const XPUType*>(in1_data),
        reinterpret_cast<XPUType*>(out_data),
        in1_rows_vec,
        xshape,
        0,
        false);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "scatter");
  }
};

#endif

template struct SelectedRowsAddToTensor<phi::CPUContext, float>;
template struct SelectedRowsAddToTensor<phi::CPUContext, double>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int64_t>;
template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::bfloat16>;

#ifdef PADDLE_WITH_XPU
template struct SelectedRowsAddToTensor<phi::XPUContext, float>;
#endif

// This is a separate namespace for manipulating SelectedRows typed
// data, such as merging duplicated rows or adding two SelectedRows
// objects.
//
// Another group of functors is called "scatter updates", which means
// using a SelectedRows object to update a dense tensor with different
// ops, such as add or mul.
namespace scatter {

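// elementwise_add_to computes out[i] += in[i] over data_len elements. The
// floating-point overload delegates to BLAS AXPY; the integral overload
// below falls back to a plain loop, since AXPY is only provided for
// floating-point types.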
template <typename T, typename DeviceContext>
typename std::enable_if<!std::is_integral<T>::value>::type elementwise_add_to(
    phi::funcs::BlasT<DeviceContext, T>* blas,
    size_t data_len,
    const T* in,
    T* out) {
  blas->AXPY(data_len, T(1.f), in, out);
}

template <typename T, typename DeviceContext>
typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to(
    phi::funcs::BlasT<DeviceContext, T>* blas,
    size_t data_len,
    const T* in,
    T* out) {
  for (size_t i = 0; i < data_len; i++) {
    out[i] += in[i];
  }
}

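// add_sparse_inputs scatter-adds every input's rows into out_data, using
// rows_to_id to map a row id to its slot in the merged output. The
// bfloat16 overload goes through the OneDNN AXPY handler when OneDNN is
// available, since the generic BLAS path does not cover bfloat16.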
template <typename T, typename DeviceContext>
typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
                  const std::unordered_map<int64_t, size_t>& rows_to_id,
                  int64_t input_width,
                  const DeviceContext& context,
                  T* out_data) {
#ifndef PADDLE_WITH_MKLDNN
  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
#endif
  for (auto* input : inputs) {
    if (input->rows().size() == 0) {
      continue;
    }
    auto* input_data = input->value().data<T>();
    auto& input_rows = input->rows();

#ifdef PADDLE_WITH_MKLDNN
    OneDNNContext onednn_context(context.GetPlace());
    funcs::OneDNNAXPYHandler<T> axpy_handler(
        input_width, T(1.f), onednn_context.GetEngine());
    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      axpy_handler(&input_data[i * input_width],
                   &out_data[out_i * input_width]);
    }
#else
    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      elementwise_add_to<T, DeviceContext>(&blas,
                                           static_cast<size_t>(input_width),
                                           &input_data[i * input_width],
                                           &out_data[out_i * input_width]);
    }
#endif
  }
}

template <typename T, typename DeviceContext>
typename std::enable_if<!std::is_same<T, phi::dtype::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
                  const std::unordered_map<int64_t, size_t>& rows_to_id,
                  int64_t input_width,
                  const DeviceContext& context,
                  T* out_data) {
  VLOG(4) << "[CPU] add_sparse_inputs <" << typeid(T).name();
  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
  for (auto* input : inputs) {
    if (input->rows().size() == 0) {
      continue;
    }
    auto* input_data = input->value().data<T>();
    auto& input_rows = input->rows();

    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      elementwise_add_to<T, DeviceContext>(&blas,
                                           static_cast<size_t>(input_width),
                                           &input_data[i * input_width],
                                           &out_data[out_i * input_width]);
    }
  }
}

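// MergeAddImpl accumulates SelectedRows inputs into a single SelectedRows
// with unique rows: values that share a row id are summed. When the inputs
// already contain no duplicate ids (and sorted_result is false), rows and
// values are simply concatenated.
//
// A minimal usage sketch (hypothetical gradient `grad` with duplicate row
// ids):
//
//   phi::CPUContext ctx;
//   phi::SelectedRows grad;  // e.g. rows {4, 4, 7}
//   phi::SelectedRows merged = phi::funcs::scatter::MergeAdd<
//       phi::CPUContext, float>()(ctx, grad, /*sorted_result=*/false);
//   // merged.rows() now holds the unique ids {4, 7}.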
template <typename DeviceContext, typename T>
struct MergeAddImpl {
  phi::SelectedRows operator()(const DeviceContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result = false) {
    phi::SelectedRows out;
    (*this)(context, input, &out, sorted_result);
    return out;
  }

  void operator()(const DeviceContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    std::vector<const phi::SelectedRows*> inputs;
    inputs.push_back(&input);
    (*this)(context, inputs, output, sorted_result);
  }

  void operator()(const DeviceContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have the same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All inputs should have the same height."));
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    out.set_height(input_height);
    DenseTensor* out_tensor = out.mutable_value();
    out_tensor->Resize(phi::make_ddim(
        {static_cast<int64_t>(merged_row_set.size()), input_width}));
    auto* out_data = context.template Alloc<T>(out_tensor);

    if (merged_row_set.size() == row_num && !sorted_result) {
      // no duplicated ids, just concat the result together
      std::vector<int64_t> merge_rows;
      merge_rows.reserve(row_num);
      // concat rows
      for (auto* in : inputs) {
        merge_rows.insert(
            merge_rows.end(), in->rows().begin(), in->rows().end());
      }
      out.set_rows(merge_rows);
      auto in_place = inputs[0]->place();
      auto out_place = out.place();
      int64_t copied_numel = 0;
      for (auto* in : inputs) {
        auto* in_data = in->value().data<T>();
        auto in_numel = in->rows().size() * input_width;
        paddle::memory::Copy(out_place,
                             out_data + copied_numel,
                             in_place,
                             in_data,
                             in_numel * sizeof(T));
        copied_numel += in_numel;
      }
    } else {
      std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                      merged_row_set.end());

      if (sorted_result) {
        std::sort(merge_rows.begin(), merge_rows.end());
      }

      out.set_rows(merge_rows);

      phi::funcs::SetConstant<DeviceContext, T> constant_functor;
      constant_functor(context, out.mutable_value(), static_cast<T>(0.f));

      std::unordered_map<int64_t, size_t> rows_to_id;
      for (size_t i = 0; i < merge_rows.size(); ++i) {
        rows_to_id[merge_rows[i]] = i;
      }

      add_sparse_inputs<T, DeviceContext>(
          inputs, rows_to_id, input_width, context, out_data);
    }
  }
};

template <typename T>
struct MergeAdd<phi::CPUContext, T> {
  // unary functor, merge by adding duplicated rows in
  // the input SelectedRows object.
  phi::SelectedRows operator()(const phi::CPUContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result) {
    return MergeAddImpl<phi::CPUContext, T>()(context, input, sorted_result);
  }

  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<phi::CPUContext, T>()(context, input, output, sorted_result);
  }

  void operator()(const phi::CPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<phi::CPUContext, T>()(context, inputs, output, sorted_result);
  }
};

#define TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(dtype)    \
  template struct MergeAddImpl<phi::CPUContext, dtype>; \
  template struct MergeAdd<phi::CPUContext, dtype>;

TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(float)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(double)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int64_t)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::bfloat16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::complex<float>)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::complex<double>)

#ifdef PADDLE_WITH_XPU
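// XPU specialization of MergeAdd: row ids are deduplicated on the host,
// then xpu::merge_dup_rows accumulates the duplicated rows on the device.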
template <typename T>
struct MergeAdd<phi::XPUContext, T> {
  phi::SelectedRows operator()(const phi::XPUContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result = false) {
    phi::SelectedRows out;
    (*this)(context, input, &out, sorted_result);
    return out;
  }

  void operator()(const phi::XPUContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    paddle::framework::Vector<int64_t> input_rows(input.rows());
    if (input_rows.size() == 0) {
      return;
    }

    phi::SelectedRows& out = *output;
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    DenseTensor* out_tensor = out.mutable_value();
    out_tensor->Resize(
        phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}));
    context.template Alloc<T>(out_tensor);

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

    auto* y_data = out.mutable_value()->data<T>();
    auto* x_data = input.value().data<T>();
    int xm = input_rows.size();
    int ym = merge_rows.size();
    int n = input_width;

    xpu::ctx_guard RAII_GUARD(context.x_context());
    int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
    int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
    paddle::memory::Copy(context.GetPlace(),
                         y_rows_data,
                         phi::CPUPlace(),
                         merge_rows.data(),
                         ym * sizeof(int64_t));
    paddle::memory::Copy(context.GetPlace(),
                         x_rows_data,
                         phi::CPUPlace(),
                         input_rows.data(),
                         xm * sizeof(int64_t));
    int r = xpu::merge_dup_rows<T, int64_t>(context.x_context(),
                                            x_data,
                                            y_data,
                                            x_rows_data,
                                            y_rows_data,
                                            xm,
                                            n,
                                            ym);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
  }

  void operator()(const phi::XPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have the same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All inputs should have the same height."));
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                    merged_row_set.end());

    if (sorted_result) {
      std::sort(merge_rows.begin(), merge_rows.end());
    }

    out.set_rows(merge_rows);
    out.set_height(input_height);

    DenseTensor* out_tensor = out.mutable_value();
    out_tensor->Resize(phi::make_ddim(
        {static_cast<int64_t>(merged_row_set.size()), input_width}));
    context.template Alloc<T>(out_tensor);

    float* y_data = reinterpret_cast<float*>(out_tensor->data<T>());

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto& input_rows = input->rows();

      auto* x_data = input->value().data<T>();
      int xm = input_rows.size();
      int ym = merge_rows.size();
      int n = input_width;

      xpu::ctx_guard RAII_GUARD(context.x_context());
      int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
      int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
      paddle::memory::Copy(context.GetPlace(),
                           y_rows_data,
                           phi::CPUPlace(),
                           merge_rows.data(),
                           ym * sizeof(int64_t));
      paddle::memory::Copy(context.GetPlace(),
                           x_rows_data,
                           phi::CPUPlace(),
                           input_rows.data(),
                           xm * sizeof(int64_t));
      int r = xpu::merge_dup_rows<T, int64_t>(context.x_context(),
                                              x_data,
                                              y_data,
                                              x_rows_data,
                                              y_rows_data,
                                              xm,
                                              n,
                                              ym);
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
    }
  }
};

#endif
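
// MergeAverage behaves like MergeAdd but divides every merged value by the
// number of inputs, producing an element-wise average across the inputs.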
template <typename T>
struct MergeAverage<phi::CPUContext, T> {
  phi::SelectedRows operator()(const phi::CPUContext& context,
                               const phi::SelectedRows& input) {
    phi::SelectedRows out;
    (*this)(context, input, &out);
    return out;
  }

  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output) {
    std::vector<const phi::SelectedRows*> inputs;
    inputs.push_back(&input);
    (*this)(context, inputs, output);
  }

  void operator()(const phi::CPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have the same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All inputs should have the same height."));
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    out.set_height(input_height);
860 861 862 863 864

    DenseTensor* out_tensor = out.mutable_value();
    out_tensor->Resize(phi::make_ddim(
        {static_cast<int64_t>(merged_row_set.size()), input_width}));
    auto* out_data = context.template Alloc<T>(out_tensor);

    std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                    merged_row_set.end());
    std::sort(merge_rows.begin(), merge_rows.end());

    out.set_rows(merge_rows);

    phi::funcs::SetConstant<phi::CPUContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), 0.0);

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto* input_data = input->value().data<T>();
      auto& input_rows = input->rows();

      for (size_t i = 0; i < input_rows.size(); i++) {
        size_t out_i = rows_to_id[input_rows[i]];
        elementwise_add_to<T>(&blas,
                              static_cast<size_t>(input_width),
                              &input_data[i * input_width],
                              &out_data[out_i * input_width]);
      }
    }
    size_t input_width_cast = static_cast<size_t>(input_width);
    T count = static_cast<T>(inputs.size());
    for (size_t i = 0; i < merge_rows.size(); i++) {
      for (size_t j = 0; j < input_width_cast; j++) {
        out_data[i * input_width + j] = out_data[i * input_width + j] / count;
      }
    }
  }
};

#ifdef PADDLE_WITH_XPU
template struct MergeAdd<phi::XPUContext, float>;
#endif

template struct MergeAverage<phi::CPUContext, int>;
template struct MergeAverage<phi::CPUContext, int64_t>;
template struct MergeAverage<phi::CPUContext, float>;
template struct MergeAverage<phi::CPUContext, double>;

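// UpdateToTensor scatters the rows of a SelectedRows object into a dense
// tensor, combining each sparse row with the existing values according to
// the requested ScatterOps variant (assign, add, sub, subby, mul, div,
// divby).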
template <typename T>
struct UpdateToTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const ScatterOps& op,
                  const phi::SelectedRows& input1,
                  phi::DenseTensor* input2) {
    auto in1_height = input1.height();
    const auto& in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        phi::errors::InvalidArgument("The two inputs width must be equal. "
                                     "But received first input width = [%d], "
                                     "second input width = [%d]",
                                     in1_row_numel,
                                     input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* input2_data = input2->data<T>();

    // FIXME(typhoonzero): use a macro to fix the messy code below.
    switch (op) {
      case ScatterOps::ASSIGN:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::ADD:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::SUB:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] -=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::SUBBY:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j] -
            input2_data[in1_rows[i] * in1_row_numel + j];
        break;
      case ScatterOps::MUL:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] *=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::DIV:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] /=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::DIVBY:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j] /
            input2_data[in1_rows[i] * in1_row_numel + j];
        break;
    }
  }
};

}  // namespace scatter
}  // namespace funcs
}  // namespace phi