fused_seqpool_cvm_op.cu 24.2 KB
Newer Older
D
danleifeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
//   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include <vector>
16

D
danleifeng 已提交
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/fused/fused_seqpool_cvm_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"

namespace paddle {
namespace operators {

template <typename T>
using Vector = framework::Vector<T>;

// Grid-stride loop: each thread starts at its global 1-D index and advances
// by the total number of threads in the grid until it reaches n.
// The index is size_t (not the unsigned int that blockIdx/blockDim yield), and
// both the start offset and the stride are widened before multiplying, so that
// element counts beyond the 32-bit range neither wrap the index nor overflow
// the stride computation.
#define CUDA_KERNEL_LOOP(i, n)                                   \
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + \
                  threadIdx.x;                                   \
       i < static_cast<size_t>(n);                               \
       i += static_cast<size_t>(blockDim.x) * gridDim.x)

// Sum-pools each (slot, instance) sequence into a single embedding row.
//
// Work item i addresses one element of the pooled output, enumerated as
// [slot][instance][column]:
//   row      = i / embedding_size   (== slot * batch_size + instance)
//   column   = i % embedding_size
// lods_values[slot] holds the sequence offsets of that slot; rows
// [lod[instance], lod[instance + 1]) of input_values[slot] are accumulated,
// starting from pad_value (so pad_value is always added exactly once).
// N is expected to be slot_num * batch_size * embedding_size. Uses no shared
// memory; any 1-D grid works because of the CUDA_KERNEL_LOOP stride.
template <typename T>
__global__ void FusedSeqpoolKernelNormal(const size_t N,
                                         T **input_values,
                                         T **seqpool_output_values,
                                         size_t **lods_values,
                                         const int batch_size,
                                         const int embedding_size,
                                         const float pad_value) {
  CUDA_KERNEL_LOOP(i, N) {
    const int row = i / embedding_size;
    const int col = i % embedding_size;
    const int slot = row / batch_size;
    const int ins = row % batch_size;

    const size_t *lod = lods_values[slot];
    const T *column_base = input_values[slot] + col;

    T acc = static_cast<T>(pad_value);
    for (size_t k = lod[ins]; k < lod[ins + 1]; ++k) {
      acc += column_base[k * embedding_size];
    }
    seqpool_output_values[slot][ins * embedding_size + col] = acc;
  }
}

// Applies the CVM transform to pooled rows (use_cvm == true path, the "join"
// phase that keeps show/click columns in the output).
// For every pooled row of width embedding_size:
//   column 0 (show)  -> log(show + 1)
//   column 1 (click) -> log(click + 1) - log(show + 1)
//   other columns    -> copied unchanged
// Output rows have the same width as the input rows.
// NOTE(review): columns 0 and 1 are hard-coded and cvm_offset is not read, so
// this presumably assumes cvm_offset == 2 — confirm with the op definition.
template <typename T>
__global__ void FusedCVMKernelWithCVM(const size_t N,
                                      T **output_values,
                                      T **seqpool_output_values,
                                      const int batch_size,
                                      const int embedding_size,
                                      const int cvm_offset) {
  CUDA_KERNEL_LOOP(i, N) {
    const int row = i / embedding_size;
    const int col = i % embedding_size;
    const int slot = row / batch_size;
    const int ins = row % batch_size;

    const T *in_row = seqpool_output_values[slot] + ins * embedding_size;
    T *out_row = output_values[slot] + ins * embedding_size;

    switch (col) {
      case 0:  // show
        out_row[0] = log(in_row[0] + 1);
        break;
      case 1:  // click, normalized by show
        out_row[1] = log(in_row[1] + 1) - log(in_row[0] + 1);
        break;
      default:  // embedding payload
        out_row[col] = in_row[col];
        break;
    }
  }
}

// Copies pooled rows while dropping their leading cvm_offset columns
// (use_cvm == false path, the "update" phase that does not need show/click).
// Source rows are (no_cvm_embedding_size + cvm_offset) wide; destination rows
// are no_cvm_embedding_size wide. N is expected to be
// slot_num * batch_size * no_cvm_embedding_size.
template <typename T>
__global__ void FusedCVMKernelNoCVM(const size_t N,
                                    T **output_values,
                                    T **seqpool_output_values,
                                    const int batch_size,
                                    const int no_cvm_embedding_size,
                                    const int cvm_offset) {
  const int src_width = no_cvm_embedding_size + cvm_offset;
  CUDA_KERNEL_LOOP(i, N) {
    const int row = i / no_cvm_embedding_size;
    const int col = i % no_cvm_embedding_size;
    const int slot = row / batch_size;
    const int ins = row % batch_size;

    // Skip the cvm_offset show/click columns of the pooled row.
    output_values[slot][ins * no_cvm_embedding_size + col] =
        seqpool_output_values[slot][ins * src_width + cvm_offset + col];
  }
}

// Enqueues the fused forward pass on the device context's stream:
//   1. sum-pool every slot's sequences into seqpool_output_data, then
//   2a. (use_cvm)  apply the show/click log transform into output_data, or
//   2b. (!use_cvm) copy the pooled rows minus their first cvm_offset columns.
//
// Parameters (each per-slot vector has one entry per slot):
//   input_data          - per-slot input rows, embedding_size wide, read by
//                         the pooling kernel on the device.
//   output_data         - per-slot final outputs; embedding_size wide when
//                         use_cvm, otherwise embedding_size - cvm_offset.
//   seqpool_output_data - per-slot scratch of batch_size x embedding_size
//                         receiving the pooled sums.
//   lods                - per-slot sequence offsets (batch_size + 1 entries),
//                         dereferenced on the device.
// Nothing is synchronized here; all work is ordered on `stream`.
template <typename T>
void FusedSeqpoolCVM(const framework::ExecutionContext
                         &ctx,  // const paddle::platform::Place &place,
                     const std::vector<const T *> &input_data,
                     const std::vector<T *> &output_data,
                     const std::vector<T *> &seqpool_output_data,
                     const std::vector<const size_t *> &lods,
                     const int batch_size,
                     const int slot_num,
                     const int embedding_size,
                     const float padding_value,
                     const bool use_cvm,
                     const int cvm_offset) {
  auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
  auto stream = dev_ctx.stream();

  // Widen before multiplying: the int product overflows for large batches.
  const size_t N =
      static_cast<size_t>(batch_size) * slot_num * embedding_size;
  if (N == 0) {
    return;  // Nothing to pool; launching an empty grid would be an error.
  }

  // Pack the four pointer tables back to back into one host staging buffer
  // so a single host-to-device copy uploads all of them. Every entry is a
  // pointer, so the packed layout matches the T** / size_t** views below.
  std::vector<const void *> host_ptrs;
  host_ptrs.reserve(input_data.size() + output_data.size() +
                    seqpool_output_data.size() + lods.size());
  host_ptrs.insert(host_ptrs.end(), input_data.begin(), input_data.end());
  host_ptrs.insert(host_ptrs.end(), output_data.begin(), output_data.end());
  host_ptrs.insert(host_ptrs.end(),
                   seqpool_output_data.begin(),
                   seqpool_output_data.end());
  host_ptrs.insert(host_ptrs.end(), lods.begin(), lods.end());

  const size_t ptr_bytes = host_ptrs.size() * sizeof(void *);
  // NOTE(review): temp_ptr is released when this function returns while the
  // kernels below may still be pending on `stream`; this relies on the
  // allocator not reusing the block on another stream first — confirm.
  auto temp_ptr = memory::AllocShared(ctx.GetPlace(), ptr_bytes);
  // host_ptrs is pageable memory that dies at scope exit. For pageable H2D
  // copies, cudaMemcpyAsync returns only after the source has been staged;
  // presumably hipMemcpyAsync behaves likewise — confirm for ROCm.
#ifdef PADDLE_WITH_HIP
  platform::GpuMemcpyAsync(temp_ptr->ptr(),
                           host_ptrs.data(),
                           ptr_bytes,
                           hipMemcpyHostToDevice,
                           stream);
#else
  platform::GpuMemcpyAsync(temp_ptr->ptr(),
                           host_ptrs.data(),
                           ptr_bytes,
                           cudaMemcpyHostToDevice,
                           stream);
#endif

  T **gpu_input_values = reinterpret_cast<T **>(temp_ptr->ptr());
  T **gpu_output_values = gpu_input_values + input_data.size();
  T **gpu_seqpool_output_values = gpu_output_values + output_data.size();
  size_t **lods_values = reinterpret_cast<size_t **>(
      gpu_seqpool_output_values + seqpool_output_data.size());

  platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(dev_ctx, N);

  // Step 1: sum pool into the scratch buffers.
  FusedSeqpoolKernelNormal<<<config.block_per_grid.x,
                             config.thread_per_block.x,
                             0,
                             stream>>>(N,
                                       gpu_input_values,
                                       gpu_seqpool_output_values,
                                       lods_values,
                                       batch_size,
                                       embedding_size,
                                       padding_value);

  if (use_cvm) {
    // Step 2a: show/click log transform, same row width as the input.
    FusedCVMKernelWithCVM<<<config.block_per_grid.x,
                            config.thread_per_block.x,
                            0,
                            stream>>>(N,
                                      gpu_output_values,
                                      gpu_seqpool_output_values,
                                      batch_size,
                                      embedding_size,
                                      cvm_offset);
  } else {
    // Step 2b: drop the show/click columns.
    const int no_cvm_embedding_size = embedding_size - cvm_offset;
    if (no_cvm_embedding_size <= 0) {
      return;  // No payload columns left to copy.
    }
    const size_t no_cvm_N =
        static_cast<size_t>(batch_size) * slot_num * no_cvm_embedding_size;
    platform::GpuLaunchConfig no_cvm_config =
        platform::GetGpuLaunchConfig1D(dev_ctx, no_cvm_N);
    FusedCVMKernelNoCVM<<<no_cvm_config.block_per_grid.x,
                          no_cvm_config.thread_per_block.x,
                          0,
                          stream>>>(no_cvm_N,
                                    gpu_output_values,
                                    gpu_seqpool_output_values,
                                    batch_size,
                                    no_cvm_embedding_size,
                                    cvm_offset);
  }
}

// Backward of the use_cvm == true path ("join" grad): broadcasts each pooled
// row's gradient back to every input row of the corresponding sequence.
// Columns [0, cvm_offset) take their value from cvm_values (rows of width
// cvm_offset). All other columns take it from out_grads_values (rows of full
// embedding_size width). The chosen value is written to input rows
// k in [lod[instance], lod[instance + 1]) of in_grads_values[slot].
template <typename T>
__global__ void FusedSeqpoolCVMGradKernelWithCVM(const size_t N,
                                                 T **out_grads_values,
                                                 T **in_grads_values,
                                                 T **cvm_values,
                                                 size_t **lods_values,
                                                 const int batch_size,
                                                 const int embedding_size,
                                                 const int cvm_offset) {
  CUDA_KERNEL_LOOP(i, N) {
    const int row = i / embedding_size;
    const int col = i % embedding_size;
    const int slot = row / batch_size;
    const int ins = row % batch_size;

    const bool from_cvm = col < cvm_offset;
    const T &grad = from_cvm
                        ? cvm_values[slot][ins * cvm_offset + col]
                        : out_grads_values[slot][ins * embedding_size + col];

    const size_t *lod = lods_values[slot];
    T *dst_column = in_grads_values[slot] + col;
    for (size_t k = lod[ins]; k < lod[ins + 1]; ++k) {
      dst_column[k * embedding_size] = grad;
    }
  }
}

// Backward variant where the forward output kept show but not click
// ("join" with show only). Columns [0, cvm_offset) take their gradient from
// cvm_values (rows of width cvm_offset). Other columns read out_grads_values,
// whose rows are (embedding_size - 1) wide and shifted left by one column.
// The value is broadcast to every input row of the sequence.
// NOTE(review): the fixed "- 1" presumably assumes exactly one dropped
// column; FusedSeqpoolCVMGrad in this file does not launch this kernel.
template <typename T>
__global__ void FusedSeqpoolCVMGradKernelWithShow(const size_t N,
                                                  T **out_grads_values,
                                                  T **in_grads_values,
                                                  T **cvm_values,
                                                  size_t **lods_values,
                                                  const int batch_size,
                                                  const int embedding_size,
                                                  const int cvm_offset) {
  const int out_width = embedding_size - 1;
  CUDA_KERNEL_LOOP(i, N) {
    const int row = i / embedding_size;
    const int col = i % embedding_size;
    const int slot = row / batch_size;
    const int ins = row % batch_size;

    const bool from_cvm = col < cvm_offset;
    const T &grad = from_cvm
                        ? cvm_values[slot][ins * cvm_offset + col]
                        : out_grads_values[slot][ins * out_width + col - 1];

    const size_t *lod = lods_values[slot];
    T *dst_column = in_grads_values[slot] + col;
    for (size_t k = lod[ins]; k < lod[ins + 1]; ++k) {
      dst_column[k * embedding_size] = grad;
    }
  }
}

// Backward of the use_cvm == false path ("update" grad). Columns
// [0, cvm_offset) take their gradient from cvm_values (rows of width
// cvm_offset). The remaining columns read out_grads_values, whose rows are
// (embedding_size - cvm_offset) wide and start at the first non-cvm column.
// The value is broadcast to every input row of the sequence.
template <typename T>
__global__ void FusedSeqpoolCVMGradKernelNoCVM(const size_t N,
                                               T **out_grads_values,
                                               T **in_grads_values,
                                               T **cvm_values,
                                               size_t **lods_values,
                                               const int batch_size,
                                               const int embedding_size,
                                               const int cvm_offset) {
  const int out_width = embedding_size - cvm_offset;
  CUDA_KERNEL_LOOP(i, N) {
    const int row = i / embedding_size;
    const int col = i % embedding_size;
    const int slot = row / batch_size;
    const int ins = row % batch_size;

    const bool from_cvm = col < cvm_offset;
    const T &grad =
        from_cvm ? cvm_values[slot][ins * cvm_offset + col]
                 : out_grads_values[slot][ins * out_width + col - cvm_offset];

    const size_t *lod = lods_values[slot];
    T *dst_column = in_grads_values[slot] + col;
    for (size_t k = lod[ins]; k < lod[ins + 1]; ++k) {
      dst_column[k * embedding_size] = grad;
    }
  }
}

// Enqueues the fused backward pass on the device context's stream. For each
// slot, every pooled-row gradient is broadcast to all input rows of its
// sequence:
//   - use_cvm:  columns [0, cvm_offset) come from cvm_data, the rest from
//               out_grads_data rows of embedding_size width;
//   - !use_cvm: columns [0, cvm_offset) come from cvm_data, the rest from
//               out_grads_data rows of (embedding_size - cvm_offset) width.
// in_grads_data rows are embedding_size wide; lods holds per-slot sequence
// offsets (batch_size + 1 entries). Nothing is synchronized here.
template <typename T>
void FusedSeqpoolCVMGrad(const framework::ExecutionContext &ctx,
                         const std::vector<const T *> &out_grads_data,
                         const std::vector<T *> &in_grads_data,
                         const std::vector<const T *> &cvm_data,
                         const std::vector<const size_t *> &lods,
                         const int batch_size,
                         const int slot_num,
                         const int embedding_size,
                         const bool use_cvm,
                         const int cvm_offset) {
  auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
  auto stream = dev_ctx.stream();

  // Widen before multiplying: the int product overflows for large batches.
  const size_t N =
      static_cast<size_t>(batch_size) * slot_num * embedding_size;
  if (N == 0) {
    return;  // No gradient to scatter; avoid launching an empty grid.
  }

  // Pack the four pointer tables into one host staging buffer and upload it
  // with a single copy; all entries are pointers, so the layout matches the
  // T** / size_t** views below.
  std::vector<const void *> host_ptrs;
  host_ptrs.reserve(out_grads_data.size() + in_grads_data.size() +
                    cvm_data.size() + lods.size());
  host_ptrs.insert(
      host_ptrs.end(), out_grads_data.begin(), out_grads_data.end());
  host_ptrs.insert(host_ptrs.end(), in_grads_data.begin(), in_grads_data.end());
  host_ptrs.insert(host_ptrs.end(), cvm_data.begin(), cvm_data.end());
  host_ptrs.insert(host_ptrs.end(), lods.begin(), lods.end());

  const size_t ptr_bytes = host_ptrs.size() * sizeof(void *);
  // NOTE(review): temp_ptr is released on return while the kernels below may
  // still be pending on `stream`; this relies on the allocator not reusing
  // the block on another stream first — confirm.
  auto temp_ptr = memory::AllocShared(ctx.GetPlace(), ptr_bytes);
  // host_ptrs is pageable and dies at scope exit; pageable H2D
  // cudaMemcpyAsync returns only after the source has been staged
  // (presumably the same for hipMemcpyAsync — confirm for ROCm).
#ifdef PADDLE_WITH_HIP
  platform::GpuMemcpyAsync(temp_ptr->ptr(),
                           host_ptrs.data(),
                           ptr_bytes,
                           hipMemcpyHostToDevice,
                           stream);
#else
  platform::GpuMemcpyAsync(temp_ptr->ptr(),
                           host_ptrs.data(),
                           ptr_bytes,
                           cudaMemcpyHostToDevice,
                           stream);
#endif

  T **gpu_out_grads_values = reinterpret_cast<T **>(temp_ptr->ptr());
  T **gpu_in_grads_values = gpu_out_grads_values + out_grads_data.size();
  T **gpu_cvm_values = gpu_in_grads_values + in_grads_data.size();
  size_t **lods_values =
      reinterpret_cast<size_t **>(gpu_cvm_values + cvm_data.size());

  auto config = platform::GetGpuLaunchConfig1D(dev_ctx, N);
  if (use_cvm) {
    // join grad: out grads keep the full embedding width.
    FusedSeqpoolCVMGradKernelWithCVM<<<config.block_per_grid.x,
                                       config.thread_per_block.x,
                                       0,
                                       stream>>>(N,
                                                 gpu_out_grads_values,
                                                 gpu_in_grads_values,
                                                 gpu_cvm_values,
                                                 lods_values,
                                                 batch_size,
                                                 embedding_size,
                                                 cvm_offset);
  } else {
    // update grad: out grads lack the leading cvm_offset columns.
    FusedSeqpoolCVMGradKernelNoCVM<<<config.block_per_grid.x,
                                     config.thread_per_block.x,
                                     0,
                                     stream>>>(N,
                                               gpu_out_grads_values,
                                               gpu_in_grads_values,
                                               gpu_cvm_values,
                                               lods_values,
                                               batch_size,
                                               embedding_size,
                                               cvm_offset);
  }
}

// Forward CUDA kernel of fused_seqpool_cvm. Every input slot ("X", one
// LoDTensor per slot) is sum-pooled per instance and written to the matching
// "Out" tensor, either with the show/click transform (use_cvm) or with the
// first cvm_offset columns dropped. Attributes read: pad_value, use_cvm,
// cvm_offset. All slots must share the same batch size.
template <typename T>
class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto inputs = ctx.MultiInput<LoDTensor>("X");
    auto outputs = ctx.MultiOutput<framework::Tensor>("Out");

