selected_rows_functor.cu 21.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
typhoonzero 已提交
15
#include <set>
16
#include <vector>
T
typhoonzero 已提交
17

Y
Yi Wang 已提交
18 19
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
D
dzhwinter 已提交
20
#include "paddle/fluid/platform/cuda_primitives.h"
C
chengduo 已提交
21
#include "paddle/fluid/platform/float16.h"
22 23 24 25 26

namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
  // Concatenates two SelectedRows into `output`.
  //
  // The output rows are input1's row indices followed by input2's row indices
  // (duplicates are kept, nothing is merged). The output value is input1's
  // value followed by input2's value; both copies are enqueued with
  // memory::Copy on `context.stream()`.
  //
  // Preconditions, enforced below:
  //   * both inputs have the same height and the same row width;
  //   * both inputs contain at least one row (the row width is derived by
  //     dividing value().numel() by the number of rows);
  //   * output's value already has the same row width. NOTE(review): it is
  //     not resized here, so callers presumably allocate it for
  //     in1_rows.size() + in2_rows.size() rows - confirm;
  //   * both inputs and the context live on a CUDA place.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::SelectedRows& input2,
                  framework::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height, input2.height(),
        platform::errors::InvalidArgument("The two inputs height must be equal. "
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height, input2.height()));
    output->set_height(in1_height);

    // Take a copy of input1's rows: output->set_rows() below must not be able
    // to change what we read here.
    framework::Vector<int64_t> in1_rows(input1.rows());
    auto& in2_rows = input2.rows();

    // The row width is computed as numel / rows.size(); reject inputs without
    // rows up front instead of crashing on an integer division by zero.
    PADDLE_ENFORCE_EQ(
        in1_rows.size() > 0 && in2_rows.size() > 0, true,
        platform::errors::InvalidArgument(
            "The two inputs must both contain at least one row. But received "
            "first input rows size = [%d], second input rows size = [%d]",
            in1_rows.size(), in2_rows.size()));

    // concat rows
    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel, in2_value.numel() / in2_rows.size(),
        platform::errors::InvalidArgument(
            "The two inputs width must be equal. But received first input "
            "width = [%d], second input width = [%d]",
            in1_row_numel, in2_value.numel() / in2_rows.size()));
    PADDLE_ENFORCE_EQ(
        in1_row_numel, out_value->numel() / out_rows.size(),
        platform::errors::InvalidArgument(
            "The input and output width must be equal. But received input "
            "width = [%d], output width = [%d]",
            in1_row_numel, out_value->numel() / out_rows.size()));

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in1_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto in2_place = input2.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in2_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(out_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value.data<T>();

    // input1's value fills the front of the output buffer ...
    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, out_place), out_data,
                 BOOST_GET_CONST(platform::CUDAPlace, in1_place), in1_data,
                 in1_value.numel() * sizeof(T), context.stream());
    // ... and input2's value is appended directly after it.
    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, out_place),
                 out_data + in1_value.numel(),
                 BOOST_GET_CONST(platform::CUDAPlace, in2_place), in2_data,
                 in2_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAdd<platform::CUDADeviceContext, float>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, double>;
99 100

namespace {
// Scatter-adds the rows of a SelectedRows value into a dense tensor.
//
// Launch layout: one thread block per selected row (block `b` handles row
// `b`) and `block_size` threads per block; blockDim.x must equal block_size
// because the loop strides by block_size starting at threadIdx.x.
// Block `b` adds selected_rows[b * row_numel + i] into
// tensor_out[rows[b] * row_numel + i] for i in [0, row_numel).
template <typename T, int block_size>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
                                            const int64_t* rows, T* tensor_out,
                                            int64_t row_numel) {
  const int ty = blockIdx.x;
  const int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;

  // The index is int64_t to match row_numel, so very wide rows cannot
  // overflow the loop counter.
  for (int64_t index = tid; index < row_numel; index += block_size) {
    // Row indices in a SelectedRows can be duplicated, so several blocks may
    // target the same output row. A plain `tensor_out[index] += ...` would
    // race; use an atomic add instead.
    paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
  }
}
}  // namespace

