// paddle/phi/kernels/funcs/selected_rows_functor.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
16

17
#include "paddle/fluid/framework/mixed_vector.h"
18
#include "paddle/fluid/platform/device/device_wrapper.h"
19

L
lidanqing 已提交
20
#ifdef PADDLE_WITH_MKLDNN
21
#include "paddle/phi/backends/onednn/axpy_handler.h"
L
lidanqing 已提交
22 23
#endif

24 25
namespace phi {
namespace funcs {
26
template <typename T>
L
Leo Chen 已提交
27 28
struct SelectedRowsAdd<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
29
                  const phi::SelectedRows& input1,
30 31
                  const phi::SelectedRows& input2,
                  phi::SelectedRows* output) {
32
    auto in1_height = input1.height();
33
    PADDLE_ENFORCE_EQ(
34 35
        in1_height,
        input2.height(),
36 37 38 39 40
        phi::errors::InvalidArgument("The two inputs height must be equal."
                                     "But received first input height  = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     input2.height()));
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
    output->set_height(in1_height);

    auto& in1_rows = input1.rows();
    auto& in2_rows = input2.rows();
    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
58
    PADDLE_ENFORCE_EQ(
59 60
        in1_row_numel,
        in2_value.numel() / in2_rows.size(),
61
        phi::errors::InvalidArgument(
62
            "The two inputs width must be equal."
63
            "But received first input width = [%d], second input width = [%d]",
64 65
            in1_row_numel,
            in2_value.numel() / in2_rows.size()));
66
    PADDLE_ENFORCE_EQ(
67 68
        in1_row_numel,
        out_value->numel() / out_rows.size(),
69
        phi::errors::InvalidArgument(
70
            "The input and oupput width must be equal."
71
            "But received input width = [%d], output width = [%d]",
72 73
            in1_row_numel,
            out_value->numel() / out_rows.size()));
74 75

    auto in1_place = input1.place();
76
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in1_place),
77
                      true,
78
                      phi::errors::InvalidArgument(
79
                          "The running environment is not on the CPU place."));
80
    auto in2_place = input2.place();
81
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in2_place),
82
                      true,
83
                      phi::errors::InvalidArgument(
84
                          "The running environment is not on the CPU place."));
85
    auto out_place = context.GetPlace();
86
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(out_place),
87
                      true,
88
                      phi::errors::InvalidArgument(
89
                          "The running environment is not on the CPU place."));
90 91 92

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();
93 94 95 96 97
    paddle::memory::Copy(out_place,
                         out_data,
                         in1_place,
                         in1_data,
                         in1_value.numel() * sizeof(T));
98 99

    auto* in2_data = in2_value.data<T>();
100 101 102 103 104
    paddle::memory::Copy(out_place,
                         out_data + in1_value.numel(),
                         in2_place,
                         in2_data,
                         in2_value.numel() * sizeof(T));
105 106 107
  }
};

L
Leo Chen 已提交
108 109
template struct SelectedRowsAdd<phi::CPUContext, float>;
template struct SelectedRowsAdd<phi::CPUContext, double>;
110 111

template <typename T>
L
Leo Chen 已提交
112 113
struct SelectedRowsAddTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
114
                  const phi::SelectedRows& input1,