    const auto slot_size = inputs.size();
    std::vector<const T *> input_data(slot_size);
    std::vector<const size_t *> lods_data(slot_size);
    std::vector<T *> output_data(slot_size);

    std::vector<LoDTensor> seqpool_outputs(slot_size);
    std::vector<T *> seqpool_output_data(slot_size);

    auto padding_value = ctx.Attr<float>("pad_value");
    auto use_cvm = ctx.Attr<bool>("use_cvm");
    const int cvm_offset = ctx.Attr<int>("cvm_offset");

    int embedding_size = inputs[0]->numel() / inputs[0]->dims()[0];
    int batch_size = -1;

    // Host offsets and their MixVector wrappers are kept alive until after
    // the launch has been enqueued (the wrappers presumably own the device
    // copies returned by CUDAData() — confirm against mixed_vector.h).
    // lods_storage is sized up front so the addresses handed to MixVector
    // stay valid; unique_ptr releases the wrappers even if an enforcement
    // below throws.
    std::vector<Vector<size_t>> lods_storage(slot_size);
    std::vector<std::unique_ptr<paddle::framework::MixVector<size_t>>>
        mix_lods_v(slot_size);

    for (size_t i = 0; i < slot_size; ++i) {
      const auto *input = inputs[i];

      Vector<size_t> &lods = lods_storage[i];
      if (input->lod().size() != 0) {
        lods = input->lod()[0];
      } else {
        // No LoD: every row is its own length-1 sequence.
        lods.push_back(0);
        for (int64_t row = 0; row < input->dims()[0]; ++row) {
          lods.push_back(row + 1);
        }
      }
      // Both branches above leave one more offset than there are instances.
      const int cur_batch_size = static_cast<int>(lods.size()) - 1;
      if (batch_size == -1) {
        batch_size = cur_batch_size;
      } else {
        PADDLE_ENFORCE_EQ(batch_size,
                          cur_batch_size,
                          platform::errors::PreconditionNotMet(
                              "The batch size of all input should be same, "
                              "please cheack, last batchsize is %d, current "
                              "batchsize is %d",
                              batch_size,
                              cur_batch_size));
      }
      input_data[i] = input->data<T>();

      auto *output = outputs[i];
      // Without CVM the show/click columns are stripped from the output.
      const int out_width =
          use_cvm ? embedding_size : embedding_size - cvm_offset;
      output->Resize({batch_size, out_width});
      output_data[i] = output->mutable_data<T>(ctx.GetPlace());

      mix_lods_v[i].reset(new paddle::framework::MixVector<size_t>(&lods));
      lods_data[i] = mix_lods_v[i]->CUDAData(ctx.GetPlace());

      // Scratch buffer holding the pooled sums before the CVM step.
      seqpool_output_data[i] = seqpool_outputs[i].mutable_data<T>(
          {batch_size, embedding_size}, ctx.GetPlace());
    }

