margin_cross_entropy_op.cu 21.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HIP
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#else
#include <cub/cub.cuh>
#endif

#include <vector>
23

24 25 26
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/margin_cross_entropy_op.h"
#include "paddle/fluid/operators/math/softmax_impl.h"
27
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
28 29
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/string/string_helper.h"
30
#include "paddle/phi/api/include/tensor.h"
31
#include "paddle/phi/kernels/funcs/axis_utils.h"
32
#include "paddle/phi/kernels/funcs/math_function.h"
33 34

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
35
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
36
#include "paddle/fluid/platform/collective_helper.h"
37
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
#endif

namespace paddle {
namespace operators {

// Short local name for the framework tensor type used by every kernel below.
typedef framework::Tensor Tensor;

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;

// Returns how many kNumCUDAThreads-sized blocks to launch for N work items,
// clamped to [1, kNumMaxinumNumBlocks].
//
// The kernels in this file iterate with CUDA_KERNEL_LOOP (presumably a
// grid-stride loop), so a clamped grid is expected to still cover every
// item. At least one block is always returned: a zero-block launch (N == 0)
// is an invalid launch configuration, while a single block simply runs an
// empty loop.
static inline int NumBlocks(const int N) {
  // Ceil-divide without forming N + kNumCUDAThreads - 1, which overflows
  // int when N is close to INT_MAX.
  const int needed = N / kNumCUDAThreads + (N % kNumCUDAThreads != 0 ? 1 : 0);
  return std::max(1, std::min(needed, kNumMaxinumNumBlocks));
}

// Builds `class_interval`: an int tensor of nranks + 1 entries on `place`
// holding the prefix sums of the per-rank class counts,
//   class_interval[0] = 0, class_interval[r + 1] = D_0 + ... + D_r,
// where D_r is the `D` passed in by rank r. Rank r therefore owns the global
// class ids [class_interval[r], class_interval[r + 1]).
//
// With nranks <= 1 no communication happens and the result is {0, D}.
// Otherwise every rank contributes a vector that is zero except for its own
// slot. These vectors are combined with a SUM all-reduce, through the
// ProcessGroup registered for `rid` when there is one, else through the
// NCCL/RCCL communicator for (`rid`, `place`). The result is then scanned
// with cub::DeviceScan::InclusiveSum on `stream`.
//
// NOTE(review): the raw NCCL all-reduce runs on the DeviceContextPool
// stream for `place`, while the scan runs on `stream`. The two are assumed
// to be the same stream; TODO confirm.
//
// Throws InvalidArgument if rank is outside [0, nranks). Throws
// Unimplemented if nranks > 1 in a build without NCCL/RCCL.
void GetClassInterval(const gpuStream_t& stream,
                      const platform::Place& place,
                      const platform::DeviceContext& ctx,
                      const int rid,
                      const int rank,
                      const int nranks,
                      const int D,
                      Tensor* class_interval) {
  // shard_dim_vec[rank + 1] is written below; an out-of-range rank would
  // write past the end of the vector.
  PADDLE_ENFORCE_GE(
      rank,
      0,
      platform::errors::InvalidArgument(
          "The rank of margin_cross_entropy must be >= 0, but got %d.", rank));
  PADDLE_ENFORCE_LT(
      rank,
      nranks,
      platform::errors::InvalidArgument(
          "The rank of margin_cross_entropy must be less than nranks (%d), "
          "but got %d.",
          nranks,
          rank));

  std::vector<int> shard_dim_vec(nranks + 1, 0);
  shard_dim_vec[rank + 1] = D;
  if (nranks <= 1) {
    framework::TensorFromVector(shard_dim_vec, ctx, class_interval);
    return;
  }

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  Tensor num_classes_per_device;
  framework::TensorFromVector(shard_dim_vec, ctx, &num_classes_per_device);
  int* num_classes_per_device_ptr = num_classes_per_device.data<int>();

  auto map = distributed::ProcessGroupMapFromGid::getInstance();
  if (map->has(rid)) {
    // Use ProcessGroup
    distributed::ProcessGroup* pg = map->get(rid);
    std::vector<phi::DenseTensor> in_tensor;
    std::vector<phi::DenseTensor> out_tensor;
    in_tensor.push_back(num_classes_per_device);
    out_tensor.push_back(num_classes_per_device);

    distributed::AllreduceOptions opts;
    opts.reduce_op = distributed::ReduceOp::SUM;
    auto task = pg->AllReduce(in_tensor, out_tensor, opts);
    task->Wait();
  } else {
    const auto& comm = platform::NCCLCommContext::Instance().Get(rid, place);
    // use global calculate stream
    const auto calcu_stream =
        static_cast<phi::GPUContext*>(
            platform::DeviceContextPool::Instance().Get(place))
            ->stream();

    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        num_classes_per_device_ptr,
        num_classes_per_device_ptr,
        num_classes_per_device.numel(),
        platform::ToNCCLDataType(
            framework::TransToProtoVarType(num_classes_per_device.dtype())),
        ncclSum,
        comm->comm(),
        calcu_stream));
  }

