diff --git a/paddle/fluid/operators/margin_cross_entropy_op.cc b/paddle/fluid/operators/margin_cross_entropy_op.cc
index 6ae692260a554d6f8ab4d373c77c01b73e7f10a2..5813a86a6930195870ac963da494d360bf2855fb 100644
--- a/paddle/fluid/operators/margin_cross_entropy_op.cc
+++ b/paddle/fluid/operators/margin_cross_entropy_op.cc
@@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/margin_cross_entropy_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -204,8 +205,3 @@ REGISTER_OPERATOR(
     ops::MarginCrossEntropyOpGradMaker<paddle::framework::OpDesc>,
     ops::MarginCrossEntropyOpGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(margin_cross_entropy_grad, ops::MarginCrossEntropyOpGrad);
-
-REGISTER_OP_CPU_KERNEL(margin_cross_entropy,
-                       ops::MarginCrossEntropyOpCPUKernel<float>,
-                       ops::MarginCrossEntropyOpCPUKernel<double>,
-                       ops::MarginCrossEntropyOpCPUKernel<paddle::platform::float16>);
diff --git a/paddle/fluid/operators/margin_cross_entropy_op.cu b/paddle/fluid/operators/margin_cross_entropy_op.cu
deleted file mode 100644
index 6d1ff9f296eb85601c9bb9eb2d956986f48d5d8c..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/margin_cross_entropy_op.cu
+++ /dev/null
@@ -1,618 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
*/ - -#ifdef PADDLE_WITH_HIP -#include -namespace cub = hipcub; -#else -#include -#endif - -#include - -#include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/margin_cross_entropy_op.h" -#include "paddle/fluid/operators/math/softmax_impl.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" -#include "paddle/fluid/string/string_helper.h" -#include "paddle/phi/api/include/tensor.h" -#include "paddle/phi/kernels/funcs/axis_utils.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#include "paddle/fluid/distributed/collective/ProcessGroup.h" -#include "paddle/fluid/platform/collective_helper.h" -#include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#endif - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -static constexpr int kNumCUDAThreads = 512; -static constexpr int kNumMaxinumNumBlocks = 4096; - -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaxinumNumBlocks); -} - -void GetClassInterval(const gpuStream_t& stream, - const platform::Place& place, - const platform::DeviceContext& ctx, - const int rid, - const int rank, - const int nranks, - const int D, - Tensor* class_interval) { - std::vector shard_dim_vec(nranks + 1, 0); - shard_dim_vec[rank + 1] = D; - if (nranks <= 1) { - framework::TensorFromVector(shard_dim_vec, ctx, class_interval); - return; - } - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - Tensor num_classes_per_device; - framework::TensorFromVector(shard_dim_vec, ctx, &num_classes_per_device); - int* num_classes_per_device_ptr = num_classes_per_device.data(); - - auto map = distributed::ProcessGroupMapFromGid::getInstance(); - if (map->has(rid)) { - // Use ProcessGroup - distributed::ProcessGroup* pg = map->get(rid); - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(num_classes_per_device); - out_tensor.push_back(num_classes_per_device); - - distributed::AllreduceOptions opts; - opts.reduce_op = distributed::ReduceOp::SUM; - auto task = pg->AllReduce(in_tensor, out_tensor, opts); - task->Wait(); - } else { - const auto& comm = platform::NCCLCommContext::Instance().Get(rid, place); - // use global calculate stream - const auto calcu_stream = - static_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( - num_classes_per_device_ptr, - num_classes_per_device_ptr, - num_classes_per_device.numel(), - platform::ToNCCLDataType( - framework::TransToProtoVarType(num_classes_per_device.dtype())), - ncclSum, - comm->comm(), - calcu_stream)); - } - - auto class_interval_ptr = - class_interval->mutable_data({nranks + 1}, place); - size_t cub_temp_storage_bytes = 0; - cub::DeviceScan::InclusiveSum( - nullptr, cub_temp_storage_bytes, nullptr, nullptr, nranks + 1, stream); - auto cub_temp_storage = memory::Alloc(place, cub_temp_storage_bytes); - cub::DeviceScan::InclusiveSum(cub_temp_storage->ptr(), - cub_temp_storage_bytes, - num_classes_per_device_ptr, - class_interval_ptr, - nranks + 1, - stream); - return; -#endif -} - -template -__global__ void AddMarginToPositiveLogitsKernel(T* logit, - const IndexT* label, - const float margin1, - const float margin2, - const float margin3, - const int rank, - const int nranks, - const int64_t N, - const int64_t D, - const int* class_interval_ptr) { - using 
MPType = typename details::MPTypeTrait::Type; - int start_index = class_interval_ptr[rank]; - int end_index = class_interval_ptr[rank + 1]; - int num_classes = class_interval_ptr[nranks]; - CUDA_KERNEL_LOOP(i, N) { - auto real_label = label[i]; - PADDLE_ENFORCE((real_label < num_classes) && (real_label >= 0), - "The index is out of bounds, " - "please check whether the value of label and " - "input meet the number of class. It should " - "be less than [%d], but received [%d]", - num_classes, - real_label); - - if (real_label >= start_index && real_label < end_index) { - int64_t offset = i * D + real_label - start_index; - if (fabs(margin1 - 1.0) > 1e-8 || fabs(margin2) > 1e-8) { - MPType x = static_cast(logit[offset]); - MPType theta = acos(x); - if (fabs(margin1 - 1.0) > 1e-8) { - theta *= static_cast(margin1); - } - if (fabs(margin2) > 1e-8) { - theta += static_cast(margin2); - } - logit[offset] = static_cast(cos(theta)); - } - if (fabs(margin3) > 1e-8) { - MPType y = static_cast(logit[offset]); - y -= static_cast(margin3); - logit[offset] = static_cast(y); - } - } - } -} - -template -__global__ void ScaleLogitKernel(T* logits, - const float scale, - const int64_t N, - const int64_t D) { - CUDA_KERNEL_LOOP(i, N * D) { logits[i] *= static_cast(scale); } -} - -template -__global__ void LogitsMinusMaxKernel(T* logits, - const T* logits_max_per_row, - const int64_t N, - const int64_t D) { - CUDA_KERNEL_LOOP(i, N * D) { - auto row = i / D; - logits[i] -= logits_max_per_row[row]; - } -} - -template -__global__ void LogitsMinusLogSumKernel(T* logits, - const T* logits_sum_per_row, - const int64_t N, - const int64_t D) { - CUDA_KERNEL_LOOP(i, N * D) { - auto row = i / D; - logits[i] -= kps::details::Log(logits_sum_per_row[row]); - } -} - -template -__global__ void HardLabelSoftmaxWithCrossEntropyKernel( - T* loss, - T* log_softmax, - const IndexT* labels, - const int rank, - const int64_t N, - const int64_t D, - const int* class_interval_ptr) { - int start_index = class_interval_ptr[rank]; - CUDA_KERNEL_LOOP(i, N * D) { - auto row = i / D; - auto col = i % D; - if ((col + start_index) == labels[row]) { - auto softmax = log_softmax[i]; - loss[row] = -softmax; - log_softmax[i] = kps::details::Exp(softmax); - } else { - log_softmax[i] = kps::details::Exp(log_softmax[i]); - } - } -} - -template -__global__ void CalculateGrad(T* logits_grad, - const T* loss_grad, - const T* logits, - const IndexT* labels, - const float margin1, - const float margin2, - const float scale, - const int rank, - const int64_t N, - const int64_t D, - const int* class_interval_ptr) { - using MPType = typename details::MPTypeTrait::Type; - int start_index = class_interval_ptr[rank]; - CUDA_KERNEL_LOOP(i, N * D) { - auto row = i / D; - auto col = i % D; - if ((col + start_index) == labels[row]) { - logits_grad[i] = (logits_grad[i] - static_cast(1.0)) * loss_grad[row]; - if (fabs(margin1 - 1.0) > 1e-8 || fabs(margin2) > 1e-8) { - MPType dout = static_cast(logits_grad[i]); - MPType one = static_cast(1.0f); - MPType x = static_cast(logits[i]); - MPType m1 = static_cast(margin1); - MPType m2 = static_cast(margin2); - - MPType d = m1 * sin(m1 * acos(x) + m2) / sqrt(one - x * x); - logits_grad[i] = static_cast(dout * d); - } - } else { - logits_grad[i] *= loss_grad[row]; - } - if (fabs(scale - 1.0) > 1e-8) { - logits_grad[i] *= static_cast(scale); - } - } -} - -template -class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const 
Tensor* logits = ctx.Input("Logits"); - const Tensor* labels = ctx.Input("Label"); - Tensor* softmax = ctx.Output("Softmax"); - Tensor* loss = ctx.Output("Loss"); - - const int rid = ctx.Attr("ring_id"); - const int nranks = ctx.Attr("nranks"); - const int rank = ctx.Attr("rank"); - - const float margin1 = ctx.Attr("margin1"); - const float margin2 = ctx.Attr("margin2"); - const float margin3 = ctx.Attr("margin3"); - const float scale = ctx.Attr("scale"); - - const auto& place = ctx.GetPlace(); - auto& dev_ctx = ctx.template device_context(); - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - platform::NCCLComm* comm; - distributed::ProcessGroup* pg = nullptr; - gpuStream_t stream; - if (nranks > 1) { - auto map = distributed::ProcessGroupMapFromGid::getInstance(); - if (map->has(rid)) { - // Use ProcessGroup - pg = map->get(rid); - } else { - comm = platform::NCCLCommContext::Instance().Get(rid, place); - - // use global calculate stream - stream = static_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); - } - } -#endif - - // allocate memory on device. - T* softmax_ptr = softmax->mutable_data(place); - T* loss_ptr = loss->mutable_data(place); - - const auto& logits_dims = logits->dims(); - const auto& labels_dims = labels->dims(); - - const int axis = logits_dims.size() - 1; - const int N = phi::funcs::SizeToAxis(axis, logits_dims); - const int D = phi::funcs::SizeFromAxis(axis, logits_dims); - - int blocks = NumBlocks(N); - int threads = kNumCUDAThreads; - const auto& label_type = framework::TransToProtoVarType(labels->dtype()); - - // copy logits to softmax variable since we can't modify logits, - // and it also be used when calculate grad - framework::TensorCopy( - *logits, ctx.GetPlace(), ctx.device_context(), softmax); - - Tensor softmax_2d; - softmax_2d.ShareDataWith(*softmax).Resize({N, D}); - T* logits_ptr = softmax_2d.data(); - - Tensor class_interval; - GetClassInterval(dev_ctx.stream(), - place, - ctx.cuda_device_context(), - rid, - rank, - nranks, - D, - &class_interval); - - // step 1, preprocess logits - // add margin for positive elements - // theta = acos(x_i) - // (cos(m1 * theta + m2) - m3) - // save match_logits, used for gradient computation. 
- if (label_type == framework::proto::VarType::INT32) { - typedef int32_t LabelT; - AddMarginToPositiveLogitsKernel - <<>>( - logits_ptr, - labels->data(), - margin1, - margin2, - margin3, - rank, - nranks, - N, - D, - class_interval.data()); - } else if (label_type == framework::proto::VarType::INT64) { - typedef int64_t LabelT; - AddMarginToPositiveLogitsKernel - <<>>( - logits_ptr, - labels->data(), - margin1, - margin2, - margin3, - rank, - nranks, - N, - D, - class_interval.data()); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "margin_cross_entropy label type noly support int32 and int64, " - "but got %s", - label_type)); - } - - // scale by s - ScaleLogitKernel<<>>( - logits_ptr, scale, N, D); - - // step 2, obtain logit_max - Tensor logits_max; - logits_max = ctx.AllocateTmpTensor({N, 1}, dev_ctx); - T* logits_max_buff = logits_max.mutable_data(place); - TensorReduceImpl>( - dev_ctx, - softmax_2d, - &logits_max, - kps::IdentityFunctor(), - {1}, - dev_ctx.stream()); - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - if (nranks > 1) { - if (pg) { - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(logits_max); - out_tensor.push_back(logits_max); - - distributed::AllreduceOptions opts; - opts.reduce_op = distributed::ReduceOp::MAX; - auto task = pg->AllReduce(in_tensor, out_tensor, opts); - task->Wait(); - } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( - logits_max_buff, - logits_max_buff, - logits_max.numel(), - platform::ToNCCLDataType( - framework::TransToProtoVarType(logits_max.dtype())), - ncclMax, - comm->comm(), - stream)); - } - } -#endif - - // step 3, logit - logit_max - LogitsMinusMaxKernel<<>>( - logits_ptr, logits_max_buff, N, D); - - // step 4, sum(exp(logit - logit_max)) - Tensor sum_exp_logits; - sum_exp_logits = ctx.AllocateTmpTensor({N, 1}, dev_ctx); - T* sum_exp_logits_buff = sum_exp_logits.mutable_data(place); - TensorReduceImpl>( - dev_ctx, - softmax_2d, - &sum_exp_logits, - kps::ExpFunctor(), - {1}, - dev_ctx.stream()); - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - if (nranks > 1) { - if (pg) { - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(sum_exp_logits); - out_tensor.push_back(sum_exp_logits); - - distributed::AllreduceOptions opts; - opts.reduce_op = distributed::ReduceOp::SUM; - auto task = pg->AllReduce(in_tensor, out_tensor, opts); - task->Wait(); - } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( - sum_exp_logits_buff, - sum_exp_logits_buff, - sum_exp_logits.numel(), - platform::ToNCCLDataType( - framework::TransToProtoVarType(sum_exp_logits.dtype())), - ncclSum, - comm->comm(), - stream)); - } - } -#endif - - // step 5, (logit - logit_max) - log(sum(exp(logit - logit_max))) - LogitsMinusLogSumKernel - <<>>( - logits_ptr, sum_exp_logits_buff, N, D); - - // step 6, prob = exp((logit - logit_max) - log(sum(exp(logit - - // logit_max)))) - // loss = -((logit_i - logit_max) - log(sum(exp(logit - logit_max)))) - phi::funcs::SetConstant()( - dev_ctx, loss, static_cast(0.0)); - if (label_type == framework::proto::VarType::INT32) { - typedef int32_t LabelT; - HardLabelSoftmaxWithCrossEntropyKernel - <<>>( - loss_ptr, - logits_ptr, - labels->data(), - rank, - N, - D, - class_interval.data()); - } else if (label_type == framework::proto::VarType::INT64) { - typedef int64_t LabelT; - HardLabelSoftmaxWithCrossEntropyKernel - <<>>( - loss_ptr, - logits_ptr, - labels->data(), - rank, - N, - D, - class_interval.data()); - } - -#if 
defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - if (nranks > 1) { - if (pg) { - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(*loss); - out_tensor.push_back(*loss); - - distributed::AllreduceOptions opts; - opts.reduce_op = distributed::ReduceOp::SUM; - auto task = pg->AllReduce(in_tensor, out_tensor, opts); - task->Wait(); - } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( - loss_ptr, - loss_ptr, - loss->numel(), - platform::ToNCCLDataType( - framework::TransToProtoVarType(loss->dtype())), - ncclSum, - comm->comm(), - stream)); - } - } -#endif - } -}; - -template -class MarginCrossEntropyGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* labels = context.Input("Label"); - const Tensor* logits = context.Input("Logits"); - const Tensor* softmax = context.Input("Softmax"); - - const Tensor* loss_grad = - context.Input(framework::GradVarName("Loss")); - Tensor* logit_grad = - context.Output(framework::GradVarName("Logits")); - - const bool return_softmax = context.Attr("return_softmax"); - - const int rid = context.Attr("ring_id"); - const int nranks = context.Attr("nranks"); - const int rank = context.Attr("rank"); - - const float margin1 = context.Attr("margin1"); - const float margin2 = context.Attr("margin2"); - const float margin3 = context.Attr("margin3"); - const float scale = context.Attr("scale"); - - auto& dev_ctx = context.template device_context(); - - const auto sofrmax_dims = softmax->dims(); - const int axis = sofrmax_dims.size() - 1; - const int N = phi::funcs::SizeToAxis(axis, sofrmax_dims); - const int D = phi::funcs::SizeFromAxis(axis, sofrmax_dims); - - if (return_softmax) { - framework::TensorCopy( - *softmax, context.GetPlace(), context.device_context(), logit_grad); - } else { - logit_grad->ShareDataWith(*softmax); - } - - int blocks = NumBlocks(N * D); - int threads = kNumCUDAThreads; - const auto& label_type = framework::TransToProtoVarType(labels->dtype()); - - Tensor class_interval; - GetClassInterval(dev_ctx.stream(), - context.GetPlace(), - context.cuda_device_context(), - rid, - rank, - nranks, - D, - &class_interval); - - if (label_type == framework::proto::VarType::INT32) { - typedef int32_t LabelT; - CalculateGrad<<>>( - logit_grad->data(), - loss_grad->data(), - logits->data(), - labels->data(), - margin1, - margin2, - scale, - rank, - N, - D, - class_interval.data()); - } else if (label_type == framework::proto::VarType::INT64) { - typedef int64_t LabelT; - CalculateGrad<<>>( - logit_grad->data(), - loss_grad->data(), - logits->data(), - labels->data(), - margin1, - margin2, - scale, - rank, - N, - D, - class_interval.data()); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL(margin_cross_entropy, - ops::MarginCrossEntropyOpCUDAKernel, - ops::MarginCrossEntropyOpCUDAKernel, - ops::MarginCrossEntropyOpCUDAKernel); - -REGISTER_OP_CUDA_KERNEL(margin_cross_entropy_grad, - ops::MarginCrossEntropyGradCUDAKernel, - ops::MarginCrossEntropyGradCUDAKernel, - ops::MarginCrossEntropyGradCUDAKernel); diff --git a/paddle/fluid/operators/margin_cross_entropy_op.h b/paddle/fluid/operators/margin_cross_entropy_op.h deleted file mode 100644 index 9261c84c8552c3eb6b441a28324859970eb0a0b3..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/margin_cross_entropy_op.h +++ /dev/null @@ -1,40 +0,0 @@ 
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <iostream>
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/softmax.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class MarginCrossEntropyOpCPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_THROW(platform::errors::Unavailable(
-        "Do not support margin_cross_entropy for cpu kernel "
-        "now."));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu b/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..54813422ef5e8cb092afbed8a62569ae6d2c170d
--- /dev/null
+++ b/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu
@@ -0,0 +1,243 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// old op include, fluid should be removed
+#ifdef PADDLE_WITH_HIP
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#else
+#include <cub/cub.cuh>
+#endif
+
+#include <vector>
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/impl/softmax_kernel_impl.h"
+#include "paddle/phi/kernels/margin_cross_entropy_grad_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/core/visit_type.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
+#endif
+#include "paddle/phi/backends/gpu/gpu_context.h"
+
+namespace phi {
+
+static constexpr int kNumCUDAThreads = 512;
+static constexpr int kNumMaxinumNumBlocks = 4096;
+
+static inline int NumBlocks(const int N) {
+  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
+                  kNumMaxinumNumBlocks);
+}
+
+template <typename T, typename Context>
+void GetClassInterval(const gpuStream_t& stream,
+                      const phi::Place& place,
+                      const Context& dev_ctx,
+                      const int rid,
+                      const int rank,
+                      const int nranks,
+                      const int D,
+                      DenseTensor* class_interval) {
+  std::vector<int> shard_dim_vec(nranks + 1, 0);
+  shard_dim_vec[rank + 1] = D;
+  if (nranks <= 1) {
+    paddle::framework::TensorFromVector(shard_dim_vec, dev_ctx, class_interval);
+    return;
+  }
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  DenseTensor num_classes_per_device;
+  paddle::framework::TensorFromVector(
+      shard_dim_vec, dev_ctx, &num_classes_per_device);
+  int* num_classes_per_device_ptr = num_classes_per_device.data<int>();
+
+  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
+  if (map->has(rid)) {
+    // Use ProcessGroup
+    paddle::distributed::ProcessGroup* pg = map->get(rid);
+    std::vector<phi::DenseTensor> in_tensor;
+    std::vector<phi::DenseTensor> out_tensor;
+    in_tensor.push_back(num_classes_per_device);
+    out_tensor.push_back(num_classes_per_device);
+
+    paddle::distributed::AllreduceOptions opts;
+    opts.reduce_op = paddle::distributed::ReduceOp::SUM;
+    auto task = pg->AllReduce(in_tensor, out_tensor, opts);
+    task->Wait();
+  } else {
+    const auto& comm =
+        paddle::platform::NCCLCommContext::Instance().Get(rid, place);
+    // use global calculate stream
+    const auto calcu_stream =
+        static_cast<GPUContext*>(
+            paddle::platform::DeviceContextPool::Instance().Get(place))
+            ->stream();
+
+    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
+        num_classes_per_device_ptr,
+        num_classes_per_device_ptr,
+        num_classes_per_device.numel(),
+        paddle::platform::ToNCCLDataType(paddle::framework::TransToProtoVarType(
+            num_classes_per_device.dtype())),
+        ncclSum,
+        comm->comm(),
+        calcu_stream));
+  }
+
+  class_interval->Resize({nranks + 1});
+  auto class_interval_ptr = dev_ctx.template Alloc<int>(class_interval);
+
+  size_t cub_temp_storage_bytes = 0;
+  cub::DeviceScan::InclusiveSum<int*, int*>(
+      nullptr, cub_temp_storage_bytes, nullptr, nullptr, nranks + 1, stream);
+  auto cub_temp_storage = paddle::memory::Alloc(place, cub_temp_storage_bytes);
+  cub::DeviceScan::InclusiveSum<int*, int*>(cub_temp_storage->ptr(),
+                                            cub_temp_storage_bytes,
+                                            num_classes_per_device_ptr,
+                                            class_interval_ptr,
+                                            nranks + 1,
+                                            stream);
+  return;
+#endif
+}
+
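+// The forward pass computed cos(m1 * acos(x) + m2) - m3 on the positive
+// logit x = cos(theta); CalculateGrad below therefore applies the
+// chain-rule factor
+//   d/dx cos(m1 * acos(x) + m2) = m1 * sin(m1 * acos(x) + m2) / sqrt(1 - x^2)
+// to the positive-class entries (the variable `d`).
+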
+template <typename T, typename IndexT>
+__global__ void CalculateGrad(T* logits_grad,
+                              const T* loss_grad,
+                              const T* logits,
+                              const IndexT* label,
+                              const float margin1,
+                              const float margin2,
+                              const float scale,
+                              const int rank,
+                              const int64_t N,
+                              const int64_t D,
+                              const int* class_interval_ptr) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  int start_index = class_interval_ptr[rank];
+  CUDA_KERNEL_LOOP(i, N * D) {
+    auto row = i / D;
+    auto col = i % D;
+    if ((col + start_index) == label[row]) {
+      logits_grad[i] = (logits_grad[i] - static_cast<T>(1.0)) * loss_grad[row];
+      if (fabs(margin1 - 1.0) > 1e-8 || fabs(margin2) > 1e-8) {
+        MPType dout = static_cast<MPType>(logits_grad[i]);
+        MPType one = static_cast<MPType>(1.0f);
+        MPType x = static_cast<MPType>(logits[i]);
+        MPType m1 = static_cast<MPType>(margin1);
+        MPType m2 = static_cast<MPType>(margin2);
+
+        MPType d = m1 * sin(m1 * acos(x) + m2) / sqrt(one - x * x);
+        logits_grad[i] = static_cast<T>(dout * d);
+      }
+    } else {
+      logits_grad[i] *= loss_grad[row];
+    }
+    if (fabs(scale - 1.0) > 1e-8) {
+      logits_grad[i] *= static_cast<T>(scale);
+    }
+  }
+}
+
+template <typename T, typename Context>
+void MarginCrossEntropyGradKernel(const Context& dev_ctx,
+                                  const DenseTensor& logits,
+                                  const DenseTensor& label,
+                                  const DenseTensor& softmax,
+                                  const DenseTensor& loss_grad,
+                                  bool return_softmax,
+                                  int ring_id,
+                                  int rank,
+                                  int nranks,
+                                  float margin1,
+                                  float margin2,
+                                  float margin3,
+                                  float scale,
+                                  DenseTensor* logits_grad) {
+  const auto softmax_dims = softmax.dims();
+  const int axis = softmax_dims.size() - 1;
+  const int N = phi::funcs::SizeToAxis(axis, softmax_dims);
+  const int D = phi::funcs::SizeFromAxis(axis, softmax_dims);
+
+  if (return_softmax) {
+    phi::Copy<Context>(
+        dev_ctx, softmax, dev_ctx.GetPlace(), false, logits_grad);
+  } else {
+    logits_grad->ShareDataWith(softmax);
+  }
+
+  int blocks = NumBlocks(N * D);
+  int threads = kNumCUDAThreads;
+  const auto& label_type =
+      paddle::framework::TransToProtoVarType(label.dtype());
+
+  DenseTensor class_interval;
+  GetClassInterval<T, Context>(dev_ctx.stream(),
+                               dev_ctx.GetPlace(),
+                               dev_ctx,
+                               ring_id,
+                               rank,
+                               nranks,
+                               D,
+                               &class_interval);
+
+  if (label_type == paddle::framework::proto::VarType::INT32) {
+    typedef int32_t LabelT;
+    CalculateGrad<T, LabelT>
+        <<<blocks, threads, 0, dev_ctx.stream()>>>(logits_grad->data<T>(),
+                                                   loss_grad.data<T>(),
+                                                   logits.data<T>(),
+                                                   label.data<LabelT>(),
+                                                   margin1,
+                                                   margin2,
+                                                   scale,
+                                                   rank,
+                                                   N,
+                                                   D,
+                                                   class_interval.data<int>());
+  } else if (label_type == paddle::framework::proto::VarType::INT64) {
+    typedef int64_t LabelT;
+    CalculateGrad<T, LabelT>
+        <<<blocks, threads, 0, dev_ctx.stream()>>>(logits_grad->data<T>(),
+                                                   loss_grad.data<T>(),
+                                                   logits.data<T>(),
+                                                   label.data<LabelT>(),
+                                                   margin1,
+                                                   margin2,
+                                                   scale,
+                                                   rank,
+                                                   N,
+                                                   D,
+                                                   class_interval.data<int>());
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(margin_cross_entropy_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MarginCrossEntropyGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
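For readers checking the math in `CalculateGrad`, the analytic factor can be validated against a finite difference. The sketch below is illustrative only (standalone host C++, not part of this patch):

```cpp
// Standalone check (illustrative, not part of the patch): compare the
// analytic derivative of cos(m1*acos(x) + m2) with a central difference.
#include <cmath>
#include <cstdio>

int main() {
  const double m1 = 1.35, m2 = 0.25;  // sample margin values
  const double h = 1e-6;
  for (double x : {-0.9, -0.3, 0.0, 0.4, 0.8}) {
    double analytic =
        m1 * std::sin(m1 * std::acos(x) + m2) / std::sqrt(1.0 - x * x);
    double numeric = (std::cos(m1 * std::acos(x + h) + m2) -
                      std::cos(m1 * std::acos(x - h) + m2)) /
                     (2.0 * h);
    std::printf("x=%5.2f  analytic=%9.6f  numeric=%9.6f\n", x, analytic, numeric);
  }
  return 0;
}
```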
diff --git a/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu b/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7ed0eb1f5049970312051e4a41077b6ae3cdf0cc
--- /dev/null
+++ b/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
@@ -0,0 +1,484 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// old op include, fluid should be removed
+#ifdef PADDLE_WITH_HIP
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#else
+#include <cub/cub.cuh>
+#endif
+
+#include <vector>
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
+#endif
+// trace op include
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+static constexpr int kNumCUDAThreads = 512;
+static constexpr int kNumMaxinumNumBlocks = 4096;
+
+static inline int NumBlocks(const int N) {
+  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
+                  kNumMaxinumNumBlocks);
+}
+
+template <typename T, typename Context>
+void GetClassInterval(const gpuStream_t& stream,
+                      const phi::Place& place,
+                      const Context& dev_ctx,
+                      const int rid,
+                      const int rank,
+                      const int nranks,
+                      const int D,
+                      DenseTensor* class_interval) {
+  std::vector<int> shard_dim_vec(nranks + 1, 0);
+  shard_dim_vec[rank + 1] = D;
+  if (nranks <= 1) {
+    paddle::framework::TensorFromVector(shard_dim_vec, dev_ctx, class_interval);
+    return;
+  }
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  DenseTensor num_classes_per_device;
+  paddle::framework::TensorFromVector(
+      shard_dim_vec, dev_ctx, &num_classes_per_device);
+  int* num_classes_per_device_ptr = num_classes_per_device.data<int>();
+
+  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
+  if (map->has(rid)) {
+    // Use ProcessGroup
+    paddle::distributed::ProcessGroup* pg = map->get(rid);
+    std::vector<phi::DenseTensor> in_tensor;
+    std::vector<phi::DenseTensor> out_tensor;
+    in_tensor.push_back(num_classes_per_device);
+    out_tensor.push_back(num_classes_per_device);
+
+    paddle::distributed::AllreduceOptions opts;
+    opts.reduce_op = paddle::distributed::ReduceOp::SUM;
+    auto task = pg->AllReduce(in_tensor, out_tensor, opts);
+    task->Wait();
+  } else {
+    const auto& comm =
+        paddle::platform::NCCLCommContext::Instance().Get(rid, place);
+    // use global calculate stream
+    const auto calcu_stream =
+        static_cast<GPUContext*>(
+            paddle::platform::DeviceContextPool::Instance().Get(place))
+            ->stream();
+
+    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
+        num_classes_per_device_ptr,
+        num_classes_per_device_ptr,
+        num_classes_per_device.numel(),
+        paddle::platform::ToNCCLDataType(paddle::framework::TransToProtoVarType(
+            num_classes_per_device.dtype())),
+        ncclSum,
+        comm->comm(),
+        calcu_stream));
+  }
+
+  class_interval->Resize({nranks + 1});
+  auto class_interval_ptr = dev_ctx.template Alloc<int>(class_interval);
+  size_t cub_temp_storage_bytes = 0;
+  cub::DeviceScan::InclusiveSum<int*, int*>(
+      nullptr, cub_temp_storage_bytes, nullptr, nullptr, nranks + 1, stream);
+  auto cub_temp_storage = paddle::memory::Alloc(place, cub_temp_storage_bytes);
+  cub::DeviceScan::InclusiveSum<int*, int*>(cub_temp_storage->ptr(),
+                                            cub_temp_storage_bytes,
+                                            num_classes_per_device_ptr,
+                                            class_interval_ptr,
+                                            nranks + 1,
+                                            stream);
+  return;
+#endif
+}
+
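+// For reference: the kernel below applies the combined margin to the
+// positive logit x = cos(theta), computing cos(m1 * theta + m2) - m3.
+// Common settings recover ArcFace (m2 > 0), CosFace (m3 > 0), and
+// SphereFace-style (m1 > 1) losses; identity margins are skipped.
+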
+template <typename T, typename IndexT>
+__global__ void AddMarginToPositiveLogitsKernel(T* logit,
+                                                const IndexT* label,
+                                                const float margin1,
+                                                const float margin2,
+                                                const float margin3,
+                                                const int rank,
+                                                const int nranks,
+                                                const int64_t N,
+                                                const int64_t D,
+                                                const int* class_interval_ptr) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  int start_index = class_interval_ptr[rank];
+  int end_index = class_interval_ptr[rank + 1];
+  int num_classes = class_interval_ptr[nranks];
+  CUDA_KERNEL_LOOP(i, N) {
+    auto real_label = label[i];
+    PADDLE_ENFORCE((real_label < num_classes) && (real_label >= 0),
+                   "The index is out of bounds, "
+                   "please check whether the value of label and "
+                   "input meet the number of class. It should "
+                   "be less than [%d], but received [%d]",
+                   num_classes,
+                   real_label);
+
+    if (real_label >= start_index && real_label < end_index) {
+      int64_t offset = i * D + real_label - start_index;
+      if (fabs(margin1 - 1.0) > 1e-8 || fabs(margin2) > 1e-8) {
+        MPType x = static_cast<MPType>(logit[offset]);
+        MPType theta = acos(x);
+        if (fabs(margin1 - 1.0) > 1e-8) {
+          theta *= static_cast<MPType>(margin1);
+        }
+        if (fabs(margin2) > 1e-8) {
+          theta += static_cast<MPType>(margin2);
+        }
+        logit[offset] = static_cast<T>(cos(theta));
+      }
+      if (fabs(margin3) > 1e-8) {
+        MPType y = static_cast<MPType>(logit[offset]);
+        y -= static_cast<MPType>(margin3);
+        logit[offset] = static_cast<T>(y);
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void ScaleLogitKernel(T* logits,
+                                 const float scale,
+                                 const int64_t N,
+                                 const int64_t D) {
+  CUDA_KERNEL_LOOP(i, N * D) { logits[i] *= static_cast<T>(scale); }
+}
+
+template <typename T>
+__global__ void LogitsMinusMaxKernel(T* logits,
+                                     const T* logits_max_per_row,
+                                     const int64_t N,
+                                     const int64_t D) {
+  CUDA_KERNEL_LOOP(i, N * D) {
+    auto row = i / D;
+    logits[i] -= logits_max_per_row[row];
+  }
+}
+
+template <typename T>
+__global__ void LogitsMinusLogSumKernel(T* logits,
+                                        const T* logits_sum_per_row,
+                                        const int64_t N,
+                                        const int64_t D) {
+  CUDA_KERNEL_LOOP(i, N * D) {
+    auto row = i / D;
+    logits[i] -= phi::kps::details::Log(logits_sum_per_row[row]);
+  }
+}
+
+template <typename T, typename IndexT>
+__global__ void HardLabelSoftmaxWithCrossEntropyKernel(
+    T* loss,
+    T* log_softmax,
+    const IndexT* labels,
+    const int rank,
+    const int64_t N,
+    const int64_t D,
+    const int* class_interval_ptr) {
+  int start_index = class_interval_ptr[rank];
+  CUDA_KERNEL_LOOP(i, N * D) {
+    auto row = i / D;
+    auto col = i % D;
+    if ((col + start_index) == labels[row]) {
+      auto softmax = log_softmax[i];
+      loss[row] = -softmax;
+      log_softmax[i] = phi::kps::details::Exp(softmax);
+    } else {
+      log_softmax[i] = phi::kps::details::Exp(log_softmax[i]);
+    }
+  }
+}
+
+template <typename T, typename Context>
+void MarginCrossEntropyKernel(const Context& dev_ctx,
+                              const DenseTensor& logits,
+                              const DenseTensor& labels,
+                              bool return_softmax,
+                              int ring_id,
+                              int rank,
+                              int nranks,
+                              float margin1,
+                              float margin2,
+                              float margin3,
+                              float scale,
+                              DenseTensor* softmax,
+                              DenseTensor* loss) {
+  const auto& place = dev_ctx.GetPlace();  // old code
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  paddle::platform::NCCLComm* comm;
+  paddle::distributed::ProcessGroup* pg = nullptr;
+  gpuStream_t stream;
+  if (nranks > 1) {
+    auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(ring_id)) {
+      // Use ProcessGroup
+      pg = map->get(ring_id);
+    } else {
+      comm = paddle::platform::NCCLCommContext::Instance().Get(ring_id, place);
+
+      // use global calculate stream
+      stream = static_cast<GPUContext*>(
+                   paddle::platform::DeviceContextPool::Instance().Get(place))
+                   ->stream();
+    }
+  }
+#endif
+
+  // allocate memory on device.
+  T* softmax_ptr = dev_ctx.template Alloc<T>(softmax);
+  T* loss_ptr = dev_ctx.template Alloc<T>(loss);
+
+  const auto& logits_dims = logits.dims();
+  const auto& labels_dims = labels.dims();
+
+  const int axis = logits_dims.size() - 1;
+  const int N = phi::funcs::SizeToAxis(axis, logits_dims);
+  const int D = phi::funcs::SizeFromAxis(axis, logits_dims);
+
+  int blocks = NumBlocks(N);
+  int threads = kNumCUDAThreads;
+  const auto& label_type =
+      paddle::framework::TransToProtoVarType(labels.dtype());
+
+  // copy logits to softmax variable since we can't modify logits,
+  // and it also be used when calculate grad
+  phi::Copy<Context>(dev_ctx, logits, dev_ctx.GetPlace(), true, softmax);
+
+  DenseTensor softmax_2d;
+  softmax_2d.ShareDataWith(*softmax).Resize({N, D});
+  T* logits_ptr = softmax_2d.data<T>();
+
+  DenseTensor class_interval;
+  GetClassInterval<T, Context>(dev_ctx.stream(),
+                               dev_ctx.GetPlace(),
+                               dev_ctx,
+                               ring_id,
+                               rank,
+                               nranks,
+                               D,
+                               &class_interval);
+
+  // step 1, preprocess logits
+  // add margin for positive elements
+  // theta = acos(x_i)
+  // (cos(m1 * theta + m2) - m3)
+  // save match_logits, used for gradient computation.
+  if (label_type == paddle::framework::proto::VarType::INT32) {
+    typedef int32_t LabelT;
+    AddMarginToPositiveLogitsKernel<T, LabelT>
+        <<<blocks, threads, 0, dev_ctx.stream()>>>(
+            logits_ptr,
+            labels.data<LabelT>(),
+            margin1,
+            margin2,
+            margin3,
+            rank,
+            nranks,
+            N,
+            D,
+            class_interval.data<int>());
+  } else if (label_type == paddle::framework::proto::VarType::INT64) {
+    typedef int64_t LabelT;
+    AddMarginToPositiveLogitsKernel<T, LabelT>
+        <<<blocks, threads, 0, dev_ctx.stream()>>>(
+            logits_ptr,
+            labels.data<LabelT>(),
+            margin1,
+            margin2,
+            margin3,
+            rank,
+            nranks,
+            N,
+            D,
+            class_interval.data<int>());
+  } else {
+    PADDLE_THROW(errors::Unimplemented(
+        "margin_cross_entropy label type noly support int32 and int64, "
+        "but got %s",
+        label_type));
+  }
+
+  // scale by s
+  ScaleLogitKernel<T><<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
+      logits_ptr, scale, N, D);
+
+  // step 2, obtain logit_max
+  DenseTensor logits_max;
+  logits_max.Resize({N, 1});
+  dev_ctx.template Alloc<T>(&logits_max);
+  T* logits_max_buff = dev_ctx.template Alloc<T>(&logits_max);
+
+  phi::funcs::
+      ReduceKernel<T, T, phi::kps::MaxFunctor, phi::kps::IdentityFunctor<T>>(
+          static_cast<const phi::GPUContext&>(dev_ctx),
+          softmax_2d,
+          &logits_max,
+          phi::kps::IdentityFunctor<T>(),
+          {1});
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  if (nranks > 1) {
+    if (pg) {
+      std::vector<phi::DenseTensor> in_tensor;
+      std::vector<phi::DenseTensor> out_tensor;
+      in_tensor.push_back(logits_max);
+      out_tensor.push_back(logits_max);
+
+      paddle::distributed::AllreduceOptions opts;
+      opts.reduce_op = paddle::distributed::ReduceOp::MAX;
+      auto task = pg->AllReduce(in_tensor, out_tensor, opts);
+      task->Wait();
+    } else {
+      PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
+          logits_max_buff,
+          logits_max_buff,
+          logits_max.numel(),
+          paddle::platform::ToNCCLDataType(
+              paddle::framework::TransToProtoVarType(logits_max.dtype())),
+          ncclMax,
+          comm->comm(),
+          stream));
+    }
+  }
+#endif
+
+  // step 3, logit - logit_max
+  LogitsMinusMaxKernel<T><<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
+      logits_ptr, logits_max_buff, N, D);
+
+  // step 4, sum(exp(logit - logit_max))
+  DenseTensor sum_exp_logits;
+  sum_exp_logits.Resize({N, 1});
+  dev_ctx.template Alloc<T>(&sum_exp_logits);
+  // T* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);
+  T* sum_exp_logits_buff = dev_ctx.template Alloc<T>(&sum_exp_logits);
+  phi::funcs::ReduceKernel<T, T, phi::kps::AddFunctor, phi::kps::ExpFunctor<T>>(
+      static_cast<const phi::GPUContext&>(dev_ctx),
+      softmax_2d,
+      &sum_exp_logits,
+      phi::kps::ExpFunctor<T>(),
+      {1});
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  if (nranks > 1) {
+    if (pg) {
+      std::vector<phi::DenseTensor> in_tensor;
+      std::vector<phi::DenseTensor> out_tensor;
+      in_tensor.push_back(sum_exp_logits);
+      out_tensor.push_back(sum_exp_logits);
+
+      paddle::distributed::AllreduceOptions opts;
+      opts.reduce_op = paddle::distributed::ReduceOp::SUM;
+      auto task = pg->AllReduce(in_tensor, out_tensor, opts);
+      task->Wait();
+    } else {
+      PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
+          sum_exp_logits_buff,
+          sum_exp_logits_buff,
+          sum_exp_logits.numel(),
+          paddle::platform::ToNCCLDataType(
+              paddle::framework::TransToProtoVarType(sum_exp_logits.dtype())),
+          ncclSum,
+          comm->comm(),
+          stream));
+    }
+  }
+#endif
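+
+  // At this point every row of logits_ptr holds (logit - logit_max) and
+  // sum_exp_logits holds sum(exp(logit - logit_max)), reduced over all
+  // shards, so step 5 forms the numerically stable log-softmax
+  //   (logit - logit_max) - log(sum(exp(logit - logit_max))).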
+
+  // step 5, (logit - logit_max) - log(sum(exp(logit - logit_max)))
+  LogitsMinusLogSumKernel<T>
+      <<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
+          logits_ptr, sum_exp_logits_buff, N, D);
+
+  // step 6, prob = exp((logit - logit_max) - log(sum(exp(logit -
+  // logit_max))))
+  // loss = -((logit_i - logit_max) - log(sum(exp(logit - logit_max))))
+
+  phi::funcs::SetConstant<Context, T> functor;
+  functor(dev_ctx, loss, static_cast<T>(0.0));
+  if (label_type == paddle::framework::proto::VarType::INT32) {
+    typedef int32_t LabelT;
+    HardLabelSoftmaxWithCrossEntropyKernel<T, LabelT>
+        <<<blocks, threads, 0, dev_ctx.stream()>>>(loss_ptr,
+                                                   logits_ptr,
+                                                   labels.data<LabelT>(),
+                                                   rank,
+                                                   N,
+                                                   D,
+                                                   class_interval.data<int>());
+  } else if (label_type == paddle::framework::proto::VarType::INT64) {
+    typedef int64_t LabelT;
+    HardLabelSoftmaxWithCrossEntropyKernel<T, LabelT>
+        <<<blocks, threads, 0, dev_ctx.stream()>>>(loss_ptr,
+                                                   logits_ptr,
+                                                   labels.data<LabelT>(),
+                                                   rank,
+                                                   N,
+                                                   D,
+                                                   class_interval.data<int>());
+  }
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  if (nranks > 1) {
+    if (pg) {
+      std::vector<phi::DenseTensor> in_tensor;
+      std::vector<phi::DenseTensor> out_tensor;
+      in_tensor.push_back(*loss);
+      out_tensor.push_back(*loss);
+
+      paddle::distributed::AllreduceOptions opts;
+      opts.reduce_op = paddle::distributed::ReduceOp::SUM;
+      auto task = pg->AllReduce(in_tensor, out_tensor, opts);
+      task->Wait();
+    } else {
+      PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
+          loss_ptr,
+          loss_ptr,
+          loss->numel(),
+          paddle::platform::ToNCCLDataType(
+              paddle::framework::TransToProtoVarType(loss->dtype())),
+          ncclSum,
+          comm->comm(),
+          stream));
+    }
+  }
+#endif
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(margin_cross_entropy,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MarginCrossEntropyKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
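Taken together, steps 1-6 above compute a margin-adjusted, scaled, numerically stable softmax cross entropy. A compact single-device (nranks = 1) host-side sketch of the same per-row math, illustrative only and not part of this patch:

```cpp
// Single-row, single-device reference of steps 1-6 (illustrative only).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> logits = {0.7, 0.1, -0.2, 0.4};  // cosine similarities
  const int label = 0;
  const double m1 = 1.0, m2 = 0.5, m3 = 0.0, scale = 64.0;

  // step 1: combined margin on the positive logit, then scale by s
  logits[label] = std::cos(m1 * std::acos(logits[label]) + m2) - m3;
  for (double& v : logits) v *= scale;

  // steps 2-5: max-subtraction plus log-sum-exp for stability
  double mx = *std::max_element(logits.begin(), logits.end());
  double sum = 0.0;
  for (double v : logits) sum += std::exp(v - mx);

  // step 6: loss = -log softmax(label)
  double loss = -((logits[label] - mx) - std::log(sum));
  std::printf("loss = %f\n", loss);
  return 0;
}
```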
diff --git a/paddle/phi/kernels/margin_cross_entropy_grad_kernel.h b/paddle/phi/kernels/margin_cross_entropy_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..2d0715149751edd586ef1910309295e651446e53
--- /dev/null
+++ b/paddle/phi/kernels/margin_cross_entropy_grad_kernel.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+namespace phi {
+template <typename T, typename Context>
+void MarginCrossEntropyGradKernel(const Context& dev_ctx,
+                                  const DenseTensor& logits,
+                                  const DenseTensor& label,
+                                  const DenseTensor& softmax,
+                                  const DenseTensor& loss_grad,
+                                  bool return_softmax,
+                                  int ring_id,
+                                  int rank,
+                                  int nranks,
+                                  float margin1,
+                                  float margin2,
+                                  float margin3,
+                                  float scale,
+                                  DenseTensor* logits_grad);
+}  // namespace phi
diff --git a/paddle/phi/kernels/margin_cross_entropy_kernel.h b/paddle/phi/kernels/margin_cross_entropy_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..df58256597695b09fbca54436ee5e93d5b82b473
--- /dev/null
+++ b/paddle/phi/kernels/margin_cross_entropy_kernel.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MarginCrossEntropyKernel(const Context& dev_ctx,
+                              const DenseTensor& logits,
+                              const DenseTensor& label,
+                              bool return_softmax,
+                              int ring_id,
+                              int rank,
+                              int nranks,
+                              float margin1,
+                              float margin2,
+                              float margin3,
+                              float scale,
+                              DenseTensor* softmax,
+                              DenseTensor* loss);
+}  // namespace phi
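Both GPU kernels above branch on the runtime label dtype before launching (`typedef int32_t LabelT;` vs `typedef int64_t LabelT;`). A minimal standalone sketch of that dispatch pattern, with hypothetical names rather than Paddle API:

```cpp
// Minimal sketch of runtime label-dtype dispatch (hypothetical helper,
// not Paddle API): pick a template instantiation from a runtime enum.
#include <cstdint>
#include <cstdio>

enum class LabelType { kInt32, kInt64 };

template <typename LabelT>
void RunKernel(const void* labels, int64_t n) {
  auto* ptr = static_cast<const LabelT*>(labels);
  std::printf("first of %lld labels: %lld\n", static_cast<long long>(n),
              static_cast<long long>(ptr[0]));
}

void Dispatch(LabelType type, const void* labels, int64_t n) {
  if (type == LabelType::kInt32) {
    RunKernel<int32_t>(labels, n);  // mirrors `typedef int32_t LabelT;`
  } else {
    RunKernel<int64_t>(labels, n);  // mirrors `typedef int64_t LabelT;`
  }
}

int main() {
  int32_t labels[] = {3, 1, 2};
  Dispatch(LabelType::kInt32, labels, 3);
  return 0;
}
```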
diff --git a/paddle/phi/ops/compat/margin_cross_entropy_sig.cc b/paddle/phi/ops/compat/margin_cross_entropy_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..adc0e426d19528b32fde80d4a0a8519d14ad1114
--- /dev/null
+++ b/paddle/phi/ops/compat/margin_cross_entropy_sig.cc
@@ -0,0 +1,54 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature MarginCrossEntropyOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("margin_cross_entropy",
+                         {"Logits", "Label"},
+                         {"return_softmax",
+                          "ring_id",
+                          "rank",
+                          "nranks",
+                          "margin1",
+                          "margin2",
+                          "margin3",
+                          "scale"},
+                         {"Softmax", "Loss"});
+}
+
+KernelSignature MarginCrossEntropyGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("margin_cross_entropy_grad",
+                         {"Logits", "Label", "Softmax", "Loss@GRAD"},
+                         {"return_softmax",
+                          "ring_id",
+                          "rank",
+                          "nranks",
+                          "margin1",
+                          "margin2",
+                          "margin3",
+                          "scale"},
+                         {"Logits@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(margin_cross_entropy,
+                           phi::MarginCrossEntropyOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(margin_cross_entropy_grad,
+                           phi::MarginCrossEntropyGradOpArgumentMapping);
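Every launch in these files sizes its grid with `NumBlocks`, which caps the grid at `kNumMaxinumNumBlocks` (identifier spelling as in the source) and relies on the grid-stride `CUDA_KERNEL_LOOP` to cover any remaining elements. A host-side sketch of that sizing math, illustrative only:

```cpp
// Host-side sketch of the launch-size math used by NumBlocks (illustrative).
#include <algorithm>
#include <cstdio>

constexpr int kNumCUDAThreads = 512;
constexpr int kNumMaxinumNumBlocks = 4096;  // spelling kept from the source

int NumBlocks(int n) {
  // ceil(n / threads), capped; grid-stride loops handle the overflow.
  return std::min((n + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaxinumNumBlocks);
}

int main() {
  for (int n : {1, 512, 513, 4096 * 512, 10 * 4096 * 512}) {
    std::printf("N=%10d -> blocks=%d\n", n, NumBlocks(n));
  }
  return 0;
}
```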