未验证 提交 5c76b38b 编写于 作者: G Ghost Screaming 提交者: GitHub

Support ignore_index for c_softmax_with_cross_entropy_op. (#52157)

* Support ignore_index for c_softmax_with_cross_entropy_op.

* Polish code. Remove useless comments and add Testcase.

* Polish code for TestCase.

* Polish code.

* Polish code style.

* Polish code.

* Change loss calculation formula and ignore_index dtype.

* Polish TestCase.
上级 7067763e
......@@ -106,6 +106,10 @@ class CSoftmaxWithCrossEntropyOpMaker
"Input(Logits) "
"except the shape in dimension :attr:`axis` as 1. The cross "
"entropy loss.");
AddAttr<int64_t>("ignore_index",
"(int default -100) Specifies a target value "
"that is ignored and does not contribute to the loss.")
.SetDefault(-100);
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<int>("rank",
......
......@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/cross_entropy.h"
#include "paddle/phi/kernels/funcs/math.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/softmax_impl.h"
namespace paddle {
......@@ -59,6 +61,24 @@ __global__ void MaskLabelByIndex(T* predicted_logits,
}
}
template <typename T, typename IndexT>
__global__ void CaculateLoss(T* loss,
const T* predict_logits,
const T* sum_exp_logits,
const IndexT* label,
const int64_t ignore_index,
const int N) {
CUDA_KERNEL_LOOP(i, N) {
auto real_label = static_cast<int64_t>(label[i]);
loss[i] = ignore_index == real_label
? static_cast<T>(0)
: phi::funcs::TolerableValue<T>()(
phi::funcs::TolerableValue<T>()(
phi::funcs::real_log(sum_exp_logits[i])) -
predict_logits[i]);
}
}
template <typename T, typename IndexT>
__global__ void MaskLabelByIndexGrad(T* logits_grad,
const T* loss_grad,
......@@ -66,11 +86,15 @@ __global__ void MaskLabelByIndexGrad(T* logits_grad,
const int start_index,
const int end_index,
const int64_t N,
const int64_t D) {
const int64_t D,
const int64_t ignore_index) {
CUDA_KERNEL_LOOP(i, N * D) {
auto row = i / D;
auto col = i % D;
if ((col + start_index) == labels[row]) {
auto lbl = static_cast<int64_t>(labels[row]);
if (lbl == ignore_index) {
logits_grad[i] = static_cast<T>(0.0);
} else if ((col + start_index) == labels[row]) {
logits_grad[i] = (logits_grad[i] - static_cast<T>(1.0)) * loss_grad[row];
} else {
logits_grad[i] *= loss_grad[row];
......@@ -102,6 +126,7 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
phi::DenseTensor* softmax = ctx.Output<phi::DenseTensor>("Softmax");
phi::DenseTensor* loss = ctx.Output<phi::DenseTensor>("Loss");
const int64_t ignore_index = ctx.Attr<int64_t>("ignore_index");
const int rid = ctx.Attr<int>("ring_id");
const int nranks = ctx.Attr<int>("nranks");
const int rank = ctx.Attr<int>("rank");
......@@ -234,14 +259,23 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
comm->comm(),
stream));
auto eigen_loss = phi::funcs::EigenMatrix<T>::From(loss_2d);
auto eigen_predicted_logits =
phi::funcs::EigenMatrix<T>::From(predicted_logits);
eigen_loss.device(*dev_ctx.eigen_device()) =
(eigen_sum_exp_logits.log().unaryExpr(phi::funcs::TolerableValue<T>()) -
eigen_predicted_logits)
.unaryExpr(phi::funcs::TolerableValue<T>());
if (label_type == framework::proto::VarType::INT32) {
CaculateLoss<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_2d.data<T>(),
predicted_logits.data<T>(),
sum_exp_logits.data<T>(),
labels->data<int32_t>(),
ignore_index,
N);
} else {
CaculateLoss<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_2d.data<T>(),
predicted_logits.data<T>(),
sum_exp_logits.data<T>(),
labels->data<int64_t>(),
ignore_index,
N);
}
eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_softmax *
......@@ -257,6 +291,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
phi::DenseTensor* softmax = ctx.Output<phi::DenseTensor>("Softmax");
phi::DenseTensor* loss = ctx.Output<phi::DenseTensor>("Loss");
const int64_t ignore_index = ctx.Attr<int64_t>("ignore_index");
const int rid = ctx.Attr<int>("ring_id");
const int nranks = ctx.Attr<int>("nranks");
const int rank = ctx.Attr<int>("rank");
......@@ -371,14 +406,23 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
opts.reduce_op = distributed::ReduceOp::SUM;
pg->AllReduce(in_out, in_out, opts)->Synchronize();
auto eigen_loss = phi::funcs::EigenMatrix<T>::From(loss_2d);
auto eigen_predicted_logits =
phi::funcs::EigenMatrix<T>::From(predicted_logits);
eigen_loss.device(*dev_ctx.eigen_device()) =
(eigen_sum_exp_logits.log().unaryExpr(phi::funcs::TolerableValue<T>()) -
eigen_predicted_logits)
.unaryExpr(phi::funcs::TolerableValue<T>());
if (label_type == framework::proto::VarType::INT32) {
CaculateLoss<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_2d.data<T>(),
predicted_logits.data<T>(),
sum_exp_logits.data<T>(),
labels->data<int32_t>(),
ignore_index,
N);
} else {
CaculateLoss<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_2d.data<T>(),
predicted_logits.data<T>(),
sum_exp_logits.data<T>(),
labels->data<int64_t>(),
ignore_index,
N);
}
eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_softmax *
......@@ -397,6 +441,8 @@ class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
context.Output<phi::DenseTensor>(framework::GradVarName("Logits"));
const phi::DenseTensor* softmax =
context.Input<phi::DenseTensor>("Softmax");
const int64_t ignore_index = context.Attr<int64_t>("ignore_index");
const int rank = context.Attr<int>("rank");
auto& dev_ctx = context.template device_context<phi::GPUContext>();
......@@ -426,7 +472,8 @@ class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
start_index,
end_index,
N,
D);
D,
ignore_index);
} else if (label_type == framework::proto::VarType::INT64) {
MaskLabelByIndexGrad<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(logit_grad_2d.data<T>(),
......@@ -435,7 +482,8 @@ class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
start_index,
end_index,
N,
D);
D,
ignore_index);
}
}
};
......
......@@ -33,6 +33,13 @@ template <typename DeviceContext, typename T>
class CSoftmaxWithCrossEntropyOp : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const int ignore_index = ctx.Attr<int>("ignore_index");
PADDLE_ENFORCE_LT(ignore_index,
0,
platform::errors::InvalidArgument(
"When SoftmaxWithCrossEntropy run on XPU, "
"ignore_index should be <=0, however it's %d",
ignore_index));
const int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
......@@ -453,6 +460,13 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
context.Output<phi::DenseTensor>(framework::GradVarName("Logits"));
const phi::DenseTensor* softmax =
context.Input<phi::DenseTensor>("Softmax");
const int ignore_index = context.Attr<int>("ignore_index");
PADDLE_ENFORCE_LT(ignore_index,
0,
platform::errors::InvalidArgument(
"When SoftmaxWithCrossEntropy run on XPU, "
"ignore_index should be <=0, however it's %d",
ignore_index));
const int rank = context.Attr<int>("rank");
auto& dev_ctx = context.template device_context<DeviceContext>();
......
......@@ -529,6 +529,9 @@ class ParallelCrossEntropy(paddle.nn.Layer):
mp_group(Group): The tensor parallel group.
name(str, optional): Normally there is no need for user to set this parameter.
For detailed information, please refer to :ref:`api_guide_Name` .
ignore_index (int, optional): Specifies a target value that is ignored and
does not contribute to the loss. A negative value means that no label value
needs to be ignored. Default is -100 .
Examples:
.. code-block:: python
......@@ -536,7 +539,7 @@ class ParallelCrossEntropy(paddle.nn.Layer):
loss = loss_func(img, lable)
"""
def __init__(self, mp_group=None, name=None):
def __init__(self, mp_group=None, name=None, ignore_index=-100):
super().__init__()
self.name = name
self.model_parallel_group = (
......@@ -554,9 +557,13 @@ class ParallelCrossEntropy(paddle.nn.Layer):
if mp_group is None
else mp_group.rank
)
self.ignore_index = ignore_index
def forward(self, input, label):
loss = mp_ops._c_softmax_with_cross_entropy(
input, label, group=self.model_parallel_group
input,
label,
group=self.model_parallel_group,
ignore_index=self.ignore_index,
)
return loss
......@@ -357,7 +357,11 @@ class _Linear(Layer):
def _c_softmax_with_cross_entropy(
logits, label, group=None, return_softmax=False
logits,
label,
group=None,
return_softmax=False,
ignore_index=-100,
):
if group is not None and not group.is_member():
return
......@@ -384,7 +388,16 @@ def _c_softmax_with_cross_entropy(
if in_dygraph_mode():
softmax, loss = _legacy_C_ops.c_softmax_with_cross_entropy(
logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks
logits,
label,
'ring_id',
ring_id,
'rank',
rank,
'nranks',
nranks,
'ignore_index',
ignore_index,
)
if not return_softmax:
return loss
......@@ -395,6 +408,7 @@ def _c_softmax_with_cross_entropy(
'ring_id': ring_id,
'rank': rank,
'nranks': nranks,
'ignore_index': ignore_index,
}
helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.collective import _c_softmax_with_cross_entropy
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
# clip to shiftx, otherwise, when calc loss with
# log(exp(shiftx)), may get log(0)=INF
shiftx = (x - np.max(x)).clip(-64.0)
exps = np.exp(shiftx)
return exps / np.sum(exps)
def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
if soft_label:
return (-label * np.log(softmax)).sum(axis=axis, keepdims=True)
shape = softmax.shape
axis %= len(shape)
n = int(np.prod(shape[:axis]))
axis_dim = shape[axis]
remain = int(np.prod(shape[axis + 1 :]))
softmax_reshape = softmax.reshape((n, axis_dim, remain))
label_reshape = label.reshape((n, 1, remain))
result = np.zeros_like(label_reshape, dtype=softmax.dtype)
for i in range(n):
for j in range(remain):
lbl = label_reshape[i, 0, j]
if lbl != ignore_index:
result[i, 0, j] -= np.log(softmax_reshape[i, lbl, j])
return result.reshape(label.shape)
def softmax_with_cross_entropy_grad(softmax, label, loss_grad, axis):
logit_grad = softmax.copy()
shape = softmax.shape
axis %= len(shape)
n = int(np.prod(shape[:axis]))
d = int(np.prod(shape[axis:]))
for i in range(n * d):
row = int(i / d)
col = i % d
if col == label[row]:
logit_grad[row][col] = (logit_grad[row][col] - 1.0) * loss_grad[row]
else:
logit_grad[row][col] = logit_grad[row][col] * loss_grad[row]
return logit_grad
class TestCSoftmaxWithCrossEntropy(unittest.TestCase):
def test_model(self, data_type="float32"):
self.num_class = 1000
self.batch_size = 1024
fleet.init(is_collective=True)
strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2}
rank = fleet.worker_index()
# get data that is shared by both ranks
np.random.seed(os.getuid())
label = np.random.randint(
0, self.num_class, size=(self.batch_size, 1), dtype='int32'
)
ignore_index = label[0][0]
local_elements = int(self.num_class / 2)
# get input data for rank 0
np.random.seed(0)
input0 = np.random.uniform(
low=-10.0, high=10.0, size=(self.batch_size, local_elements)
).astype(data_type)
# get input data for rank 1
np.random.seed(1)
input1 = np.random.uniform(
low=-10.0, high=10.0, size=(self.batch_size, local_elements)
).astype(data_type)
# get combined input data
inputs = np.concatenate((input0, input1), axis=1)
if rank == 0:
loss, softmax = _c_softmax_with_cross_entropy(
paddle.to_tensor(input0),
paddle.to_tensor(label),
ignore_index=ignore_index,
return_softmax=True,
)
else:
loss, softmax = _c_softmax_with_cross_entropy(
paddle.to_tensor(input1),
paddle.to_tensor(label),
ignore_index=ignore_index,
return_softmax=True,
)
paddle.device.cuda.synchronize()
softmax_list = []
paddle.distributed.all_gather(softmax_list, softmax)
# calculate analytic result
need_softmax = np.apply_along_axis(stable_softmax, 1, inputs)
need_loss = cross_entropy(
need_softmax, label, False, 1, ignore_index=ignore_index
)
softmax = np.concatenate(
(softmax_list[0].numpy(), softmax_list[1].numpy()), axis=1
)
# compare results
rtol = 1e-6
np.testing.assert_allclose(loss.numpy(), need_loss, rtol=rtol)
np.testing.assert_allclose(softmax, need_softmax, rtol=rtol)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys
import unittest
sys.path.append(".")
class TestCSoftmaxWithCrossEntropy(unittest.TestCase):
def pdrun(self):
cmd = [
sys.executable,
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"c_softmax_with_cross_entropy_op.py",
]
proc = subprocess.Popen(cmd)
return proc
def test_c_softmax_with_cross_entropy_op(self):
p = self.pdrun()
p.wait()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册