  auto class_interval_ptr =
      class_interval->mutable_data<int>({nranks + 1}, place);

  // The statuses are stored in locals first: the `<int*, int*>` template
  // argument list contains a comma and cannot be written directly inside
  // the single-argument PADDLE_ENFORCE_GPU_SUCCESS macro.
  size_t cub_temp_storage_bytes = 0;
  auto query_status = cub::DeviceScan::InclusiveSum<int*, int*>(
      nullptr, cub_temp_storage_bytes, nullptr, nullptr, nranks + 1, stream);
  PADDLE_ENFORCE_GPU_SUCCESS(query_status);

  auto cub_temp_storage = memory::Alloc(place, cub_temp_storage_bytes);
  auto scan_status =
      cub::DeviceScan::InclusiveSum<int*, int*>(cub_temp_storage->ptr(),
                                                cub_temp_storage_bytes,
                                                num_classes_per_device_ptr,
                                                class_interval_ptr,
                                                nranks + 1,
                                                stream);
  PADDLE_ENFORCE_GPU_SUCCESS(scan_status);
#else
  // Without a collective backend the other ranks' class counts cannot be
  // gathered. Fail loudly instead of returning an unallocated tensor.
  PADDLE_THROW(platform::errors::Unimplemented(
      "margin_cross_entropy with nranks > 1 requires PaddlePaddle to be "
      "compiled with NCCL or RCCL, but got nranks = %d.",
      nranks));
#endif
}