115 116
                  const phi::DenseTensor& input2,
                  phi::DenseTensor* output) {
117
    auto in1_height = input1.height();
118 119
    const auto& in2_dims = input2.dims();
    const auto& out_dims = output->dims();
120
    PADDLE_ENFORCE_EQ(
121 122
        in1_height,
        in2_dims[0],
123 124 125 126 127
        phi::errors::InvalidArgument("The two inputs height must be equal."
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));
128
    PADDLE_ENFORCE_EQ(
129 130
        in1_height,
        out_dims[0],
131
        phi::errors::InvalidArgument(
132
            "The input and output height must be equal."
133
            "But received input height = [%d], output height = [%d]",
134 135
            in1_height,
            out_dims[0]));
136 137 138 139 140

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
141
    PADDLE_ENFORCE_EQ(
142 143
        in1_row_numel,
        input2.numel() / in1_height,
144
        phi::errors::InvalidArgument(
145
            "The two inputs width must be equal."
146
            "But received first input width = [%d], second input width = [%d]",
147 148
            in1_row_numel,
            input2.numel() / in1_height));
149
    PADDLE_ENFORCE_EQ(
150 151
        in1_row_numel,
        output->numel() / in1_height,
152
        phi::errors::InvalidArgument(
153
            "The input and output width must be equal."
154
            "But received input width = [%d], output width = [%d]",
155 156
            in1_row_numel,
            output->numel() / in1_height));
157

L
Leo Chen 已提交
158
    phi::funcs::SetConstant<phi::CPUContext, T> functor;
159 160 161 162 163 164 165 166 167 168 169 170
    functor(context, output, 0.0);

    auto* in1_data = in1_value.data<T>();
    auto* out_data = output->data<T>();

    for (size_t i = 0; i < in1_rows.size(); i++) {
      for (int64_t j = 0; j < in1_row_numel; j++) {
        out_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
      }
    }

171 172
    auto out_eigen = EigenVector<T>::Flatten(*output);
    auto in2_eigen = EigenVector<T>::Flatten(input2);
Q
QI JUN 已提交
173
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
174 175 176
  }
};

L
Leo Chen 已提交
177 178
template struct SelectedRowsAddTensor<phi::CPUContext, float>;
template struct SelectedRowsAddTensor<phi::CPUContext, double>;
Q
QI JUN 已提交
179 180

template <typename T>
L
Leo Chen 已提交
181 182
struct SelectedRowsAddTo<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
183 184
                  const phi::SelectedRows& input1,
                  const int64_t input2_offset,
185
                  phi::SelectedRows* input2) {
Q
QI JUN 已提交
186
    auto in1_height = input1.height();
187
    PADDLE_ENFORCE_EQ(
188 189
        in1_height,
        input2->height(),
190 191 192 193 194
        phi::errors::InvalidArgument("The two inputs height must be equal."
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     input2->height()));
Q
QI JUN 已提交
195 196 197 198 199 200 201 202

    auto& in1_rows = input1.rows();
    auto& in2_rows = *(input2->mutable_rows());

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();

    // concat rows
203 204
    paddle::framework::MixVector<int64_t> mixv_in2_rows(&in2_rows);
    mixv_in2_rows.Extend(in1_rows.begin(), in1_rows.end());
Q
QI JUN 已提交
205 206

    auto in1_place = input1.place();
207
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in1_place),
208
                      true,
209
                      phi::errors::InvalidArgument(
210
                          "The running environment is not on the CPU place."));
Q
QI JUN 已提交
211
    auto in2_place = input2->place();
212
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in2_place),
213
                      true,
214
                      phi::errors::InvalidArgument(
215
                          "The running environment is not on the CPU place."));
Q
QI JUN 已提交
216 217 218

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();
219 220 221 222 223
    paddle::memory::Copy(in2_place,
                         in2_data + input2_offset,
                         in1_place,
                         in1_data,
                         in1_value.numel() * sizeof(T));
Q
QI JUN 已提交
224 225 226
  }
};

L
Leo Chen 已提交
227 228 229 230
template struct SelectedRowsAddTo<phi::CPUContext, float>;
template struct SelectedRowsAddTo<phi::CPUContext, double>;
template struct SelectedRowsAddTo<phi::CPUContext, int>;
template struct SelectedRowsAddTo<phi::CPUContext, int64_t>;
Q
QI JUN 已提交
231

M
minqiyang 已提交
232
template <typename T>
L
Leo Chen 已提交
233 234
struct SelectedRowsSumTo<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
235
                  const std::vector<phi::SelectedRows*>& input1,
M
minqiyang 已提交
236
                  const std::vector<int64_t>& input2_offsets,
