/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/funcs/selected_rows_functor.h"

#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/device/device_wrapper.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/operators/mkldnn/axpy_handler.h"
#endif

namespace phi {
namespace funcs {

template <typename T>
struct SelectedRowsAdd<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  const phi::SelectedRows& input2,
                  phi::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height,
        input2.height(),
        phi::errors::InvalidArgument("The two inputs height must be equal."
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     input2.height()));
    output->set_height(in1_height);

    auto& in1_rows = input1.rows();
    auto& in2_rows = input2.rows();
    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        in2_value.numel() / in2_rows.size(),
        phi::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            in2_value.numel() / in2_rows.size()));
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        out_value->numel() / out_rows.size(),
        phi::errors::InvalidArgument(
            "The input and output width must be equal."
            "But received input width = [%d], output width = [%d]",
            in1_row_numel,
            out_value->numel() / out_rows.size()));

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in1_place),
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));
    auto in2_place = input2.place();
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in2_place),
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(out_place),
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();
    paddle::memory::Copy(out_place,
                         out_data,
                         in1_place,
                         in1_data,
                         in1_value.numel() * sizeof(T));

    auto* in2_data = in2_value.data<T>();
    paddle::memory::Copy(out_place,
                         out_data + in1_value.numel(),
                         in2_place,
                         in2_data,
                         in2_value.numel() * sizeof(T));
  }
};

template struct SelectedRowsAdd<phi::CPUContext, float>;
template struct SelectedRowsAdd<phi::CPUContext, double>;
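
// Illustrative sketch (comment only; `ctx`, `in1`, `in2`, `out` are
// hypothetical):
//
//   phi::funcs::SelectedRowsAdd<phi::CPUContext, float>()(ctx, in1, in2, &out);
//
// With in1.rows() = [0, 2] and in2.rows() = [2, 3], out.rows() becomes
// [0, 2, 2, 3]: rows are concatenated, not merged, so the duplicated id 2
// survives. Note the width check above reads out's existing numel, so the
// caller pre-sizes out.mutable_value() to (|in1.rows| + |in2.rows|) x width.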

template <typename T>
struct SelectedRowsAddTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  const phi::DenseTensor& input2,
                  phi::DenseTensor* output) {
    auto in1_height = input1.height();
    const auto& in2_dims = input2.dims();
    const auto& out_dims = output->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal."
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));
    PADDLE_ENFORCE_EQ(
        in1_height,
        out_dims[0],
        phi::errors::InvalidArgument(
            "The input and output height must be equal."
            "But received input height = [%d], output height = [%d]",
            in1_height,
            out_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2.numel() / in1_height,
        phi::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2.numel() / in1_height));
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        output->numel() / in1_height,
        phi::errors::InvalidArgument(
            "The input and output width must be equal."
            "But received input width = [%d], output width = [%d]",
            in1_row_numel,
            output->numel() / in1_height));

    phi::funcs::SetConstant<phi::CPUContext, T> functor;
    functor(context, output, 0.0);

    auto* in1_data = in1_value.data<T>();
    auto* out_data = output->data<T>();

    for (size_t i = 0; i < in1_rows.size(); i++) {
      for (int64_t j = 0; j < in1_row_numel; j++) {
        out_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
      }
    }

    auto out_eigen = EigenVector<T>::Flatten(*output);
    auto in2_eigen = EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<phi::CPUContext, float>;
template struct SelectedRowsAddTensor<phi::CPUContext, double>;
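
// Illustrative sketch (comment only; names are hypothetical): the functor
// computes output = densify(input1) + input2, i.e. output is zeroed, each
// sparse row i is added into output at offset input1.rows()[i], and input2
// is then added elementwise via Eigen. With input1.rows() = [1, 3], width 2:
// output[1] += sparse_row(0); output[3] += sparse_row(1); output += input2.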

template <typename T>
struct SelectedRowsAddTo<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  const int64_t input2_offset,
                  phi::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height,
        input2->height(),
        phi::errors::InvalidArgument("The two inputs height must be equal."
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     input2->height()));

    auto& in1_rows = input1.rows();
    auto& in2_rows = *(input2->mutable_rows());

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();

    // concat rows
    paddle::framework::MixVector<int64_t> mixv_in2_rows(&in2_rows);
    mixv_in2_rows.Extend(in1_rows.begin(), in1_rows.end());

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in1_place),
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));
    auto in2_place = input2->place();
    PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(in2_place),
                      true,
                      phi::errors::InvalidArgument(
                          "The running environment is not on the CPU place."));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();
    paddle::memory::Copy(in2_place,
                         in2_data + input2_offset,
                         in1_place,
                         in1_data,
                         in1_value.numel() * sizeof(T));
  }
};

template struct SelectedRowsAddTo<phi::CPUContext, float>;
template struct SelectedRowsAddTo<phi::CPUContext, double>;
template struct SelectedRowsAddTo<phi::CPUContext, int>;
template struct SelectedRowsAddTo<phi::CPUContext, int64_t>;
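
// Illustrative note: SelectedRowsAddTo appends in place. input1's row ids are
// extended onto input2->rows() and input1's values are copied into input2's
// value buffer starting at element offset input2_offset, so the caller must
// have reserved enough space in input2's value tensor beforehand.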

template <typename T>
struct SelectedRowsSumTo<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const std::vector<phi::SelectedRows*>& input1,
                  const std::vector<int64_t>& input2_offsets,
                  phi::SelectedRows* input2) {
    // Ensure all selected rows have the same height
    size_t size = 0u;
    for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
      auto& in_rows = (*iter)->rows();
      size += in_rows.end() - in_rows.begin();
      auto in1_height = (*iter)->height();
      PADDLE_ENFORCE_EQ(in1_height,
                        input2->height(),
                        phi::errors::InvalidArgument(
                            "The two inputs height must be equal."
                            "But received first input height = [%d], second "
                            "input height = [%d]",
                            in1_height,
                            input2->height()));
    }
    // concat rows
    std::vector<int64_t> in2_rows;
    in2_rows.reserve(in2_rows.size() + size);
    for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
      const paddle::framework::Vector<int64_t>& in_rows = (*iter)->rows();
      in2_rows.insert(in2_rows.end(), in_rows.begin(), in_rows.end());
    }
    input2->set_rows(in2_rows);

    auto* in2_value = input2->mutable_value();
    auto* in2_data = in2_value->data<T>();
    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
    size_t offset = 0u;
    for (size_t i = 0u; i != input1.size(); ++i) {
      auto& in_value = input1[i]->value();
      const auto* in_data = in_value.data<T>();
      offset += input2_offsets[i];
      blas.VCOPY(in_value.numel(), in_data, in2_data + offset);
    }
  }
};

template struct SelectedRowsSumTo<phi::CPUContext, float>;
template struct SelectedRowsSumTo<phi::CPUContext, double>;
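
// Illustrative note: despite the name, SelectedRowsSumTo does not accumulate
// elementwise; it concatenates all input rows into input2 and VCOPYs each
// input's values into input2's buffer at the running offset accumulated from
// input2_offsets, so the inputs end up laid out side by side.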
template <typename T>
struct SelectedRowsAddToTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input1,
                  phi::DenseTensor* input2) {
    if (UNLIKELY(input1.rows().size() == 0)) {
      LOG(WARNING) << "input selected rows is empty!";
      return;
    }
    auto in1_height = input1.height();
    const auto& in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal."
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        phi::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* input2_data = input2->data<T>();

    for (size_t i = 0; i < in1_rows.size(); i++) {
      for (int64_t j = 0; j < in1_row_numel; j++) {
        input2_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
      }
    }
  }
};

template struct SelectedRowsAddToTensor<phi::CPUContext, float>;
template struct SelectedRowsAddToTensor<phi::CPUContext, double>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int64_t>;
template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::bfloat16>;
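
// Illustrative sketch (hypothetical shapes): with input1.rows() = [1, 3] and
// width 2, the loop above performs dense[1] += sparse_row(0) and
// dense[3] += sparse_row(1), leaving all other rows of input2 untouched.
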
// This is a separate namespace for manipulating SelectedRows-typed
// data, e.g. merging duplicated rows, adding two SelectedRows, etc.
//
// Another group of functors is called "scatter updates", which means
// using a SelectedRows to update a dense tensor with different ops,
// such as add or mul.
namespace scatter {

template <typename T, typename DeviceContext>
typename std::enable_if<!std::is_integral<T>::value>::type elementwise_add_to(
    phi::funcs::BlasT<DeviceContext, T>* blas,
    size_t data_len,
    const T* in,
    T* out) {
  blas->AXPY(data_len, T(1.f), in, out);
}

template <typename T, typename DeviceContext>
typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to(
    phi::funcs::BlasT<DeviceContext, T>* blas,
    size_t data_len,
    const T* in,
    T* out) {
  for (size_t i = 0; i < data_len; i++) {
    out[i] += in[i];
  }
}
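
// Note on the two overloads above: SFINAE selects the BLAS AXPY fast path for
// non-integral element types and the plain loop for integral ones, since the
// BLAS backend provides no integer AXPY.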

template <typename T, typename DeviceContext>
typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
                  const std::unordered_map<int64_t, size_t>& rows_to_id,
                  int64_t input_width,
                  const DeviceContext& context,
                  T* out_data) {
#ifndef PADDLE_WITH_MKLDNN
  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
#endif
  for (auto* input : inputs) {
    if (input->rows().size() == 0) {
      continue;
    }
    auto* input_data = input->value().data<T>();
    auto& input_rows = input->rows();

#ifdef PADDLE_WITH_MKLDNN
    paddle::operators::OneDNNAXPYHandler<T> axpy_handler(input_width, T(1.f));
    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      axpy_handler(&input_data[i * input_width],
                   &out_data[out_i * input_width]);
    }
#else
    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      elementwise_add_to<T, DeviceContext>(&blas,
                                           static_cast<size_t>(input_width),
                                           &input_data[i * input_width],
                                           &out_data[out_i * input_width]);
    }
#endif
  }
}

template <typename T, typename DeviceContext>
typename std::enable_if<!std::is_same<T, phi::dtype::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
                  const std::unordered_map<int64_t, size_t>& rows_to_id,
                  int64_t input_width,
                  const DeviceContext& context,
                  T* out_data) {
  VLOG(4) << "[CPU] add_sparse_inputs <" << typeid(T).name();
  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
  for (auto* input : inputs) {
    if (input->rows().size() == 0) {
      continue;
    }
    auto* input_data = input->value().data<T>();
    auto& input_rows = input->rows();

    for (size_t i = 0; i < input_rows.size(); i++) {
      size_t out_i = rows_to_id.at(input_rows[i]);
      elementwise_add_to<T, DeviceContext>(&blas,
                                           static_cast<size_t>(input_width),
                                           &input_data[i * input_width],
                                           &out_data[out_i * input_width]);
    }
  }
}

template <typename DeviceContext, typename T>
struct MergeAddImpl {
  phi::SelectedRows operator()(const DeviceContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result = false) {
    phi::SelectedRows out;
    (*this)(context, input, &out, sorted_result);
    return out;
  }

  void operator()(const DeviceContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    std::vector<const phi::SelectedRows*> inputs;
    inputs.push_back(&input);
    (*this)(context, inputs, output, sorted_result);
  }

  void operator()(const DeviceContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All inputs should have same height."));
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    out.set_height(input_height);
    out.mutable_value()->mutable_data<T>(
        phi::make_ddim(
            {static_cast<int64_t>(merged_row_set.size()), input_width}),
        context.GetPlace());
    auto* out_data = out.mutable_value()->data<T>();

    if (merged_row_set.size() == row_num && !sorted_result) {
      // no duplicated ids, just concat the result together
      std::vector<int64_t> merge_rows;
      merge_rows.reserve(row_num);
      // concat rows
      for (auto* in : inputs) {
        merge_rows.insert(
            merge_rows.end(), in->rows().begin(), in->rows().end());
      }
      out.set_rows(merge_rows);
      auto in_place = inputs[0]->place();
      auto out_place = out.place();
      int64_t copied_numel = 0;
      for (auto* in : inputs) {
        auto* in_data = in->value().data<T>();
        auto in_numel = in->rows().size() * input_width;
        paddle::memory::Copy(out_place,
                             out_data + copied_numel,
                             in_place,
                             in_data,
                             in_numel * sizeof(T));
        copied_numel += in_numel;
      }
    } else {
      std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                      merged_row_set.end());

      if (sorted_result) {
        std::sort(merge_rows.begin(), merge_rows.end());
      }

      out.set_rows(merge_rows);

      phi::funcs::SetConstant<DeviceContext, T> constant_functor;
      constant_functor(context, out.mutable_value(), static_cast<T>(0.f));

      std::unordered_map<int64_t, size_t> rows_to_id;
      for (size_t i = 0; i < merge_rows.size(); ++i) {
        rows_to_id[merge_rows[i]] = i;
      }

      add_sparse_inputs<T, DeviceContext>(
          inputs, rows_to_id, input_width, context, out_data);
    }
  }
};

template <typename T>
struct MergeAdd<phi::CPUContext, T> {
  // unary functor, merge by adding duplicated rows in
  // the input SelectedRows object.
  phi::SelectedRows operator()(const phi::CPUContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result) {
    return MergeAddImpl<phi::CPUContext, T>()(context, input, sorted_result);
  }

  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<phi::CPUContext, T>()(context, input, output, sorted_result);
  }

  void operator()(const phi::CPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result) {
    MergeAddImpl<phi::CPUContext, T>()(context, inputs, output, sorted_result);
  }
};

#define TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(dtype)    \
  template struct MergeAddImpl<phi::CPUContext, dtype>; \
  template struct MergeAdd<phi::CPUContext, dtype>;

TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(float)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(double)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int64_t)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::bfloat16)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::complex<float>)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(phi::dtype::complex<double>)
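
// Illustrative usage sketch (hypothetical caller; `ctx` and `grad` are
// assumptions): MergeAdd sums the values of duplicated row ids into one row:
//
//   phi::funcs::scatter::MergeAdd<phi::CPUContext, float> merge_add;
//   phi::SelectedRows merged =
//       merge_add(ctx, grad, /*sorted_result=*/true);
//   // grad.rows() = [4, 4, 1]  =>  merged.rows() = [1, 4], where merged
//   // row 4 holds the elementwise sum of the two original rows with id 4.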

#ifdef PADDLE_WITH_XPU
template <typename T>
struct MergeAdd<phi::XPUContext, T> {
  phi::SelectedRows operator()(const phi::XPUContext& context,
                               const phi::SelectedRows& input,
                               const bool sorted_result = false) {
    phi::SelectedRows out;
    (*this)(context, input, &out, sorted_result);
    return out;
  }

  void operator()(const phi::XPUContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    paddle::framework::Vector<int64_t> input_rows(input.rows());
    if (input_rows.size() == 0) {
      return;
    }

    phi::SelectedRows& out = *output;
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    out.mutable_value()->mutable_data<T>(
        phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

    auto* y_data = out.mutable_value()->data<T>();
    auto* x_data = input.value().data<T>();
    int xm = input_rows.size();
    int ym = merge_rows.size();
    int n = input_width;

    xpu::ctx_guard RAII_GUARD(context.x_context());
    int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
    int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
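    // Stage both row-id lists on the device so the XDNN kernel below can
    // accumulate duplicated input rows into their unique output slots.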
    paddle::memory::Copy(context.GetPlace(),
                         y_rows_data,
                         phi::CPUPlace(),
                         merge_rows.data(),
                         ym * sizeof(int64_t));
    paddle::memory::Copy(context.GetPlace(),
                         x_rows_data,
                         phi::CPUPlace(),
                         input_rows.data(),
                         xm * sizeof(int64_t));
    int r = xpu::merge_dup_rows<T, int64_t>(context.x_context(),
                                            x_data,
                                            y_data,
                                            x_rows_data,
                                            y_rows_data,
                                            xm,
                                            n,
                                            ym);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
  }

  void operator()(const phi::XPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output,
                  const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All inputs should have same height."));
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                    merged_row_set.end());

    if (sorted_result) {
      std::sort(merge_rows.begin(), merge_rows.end());
    }

    out.set_rows(merge_rows);
    out.set_height(input_height);
    out.mutable_value()->mutable_data<T>(
        phi::make_ddim(
            {static_cast<int64_t>(merged_row_set.size()), input_width}),
        context.GetPlace());

    float* y_data = reinterpret_cast<float*>(out.mutable_value()->data<T>());

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto& input_rows = input->rows();

      auto* x_data = input->value().data<T>();
      int xm = input_rows.size();
      int ym = merge_rows.size();
      int n = input_width;

      xpu::ctx_guard RAII_GUARD(context.x_context());
      int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
      int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
      paddle::memory::Copy(context.GetPlace(),
                           y_rows_data,
                           phi::CPUPlace(),
                           merge_rows.data(),
                           ym * sizeof(int64_t));
      paddle::memory::Copy(context.GetPlace(),
                           x_rows_data,
                           phi::CPUPlace(),
                           input_rows.data(),
                           xm * sizeof(int64_t));
      int r = xpu::merge_dup_rows<T, int64_t>(context.x_context(),
                                              x_data,
                                              y_data,
                                              x_rows_data,
                                              y_rows_data,
                                              xm,
                                              n,
                                              ym);
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
    }
  }
};

#endif
template <typename T>
struct MergeAverage<phi::CPUContext, T> {
  phi::SelectedRows operator()(const phi::CPUContext& context,
                               const phi::SelectedRows& input) {
    phi::SelectedRows out;
    (*this)(context, input, &out);
    return out;
  }

  void operator()(const phi::CPUContext& context,
                  const phi::SelectedRows& input,
                  phi::SelectedRows* output) {
    std::vector<const phi::SelectedRows*> inputs;
    inputs.push_back(&input);
    (*this)(context, inputs, output);
  }

  void operator()(const phi::CPUContext& context,
                  const std::vector<const phi::SelectedRows*>& inputs,
                  phi::SelectedRows* output) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const phi::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    phi::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    size_t row_num = 0;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(
          input_width,
          input->value().dims()[1],
          phi::errors::InvalidArgument("All inputs should have same "
                                       "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(
          input_height,
          input->height(),
          phi::errors::InvalidArgument("All inputs should have same height."));
      row_num += input->rows().size();
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }

    out.set_height(input_height);
    out.mutable_value()->mutable_data<T>(
        phi::make_ddim(
            {static_cast<int64_t>(merged_row_set.size()), input_width}),
        context.GetPlace());
    auto* out_data = out.mutable_value()->data<T>();

    std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                    merged_row_set.end());
    std::sort(merge_rows.begin(), merge_rows.end());

    out.set_rows(merge_rows);

    phi::funcs::SetConstant<phi::CPUContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), 0.0);

    std::unordered_map<int64_t, size_t> rows_to_id;
    for (size_t i = 0; i < merge_rows.size(); ++i) {
      rows_to_id[merge_rows[i]] = i;
    }

    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto* input_data = input->value().data<T>();
      auto& input_rows = input->rows();

      for (size_t i = 0; i < input_rows.size(); i++) {
        size_t out_i = rows_to_id[input_rows[i]];
        elementwise_add_to<T>(&blas,
                              static_cast<size_t>(input_width),
                              &input_data[i * input_width],
                              &out_data[out_i * input_width]);
      }
    }
    size_t input_width_cast = static_cast<size_t>(input_width);
    T count = static_cast<T>(inputs.size());
    for (size_t i = 0; i < merge_rows.size(); i++) {
      for (size_t j = 0; j < input_width_cast; j++) {
        out_data[i * input_width + j] = out_data[i * input_width + j] / count;
      }
    }
  }
};

#ifdef PADDLE_WITH_XPU
template struct MergeAdd<phi::XPUContext, float>;
#endif

template struct MergeAverage<phi::CPUContext, int>;
template struct MergeAverage<phi::CPUContext, int64_t>;
template struct MergeAverage<phi::CPUContext, float>;
template struct MergeAverage<phi::CPUContext, double>;
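
// Illustrative note: MergeAverage divides every merged row by the number of
// inputs (inputs.size()), not by how often a row id actually occurred, so a
// row contributed by only some of the inputs is still divided by the total
// input count.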

template <typename T>
struct UpdateToTensor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& context,
                  const ScatterOps& op,
                  const phi::SelectedRows& input1,
                  phi::DenseTensor* input2) {
    auto in1_height = input1.height();
    const auto& in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height,
        in2_dims[0],
        phi::errors::InvalidArgument("The two inputs height must be equal."
                                     "But received first input height = "
                                     "[%d], second input height = [%d]",
                                     in1_height,
                                     in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel,
        input2->numel() / in1_height,
        phi::errors::InvalidArgument(
            "The two inputs width must be equal."
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel,
            input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* input2_data = input2->data<T>();

    // FIXME(typhoonzero): use a macro to fix the messy code below.
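    // Per-op semantics, with x = the input1 row and y = the input2 row
    // selected by the same row id:
    //   ASSIGN: y = x;  ADD: y += x;  SUB: y -= x;  SUBBY: y = x - y;
    //   MUL: y *= x;  DIV: y /= x;  DIVBY: y = x / y.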
    switch (op) {
      case ScatterOps::ASSIGN:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::ADD:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::SUB:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] -=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::SUBBY:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j] -
            input2_data[in1_rows[i] * in1_row_numel + j];
        break;
      case ScatterOps::MUL:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] *=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::DIV:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] /=
            in1_data[i * in1_row_numel + j];
        break;
      case ScatterOps::DIVBY:
        INLINE_FOR2(in1_rows.size(), in1_row_numel)
        input2_data[in1_rows[i] * in1_row_numel + j] =
            in1_data[i * in1_row_numel + j] /
            input2_data[in1_rows[i] * in1_row_numel + j];
        break;
    }
  }
};

}  // namespace scatter
}  // namespace funcs
}  // namespace phi