    FusedSeqpoolCVM(ctx,
                    input_data,
                    output_data,
                    seqpool_output_data,
                    lods_data,
                    batch_size,
                    slot_size,
                    embedding_size,
                    padding_value,
                    use_cvm,
                    cvm_offset);
  }
};

// Backward CUDA kernel of fused_seqpool_cvm. For each slot, the gradient of
// "Out" (plus the shared "CVM" tensor for the first cvm_offset columns) is
// broadcast back to every row of the corresponding "X" sequence. Attributes
// read: use_cvm, cvm_offset. All slots must share the same batch size.
template <typename T>
class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto out_grads = ctx.MultiInput<LoDTensor>(framework::GradVarName("Out"));
    auto in_grads = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
    auto *cvm = ctx.Input<LoDTensor>("CVM");

    auto use_cvm = ctx.Attr<bool>("use_cvm");
    const int cvm_offset = ctx.Attr<int>("cvm_offset");

    const auto slot_size = in_grads.size();
    std::vector<const T *> out_grads_data(slot_size);
    std::vector<T *> in_grads_data(slot_size);
    std::vector<const size_t *> lods_data(slot_size);

    int embedding_size = in_grads[0]->numel() / in_grads[0]->dims()[0];
    int batch_size = -1;

    // Every slot reads its show/click gradient from the same CVM tensor.
    std::vector<const T *> cvm_data(slot_size, cvm->data<T>());

    // Host offsets and their MixVector wrappers are kept alive until after
    // the launch has been enqueued (the wrappers presumably own the device
    // copies returned by CUDAData() — confirm against mixed_vector.h).
    // lods_storage is sized up front so the addresses handed to MixVector
    // stay valid; unique_ptr releases the wrappers even if an enforcement
    // below throws.
    std::vector<Vector<size_t>> lods_storage(slot_size);
    std::vector<std::unique_ptr<paddle::framework::MixVector<size_t>>>
        mix_lods_v(slot_size);

    for (size_t i = 0; i < slot_size; ++i) {
      auto *in_grad = in_grads[i];

      Vector<size_t> &lods = lods_storage[i];
      if (in_grad->lod().size() != 0) {
        lods = in_grad->lod()[0];
      } else {
        // No LoD: every row is its own length-1 sequence.
        lods.push_back(0);
        for (int64_t row = 0; row < in_grad->dims()[0]; ++row) {
          lods.push_back(row + 1);
        }
      }

      // Both branches above leave one more offset than there are instances.
      const int cur_batch_size = static_cast<int>(lods.size()) - 1;
      if (batch_size == -1) {
        batch_size = cur_batch_size;
      } else {
        PADDLE_ENFORCE_EQ(batch_size,
                          cur_batch_size,
                          platform::errors::PreconditionNotMet(
                              "The batch size of all input should be same, "
                              "please cheack, last batchsize is %d, current "
                              "batchsize is %d",
                              batch_size,
                              cur_batch_size));
      }

      out_grads_data[i] = out_grads[i]->data<T>();
      in_grads_data[i] = in_grad->mutable_data<T>(ctx.GetPlace());

      mix_lods_v[i].reset(new paddle::framework::MixVector<size_t>(&lods));
      lods_data[i] = mix_lods_v[i]->CUDAData(ctx.GetPlace());
    }

    FusedSeqpoolCVMGrad(ctx,
                        out_grads_data,
                        in_grads_data,
                        cvm_data,
                        lods_data,
                        batch_size,
                        slot_size,
                        embedding_size,
                        use_cvm,
                        cvm_offset);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Only float32 kernels are registered, for both the forward op and its grad.
REGISTER_OP_CUDA_KERNEL(
    fused_seqpool_cvm,
    ops::FusedSeqpoolCVMCUDAKernel<float>);

REGISTER_OP_CUDA_KERNEL(
    fused_seqpool_cvm_grad,
    ops::FusedSeqpoolCVMGradCUDAKernel<float>);