未验证 提交 bff9e28e 编写于 作者: G Guoxia Wang 提交者: GitHub

support dp for class_center_sample and margin_cross_entropy (#39852)

上级 a9164245
......@@ -241,9 +241,21 @@ class TestClassCenterSampleAPIError1(unittest.TestCase):
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label, self.num_classes, self.num_samples)
print(remapped_label, sampled_class_index)
def test_group_value():
for place in self.places:
with paddle.fluid.dygraph.guard(place):
label_np = np.random.randint(
0,
self.num_classes, (self.batch_size, ),
dtype=self.dtype)
label = paddle.to_tensor(label_np)
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label, self.num_classes, self.num_samples, group=True)
self.assertRaises(ValueError, test_empty_label)
self.assertRaises(ValueError, test_group_value)
if __name__ == '__main__':
......
......@@ -400,8 +400,31 @@ class TestMarginCrossEntropyOpAPIError(unittest.TestCase):
return_softmax=True,
reduction=None)
def test_group_value():
for place in self.places:
with paddle.fluid.dygraph.guard(place):
labels_np = np.random.randint(
0, self.num_class, (self.batch_dim, ), dtype="int64")
logits_np = np.random.uniform(
-0.99, 0.99,
[self.batch_dim, self.num_class]).astype(self.dtype)
labels = paddle.to_tensor(labels_np)
logits = paddle.to_tensor(logits_np)
loss, softmax = paddle.nn.functional.margin_cross_entropy(
logits,
labels,
margin1=self.margin1,
margin2=self.margin2,
margin3=self.margin3,
scale=self.scale,
return_softmax=True,
reduction=None,
group=True)
self.assertRaises(ValueError, test_dim)
self.assertRaises(NotImplementedError, test_label_type)
self.assertRaises(ValueError, test_group_value)
if __name__ == '__main__':
......
......@@ -1654,13 +1654,18 @@ def class_center_sample(label, num_classes, num_samples, group=None):
The API supports CPU, single GPU and multi GPU.
For data parallel mode, set ``group=False``.
For model parallel mode, set ``group=None`` or the group instance return by paddle.distributed.new_group.
Args:
label (Tensor): 1-D tensor with shape [N], each label in [0, num_classes)
num_classes (int): A positive integer to specify the number of classes at local rank.
Note that num_classes of each GPU can be different.
num_samples (int): A positive integer to specify the number of class center to sample.
group (Group, optional): The abstract representation of group.
See paddle.distributed.collective.Group. Default is ``None``.
group (Group, optional): The group instance return by paddle.distributed.new_group
or ``None`` for global default group or ``False`` for data parallel (do not communication cross ranks).
Default is ``None``.
Returns:
Tuple of two ``Tensor`` : (remapped_label, sampled_class_center), remapped label using sampled class center,
......@@ -1733,12 +1738,19 @@ def class_center_sample(label, num_classes, num_samples, group=None):
#Tensor(shape=[7], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
# [0, 1, 2, 3, 5, 7, 8])
"""
if group is not None and not group.is_member():
if not (group == False or group is None or hasattr(group, 'is_member')):
raise ValueError(
'Expected group is False, None or instance of paddle.distributed.collective.Group \
(got group: {})'.format(group))
return
if hasattr(group, 'is_member') and not group.is_member():
return
ring_id = 0 if group is None else group.id
ring_id = 0
rank = 0
nranks = 1
if group != False:
if core.is_compiled_with_dist():
parallel_env = paddle.distributed.ParallelEnv()
global_rank = parallel_env.rank
......
......@@ -1119,14 +1119,19 @@ def margin_cross_entropy(logits,
r"""
.. math::
L=-\\frac{1}{N}\sum^N_{i=1}\log\\frac{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\\neq y_i} e^{scos\\theta_{y_i}}}
L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\neq y_i} e^{scos\theta_{y_i}}}
where the :math:`\\theta_{y_i}` is the angle between the feature :math:`x` and
where the :math:`\theta_{y_i}` is the angle between the feature :math:`x` and
the representation of class :math:`i`. The details of ArcFace loss
could be referred to https://arxiv.org/abs/1801.07698.
.. hint::
The API supports model parallel and single GPU. And logits.shape[-1] can be different at each rank.
The API supports single GPU and multi GPU, and don't supports CPU.
For data parallel mode, set ``group=False``.
For model parallel mode, set ``group=None`` or the group instance return by paddle.distributed.new_group.
And logits.shape[-1] can be different at each rank.
Args:
logits (Tensor): shape[N, local_num_classes], the output of the normalized X multiply the normalized W.
......@@ -1136,8 +1141,9 @@ def margin_cross_entropy(logits,
margin2 (float, optional): m2 of margin loss, default value is `0.5`.
margin3 (float, optional): m3 of margin loss, default value is `0.0`.
scale (float, optional): s of margin loss, default value is `64.0`.
group (Group, optional): The abstract representation of group, see paddle.distributed.collective.Group.
Default `None`.
group (Group, optional): The group instance return by paddle.distributed.new_group
or ``None`` for global default group or ``False`` for data parallel (do not communication cross ranks).
Default is ``None``.
return_softmax (bool, optional): Whether return softmax probability. Default value is `False`.
reduction (str, optional): The candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, return the average of loss;
......@@ -1296,12 +1302,20 @@ def margin_cross_entropy(logits,
"""
assert reduction in ['mean', 'sum', 'none', None]
if group is not None and not group.is_member():
if not (group == False or group is None or hasattr(group, 'is_member')):
raise ValueError(
'Expected group is False, None or instance of paddle.distributed.collective.Group \
(got group: {})'.format(group))
return
ring_id = 0 if group is None else group.id
if hasattr(group, 'is_member') and not group.is_member():
return
ring_id = 0
rank = 0
nranks = 1
if group != False:
ring_id = 0 if group is None else group.id
if core.is_compiled_with_dist():
parallel_env = paddle.distributed.ParallelEnv()
global_rank = parallel_env.rank
......@@ -1313,7 +1327,7 @@ def margin_cross_entropy(logits,
label_dims = len(list(label.shape))
if input_dims - 1 != label_dims and input_dims != label_dims:
raise ValueError(
'Expected nput_dims - 1 = label_dims or input_dims == label_dims\
'Expected input_dims - 1 = label_dims or input_dims == label_dims\
(got nput_dims{}, label_dims{})'.format(input_dims, label_dims))
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=-1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册