template <typename T>
struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
  // Computes output = input1 + input2. input1 is a SelectedRows viewed as a
  // dense tensor of height input1.height() whose unlisted rows are zero.
  // Duplicate row indices in input1 are summed (the kernel uses atomic adds).
  //
  // `output` is written through output->data<T>(), so it must already be
  // allocated. Its height and row width are checked against input1/input2.
  // If input1 has no rows, output simply becomes a copy of input2.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::Tensor& input2, framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(
        in1_height, in2_dims[0],
        platform::errors::InvalidArgument(
            "The two inputs height must be equal. But received first input "
            "height = [%d], second input height = [%d]",
            in1_height, in2_dims[0]));
    PADDLE_ENFORCE_EQ(
        in1_height, out_dims[0],
        platform::errors::InvalidArgument(
            "The input and output height must be equal. But received input "
            "height = [%d], output height = [%d]",
            in1_height, out_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();
    // With no selected rows, input1 contributes nothing. Skip the width checks
    // (they would divide by zero) and the scatter kernel (a zero-sized grid is
    // an invalid launch configuration).
    const bool has_rows = in1_rows.size() > 0;

    int64_t in1_row_numel = 0;
    if (has_rows) {
      in1_row_numel = in1_value.numel() / in1_rows.size();
      PADDLE_ENFORCE_EQ(
          in1_row_numel, input2.numel() / in1_height,
          platform::errors::InvalidArgument(
              "The two inputs width must be equal. But received first input "
              "width = [%d], second input width = [%d]",
              in1_row_numel, input2.numel() / in1_height));
      PADDLE_ENFORCE_EQ(
          in1_row_numel, output->numel() / in1_height,
          platform::errors::InvalidArgument(
              "The input and output width must be equal. But received input "
              "width = [%d], output width = [%d]",
              in1_row_numel, output->numel() / in1_height));
    }

    auto* out_data = output->data<T>();

    // output = 0
    SetConstant<platform::CUDADeviceContext, T> functor;
    functor(context, output, static_cast<T>(0));

    // output[rows[i], :] += input1.value()[i, :]
    if (has_rows) {
      const int block_size = 256;
      dim3 threads(block_size, 1);
      dim3 grid(in1_rows.size(), 1);
      SelectedRowsAddTensorKernel<
          T, block_size><<<grid, threads, 0, context.stream()>>>(
          in1_value.data<T>(), in1_rows.CUDAData(context.GetPlace()), out_data,
          in1_row_numel);
    }

    // output += input2
    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, platform::float16>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext,
                                      platform::float16>;
Q
QI JUN 已提交
184 185

template <typename T>
struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
  // Appends input1 into input2 in place:
  //   * input1's row indices are appended to input2's rows;
  //   * input1's value is copied into input2's value starting at element
  //     offset `input2_offset`, enqueued on context.stream().
  // input2's value is not resized here. NOTE(review): callers presumably
  // allocate it for at least input2_offset + input1.value().numel() elements
  // - confirm.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  const int64_t input2_offset,
                  framework::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height, input2->height(),
        platform::errors::InvalidArgument("The two inputs height must be equal. "
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height, input2->height()));

    // Validate both places before modifying input2, so that a failed check
    // leaves input2's rows untouched. (Previously the second check tested
    // in1_place twice and never validated input2's place.)
    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in1_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto in2_place = input2->place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in2_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));

    auto& in1_rows = input1.rows();
    auto& in2_rows = *(input2->mutable_rows());

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();

    // concat rows
    if (in1_rows.size()) {
      in2_rows.Extend(in1_rows.begin(), in1_rows.end());
    }

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();
    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, in2_place),
                 in2_data + input2_offset,
                 BOOST_GET_CONST(platform::CUDAPlace, in1_place), in1_data,
                 in1_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAddTo<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext,
                                  platform::float16>;