237
                  phi::SelectedRows* input2) {
M
minqiyang 已提交
238 239 240 241 242 243
    // Ensure all selected rows have the same height
    size_t size = 0u;
    for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
      auto& in_rows = (*iter)->rows();
      size += in_rows.end() - in_rows.begin();
      auto in1_height = (*iter)->height();
244 245
      PADDLE_ENFORCE_EQ(in1_height,
                        input2->height(),
246
                        phi::errors::InvalidArgument(
247
                            "The two inputs height must be equal."
248
                            "But received first input height = [%d], second "
249
                            "input height = [%d]",
250 251
                            in1_height,
                            input2->height()));
M
minqiyang 已提交
252 253 254 255 256
    }
    // concat rows
    std::vector<int64_t> in2_rows;
    in2_rows.reserve(in2_rows.size() + size);
    for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
257
      const paddle::framework::Vector<int64_t>& in_rows = (*iter)->rows();
M
minqiyang 已提交
258 259 260 261 262 263
      in2_rows.insert(in2_rows.end(), in_rows.begin(), in_rows.end());
    }
    input2->set_rows(in2_rows);

    auto* in2_value = input2->mutable_value();
    auto* in2_data = in2_value->data<T>();
L
Leo Chen 已提交
264
    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
M
minqiyang 已提交
265 266 267 268 269 270 271 272 273 274
    size_t offset = 0u;
    for (size_t i = 0u; i != input1.size(); ++i) {
      auto& in_value = input1[i]->value();
      const auto* in_data = in_value.data<T>();
      offset += input2_offsets[i];
      blas.VCOPY(in_value.numel(), in_data, in2_data + offset);
    }
  }
};

L
Leo Chen 已提交
275 276
template struct SelectedRowsSumTo<phi::CPUContext, float>;
template struct SelectedRowsSumTo<phi::CPUContext, double>;
M
minqiyang 已提交
277

H
hong 已提交
278 279 280
// Scatter-adds a SelectedRows into a dense tensor in place:
// input2[row_id] += input1_row for every row of input1. An empty input is
// a no-op (with a warning) rather than an error.
template <typename T>
struct SelectedRowsAddToTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  phi::DenseTensor* input2) {
    if (UNLIKELY(input1.rows().size() == 0)) {
      LOG(WARNING) << "input selected rows is empty!";
      return;
    }
    auto in1_height = input1.height();
    const auto& in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal. "
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        phi::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* input2_data = input2->data<T>();

    // Row ids may repeat; repeated rows accumulate into the same slice.
    for (size_t i = 0; i < in1_rows.size(); i++) {
      for (int64_t j = 0; j < in1_row_numel; j++) {
        input2_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
      }
    }
  }
};