// Applies the angular/additive margin in place to the logit of each row's
// label:
//   logit = cos(margin1 * acos(logit) + margin2) - margin3
// Only rows whose label lies in this rank's class range
// [class_interval_ptr[rank], class_interval_ptr[rank + 1]) are modified, at
// column (label - class_interval_ptr[rank]) of the [N, D] logits.
// Each margin term is skipped when it is the identity (|margin1 - 1|,
// |margin2|, |margin3| <= 1e-8). Intermediate math uses the
// MPTypeTrait<T> type.
// One loop iteration per row (CUDA_KERNEL_LOOP over N). A label outside
// [0, class_interval_ptr[nranks]) is reported through PADDLE_ENFORCE.
template <typename T, typename IndexT>
__global__ void AddMarginToPositiveLogitsKernel(T* logit,
                                                const IndexT* label,
                                                const float margin1,
                                                const float margin2,
                                                const float margin3,
                                                const int rank,
                                                const int nranks,
                                                const int64_t N,
                                                const int64_t D,
                                                const int* class_interval_ptr) {
  using MPType = typename details::MPTypeTrait<T>::Type;
  const int start_index = class_interval_ptr[rank];
  const int end_index = class_interval_ptr[rank + 1];
  const int num_classes = class_interval_ptr[nranks];

  // The margin switches do not depend on the row; evaluate them once.
  const bool use_margin1 = fabs(margin1 - 1.0) > 1e-8;
  const bool use_margin2 = fabs(margin2) > 1e-8;
  const bool use_margin3 = fabs(margin3) > 1e-8;

  CUDA_KERNEL_LOOP(i, N) {
    auto real_label = label[i];
    // The label may be int64_t, so it is widened to long long and printed
    // with %lld. A %d conversion would misread the variadic argument.
    PADDLE_ENFORCE((real_label < num_classes) && (real_label >= 0),
                   "The index is out of bounds, "
                   "please check whether the value of label and "
                   "input meet the number of class. It should "
                   "be less than [%d], but received [%lld]",
                   num_classes,
                   static_cast<long long>(real_label));

    if (real_label >= start_index && real_label < end_index) {
      int64_t offset = i * D + real_label - start_index;
      if (use_margin1 || use_margin2) {
        MPType x = static_cast<MPType>(logit[offset]);
        MPType theta = acos(x);
        if (use_margin1) {
          theta *= static_cast<MPType>(margin1);
        }
        if (use_margin2) {
          theta += static_cast<MPType>(margin2);
        }
        logit[offset] = static_cast<T>(cos(theta));
      }
      if (use_margin3) {
        MPType y = static_cast<MPType>(logit[offset]);
        y -= static_cast<MPType>(margin3);
        logit[offset] = static_cast<T>(y);
      }
    }
  }
}

// Multiplies every element of the [N, D] logits buffer by `scale`, in place.
template <typename T>
__global__ void ScaleLogitKernel(T* logits,
                                 const float scale,
                                 const int64_t N,
                                 const int64_t D) {
  // Convert the factor once instead of on every iteration.
  const T factor = static_cast<T>(scale);
  CUDA_KERNEL_LOOP(idx, N * D) {
    logits[idx] *= factor;
  }
}

// Subtracts each row's maximum from every element of that row, in place.
// `logits` is a row-major [N, D] buffer and `logits_max_per_row` holds one
// value per row.
template <typename T>
__global__ void LogitsMinusMaxKernel(T* logits,
                                     const T* logits_max_per_row,
                                     const int64_t N,
                                     const int64_t D) {
  CUDA_KERNEL_LOOP(idx, N * D) {
    logits[idx] -= logits_max_per_row[idx / D];
  }
}

// Subtracts log(row sum) from every element of that row, in place.
// `logits` is a row-major [N, D] buffer and `logits_sum_per_row` holds one
// sum per row. After this step each element holds a log-softmax value.
template <typename T>
__global__ void LogitsMinusLogSumKernel(T* logits,
                                        const T* logits_sum_per_row,
                                        const int64_t N,
                                        const int64_t D) {
  CUDA_KERNEL_LOOP(idx, N * D) {
    const auto log_row_sum = kps::details::Log(logits_sum_per_row[idx / D]);
    logits[idx] -= log_row_sum;
  }
}

// Walks the row-major [N, D] log-softmax shard and converts it in place to
// probabilities (exp of each element). For the element whose global class
// id (column + class_interval_ptr[rank]) equals the row's label, it also
// stores -log_softmax into loss[row]. Rows whose label is not in this shard
// leave loss[row] untouched.
template <typename T, typename IndexT>
__global__ void HardLabelSoftmaxWithCrossEntropyKernel(
    T* loss,
    T* log_softmax,
    const IndexT* labels,
    const int rank,
    const int64_t N,
    const int64_t D,
    const int* class_interval_ptr) {
  const int first_class = class_interval_ptr[rank];
  CUDA_KERNEL_LOOP(idx, N * D) {
    const auto row = idx / D;
    const auto col = idx % D;
    const T log_prob = log_softmax[idx];
    if ((col + first_class) == labels[row]) {
      loss[row] = -log_prob;
    }
    log_softmax[idx] = kps::details::Exp(log_prob);
  }
}