Q
QI JUN 已提交
234 235 236 237 238 239 240

namespace {
// Accumulates the rows of a SelectedRows value into a dense tensor in place.
//
// Launch layout: one thread block per selected row (block `b` handles row
// `b`) and `block_size` threads per block; blockDim.x must equal block_size
// because the loop strides by block_size starting at threadIdx.x.
// Block `b` adds selected_rows[b * row_numel + i] into
// tensor_out[rows[b] * row_numel + i] for i in [0, row_numel).
template <typename T, int block_size>
__global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
                                              const int64_t* rows,
                                              T* tensor_out,
                                              int64_t row_numel) {
  const int ty = blockIdx.x;
  const int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;

  // The index is int64_t to match row_numel, so very wide rows cannot
  // overflow the loop counter.
  for (int64_t index = tid; index < row_numel; index += block_size) {
    // Row indices in a SelectedRows can be duplicated, so several blocks may
    // write the same output row; an atomic add avoids the write race.
    paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
  }
}
}  // namespace

template <typename T>
struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
  // Accumulates input1 into the dense tensor input2 in place:
  //   input2[input1.rows()[i], :] += input1.value()[i, :]   for every i.
  // Duplicate row indices in input1 are summed (the kernel uses atomic adds).
  // If input1 has no rows this is a no-op.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input1,
                  framework::Tensor* input2) {
    auto in1_height = input1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height, in2_dims[0],
        platform::errors::InvalidArgument("The two inputs height must be equal. "
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height, in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();
    // Nothing to accumulate. Returning here also avoids a division by zero in
    // the row-width computation and a kernel launch with an empty grid (an
    // invalid launch configuration).
    if (in1_rows.size() == 0) {
      return;
    }

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel, input2->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal. But received first input "
            "width = [%d], second input width = [%d]",
            in1_row_numel, input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2->data<T>();
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(in1_rows.size(), 1);
    SelectedRowsAddToTensorKernel<
        T, block_size><<<grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data,
        in1_row_numel);
  }
};

template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext,
                                        platform::float16>;
T
typhoonzero 已提交
298 299 300 301 302 303 304

namespace scatter {

template <typename T, int block_size>
// Accumulates each input row into its slot in the merged output.
//
// Launch layout: one thread block per input row (block `b` handles
// input_rows[b]) and `block_size` threads per block; blockDim.x must equal
// block_size because the loop strides by block_size from threadIdx.x.
//
// Preconditions: `out_rows` holds `out_rows_size` distinct row indices sorted
// in ascending order (the MergeAdd callers build it from a std::set), and
// `out` is zero-initialized by the caller when a plain sum is wanted.
// An input row that is not present in `out_rows` is skipped.
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                               T* out, const int64_t* out_rows,
                               size_t out_rows_size, int64_t row_numel) {
  const int ty = blockIdx.x;
  const int tid = threadIdx.x;
  __shared__ size_t out_idx;

  if (tid == 0) {
    // Binary search (lower bound) for this block's row in the sorted
    // out_rows, instead of scanning all out_rows_size entries.
    const int64_t target = input_rows[ty];
    size_t lo = 0;
    size_t hi = out_rows_size;
    while (lo < hi) {
      const size_t mid = lo + (hi - lo) / 2;
      if (out_rows[mid] < target) {
        lo = mid + 1;
      } else {
        hi = mid;
      }
    }
    // out_rows_size acts as a "not found" sentinel, so out_idx is always
    // initialized before other threads read it.
    out_idx =
        (lo < out_rows_size && out_rows[lo] == target) ? lo : out_rows_size;
  }

  __syncthreads();

  // out_idx is shared, so every thread in the block takes the same branch.
  if (out_idx >= out_rows_size) {
    return;
  }

  input += ty * row_numel;
  out += out_idx * row_numel;
  for (int64_t index = tid; index < row_numel; index += block_size) {
    // Several blocks (duplicate input rows) may target the same output row.
    paddle::platform::CudaAtomicAdd(out + index, input[index]);
  }
}