template struct SelectedRowsAddToTensor<phi::CPUContext, float>;
template struct SelectedRowsAddToTensor<phi::CPUContext, double>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int64_t>;
template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::bfloat16>;
T
typhoonzero 已提交
328 329 330 331 332 333 334 335
// This is a separated namespace for manipulate SelectedRows typed
// data. Like merge duplicated rows, adding two SelectedRows etc.
//
// Another group of functors is called "scatter updates", which means
// use SelectedRows to update a dense tensor with different Ops, like
// add or mul.
namespace scatter {

336
// out[0..data_len) += in[0..data_len), via BLAS AXPY with alpha = 1.
// Enabled only for non-integral T (the types BLAS routines accept).
template <typename T, typename DeviceContext>
typename std::enable_if<!std::is_integral<T>::value>::type elementwise_add_to(
    phi::funcs::BlasT<DeviceContext, T>* blas,
    size_t data_len,
    const T* in,
    T* out) {
  blas->AXPY(data_len, T(1.f), in, out);
}

345
// out[0..data_len) += in[0..data_len), plain loop fallback for integral T
// (BLAS has no integer AXPY); `blas` is accepted but unused so both
// overloads share one call site.
template <typename T, typename DeviceContext>
typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to(
    phi::funcs::BlasT<DeviceContext, T>* blas,
    size_t data_len,
    const T* in,
    T* out) {
  for (size_t i = 0; i < data_len; i++) {
    out[i] += in[i];
  }
}

356
// Scatter-adds every input SelectedRows into `out_data`, mapping each row id
// to its output slot through `rows_to_id`. bfloat16 specialization: when
// built with MKLDNN it uses a oneDNN AXPY handler (bf16 has no BLAS AXPY),
// otherwise it falls back to elementwise_add_to.
template <typename T, typename DeviceContext>
typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
                  const std::unordered_map<int64_t, size_t>& rows_to_id,
                  int64_t input_width,
                  const DeviceContext& context,
                  T* out_data) {
#ifndef PADDLE_WITH_MKLDNN
  // Only needed by the non-MKLDNN fallback path below.
  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
#endif
  for (auto* input : inputs) {
    if (input->rows().size() == 0) {
      continue;  // nothing to add from an empty SelectedRows
    }
    auto* input_data = input->value().data<T>();
    auto& input_rows = input->rows();

#ifdef PADDLE_WITH_MKLDNN
    OneDNNContext onednn_context(context.GetPlace());
    funcs::OneDNNAXPYHandler<T> axpy_handler(
        input_width, T(1.f), onednn_context.GetEngine());
    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      // out_row += in_row, one full row (input_width elements) at a time.
      axpy_handler(&input_data[i * input_width],
                   &out_data[out_i * input_width]);
    }
#else
    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      elementwise_add_to<T, DeviceContext>(&blas,
                                           static_cast<size_t>(input_width),
                                           &input_data[i * input_width],
                                           &out_data[out_i * input_width]);
    }
#endif
  }
}

394
// Scatter-adds every input SelectedRows into `out_data`, mapping each row id
// to its output slot through `rows_to_id`. Generic overload for all element
// types other than bfloat16; always goes through elementwise_add_to (BLAS
// AXPY or an integer loop, chosen by T).
template <typename T, typename DeviceContext>
typename std::enable_if<!std::is_same<T, phi::dtype::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
                  const std::unordered_map<int64_t, size_t>& rows_to_id,
                  int64_t input_width,
                  const DeviceContext& context,
                  T* out_data) {
  VLOG(4) << "[CPU] add_sparse_inputs <" << typeid(T).name();
  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
  for (auto* input : inputs) {
    if (input->rows().size() == 0) {
      continue;  // nothing to add from an empty SelectedRows
    }
    auto* input_data = input->value().data<T>();
    auto& input_rows = input->rows();

    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      // out_row += in_row, one full row (input_width elements) at a time.
      elementwise_add_to<T, DeviceContext>(&blas,
                                           static_cast<size_t>(input_width),
                                           &input_data[i * input_width],
                                           &out_data[out_i * input_width]);
    }
  }
}

420 421 422
// Shared implementation behind MergeAdd: merges one or more SelectedRows
// into a single SelectedRows in which every row id appears once, summing
// the values of duplicated ids. Offers value-returning, single-input, and
// multi-input overloads that all funnel into the last operator().
template <typename DeviceContext, typename T>
struct MergeAddImpl {
  // Value-returning convenience overload.
  phi::SelectedRows operator()(const DeviceContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result = false) {
    phi::SelectedRows out;
    (*this)(context, input, &out, sorted_result);
    return out;
  }

  // Single-input overload; wraps the input and forwards to the batch form.
  void operator()(const DeviceContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    std::vector<const phi::SelectedRows*> inputs;
    inputs.push_back(&input);
    (*this)(context, inputs, output, sorted_result);
  }