// Turns the stored softmax probabilities (in `logits_grad`, row-major
// [N, D]) into the gradient w.r.t. the original logits, in place:
//   non-label column: p * dloss[row]
//   label column:     (p - 1) * dloss[row], additionally multiplied by
//                     d/dx cos(m1 * acos(x) + m2)
//                       = m1 * sin(m1 * acos(x) + m2) / sqrt(1 - x^2)
//                     when margin1 != 1 or margin2 != 0
//                     (x = the original logit)
// and finally by `scale` when it differs from 1. The label column is the
// one whose global id (column + class_interval_ptr[rank]) equals the row's
// label. Every intermediate is rounded to T, exactly as when it is stored
// back to memory.
template <typename T, typename IndexT>
__global__ void CalculateGrad(T* logits_grad,
                              const T* loss_grad,
                              const T* logits,
                              const IndexT* labels,
                              const float margin1,
                              const float margin2,
                              const float scale,
                              const int rank,
                              const int64_t N,
                              const int64_t D,
                              const int* class_interval_ptr) {
  using MPType = typename details::MPTypeTrait<T>::Type;
  const int first_class = class_interval_ptr[rank];
  const bool has_angular_margin =
      fabs(margin1 - 1.0) > 1e-8 || fabs(margin2) > 1e-8;
  const bool has_scale = fabs(scale - 1.0) > 1e-8;

  CUDA_KERNEL_LOOP(idx, N * D) {
    const auto row = idx / D;
    const auto col = idx % D;
    T grad = logits_grad[idx];
    if ((col + first_class) != labels[row]) {
      grad *= loss_grad[row];
    } else {
      grad = (grad - static_cast<T>(1.0)) * loss_grad[row];
      if (has_angular_margin) {
        const MPType one = static_cast<MPType>(1.0f);
        const MPType x = static_cast<MPType>(logits[idx]);
        const MPType m1 = static_cast<MPType>(margin1);
        const MPType m2 = static_cast<MPType>(margin2);
        const MPType d = m1 * sin(m1 * acos(x) + m2) / sqrt(one - x * x);
        grad = static_cast<T>(static_cast<MPType>(grad) * d);
      }
    }
    if (has_scale) {
      grad *= static_cast<T>(scale);
    }
    logits_grad[idx] = grad;
  }
}

template <typename T>
class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
 public:
  // Forward kernel of margin_cross_entropy for one shard of the class axis.
  //
  // Reads "Logits" and "Label". Writes "Softmax" (same shape as Logits) and
  // "Loss" (one entry per row). The logits are viewed as an [N, D] matrix
  // whose last axis is this rank's slice of the classes. For the label
  // column of every row the logit x is replaced by
  //   cos(margin1 * acos(x) + margin2) - margin3,
  // everything is multiplied by `scale`, and a max-subtracted log-softmax
  // and cross entropy are computed. With nranks > 1, three quantities are
  // all-reduced across ring `ring_id`: the row maxima (MAX), the row sums of
  // exp (SUM) and the loss (SUM).
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* logits = ctx.Input<Tensor>("Logits");
    const Tensor* labels = ctx.Input<Tensor>("Label");
    Tensor* softmax = ctx.Output<Tensor>("Softmax");
    Tensor* loss = ctx.Output<Tensor>("Loss");

    const int rid = ctx.Attr<int>("ring_id");
    const int nranks = ctx.Attr<int>("nranks");
    const int rank = ctx.Attr<int>("rank");

    const float margin1 = ctx.Attr<float>("margin1");
    const float margin2 = ctx.Attr<float>("margin2");
    const float margin3 = ctx.Attr<float>("margin3");
    const float scale = ctx.Attr<float>("scale");

    const auto& place = ctx.GetPlace();
    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();

    // Validate the label dtype once, so the kernel dispatches below only
    // need to distinguish the two supported types.
    const auto& label_type = framework::TransToProtoVarType(labels->dtype());
    if (label_type != framework::proto::VarType::INT32 &&
        label_type != framework::proto::VarType::INT64) {
      PADDLE_THROW(platform::errors::Unimplemented(
          "margin_cross_entropy label type only support int32 and int64, "
          "but got %s",
          label_type));
    }

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    // For nranks > 1 exactly one of `pg` (ProcessGroup registered for the
    // ring) or `comm` + `stream` (raw NCCL/RCCL) is set.
    platform::NCCLComm* comm = nullptr;
    distributed::ProcessGroup* pg = nullptr;
    gpuStream_t stream = nullptr;
    if (nranks > 1) {
      auto map = distributed::ProcessGroupMapFromGid::getInstance();
      if (map->has(rid)) {
        // Use ProcessGroup
        pg = map->get(rid);
      } else {
        comm = platform::NCCLCommContext::Instance().Get(rid, place);

        // use global calculate stream
        stream = static_cast<phi::GPUContext*>(
                     platform::DeviceContextPool::Instance().Get(place))
                     ->stream();
      }
    }
