/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <set>
#include <vector>

#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {
namespace math {
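// SelectedRowsAdd concatenates two SelectedRows: the output's row list is
// input1's rows followed by input2's rows (duplicate row indices are kept,
// not merged), and the two value tensors are copied back-to-back on the GPU.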
template <typename T>
struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const pten::SelectedRows& input1,
                  const pten::SelectedRows& input2,
                  pten::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height, input2.height(),
        platform::errors::InvalidArgument("The two inputs height must be equal. "
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height, input2.height()));
    output->set_height(in1_height);

    framework::Vector<int64_t> in1_rows(input1.rows());
    auto& in2_rows = input2.rows();
    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel, in2_value.numel() / in2_rows.size(),
        platform::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel, in2_value.numel() / in2_rows.size()));
    PADDLE_ENFORCE_EQ(
        in1_row_numel, out_value->numel() / out_rows.size(),
        platform::errors::InvalidArgument(
            "The input and output width must be equal. "
            "But received input width = [%d], output width = [%d]",
            in1_row_numel, out_value->numel() / out_rows.size()));

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in1_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto in2_place = input2.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in2_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(out_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));

    memory::Copy(out_place, out_data, in1_place, in1_data,
                 in1_value.numel() * sizeof(T), context.stream());

    auto* in2_data = in2_value.data<T>();
    memory::Copy(out_place, out_data + in1_value.numel(), in2_place, in2_data,
                 in2_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAdd<platform::CUDADeviceContext, float>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, double>;

namespace {
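// Kernel mapping (see the launch below): one CUDA block per row of the
// SelectedRows input, with block_size threads striding over that row's
// elements. Because the same row index can appear more than once, updates to
// tensor_out must be atomic.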
template <typename T, int block_size>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
                                            const int64_t* rows, T* tensor_out,
                                            int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;

  for (int index = tid; index < row_numel; index += block_size) {
    // Since the indices in rows of SelectedRows can be duplicated, we cannot
    // use tensor_out[index] += selected_rows[index]; instead, we have to use
    // AtomicAdd to avoid concurrent write errors.
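    // For example, if rows = [1, 1], two different blocks write to the same
    // tensor_out row, so each per-element update must be atomic.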
    paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
  }
}
}  // namespace

template <typename T>
struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const pten::SelectedRows& input1,
                  const framework::Tensor& input2, framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(
        in1_height, in2_dims[0],
        platform::errors::InvalidArgument(
            "The two inputs height must be equal. "
            "But received first input height = [%d], second input height = [%d]",
            in1_height, in2_dims[0]));
    PADDLE_ENFORCE_EQ(
        in1_height, out_dims[0],
        platform::errors::InvalidArgument(
            "The input and output height must be equal. "
            "But received input height = [%d], output height = [%d]",
            in1_height, out_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel, input2.numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel, input2.numel() / in1_height));
    PADDLE_ENFORCE_EQ(
        in1_row_numel, output->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The input and output width must be equal. "
            "But received input width = [%d], output width = [%d]",
            in1_row_numel, output->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2.data<T>();
    auto* out_data = output->data<T>();

    pten::funcs::SetConstant<platform::CUDADeviceContext, T> functor;
    functor(context, output, static_cast<T>(0));

    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(in1_rows.size(), 1);
    SelectedRowsAddTensorKernel<
        T, block_size><<<grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), out_data,
        in1_row_numel);

    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAdd<platform::CUDADeviceContext,
                                platform::float16>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext,
                                      platform::float16>;
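
// A minimal usage sketch (illustrative only; `ctx`, `rows_in`, `dense_in`,
// and `dense_out` are hypothetical names, not part of this file):
//
//   SelectedRowsAddTensor<platform::CUDADeviceContext, float> add_functor;
//   add_functor(ctx, rows_in, dense_in, &dense_out);
//
// This scatter-adds rows_in's value rows into a zeroed buffer and then adds
// dense_in element-wise, writing the dense result to dense_out.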

template <typename T>
struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const pten::SelectedRows& input1, const int64_t input2_offset,
                  pten::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(
        in1_height, input2->height(),
        platform::errors::InvalidArgument("The two inputs height must be equal. "
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height, input2->height()));

    auto& in1_rows = input1.rows();
    auto& in2_rows = *(input2->mutable_rows());

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();

    // concat rows
    if (in1_rows.size()) {
      in2_rows.Extend(in1_rows.begin(), in1_rows.end());
    }

    auto in1_place = input1.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in1_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));
    auto in2_place = input2->place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(in2_place), true,
                      platform::errors::InvalidArgument(
                          "The running environment is not on the GPU place."));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();
    memory::Copy(in2_place, in2_data + input2_offset, in1_place, in1_data,
                 in1_value.numel() * sizeof(T), context.stream());
  }
};

template struct SelectedRowsAddTo<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext,
                                  platform::float16>;

namespace {
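// SelectedRowsAddToTensorKernel uses the same mapping as
// SelectedRowsAddTensorKernel above: one block per input row, block_size
// threads per row, and atomic accumulation into the destination tensor.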
template <typename T, int block_size>
__global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
                                              const int64_t* rows,
                                              T* tensor_out,
                                              int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;

  for (int index = tid; index < row_numel; index += block_size) {
    // Since the indices in rows of SelectedRows can be duplicated, we have to
    // use an atomic operation to avoid concurrent write errors.
    paddle::platform::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
  }
}
}  // namespace

template <typename T>
struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const pten::SelectedRows& input1, framework::Tensor* input2) {
    auto in1_height = input1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height, in2_dims[0],
        platform::errors::InvalidArgument("The two inputs height must be equal. "
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height, in2_dims[0]));

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel, input2->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel, input2->numel() / in1_height));

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2->data<T>();
    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(in1_rows.size(), 1);
    SelectedRowsAddToTensorKernel<
        T, block_size><<<grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data,
        in1_row_numel);
  }
};

template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext,
                                        platform::float16>;

namespace scatter {

template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                               T* out, const int64_t* out_rows,
                               size_t out_rows_size, int64_t row_numel) {
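  // Each block handles one input row. out_rows is assumed to hold each row
  // index exactly once: thread 0 linearly scans it for the merged destination
  // index, and after the barrier all threads accumulate the row atomically.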
  const int ty = blockIdx.x;
  int tid = threadIdx.x;
  __shared__ size_t out_idx;

  if (tid == 0) {
    for (size_t i = 0; i < out_rows_size; i++) {
      if (input_rows[ty] == out_rows[i]) {
        out_idx = i;
      }
    }
  }

  __syncthreads();

  input += ty * row_numel;
  out += out_idx * row_numel;
  for (int index = tid; index < row_numel; index += block_size) {
    paddle::platform::CudaAtomicAdd(out + index, input[index]);
  }
}

template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
  pten::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                const pten::SelectedRows& input,
                                const bool sorted_result = false) {
    pten::SelectedRows out;
    (*this)(context, input, &out);
    return out;
  }

  void operator()(const platform::CUDADeviceContext& context,
                  const pten::SelectedRows& input, pten::SelectedRows* output,
                  const bool sorted_result = false) {
    framework::Vector<int64_t> input_rows(input.rows());
    if (input_rows.size() == 0) {
      return;
    }

    pten::SelectedRows& out = *output;
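    // Deduplicate the row indices; e.g. (illustrative) input rows [4, 0, 4]
    // yield merged rows {0, 4}, and the kernel below sums both input rows
    // with index 4 into the single merged row for index 4.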
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows_cpu(row_set.begin(), row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    out.mutable_value()->mutable_data<T>(
        pten::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    pten::funcs::SetConstant<platform::CUDADeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    auto* out_data = out.mutable_value()->data<T>();
    auto* input_data = input.value().data<T>();

    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid1(input_rows.size(), 1);

    MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
        input_data, input_rows.CUDAData(context.GetPlace()), out_data,
        out.mutable_rows()->CUDAMutableData(context.GetPlace()),
        out.rows().size(), input_width);
  }

  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<const pten::SelectedRows*>& inputs,
                  pten::SelectedRows* output,
                  const bool sorted_result = false) {
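    // Merge every input's rows into one deduplicated row set, then accumulate
    // all value rows into the single merged output.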
    if (inputs.size() == 0) {
      VLOG(3) << "no input! return";
      return;
    }
    const pten::SelectedRows* has_value_input = nullptr;
    for (auto* in : inputs) {
      if (in->rows().size() > 0) {
        has_value_input = in;
        break;
      }
    }
    if (has_value_input == nullptr) {
      VLOG(3) << "no input has value! just return" << std::endl;
      return;
    }
    auto input_width = has_value_input->value().dims()[1];
    auto input_height = has_value_input->height();
    pten::SelectedRows& out = *output;
    std::set<int64_t> merged_row_set;
    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
                        platform::errors::InvalidArgument(
                            "All inputs should have the same "
                            "dimension except for the first one."));
      PADDLE_ENFORCE_EQ(input_height, input->height(),
                        platform::errors::InvalidArgument(
                            "All inputs should have the same height."));
      merged_row_set.insert(input->rows().begin(), input->rows().end());
    }
    std::vector<int64_t> merge_rows_cpu(merged_row_set.begin(),
                                        merged_row_set.end());
    framework::Vector<int64_t> merge_rows(merge_rows_cpu);

    out.set_rows(merge_rows);
    out.set_height(input_height);
    out.mutable_value()->mutable_data<T>(
        pten::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    pten::funcs::SetConstant<platform::CUDADeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    auto* out_data = out.mutable_value()->data<T>();

    const int block_size = 256;
    dim3 threads(block_size, 1);

    for (auto* input : inputs) {
      if (input->rows().size() == 0) {
        continue;
      }
      auto* input_data = input->value().data<T>();
      auto& input_rows = input->rows();
      dim3 grid1(input_rows.size(), 1);

      MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
          input_data, input_rows.CUDAData(context.GetPlace()), out_data,
          out.mutable_rows()->CUDAMutableData(context.GetPlace()),
          out.rows().size(), input_width);
    }
  }
};
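
// A minimal usage sketch (illustrative only; `ctx` and `grad` are
// hypothetical names, not part of this file):
//
//   MergeAdd<platform::CUDADeviceContext, float> merge_func;
//   pten::SelectedRows merged = merge_func(ctx, grad);
//
// After the call, merged.rows() holds each distinct row index of grad once,
// with duplicate rows summed element-wise.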

template struct MergeAdd<platform::CUDADeviceContext, float>;
template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;
template struct MergeAdd<platform::CUDADeviceContext,
                         platform::complex<float>>;
template struct MergeAdd<platform::CUDADeviceContext,
                         platform::complex<double>>;

template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
                                     const int64_t* rows, const ScatterOps& op,
                                     T* tensor_out, int64_t row_numel) {
  const int ty = blockIdx.x;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;
  // FIXME(typhoonzero): use macro fix the below messy code.
  switch (op) {
    case ScatterOps::ASSIGN:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index];
      }
      break;
    case ScatterOps::ADD:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] += selected_rows[index];
      }
      break;
    case ScatterOps::SUB:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] -= selected_rows[index];
      }
      break;
    case ScatterOps::SUBBY:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] - tensor_out[index];
      }
      break;
    case ScatterOps::MUL:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] *= selected_rows[index];
      }
      break;
    case ScatterOps::DIV:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] /= selected_rows[index];
      }
      break;
    case ScatterOps::DIVBY:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] / tensor_out[index];
      }
      break;
  }
}
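
// Note: the *BY variants reverse the operand order; e.g. ScatterOps::SUBBY
// computes tensor_out = selected_rows - tensor_out rather than
// tensor_out -= selected_rows.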

template <typename T>
struct UpdateToTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const ScatterOps& op, const pten::SelectedRows& input1,
                  framework::Tensor* input2) {
    // NOTE: Prefer SelectedRowsAddToTensor for better performance, since it
    //       needs no additional MergeAdd call.
    MergeAdd<platform::CUDADeviceContext, T> merge_func;
    auto merged_in1 = merge_func(context, input1);

    auto in1_height = merged_in1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(
        in1_height, in2_dims[0],
        platform::errors::InvalidArgument("The two inputs height must be equal. "
                                          "But received first input height = "
                                          "[%d], second input height = [%d]",
                                          in1_height, in2_dims[0]));

    auto& in1_value = merged_in1.value();
    auto& in1_rows = merged_in1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(
        in1_row_numel, input2->numel() / in1_height,
        platform::errors::InvalidArgument(
            "The two inputs width must be equal. "
            "But received first input width = [%d], second input width = [%d]",
            in1_row_numel, input2->numel() / in1_height));

    auto* in1_data = in1_value.template data<T>();
    auto* in2_data = input2->data<T>();

    dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
    dim3 grid(in1_rows.size(), 1);
    UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
        grid, threads, 0, context.stream()>>>(
        in1_data, in1_rows.CUDAData(context.GetPlace()), op, in2_data,
        in1_row_numel);
  }
};
}  // namespace scatter
}  // namespace math
}  // namespace operators
}  // namespace paddle