From bff9e28e54bbb8cce1db8b81b06db1b99b225344 Mon Sep 17 00:00:00 2001
From: Guoxia Wang
Date: Thu, 24 Mar 2022 17:27:42 +0800
Subject: [PATCH] support dp for class_center_sample and margin_cross_entropy (#39852)

---
 .../unittests/test_class_center_sample_op.py  | 14 ++++++-
 .../unittests/test_margin_cross_entropy_op.py | 23 ++++++++++
 python/paddle/nn/functional/common.py         | 34 ++++++++++-----
 python/paddle/nn/functional/loss.py           | 42 ++++++++++++-------
 4 files changed, 87 insertions(+), 26 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py b/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py
index 29cae0eb001..eb7d05df492 100644
--- a/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py
+++ b/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py
@@ -241,9 +241,21 @@ class TestClassCenterSampleAPIError1(unittest.TestCase):
                 remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
                     label, self.num_classes, self.num_samples)
-                print(remapped_label, sampled_class_index)
+
+        def test_group_value():
+            for place in self.places:
+                with paddle.fluid.dygraph.guard(place):
+                    label_np = np.random.randint(
+                        0,
+                        self.num_classes, (self.batch_size, ),
+                        dtype=self.dtype)
+                    label = paddle.to_tensor(label_np)
+
+                    remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
+                        label, self.num_classes, self.num_samples, group=True)
 
         self.assertRaises(ValueError, test_empty_label)
+        self.assertRaises(ValueError, test_group_value)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py
index 15730710adf..2b511b9eb44 100644
--- a/python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py
+++ b/python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py
@@ -400,8 +400,31 @@ class TestMarginCrossEntropyOpAPIError(unittest.TestCase):
                         return_softmax=True,
                         reduction=None)
 
+        def test_group_value():
+            for place in self.places:
+                with paddle.fluid.dygraph.guard(place):
+                    labels_np = np.random.randint(
+                        0, self.num_class, (self.batch_dim, ), dtype="int64")
+                    logits_np = np.random.uniform(
+                        -0.99, 0.99,
+                        [self.batch_dim, self.num_class]).astype(self.dtype)
+                    labels = paddle.to_tensor(labels_np)
+                    logits = paddle.to_tensor(logits_np)
+
+                    loss, softmax = paddle.nn.functional.margin_cross_entropy(
+                        logits,
+                        labels,
+                        margin1=self.margin1,
+                        margin2=self.margin2,
+                        margin3=self.margin3,
+                        scale=self.scale,
+                        return_softmax=True,
+                        reduction=None,
+                        group=True)
+
         self.assertRaises(ValueError, test_dim)
         self.assertRaises(NotImplementedError, test_label_type)
+        self.assertRaises(ValueError, test_group_value)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index 9e78ca6be3f..e757fbf5348 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -1651,16 +1651,21 @@ def class_center_sample(label, num_classes, num_samples, group=None):
     .. hint::
         If the number of the positive class centers is greater than the input num_samples, it keeps all the
         positive class centers and the shape of sampled_class_center will be [num_positive_class_centers].
-    
+
         The API supports CPU, single GPU and multi GPU.
+
+        For data parallel mode, set ``group=False``.
+
+        For model parallel mode, set ``group=None`` or the group instance returned by paddle.distributed.new_group.
+
     Args:
         label (Tensor): 1-D tensor with shape [N], each label in [0, num_classes)
         num_classes (int): A positive integer to specify the number of classes at local rank.
             Note that num_classes of each GPU can be different.
         num_samples (int): A positive integer to specify the number of class center to sample.
-        group (Group, optional): The abstract representation of group.
-            See paddle.distributed.collective.Group. Default is ``None``.
+        group (Group, optional): The group instance returned by paddle.distributed.new_group,
+            or ``None`` for the global default group, or ``False`` for data parallel (no communication across ranks).
+            Default is ``None``.
 
     Returns:
         Tuple of two ``Tensor`` : (remapped_label, sampled_class_center), remapped label using sampled class center,
@@ -1733,18 +1738,25 @@ def class_center_sample(label, num_classes, num_samples, group=None):
            #Tensor(shape=[7], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
            #       [0, 1, 2, 3, 5, 7, 8])
     """
-    if group is not None and not group.is_member():
+    if not (group == False or group is None or hasattr(group, 'is_member')):
+        raise ValueError(
+            'Expected group to be False, None or an instance of paddle.distributed.collective.Group \
+             (got group: {})'.format(group))
+        return
+
+    if hasattr(group, 'is_member') and not group.is_member():
         return
 
-    ring_id = 0 if group is None else group.id
+    ring_id = 0
     rank = 0
     nranks = 1
-    if core.is_compiled_with_dist():
-        parallel_env = paddle.distributed.ParallelEnv()
-        global_rank = parallel_env.rank
-        rank = global_rank if group is None else group.get_group_rank(
-            global_rank)
-        nranks = parallel_env.world_size if group is None else group.nranks
+    if group != False:
+        if core.is_compiled_with_dist():
+            parallel_env = paddle.distributed.ParallelEnv()
+            global_rank = parallel_env.rank
+            rank = global_rank if group is None else group.get_group_rank(
+                global_rank)
+            nranks = parallel_env.world_size if group is None else group.nranks
 
     if num_samples > num_classes:
         raise ValueError(
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 10d4073b80c..b4594986f41 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1119,14 +1119,19 @@ def margin_cross_entropy(logits,
     r"""
     .. math::
 
-        L=-\\frac{1}{N}\sum^N_{i=1}\log\\frac{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\\neq y_i} e^{scos\\theta_{y_i}}}
+        L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\neq y_i} e^{scos\theta_{y_i}}}
 
-    where the :math:`\\theta_{y_i}` is the angle between the feature :math:`x` and
+    where the :math:`\theta_{y_i}` is the angle between the feature :math:`x` and
     the representation of class :math:`i`. The details of ArcFace loss
     could be referred to https://arxiv.org/abs/1801.07698.
 
     .. hint::
-        The API supports model parallel and single GPU. And logits.shape[-1] can be different at each rank.
+        The API supports single GPU and multi GPU, and does not support CPU.
+
+        For data parallel mode, set ``group=False``.
+
+        For model parallel mode, set ``group=None`` or the group instance returned by paddle.distributed.new_group.
+        Note that logits.shape[-1] can be different at each rank.
 
     Args:
         logits (Tensor): shape[N, local_num_classes], the output of the normalized X multiply the normalized W.
                 The logits is shard_logits when using model parallel.
         label (Tensor): shape[N] or shape[N, 1], the groudtruth label.
         margin1 (float, optional): m1 of margin loss, default value is `1.0`.
@@ -1136,8 +1141,9 @@ def margin_cross_entropy(logits,
         margin2 (float, optional): m2 of margin loss, default value is `0.5`.
        margin3 (float, optional): m3 of margin loss, default value is `0.0`.
         scale (float, optional): s of margin loss, default value is `64.0`.
-        group (Group, optional): The abstract representation of group, see paddle.distributed.collective.Group.
-            Default `None`.
+        group (Group, optional): The group instance returned by paddle.distributed.new_group,
+            or ``None`` for the global default group, or ``False`` for data parallel (no communication across ranks).
+            Default is ``None``.
         return_softmax (bool, optional): Whether return softmax probability. Default value is `False`.
         reduction (str, optional): The candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
                    If :attr:`reduction` is ``'mean'``, return the average of loss;
@@ -1296,24 +1302,32 @@ def margin_cross_entropy(logits,
     """
     assert reduction in ['mean', 'sum', 'none', None]
 
-    if group is not None and not group.is_member():
+    if not (group == False or group is None or hasattr(group, 'is_member')):
+        raise ValueError(
+            'Expected group to be False, None or an instance of paddle.distributed.collective.Group \
+             (got group: {})'.format(group))
         return
 
+    if hasattr(group, 'is_member') and not group.is_member():
+        return
+
-    ring_id = 0 if group is None else group.id
+    ring_id = 0
     rank = 0
     nranks = 1
-    if core.is_compiled_with_dist():
-        parallel_env = paddle.distributed.ParallelEnv()
-        global_rank = parallel_env.rank
-        rank = global_rank if group is None else group.get_group_rank(
-            global_rank)
-        nranks = parallel_env.world_size if group is None else group.nranks
+    if group != False:
+        ring_id = 0 if group is None else group.id
+        if core.is_compiled_with_dist():
+            parallel_env = paddle.distributed.ParallelEnv()
+            global_rank = parallel_env.rank
+            rank = global_rank if group is None else group.get_group_rank(
+                global_rank)
+            nranks = parallel_env.world_size if group is None else group.nranks
 
     input_dims = len(list(logits.shape))
     label_dims = len(list(label.shape))
     if input_dims - 1 != label_dims and input_dims != label_dims:
         raise ValueError(
-            'Expected nput_dims - 1 = label_dims or input_dims == label_dims\
+            'Expected input_dims - 1 = label_dims or input_dims == label_dims\
             (got nput_dims{}, label_dims{})'.format(input_dims, label_dims))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=-1)
--
GitLab
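
A minimal usage sketch of the ``group=False`` (data parallel) path that this patch enables, under assumptions not taken from the patch itself: the sizes below (num_classes=4, num_samples=2, batch=8) are illustrative, and a GPU build is assumed because margin_cross_entropy does not support CPU. With ``group=False`` both APIs skip cross-rank communication, so the same calls can run unchanged on every rank of a data parallel job; passing ``group=None`` or a group created by paddle.distributed.new_group selects the model parallel path instead.

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.seed(2022)
    num_classes, num_samples, batch = 4, 2, 8  # illustrative sizes, not from the patch

    # Integer class labels for this rank; in data parallel every rank holds the full class set.
    label = paddle.to_tensor(
        np.random.randint(0, num_classes, (batch, )), dtype='int64')

    # group=False: sample class centers locally, with no cross-rank communication.
    remapped_label, sampled_centers = F.class_center_sample(
        label, num_classes, num_samples, group=False)

    # Cosine-like logits in (-1, 1), mirroring the unit tests; shape [batch, num_classes].
    logits = paddle.uniform([batch, num_classes], min=-0.99, max=0.99)

    # group=False: plain (non-sharded) margin softmax loss computed on this rank only.
    loss, softmax = F.margin_cross_entropy(
        logits, label,
        margin1=1.0, margin2=0.5, margin3=0.0, scale=64.0,
        group=False, return_softmax=True, reduction='mean')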