#endif

    // allocate memory on device.
    T* loss_ptr = loss->mutable_data<T>(place);

    const auto& logits_dims = logits->dims();
    const int axis = logits_dims.size() - 1;
    const int N = phi::funcs::SizeToAxis(axis, logits_dims);
    const int D = phi::funcs::SizeFromAxis(axis, logits_dims);

    const int threads = kNumCUDAThreads;
    // Per-row kernels loop over N; per-element kernels loop over N * D.
    const int row_blocks = NumBlocks(N);
    const int elem_blocks = NumBlocks(N * D);

    // Copy logits into the Softmax output: logits must not be modified and
    // are needed again when the gradient is computed. All steps below work
    // in place on this copy, viewed as an [N, D] matrix.
    framework::TensorCopy(
        *logits, ctx.GetPlace(), ctx.device_context(), softmax);

    Tensor softmax_2d;
    softmax_2d.ShareDataWith(*softmax).Resize({N, D});
    T* logits_ptr = softmax_2d.data<T>();

    // Rank r owns global class ids [class_interval[r], class_interval[r + 1]).
    Tensor class_interval;
    GetClassInterval(dev_ctx.stream(),
                     place,
                     ctx.cuda_device_context(),
                     rid,
                     rank,
                     nranks,
                     D,
                     &class_interval);

    // step 1, preprocess logits
    // add margin for positive elements
    // theta = acos(x_i)
    // (cos(m1 * theta + m2) - m3)
    if (label_type == framework::proto::VarType::INT32) {
      typedef int32_t LabelT;
      AddMarginToPositiveLogitsKernel<T>
          <<<row_blocks, threads, 0, dev_ctx.stream()>>>(
              logits_ptr,
              labels->data<LabelT>(),
              margin1,
              margin2,
              margin3,
              rank,
              nranks,
              N,
              D,
              class_interval.data<int>());
    } else {
      typedef int64_t LabelT;
      AddMarginToPositiveLogitsKernel<T>
          <<<row_blocks, threads, 0, dev_ctx.stream()>>>(
              logits_ptr,
              labels->data<LabelT>(),
              margin1,
              margin2,
              margin3,
              rank,
              nranks,
              N,
              D,
              class_interval.data<int>());
    }

    // scale by s
    ScaleLogitKernel<T><<<elem_blocks, threads, 0, dev_ctx.stream()>>>(
        logits_ptr, scale, N, D);

    // step 2, obtain logit_max (global across ranks when nranks > 1)
    Tensor logits_max =
        ctx.AllocateTmpTensor<T, phi::GPUContext>({N, 1}, dev_ctx);
    T* logits_max_buff = logits_max.mutable_data<T>(place);
    TensorReduceImpl<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
        dev_ctx,
        softmax_2d,
        &logits_max,
        kps::IdentityFunctor<T>(),
        {1},
        dev_ctx.stream());

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    if (nranks > 1) {
      AllReduceInPlace(
          pg, comm, stream, &logits_max, distributed::ReduceOp::MAX, ncclMax);
    }