template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
  // Returns a new SelectedRows in which duplicate row indices of `input` are
  // merged by summing their values. `sorted_result` is accepted for interface
  // compatibility but not read: this implementation always produces rows in
  // ascending order because they are collected in a std::set.
  framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                     const framework::SelectedRows& input,
                                     const bool sorted_result = false) {
    framework::SelectedRows out;
    (*this)(context, input, &out);
    return out;
  }

  // Same as above, but writes the merged result into `output` (its rows,
  // height and value are overwritten). If `input` has no rows, `output` is
  // left untouched.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input,
                  framework::SelectedRows* output,
                  const bool sorted_result = false) {
    framework::Vector<int64_t> input_rows(input.rows());
    if (input_rows.size() == 0) {
      return;
    }

    framework::SelectedRows& out = *output;
    // Unique row indices, in ascending order (MergeAddKernel relies on this).
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows_cpu(row_set.begin(), row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    out.mutable_value()->mutable_data<T>(
        framework::make_ddim(
            {static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    auto* out_data = out.mutable_value()->data<T>();
    auto* input_data = input.value().data<T>();
    // The kernel only reads the merged rows, so request a read-only device
    // copy; CUDAMutableData would needlessly mark the host copy stale.
    const auto& out_rows = out.rows();

    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(input_rows.size(), 1);
    MergeAddKernel<T, block_size><<<grid, threads, 0, context.stream()>>>(
        input_data, input_rows.CUDAData(context.GetPlace()), out_data,
        out_rows.CUDAData(context.GetPlace()), out_rows.size(), input_width);
  }

  // Merges several SelectedRows into `output`, summing values of equal row
  // indices across all inputs. Inputs without rows are ignored. All inputs
  // that have rows must share the same height and value width. If no input
  // has rows, `output` is left untouched.
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<const framework::SelectedRows*>& inputs,
                  framework::SelectedRows* output,
                  const bool sorted_result = false) {
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const framework::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return";
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    framework::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
                        platform::errors::InvalidArgument(
                            "All input should have same "
                            "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(input_height, input->height(),
                        platform::errors::InvalidArgument(
                            "All input should have same height."));
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }
    // Unique row indices, in ascending order (MergeAddKernel relies on this).
    std::vector<int64_t> merge_rows_cpu(merged_row_set.begin(),
                                        merged_row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    out.set_rows(merge_rows);
    out.set_height(input_height);
    out.mutable_value()->mutable_data<T>(
        framework::make_ddim(
            {static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    auto* out_data = out.mutable_value()->data<T>();
    // The merged rows are read-only for every launch below: fetch the device
    // pointer once instead of re-requesting mutable access per input.
    const auto& out_rows = out.rows();
    const int64_t* out_rows_data = out_rows.CUDAData(context.GetPlace());
    const size_t out_rows_size = out_rows.size();

    const int block_size = 256;
    dim3 threads(block_size, 1);

    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto* input_data = input->value().data<T>();
      auto& input_rows = input->rows();
      dim3 grid(input_rows.size(), 1);

      MergeAddKernel<T, block_size><<<grid, threads, 0, context.stream()>>>(
          input_data, input_rows.CUDAData(context.GetPlace()), out_data,
          out_rows_data, out_rows_size, input_width);
    }
  }
};

template struct MergeAdd<platform::CUDADeviceContext, float>;
template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;
template struct MergeAdd<platform::CUDADeviceContext, platform::complex<float>>;
template struct MergeAdd<platform::CUDADeviceContext,
                         platform::complex<double>>;
T
wip  
typhoonzero 已提交
454 455 456 457 458

// Applies `op` element-wise between each selected row (`in`) and the matching
// row of a dense tensor (`out`). Here `in` is row `b` of selected_rows and
// `out` is row rows[b] of tensor_out:
//   ASSIGN: out = in        ADD: out += in      SUB:   out -= in
//   SUBBY:  out = in - out  MUL: out *= in      DIV:   out /= in
//   DIVBY:  out = in / out
// Launch layout: one thread block per selected row and `block_size` threads
// per block; blockDim.x must equal block_size because the loops stride by
// block_size from threadIdx.x. Updates are NOT atomic, so `rows` must not
// contain duplicates (UpdateToTensor merges them with MergeAdd first).
//
// `op` is taken by value. A reference kernel argument would refer to host
// memory, which device code cannot dereference.
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
                                     const int64_t* rows, ScatterOps op,
                                     T* tensor_out, int64_t row_numel) {
  const int ty = blockIdx.x;
  const int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;
  // FIXME(typhoonzero): use macro fix the below messy code.
  switch (op) {
    case ScatterOps::ASSIGN:
      for (int64_t index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index];
      }
      break;
    case ScatterOps::ADD:
      for (int64_t index = tid; index < row_numel; index += block_size) {
        tensor_out[index] += selected_rows[index];
      }
      break;
    case ScatterOps::SUB:
      for (int64_t index = tid; index < row_numel; index += block_size) {
        tensor_out[index] -= selected_rows[index];
      }
      break;
    case ScatterOps::SUBBY:
      for (int64_t index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] - tensor_out[index];
      }
      break;
    case ScatterOps::MUL:
      for (int64_t index = tid; index < row_numel; index += block_size) {
        tensor_out[index] *= selected_rows[index];
      }
      break;
    case ScatterOps::DIV:
      for (int64_t index = tid; index < row_numel; index += block_size) {
        tensor_out[index] /= selected_rows[index];
      }
      break;
    case ScatterOps::DIVBY:
      for (int64_t index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] / tensor_out[index];
      }
      break;
  }
}