  // Batch form: merge all `inputs` into `output`. If `sorted_result` is
  // true the output row ids are guaranteed ascending.
  void operator()(const DeviceContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    // Find the first non-empty input to read width/height from.
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All inputs should have same height."));
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    // Allocate the merged value tensor: one row per distinct row id.
    out.set_height(input_height);
    out.mutable_value()->mutable_data<T>(
        phi::make_ddim(
            {static_cast<int64_t>(merged_row_set.size()), input_width}),
        context.GetPlace());
    auto* out_data = out.mutable_value()->data<T>();

    if (merged_row_set.size() == row_num && !sorted_result) {
      // no duplicated ids, just concat the result together
      std::vector<int64_t> merge_rows;
      merge_rows.reserve(row_num);
      // concat rows
      for (auto* in : inputs) {
        merge_rows.insert(
            merge_rows.end(), in->rows().begin(), in->rows().end());
      }
      out.set_rows(merge_rows);
      auto in_place = inputs[0]->place();
      auto out_place = out.place();
      int64_t copied_numel = 0;
      for (auto* in : inputs) {
        auto* in_data = in->value().data<T>();
        auto in_numel = in->rows().size() * input_width;
        paddle::memory::Copy(out_place,
                             out_data + copied_numel,
                             in_place,
                             in_data,
                             in_numel * sizeof(T));
        copied_numel += in_numel;
      }
    } else {
      // Duplicates exist (or sorted output requested): zero the output and
      // scatter-add each input row into its merged slot.
      std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                      merged_row_set.end());

      if (sorted_result) {
        // NOTE(review): std::set iteration is already ascending, so this
        // sort looks redundant; kept as-is to preserve behavior.
        std::sort(merge_rows.begin(), merge_rows.end());
      }

      out.set_rows(merge_rows);

      phi::funcs::SetConstant<DeviceContext, T> constant_functor;
      constant_functor(context, out.mutable_value(), static_cast<T>(0.f));

      // Map each distinct row id to its index in the merged output.
      std::unordered_map<int64_t, size_t> rows_to_id;
      for (size_t i = 0; i < merge_rows.size(); ++i) {
        rows_to_id[merge_rows[i]] = i;
      }

      add_sparse_inputs<T, DeviceContext>(
          inputs, rows_to_id, input_width, context, out_data);
    }
  }
};

534 535 536 537 538 539 540 541 542 543 544
template <typename T>
struct MergeAdd<phi::CPUContext, T> {
  // unary functor, merge by adding duplicated rows in
  // the input SelectedRows object.
  phi::SelectedRows operator()(const phi::CPUContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result) {
    return MergeAddImpl<phi::CPUContext, T>()(context, input, sorted_result);
  }

  void operator()(const phi::CPUContext& context,
545 546
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
547 548 549 550 551 552
                  const bool sorted_result) {
    MergeAddImpl<phi::CPUContext, T>()(context, input, output, sorted_result);
  }

  void operator()(const phi::CPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
553 554
                  phi::SelectedRows* output,
                  const bool sorted_result) {
555 556 557 558
    MergeAddImpl<phi::CPUContext, T>()(context, inputs, output, sorted_result);
  }
};

L
Leo Chen 已提交
559 560
// Explicitly instantiate both the implementation (MergeAddImpl) and the
// public MergeAdd functor for every element type supported on CPU.
#define TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(dtype)    \
  template struct MergeAddImpl<phi::CPUContext, dtype>; \
  template struct MergeAdd<phi::CPUContext, dtype>;

TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(float)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(double)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int64_t)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::bfloat16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::complex<float>)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::complex<double>)
570