#endif

    // step 3, logit - logit_max
    LogitsMinusMaxKernel<T><<<elem_blocks, threads, 0, dev_ctx.stream()>>>(
        logits_ptr, logits_max_buff, N, D);

    // step 4, sum(exp(logit - logit_max)) (global across ranks when
    // nranks > 1)
    Tensor sum_exp_logits =
        ctx.AllocateTmpTensor<T, phi::GPUContext>({N, 1}, dev_ctx);
    T* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);
    TensorReduceImpl<T, T, kps::AddFunctor, kps::ExpFunctor<T>>(
        dev_ctx,
        softmax_2d,
        &sum_exp_logits,
        kps::ExpFunctor<T>(),
        {1},
        dev_ctx.stream());

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    if (nranks > 1) {
      AllReduceInPlace(pg,
                       comm,
                       stream,
                       &sum_exp_logits,
                       distributed::ReduceOp::SUM,
                       ncclSum);
    }
#endif

    // step 5, (logit - logit_max) - log(sum(exp(logit - logit_max)))
    LogitsMinusLogSumKernel<T>
        <<<elem_blocks, threads, 0, dev_ctx.stream()>>>(
            logits_ptr, sum_exp_logits_buff, N, D);

    // step 6, prob = exp((logit - logit_max) - log(sum(exp(logit -
    // logit_max))))
    // loss = -((logit_i - logit_max) - log(sum(exp(logit - logit_max))))
    // The loss starts at zero: only the rank owning a row's label writes
    // it, and the SUM all-reduce below combines the shards.
    phi::funcs::SetConstant<phi::GPUContext, T>()(
        dev_ctx, loss, static_cast<T>(0.0));
    if (label_type == framework::proto::VarType::INT32) {
      typedef int32_t LabelT;
      HardLabelSoftmaxWithCrossEntropyKernel<T, LabelT>
          <<<elem_blocks, threads, 0, dev_ctx.stream()>>>(
              loss_ptr,
              logits_ptr,
              labels->data<LabelT>(),
              rank,
              N,
              D,
              class_interval.data<int>());
    } else {
      typedef int64_t LabelT;
      HardLabelSoftmaxWithCrossEntropyKernel<T, LabelT>
          <<<elem_blocks, threads, 0, dev_ctx.stream()>>>(
              loss_ptr,
              logits_ptr,
              labels->data<LabelT>(),
              rank,
              N,
              D,
              class_interval.data<int>());
    }

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    if (nranks > 1) {
      AllReduceInPlace(
          pg, comm, stream, loss, distributed::ReduceOp::SUM, ncclSum);
    }
#endif
  }

 private:
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  // All-reduces `tensor` in place across the ring. If `pg` is non-null it
  // uses the ProcessGroup with `pg_op` and waits for the task. Otherwise it
  // issues ncclAllReduce with `nccl_op` on `comm` and `stream`.
  static void AllReduceInPlace(distributed::ProcessGroup* pg,
                               platform::NCCLComm* comm,
                               gpuStream_t stream,
                               Tensor* tensor,
                               distributed::ReduceOp pg_op,
                               ncclRedOp_t nccl_op) {
    if (pg) {
      std::vector<phi::DenseTensor> in_tensor;
      std::vector<phi::DenseTensor> out_tensor;
      in_tensor.push_back(*tensor);
      out_tensor.push_back(*tensor);

      distributed::AllreduceOptions opts;
      opts.reduce_op = pg_op;
      auto task = pg->AllReduce(in_tensor, out_tensor, opts);
      task->Wait();
    } else {
      T* buff = tensor->data<T>();
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
          buff,
          buff,
          tensor->numel(),
          platform::ToNCCLDataType(
              framework::TransToProtoVarType(tensor->dtype())),
          nccl_op,
          comm->comm(),
          stream));
    }
  }
#endif
};