template <typename T>
struct UpdateToTensor<platform::CUDADeviceContext, T> {
  // Applies `op` (see ScatterOps) between the rows of input1 and the matching
  // rows of the dense tensor input2, updating input2 in place. Duplicate row
  // indices in input1 are first merged (summed) with MergeAdd, so that every
  // target row is updated exactly once by the non-atomic UpdateToTensorKernel.
  // If input1 has no rows this is a no-op.
  void operator()(const platform::CUDADeviceContext& context,
                  const ScatterOps& op, const framework::SelectedRows& input1,
                  framework::Tensor* input2) {
    // MergeAdd keeps input1's height, so validating it here is equivalent to
    // validating the merged result.
    auto in1_height = input1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height, in2_dims[0],
        platform::errors::InvalidArgument("The two inputs height must be equal. "
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height, in2_dims[0]));

    // Nothing to update. MergeAdd would return an empty SelectedRows here and
    // the row-width computation below would divide by zero.
    if (input1.rows().size() == 0) {
      return;
    }

    // NOTE: for the ADD case, SelectedRowsAddToTensor is faster because it
    // needs no extra MergeAdd pass (it relies on atomic adds instead).
    MergeAdd<platform::CUDADeviceContext, T> merge_func;
    auto merged_in1 = merge_func(context, input1);

    auto& in1_value = merged_in1.value();
    auto& in1_rows = merged_in1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel, input2->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal. But received first input "
            "width = [%d], second input width = [%d]",
            in1_row_numel, input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2->data<T>();

    dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
    dim3 grid(in1_rows.size(), 1);
    // Use the same Vector device-pointer accessor as the rest of this file.
    UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
        grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), op, in2_data,
        in1_row_numel);
  }
};
T
typhoonzero 已提交
544
}  // namespace scatter
545 546 547
}  // namespace math
}  // namespace operators
}  // namespace paddle