diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc
index 97d72457a24965c525e72109df55c4e92c96fa80..f3ab5ed65df7ce8ec1c04c3115737166b5338452 100644
--- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc
+++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc
@@ -106,6 +106,10 @@ class CSoftmaxWithCrossEntropyOpMaker
               "Input(Logits) "
               "except the shape in dimension :attr:`axis` as 1. The cross "
               "entropy loss.");
+    AddAttr<int64_t>("ignore_index",
+                     "(int default -100) Specifies a target value "
+                     "that is ignored and does not contribute to the loss.")
+        .SetDefault(-100);
     AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
         .SetDefault(0);
     AddAttr<int>("rank",
diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu
index f17a9dfe3c232db161d39fa3fc78e2c594a20651..b3c5e75971c5d4e286f442de25f1ca47871a68ce 100644
--- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu
+++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/funcs/cross_entropy.h"
+#include "paddle/phi/kernels/funcs/math.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/softmax_impl.h"
 
 namespace paddle {
@@ -59,6 +61,24 @@ __global__ void MaskLabelByIndex(T* predicted_logits,
   }
 }
 
+template <typename T, typename IndexT>
+__global__ void CaculateLoss(T* loss,
+                             const T* predict_logits,
+                             const T* sum_exp_logits,
+                             const IndexT* label,
+                             const int64_t ignore_index,
+                             const int N) {
+  CUDA_KERNEL_LOOP(i, N) {
+    auto real_label = static_cast<int64_t>(label[i]);
+    loss[i] = ignore_index == real_label
+                  ? static_cast<T>(0)
+                  : phi::funcs::TolerableValue<T>()(
+                        phi::funcs::TolerableValue<T>()(
+                            phi::funcs::real_log(sum_exp_logits[i])) -
+                        predict_logits[i]);
+  }
+}
+
 template <typename T, typename IndexT>
 __global__ void MaskLabelByIndexGrad(T* logits_grad,
                                      const T* loss_grad,
@@ -66,11 +86,15 @@ __global__ void MaskLabelByIndexGrad(T* logits_grad,
                                      const int start_index,
                                      const int end_index,
                                      const int64_t N,
-                                     const int64_t D) {
+                                     const int64_t D,
+                                     const int64_t ignore_index) {
   CUDA_KERNEL_LOOP(i, N * D) {
     auto row = i / D;
     auto col = i % D;
-    if ((col + start_index) == labels[row]) {
+    auto lbl = static_cast<int64_t>(labels[row]);
+    if (lbl == ignore_index) {
+      logits_grad[i] = static_cast<T>(0.0);
+    } else if ((col + start_index) == labels[row]) {
       logits_grad[i] = (logits_grad[i] - static_cast<T>(1.0)) * loss_grad[row];
     } else {
       logits_grad[i] *= loss_grad[row];
@@ -102,6 +126,7 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
     phi::DenseTensor* softmax = ctx.Output<phi::DenseTensor>("Softmax");
     phi::DenseTensor* loss = ctx.Output<phi::DenseTensor>("Loss");
 
+    const int64_t ignore_index = ctx.Attr<int64_t>("ignore_index");
     const int rid = ctx.Attr<int>("ring_id");
     const int nranks = ctx.Attr<int>("nranks");
     const int rank = ctx.Attr<int>("rank");
@@ -234,14 +259,23 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
                                          comm->comm(),
                                          stream));
 
-    auto eigen_loss = phi::funcs::EigenMatrix<T>::From(loss_2d);
-    auto eigen_predicted_logits =
-        phi::funcs::EigenMatrix<T>::From(predicted_logits);
-
-    eigen_loss.device(*dev_ctx.eigen_device()) =
-        (eigen_sum_exp_logits.log().unaryExpr(phi::funcs::TolerableValue<T>()) -
-         eigen_predicted_logits)
-            .unaryExpr(phi::funcs::TolerableValue<T>());
+    if (label_type == framework::proto::VarType::INT32) {
+      CaculateLoss<T, int32_t>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(loss_2d.data<T>(),
+                                                     predicted_logits.data<T>(),
+                                                     sum_exp_logits.data<T>(),
+                                                     labels->data<int32_t>(),
+                                                     ignore_index,
+                                                     N);
+    } else {
+      CaculateLoss<T, int64_t>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(loss_2d.data<T>(),
+                                                     predicted_logits.data<T>(),
+                                                     sum_exp_logits.data<T>(),
+                                                     labels->data<int64_t>(),
+                                                     ignore_index,
+                                                     N);
+    }
 
     eigen_softmax.device(*dev_ctx.eigen_device()) =
         (eigen_softmax *
@@ -257,6 +291,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
     phi::DenseTensor* softmax = ctx.Output<phi::DenseTensor>("Softmax");
     phi::DenseTensor* loss = ctx.Output<phi::DenseTensor>("Loss");
 
+    const int64_t ignore_index = ctx.Attr<int64_t>("ignore_index");
     const int rid = ctx.Attr<int>("ring_id");
     const int nranks = ctx.Attr<int>("nranks");
     const int rank = ctx.Attr<int>("rank");
@@ -371,14 +406,23 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
     opts.reduce_op = distributed::ReduceOp::SUM;
     pg->AllReduce(in_out, in_out, opts)->Synchronize();
 
-    auto eigen_loss = phi::funcs::EigenMatrix<T>::From(loss_2d);
-    auto eigen_predicted_logits =
-        phi::funcs::EigenMatrix<T>::From(predicted_logits);
-
-    eigen_loss.device(*dev_ctx.eigen_device()) =
-        (eigen_sum_exp_logits.log().unaryExpr(phi::funcs::TolerableValue<T>()) -
-         eigen_predicted_logits)
-            .unaryExpr(phi::funcs::TolerableValue<T>());
+    if (label_type == framework::proto::VarType::INT32) {
+      CaculateLoss<T, int32_t>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(loss_2d.data<T>(),
+                                                     predicted_logits.data<T>(),
+                                                     sum_exp_logits.data<T>(),
+                                                     labels->data<int32_t>(),
+                                                     ignore_index,
+                                                     N);
+    } else {
+      CaculateLoss<T, int64_t>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(loss_2d.data<T>(),
+                                                     predicted_logits.data<T>(),
+                                                     sum_exp_logits.data<T>(),
+                                                     labels->data<int64_t>(),
+                                                     ignore_index,
+                                                     N);
+    }
 
     eigen_softmax.device(*dev_ctx.eigen_device()) =
         (eigen_softmax *
@@ -397,6 +441,8 @@ class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
         context.Output<phi::DenseTensor>(framework::GradVarName("Logits"));
     const phi::DenseTensor* softmax =
         context.Input<phi::DenseTensor>("Softmax");
+
+    const int64_t ignore_index = context.Attr<int64_t>("ignore_index");
     const int rank = context.Attr<int>("rank");
     auto& dev_ctx = context.template device_context<phi::GPUContext>();
@@ -426,7 +472,8 @@ class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
                                                      start_index,
                                                      end_index,
                                                      N,
-                                                     D);
+                                                     D,
+                                                     ignore_index);
     } else if (label_type == framework::proto::VarType::INT64) {
       MaskLabelByIndexGrad<T, int64_t>
           <<<blocks, threads, 0, dev_ctx.stream()>>>(logit_grad_2d.data<T>(),
@@ -435,7 +482,8 @@ class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
                                                      start_index,
                                                      end_index,
                                                      N,
-                                                     D);
+                                                     D,
+                                                     ignore_index);
     }
   }
 };
diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc
index fd6d7f37509598eb298957fae9d0a6657c60e2dc..908018d7550802590a3a02e8ddab361310a2fd2f 100644
--- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc
+++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc
@@ -33,6 +33,13 @@ template <typename T>
 class CSoftmaxWithCrossEntropyOp : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    const int ignore_index = ctx.Attr<int64_t>("ignore_index");
+    PADDLE_ENFORCE_LT(ignore_index,
+                      0,
+                      platform::errors::InvalidArgument(
+                          "When SoftmaxWithCrossEntropy run on XPU, "
+                          "ignore_index should be <=0, however it's %d",
+                          ignore_index));
     const int rid = ctx.Attr<int>("ring_id");
     auto map = distributed::ProcessGroupMapFromGid::getInstance();
     if (map->has(rid)) {
@@ -453,6 +460,13 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
         context.Output<phi::DenseTensor>(framework::GradVarName("Logits"));
     const phi::DenseTensor* softmax =
         context.Input<phi::DenseTensor>("Softmax");
+    const int ignore_index = context.Attr<int64_t>("ignore_index");
+    PADDLE_ENFORCE_LT(ignore_index,
+                      0,
+                      platform::errors::InvalidArgument(
+                          "When SoftmaxWithCrossEntropy run on XPU, "
+                          "ignore_index should be <=0, however it's %d",
+                          ignore_index));
     const int rank = context.Attr<int>("rank");
     auto& dev_ctx = context.template device_context<phi::XPUContext>();
diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
index 9c531f6af7cb6189f645c98155c23c2b6bb17d7c..f820acfa8f112054d7d6501f128d7ce5adda245b 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -529,6 +529,9 @@ class ParallelCrossEntropy(paddle.nn.Layer):
         mp_group(Group): The tensor parallel group.
         name(str, optional): Normally there is no need for user to set this parameter.
             For detailed information, please refer to :ref:`api_guide_Name` .
+        ignore_index (int, optional): Specifies a target value that is ignored and
+            does not contribute to the loss. A negative value means that no label value
+            needs to be ignored. Default is -100 .
 
     Examples:
         .. code-block:: python
@@ -536,7 +539,7 @@ class ParallelCrossEntropy(paddle.nn.Layer):
             loss = loss_func(img, lable)
     """
 
-    def __init__(self, mp_group=None, name=None):
+    def __init__(self, mp_group=None, name=None, ignore_index=-100):
         super().__init__()
         self.name = name
         self.model_parallel_group = (
@@ -554,9 +557,13 @@ class ParallelCrossEntropy(paddle.nn.Layer):
             if mp_group is None
             else mp_group.rank
         )
+        self.ignore_index = ignore_index
 
     def forward(self, input, label):
         loss = mp_ops._c_softmax_with_cross_entropy(
-            input, label, group=self.model_parallel_group
+            input,
+            label,
+            group=self.model_parallel_group,
+            ignore_index=self.ignore_index,
         )
         return loss
diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
index 8aae54c967d4a24ef482702c29a35728a4c9320f..fade4aa61ce84cedc8cfad4d1ae214f76ae0c3a7 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -357,7 +357,11 @@ class _Linear(Layer):
 
 
 def _c_softmax_with_cross_entropy(
-    logits, label, group=None, return_softmax=False
+    logits,
+    label,
+    group=None,
+    return_softmax=False,
+    ignore_index=-100,
 ):
     if group is not None and not group.is_member():
         return
@@ -384,7 +388,16 @@ def _c_softmax_with_cross_entropy(
 
     if in_dygraph_mode():
         softmax, loss = _legacy_C_ops.c_softmax_with_cross_entropy(
-            logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks
+            logits,
+            label,
+            'ring_id',
+            ring_id,
+            'rank',
+            rank,
+            'nranks',
+            nranks,
+            'ignore_index',
+            ignore_index,
         )
         if not return_softmax:
             return loss
@@ -395,6 +408,7 @@ def _c_softmax_with_cross_entropy(
         'ring_id': ring_id,
         'rank': rank,
         'nranks': nranks,
+        'ignore_index': ignore_index,
     }
     helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
     softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
diff --git a/python/paddle/fluid/tests/unittests/c_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/c_softmax_with_cross_entropy_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..0217f7ee710f53fcbc6efb1809d0ebeb5fe3f829
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/c_softmax_with_cross_entropy_op.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.distributed import fleet
+from paddle.distributed.collective import _c_softmax_with_cross_entropy
+
+
+def stable_softmax(x):
+    """Compute the softmax of vector x in a numerically stable way."""
+    # clip to shiftx, otherwise, when calc loss with
+    # log(exp(shiftx)), may get log(0)=INF
+    shiftx = (x - np.max(x)).clip(-64.0)
+    exps = np.exp(shiftx)
+    return exps / np.sum(exps)
+
+
+def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
+    if soft_label:
+        return (-label * np.log(softmax)).sum(axis=axis, keepdims=True)
+    shape = softmax.shape
+    axis %= len(shape)
+    n = int(np.prod(shape[:axis]))
+    axis_dim = shape[axis]
+    remain = int(np.prod(shape[axis + 1 :]))
+    softmax_reshape = softmax.reshape((n, axis_dim, remain))
+    label_reshape = label.reshape((n, 1, remain))
+    result = np.zeros_like(label_reshape, dtype=softmax.dtype)
+    for i in range(n):
+        for j in range(remain):
+            lbl = label_reshape[i, 0, j]
+            if lbl != ignore_index:
+                result[i, 0, j] -= np.log(softmax_reshape[i, lbl, j])
+    return result.reshape(label.shape)
+
+
+def softmax_with_cross_entropy_grad(softmax, label, loss_grad, axis):
+    logit_grad = softmax.copy()
+    shape = softmax.shape
+    axis %= len(shape)
+    n = int(np.prod(shape[:axis]))
+    d = int(np.prod(shape[axis:]))
+    for i in range(n * d):
+        row = int(i / d)
+        col = i % d
+        if col == label[row]:
+            logit_grad[row][col] = (logit_grad[row][col] - 1.0) * loss_grad[row]
+        else:
+            logit_grad[row][col] = logit_grad[row][col] * loss_grad[row]
+    return logit_grad
+
+
+class TestCSoftmaxWithCrossEntropy(unittest.TestCase):
+    def test_model(self, data_type="float32"):
+        self.num_class = 1000
+        self.batch_size = 1024
+        fleet.init(is_collective=True)
+        strategy = fleet.DistributedStrategy()
+        strategy.tensor_parallel = True
+        strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2}
+
+        rank = fleet.worker_index()
+
+        # get data that is shared by both ranks
+        np.random.seed(os.getuid())
+        label = np.random.randint(
+            0, self.num_class, size=(self.batch_size, 1), dtype='int32'
+        )
+        ignore_index = label[0][0]
+
+        local_elements = int(self.num_class / 2)
+        # get input data for rank 0
+        np.random.seed(0)
+        input0 = np.random.uniform(
+            low=-10.0, high=10.0, size=(self.batch_size, local_elements)
+        ).astype(data_type)
+
+        # get input data for rank 1
+        np.random.seed(1)
+        input1 = np.random.uniform(
+            low=-10.0, high=10.0, size=(self.batch_size, local_elements)
+        ).astype(data_type)
+
+        # get combined input data
+        inputs = np.concatenate((input0, input1), axis=1)
+
+        if rank == 0:
+            loss, softmax = _c_softmax_with_cross_entropy(
+                paddle.to_tensor(input0),
+                paddle.to_tensor(label),
+                ignore_index=ignore_index,
+                return_softmax=True,
+            )
+        else:
+            loss, softmax = _c_softmax_with_cross_entropy(
+                paddle.to_tensor(input1),
+                paddle.to_tensor(label),
+                ignore_index=ignore_index,
+                return_softmax=True,
+            )
+        paddle.device.cuda.synchronize()
+        softmax_list = []
+        paddle.distributed.all_gather(softmax_list, softmax)
+
+        # calculate analytic result
+        need_softmax = np.apply_along_axis(stable_softmax, 1, inputs)
+        need_loss = cross_entropy(
+            need_softmax, label, False, 1, ignore_index=ignore_index
+        )
+
+        softmax = np.concatenate(
+            (softmax_list[0].numpy(), softmax_list[1].numpy()), axis=1
+        )
+
+        # compare results
+        rtol = 1e-6
+        np.testing.assert_allclose(loss.numpy(), need_loss, rtol=rtol)
+        np.testing.assert_allclose(softmax, need_softmax, rtol=rtol)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_c_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_c_softmax_with_cross_entropy_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ce71285af33c06aec84d4bb2b78bcedba1cd0d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_c_softmax_with_cross_entropy_op.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import sys
+import unittest
+
+sys.path.append(".")
+
+
+class TestCSoftmaxWithCrossEntropy(unittest.TestCase):
+    def pdrun(self):
+        cmd = [
+            sys.executable,
+            "-m",
+            "paddle.distributed.launch",
+            "--devices",
+            "0,1",
+            "c_softmax_with_cross_entropy_op.py",
+        ]
+        proc = subprocess.Popen(cmd)
+        return proc
+
+    def test_c_softmax_with_cross_entropy_op(self):
+        p = self.pdrun()
+        p.wait()
+
+
+if __name__ == '__main__':
+    unittest.main()
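For reference, the following minimal single-process NumPy sketch spells out the ignore_index semantics that the CaculateLoss and MaskLabelByIndexGrad kernels above implement: a row whose label equals ignore_index contributes zero loss and receives a zero gradient, while every other row follows the usual softmax cross entropy. The helper name softmax_ce_with_ignore_index is illustrative only and is not part of the patch; the gradient shown assumes a loss gradient of one per row (the kernels additionally scale by loss_grad[row]).

import numpy as np


def softmax_ce_with_ignore_index(logits, label, ignore_index=-100):
    # logits: [N, D] float array of (gathered) logits, label: [N] int class ids.
    shifted = logits - logits.max(axis=1, keepdims=True)
    softmax = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    rows = np.arange(logits.shape[0])
    ignored = label == ignore_index
    safe_label = np.where(ignored, 0, label)  # avoid indexing with ignore_index
    # forward: log(sum(exp(logits))) - logit_of_label, as in CaculateLoss
    loss = -np.log(softmax[rows, safe_label])
    # backward (loss_grad == 1): softmax - one_hot(label), as in MaskLabelByIndexGrad
    grad = softmax.copy()
    grad[rows, safe_label] -= 1.0
    loss[ignored] = 0.0  # ignored rows contribute zero loss
    grad[ignored, :] = 0.0  # and receive zero gradient
    return loss, grad


logits = np.random.uniform(-10.0, 10.0, size=(4, 8)).astype("float32")
label = np.array([1, 3, -100, 5])  # third sample is ignored
loss, grad = softmax_ce_with_ignore_index(logits, label)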