template <typename T>
class MarginCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  // Backward kernel of margin_cross_entropy.
  //
  // Builds d(Logits) from the forward "Softmax" output, the original
  // "Logits", "Label" and d(Loss) using CalculateGrad. margin3 only shifts
  // the label logit by a constant in the forward pass, so it does not enter
  // the gradient and is not read here.
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* labels = context.Input<Tensor>("Label");
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* softmax = context.Input<Tensor>("Softmax");

    const Tensor* loss_grad =
        context.Input<Tensor>(framework::GradVarName("Loss"));
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));

    const bool return_softmax = context.Attr<bool>("return_softmax");

    const int rid = context.Attr<int>("ring_id");
    const int nranks = context.Attr<int>("nranks");
    const int rank = context.Attr<int>("rank");

    const float margin1 = context.Attr<float>("margin1");
    const float margin2 = context.Attr<float>("margin2");
    const float scale = context.Attr<float>("scale");

    auto& dev_ctx = context.template device_context<phi::GPUContext>();

    // Reject unsupported label types explicitly. Otherwise logit_grad would
    // silently be left holding the softmax probabilities.
    const auto& label_type = framework::TransToProtoVarType(labels->dtype());
    if (label_type != framework::proto::VarType::INT32 &&
        label_type != framework::proto::VarType::INT64) {
      PADDLE_THROW(platform::errors::Unimplemented(
          "margin_cross_entropy_grad label type only support int32 and "
          "int64, but got %s",
          label_type));
    }

    const auto softmax_dims = softmax->dims();
    const int axis = softmax_dims.size() - 1;
    const int N = phi::funcs::SizeToAxis(axis, softmax_dims);
    const int D = phi::funcs::SizeFromAxis(axis, softmax_dims);

    // logit_grad starts as the softmax probabilities and is rewritten in
    // place by CalculateGrad. With return_softmax the Softmax tensor is
    // copied, presumably so the user-visible output stays intact; otherwise
    // its buffer is reused.
    if (return_softmax) {
      framework::TensorCopy(
          *softmax, context.GetPlace(), context.device_context(), logit_grad);
    } else {
      logit_grad->ShareDataWith(*softmax);
    }

    const int blocks = NumBlocks(N * D);
    const int threads = kNumCUDAThreads;

    Tensor class_interval;
    GetClassInterval(dev_ctx.stream(),
                     context.GetPlace(),
                     context.cuda_device_context(),
                     rid,
                     rank,
                     nranks,
                     D,
                     &class_interval);

    if (label_type == framework::proto::VarType::INT32) {
      typedef int32_t LabelT;
      CalculateGrad<T, LabelT><<<blocks, threads, 0, dev_ctx.stream()>>>(
          logit_grad->data<T>(),
          loss_grad->data<T>(),
          logits->data<T>(),
          labels->data<LabelT>(),
          margin1,
          margin2,
          scale,
          rank,
          N,
          D,
          class_interval.data<int>());
    } else {
      typedef int64_t LabelT;
      CalculateGrad<T, LabelT><<<blocks, threads, 0, dev_ctx.stream()>>>(
          logit_grad->data<T>(),
          loss_grad->data<T>(),
          logits->data<T>(),
          labels->data<LabelT>(),
          margin1,
          margin2,
          scale,
          rank,
          N,
          D,
          class_interval.data<int>());
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward kernel: float, double and float16 logits.
REGISTER_OP_CUDA_KERNEL(
    margin_cross_entropy,
    ops::MarginCrossEntropyOpCUDAKernel<float>,
    ops::MarginCrossEntropyOpCUDAKernel<double>,
    ops::MarginCrossEntropyOpCUDAKernel<plat::float16>);

// Backward kernel: the same element types as the forward kernel.
REGISTER_OP_CUDA_KERNEL(
    margin_cross_entropy_grad,
    ops::MarginCrossEntropyGradCUDAKernel<float>,
    ops::MarginCrossEntropyGradCUDAKernel<double>,
    ops::MarginCrossEntropyGradCUDAKernel<plat::float16>);