571 572
#ifdef PADDLE_WITH_XPU
template <typename T>
573 574
struct MergeAdd<phi::XPUContext, T> {
  phi::SelectedRows operator()(const phi::XPUContext& context,
575 576 577
                               const phi::SelectedRows& input,
                               const bool sorted_result = false) {
    phi::SelectedRows out;
578 579 580 581
    (*this)(context, input, &out, sorted_result);
    return out;
  }

582
  void operator()(const phi::XPUContext& context,
583 584
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
585
                  const bool sorted_result = false) {
586
    paddle::framework::Vector<int64_t> input_rows(input.rows());
587 588 589 590
    if (input_rows.size() == 0) {
      return;
    }

591
    phi::SelectedRows& out = *output;
592 593 594 595 596 597 598
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    out.mutable_value()->mutable_data<T>(
599
        phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
600 601 602 603 604 605 606
        context.GetPlace());

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

607 608 609 610
    auto* y_data = out.mutable_value()->data<T>();
    auto* x_data = input.value().data<T>();
    int xm = input_rows.size();
    int ym = merge_rows.size();
611
    int n = input_width;
612 613 614 615

    xpu::ctx_guard RAII_GUARD(context.x_context());
    int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
    int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
616 617 618 619 620 621 622 623 624 625
    paddle::memory::Copy(context.GetPlace(),
                         y_rows_data,
                         phi::CPUPlace(),
                         merge_rows.data(),
                         ym * sizeof(int64_t));
    paddle::memory::Copy(context.GetPlace(),
                         x_rows_data,
                         phi::CPUPlace(),
                         input_rows.data(),
                         xm * sizeof(int64_t));
626 627 628 629 630 631 632 633
    int r = xpu::merge_dup_rows<T, int64_t>(context.x_context(),
                                            x_data,
                                            y_data,
                                            x_rows_data,
                                            y_rows_data,
                                            xm,
                                            n,
                                            ym);
634
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
635 636
  }

637
  void operator()(const phi::XPUContext& context,
638
                  const std::vector<const phi::SelectedRows*>& inputs,
639 640
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
641 642 643 644
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
645
    const phi::SelectedRows* has_value_input = nullptr;
646 647 648 649 650 651 652 653 654 655 656 657
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
658
    phi::SelectedRows& out = *output;
659 660 661 662 663 664
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
665 666 667 668 669 670 671 672 673
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All inputs should have same height."));
674 675 676 677 678 679 680 681 682 683 684 685 686 687
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                    merged_row_set.end());

    if (sorted_result) {
      std::sort(merge_rows.begin(), merge_rows.end());
    }

    out.set_rows(merge_rows);
    out.set_height(input_height);
    out.mutable_value()->mutable_data<T>(
688
        phi::make_ddim(
689 690 691
            {static_cast<int64_t>(merged_row_set.size()), input_width}),
        context.GetPlace());

692
    float* y_data = reinterpret_cast<float*>(out.mutable_value()->data<T>());
693 694 695 696 697 698 699 700 701 702 703 704

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto& input_rows = input->rows();

705 706 707
      auto* x_data = input->value().data<T>();
      int xm = input_rows.size();
      int ym = merge_rows.size();
708
      int n = input_width;
709 710 711 712

      xpu::ctx_guard RAII_GUARD(context.x_context());
      int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
      int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
713 714 715 716 717 718 719 720 721 722
      paddle::memory::Copy(context.GetPlace(),
                           y_rows_data,
                           phi::CPUPlace(),
                           merge_rows.data(),
                           ym * sizeof(int64_t));
      paddle::memory::Copy(context.GetPlace(),
                           x_rows_data,
                           phi::CPUPlace(),
                           input_rows.data(),
                           xm * sizeof(int64_t));
723 724 725 726 727 728 729 730
      int r = xpu::merge_dup_rows<T, int64_t>(context.x_context(),
                                              x_data,
                                              y_data,
                                              x_rows_data,
                                              y_rows_data,
                                              xm,
                                              n,
                                              ym);
731
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
732 733 734 735 736
    }
  }
};

#endif
737
template <typename T>
L
Leo Chen 已提交
738 739
struct MergeAverage<phi::CPUContext, T> {
  phi::SelectedRows operator()(const phi::CPUContext& context,
740 741
                               const phi::SelectedRows& input) {
    phi::SelectedRows out;
742 743 744 745
    (*this)(context, input, &out);
    return out;
  }

L
Leo Chen 已提交
746
  void operator()(const phi::CPUContext& context,
747 748
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output) {
749
    std::vector<const phi::SelectedRows*> inputs;
750 751 752 753
    inputs.push_back(&input);
    (*this)(context, inputs, output);
  }

L
Leo Chen 已提交
754
  void operator()(const phi::CPUContext& context,
755 756
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output) {
757 758 759 760
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
761
    const phi::SelectedRows* has_value_input = nullptr;
762 763 764 765 766 767 768 769 770 771 772 773
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
774
    phi::SelectedRows& out = *output;
775 776 777 778 779 780
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
781 782 783 784 785 786 787 788 789
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All input should have same height."));
790 791 792 793 794 795
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    out.set_height(input_height);
    out.mutable_value()->mutable_data<T>(
796
        phi::make_ddim(
797 798 799 800 801 802 803 804 805 806
            {static_cast<int64_t>(merged_row_set.size()), input_width}),
        context.GetPlace());
    auto* out_data = out.mutable_value()->data<T>();

    std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                    merged_row_set.end());
    std::sort(merge_rows.begin(), merge_rows.end());

