Unverified commit bff9e28e, authored by Guoxia Wang, committed by GitHub

support dp for class_center_sample and margin_cross_entropy (#39852)

Parent a9164245
...@@ -241,9 +241,21 @@ class TestClassCenterSampleAPIError1(unittest.TestCase):
                    remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
                        label, self.num_classes, self.num_samples)
                    print(remapped_label, sampled_class_index)

        def test_group_value():
            for place in self.places:
                with paddle.fluid.dygraph.guard(place):
                    label_np = np.random.randint(
                        0,
                        self.num_classes, (self.batch_size, ),
                        dtype=self.dtype)
                    label = paddle.to_tensor(label_np)
                    remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
                        label, self.num_classes, self.num_samples, group=True)

        self.assertRaises(ValueError, test_empty_label)
        self.assertRaises(ValueError, test_group_value)

if __name__ == '__main__':
...
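For context, this new negative test exercises the argument validation added to `paddle.nn.functional.class_center_sample` later in this change: `group` must be `False`, `None`, or a `paddle.distributed` group instance, and any other value is rejected. A minimal standalone sketch of that behavior (the sizes below are illustrative and not taken from the test suite):

```python
import numpy as np
import paddle

# Illustrative sizes, not taken from the test suite.
num_classes, num_samples, batch_size = 20, 6, 10
label = paddle.to_tensor(
    np.random.randint(0, num_classes, (batch_size, ), dtype='int64'))

try:
    # group=True is neither False, None, nor a Group instance,
    # so the updated API is expected to raise ValueError.
    paddle.nn.functional.class_center_sample(
        label, num_classes, num_samples, group=True)
except ValueError as err:
    print('invalid group rejected:', err)
```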
...@@ -400,8 +400,31 @@ class TestMarginCrossEntropyOpAPIError(unittest.TestCase):
                        return_softmax=True,
                        reduction=None)

        def test_group_value():
            for place in self.places:
                with paddle.fluid.dygraph.guard(place):
                    labels_np = np.random.randint(
                        0, self.num_class, (self.batch_dim, ), dtype="int64")
                    logits_np = np.random.uniform(
                        -0.99, 0.99,
                        [self.batch_dim, self.num_class]).astype(self.dtype)
                    labels = paddle.to_tensor(labels_np)
                    logits = paddle.to_tensor(logits_np)
                    loss, softmax = paddle.nn.functional.margin_cross_entropy(
                        logits,
                        labels,
                        margin1=self.margin1,
                        margin2=self.margin2,
                        margin3=self.margin3,
                        scale=self.scale,
                        return_softmax=True,
                        reduction=None,
                        group=True)

        self.assertRaises(ValueError, test_dim)
        self.assertRaises(NotImplementedError, test_label_type)
        self.assertRaises(ValueError, test_group_value)

if __name__ == '__main__':
...
...@@ -1651,16 +1651,21 @@ def class_center_sample(label, num_classes, num_samples, group=None):
    .. hint::
        If the number of the positive class centers is greater than the input num_samples, it keeps all the positive
        class centers and the shape of sampled_class_center will be [num_positive_class_centers].

        The API supports CPU, single GPU and multi GPU.

        For data parallel mode, set ``group=False``.

        For model parallel mode, set ``group=None`` or the group instance returned by paddle.distributed.new_group.

    Args:
        label (Tensor): 1-D tensor with shape [N], each label in [0, num_classes)
        num_classes (int): A positive integer to specify the number of classes at local rank.
            Note that num_classes of each GPU can be different.
        num_samples (int): A positive integer to specify the number of class center to sample.
        group (Group, optional): The group instance returned by paddle.distributed.new_group,
            or ``None`` for the global default group, or ``False`` for data parallel (no communication across ranks).
            Default is ``None``.

    Returns:
        Tuple of two ``Tensor`` : (remapped_label, sampled_class_center), remapped label using sampled class center,
...@@ -1733,18 +1738,25 @@ def class_center_sample(label, num_classes, num_samples, group=None):
            #Tensor(shape=[7], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
            #       [0, 1, 2, 3, 5, 7, 8])
    """
    if not (group == False or group is None or hasattr(group, 'is_member')):
        raise ValueError(
            'Expected group is False, None or instance of paddle.distributed.collective.Group \
             (got group: {})'.format(group))
        return

    if hasattr(group, 'is_member') and not group.is_member():
        return

    ring_id = 0
    rank = 0
    nranks = 1
    if group != False:
        if core.is_compiled_with_dist():
            parallel_env = paddle.distributed.ParallelEnv()
            global_rank = parallel_env.rank
            rank = global_rank if group is None else group.get_group_rank(
                global_rank)
            nranks = parallel_env.world_size if group is None else group.nranks

    if num_samples > num_classes:
        raise ValueError(
...
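The hint above introduces the data-parallel path: with `group=False`, class centers are sampled locally and no cross-rank communication takes place. A minimal usage sketch under that assumption (single process, made-up sizes):

```python
import numpy as np
import paddle

# Made-up sizes for illustration only.
num_classes, num_samples, batch_size = 20, 6, 10
label = paddle.to_tensor(
    np.random.randint(0, num_classes, (batch_size, ), dtype='int64'))

# group=False: data parallel mode, sample class centers locally
# without any cross-rank communication.
remapped_label, sampled_class_center = paddle.nn.functional.class_center_sample(
    label, num_classes, num_samples, group=False)

print(remapped_label.shape)        # [batch_size]
print(sampled_class_center.shape)  # up to num_samples entries, or all
                                   # positive class centers if there are more
```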
...@@ -1119,14 +1119,19 @@ def margin_cross_entropy(logits,
    r"""
    .. math::
        L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(\cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}}{e^{s(\cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\neq y_i} e^{s\cos\theta_{j}}}
    where the :math:`\theta_{y_i}` is the angle between the feature :math:`x` and
    the representation of class :math:`i`. The details of ArcFace loss
    could be referred to https://arxiv.org/abs/1801.07698.

    .. hint::
        The API supports single GPU and multi GPU, and does not support CPU.

        For data parallel mode, set ``group=False``.

        For model parallel mode, set ``group=None`` or the group instance returned by paddle.distributed.new_group.

        And logits.shape[-1] can be different at each rank.

    Args:
        logits (Tensor): shape[N, local_num_classes], the output of the normalized X multiplied by the normalized W.
...@@ -1136,8 +1141,9 @@ def margin_cross_entropy(logits,
        margin2 (float, optional): m2 of margin loss, default value is `0.5`.
        margin3 (float, optional): m3 of margin loss, default value is `0.0`.
        scale (float, optional): s of margin loss, default value is `64.0`.
        group (Group, optional): The group instance returned by paddle.distributed.new_group,
            or ``None`` for the global default group, or ``False`` for data parallel (no communication across ranks).
            Default is ``None``.
        return_softmax (bool, optional): Whether to return the softmax probability. Default value is `False`.
        reduction (str, optional): The candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, return the average of loss;
...@@ -1296,24 +1302,32 @@ def margin_cross_entropy(logits,
    """

    assert reduction in ['mean', 'sum', 'none', None]

    if not (group == False or group is None or hasattr(group, 'is_member')):
        raise ValueError(
            'Expected group is False, None or instance of paddle.distributed.collective.Group \
             (got group: {})'.format(group))
        return

    if hasattr(group, 'is_member') and not group.is_member():
        return

    ring_id = 0
    rank = 0
    nranks = 1
    if group != False:
        ring_id = 0 if group is None else group.id
        if core.is_compiled_with_dist():
            parallel_env = paddle.distributed.ParallelEnv()
            global_rank = parallel_env.rank
            rank = global_rank if group is None else group.get_group_rank(
                global_rank)
            nranks = parallel_env.world_size if group is None else group.nranks

    input_dims = len(list(logits.shape))
    label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected input_dims - 1 = label_dims or input_dims == label_dims\
             (got input_dims{}, label_dims{})'.format(input_dims, label_dims))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=-1)
...
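Similarly, a hedged sketch of calling `margin_cross_entropy` in data-parallel mode (`group=False`). It assumes a GPU device is available, since the hint above states the op does not support CPU; shapes and margin values are only illustrative:

```python
import numpy as np
import paddle

# Illustrative shapes; assumes a GPU device since the op has no CPU support.
batch, num_class = 8, 10
logits = paddle.to_tensor(
    np.random.uniform(-1.0, 1.0, [batch, num_class]).astype('float32'))
labels = paddle.to_tensor(
    np.random.randint(0, num_class, (batch, ), dtype='int64'))

# group=False: data parallel mode, no cross-rank communication;
# margins and scale follow the documented defaults.
loss, softmax = paddle.nn.functional.margin_cross_entropy(
    logits,
    labels,
    margin1=1.0,
    margin2=0.5,
    margin3=0.0,
    scale=64.0,
    group=False,
    return_softmax=True,
    reduction=None)

print(loss.shape, softmax.shape)
```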