    out.set_rows(merge_rows);

L
Leo Chen 已提交
807
    phi::funcs::SetConstant<phi::CPUContext, T> constant_functor;
808 809 810 811 812 813 814
    constant_functor(context, out.mutable_value(), 0.0);

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

L
Leo Chen 已提交
815
    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
816 817 818 819 820 821 822 823 824
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto* input_data = input->value().data<T>();
      auto& input_rows = input->rows();

      for (size_t i = 0; i < input_rows.size(); i++) {
        size_t out_i = rows_to_id[input_rows[i]];
825 826
        elementwise_add_to<T>(&blas,
                              static_cast<size_t>(input_width),
827 828
                              &input_data[i * input_width],
                              &out_data[out_i * input_width]);
829 830 831 832 833 834 835 836 837 838 839 840
      }
    }
    size_t input_width_cast = static_cast<size_t>(input_width);
    T count = static_cast<T>(inputs.size());
    for (size_t i = 0; i < merge_rows.size(); i++) {
      for (size_t j = 0; j < input_width_cast; j++) {
        out_data[i * input_width + j] = out_data[i * input_width + j] / count;
      }
    }
  }
};

// Explicit instantiations: XPU MergeAdd is float-only; CPU MergeAverage
// covers the integer and floating-point element types used by optimizers.
#ifdef PADDLE_WITH_XPU
template struct MergeAdd<phi::XPUContext, float>;
#endif

template struct MergeAverage<phi::CPUContext, int>;
template struct MergeAverage<phi::CPUContext, int64_t>;
template struct MergeAverage<phi::CPUContext, float>;
template struct MergeAverage<phi::CPUContext, double>;

T
wip  
typhoonzero 已提交
850
template <typename T>
L
Leo Chen 已提交
851 852
struct UpdateToTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
853 854
                  const ScatterOps& op,
                  const phi::SelectedRows& input1,
855
                  phi::DenseTensor* input2) {
T
wip  
typhoonzero 已提交
856
    auto in1_height = input1.height();
857
    const auto& in2_dims = input2->dims();
858
    PADDLE_ENFORCE_EQ(
859 860
        in1_height,
        in2_dims[0],
861 862 863 864 865
        phi::errors::InvalidArgument("The two inputs height must be equal."
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));
T
wip  
typhoonzero 已提交
866 867 868 869 870

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
871
    PADDLE_ENFORCE_EQ(
872 873
        in1_row_numel,
        input2->numel() / in1_height,
874 875 876 877 878
        phi::errors::InvalidArgument("The two inputs width must be equal."
                                     "But received first input width = [%d], "
                                     "second input width = [%d]",
                                     in1_row_numel,
                                     input2->numel() / in1_height));
T
wip  
typhoonzero 已提交
879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922

    auto* in1_data = in1_value.data<T>();
    auto* input2_data = input2->data<T>();

    // FIXME(typhoonzero): use macro fix the below messy code.
    switch (op) {
      case ScatterOps::ASSIGN:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::ADD:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::SUB:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] -=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::SUBBY:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j] -
            input2_data[in1_rows[i] * in1_row_numel + j];
        break;
      case ScatterOps::MUL:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] *=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::DIV:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] /=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::DIVBY:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j] /
            input2_data[in1_rows[i] * in1_row_numel + j];
        break;
    }
T
typhoonzero 已提交
923 924 925 926
  }
};

}  // namespace scatter
}  // namespace funcs
}  // namespace phi