Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
bff9e28e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bff9e28e
编写于
3月 24, 2022
作者:
G
Guoxia Wang
提交者:
GitHub
3月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support dp for class_center_sample and margin_cross_entropy (#39852)
上级
a9164245
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
87 addition
and
26 deletion
+87
-26
python/paddle/fluid/tests/unittests/test_class_center_sample_op.py
...ddle/fluid/tests/unittests/test_class_center_sample_op.py
+13
-1
python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py
...dle/fluid/tests/unittests/test_margin_cross_entropy_op.py
+23
-0
python/paddle/nn/functional/common.py
python/paddle/nn/functional/common.py
+23
-11
python/paddle/nn/functional/loss.py
python/paddle/nn/functional/loss.py
+28
-14
未找到文件。
python/paddle/fluid/tests/unittests/test_class_center_sample_op.py
浏览文件 @
bff9e28e
...
@@ -241,9 +241,21 @@ class TestClassCenterSampleAPIError1(unittest.TestCase):
...
@@ -241,9 +241,21 @@ class TestClassCenterSampleAPIError1(unittest.TestCase):
remapped_label
,
sampled_class_index
=
paddle
.
nn
.
functional
.
class_center_sample
(
remapped_label
,
sampled_class_index
=
paddle
.
nn
.
functional
.
class_center_sample
(
label
,
self
.
num_classes
,
self
.
num_samples
)
label
,
self
.
num_classes
,
self
.
num_samples
)
print
(
remapped_label
,
sampled_class_index
)
def
test_group_value
():
for
place
in
self
.
places
:
with
paddle
.
fluid
.
dygraph
.
guard
(
place
):
label_np
=
np
.
random
.
randint
(
0
,
self
.
num_classes
,
(
self
.
batch_size
,
),
dtype
=
self
.
dtype
)
label
=
paddle
.
to_tensor
(
label_np
)
remapped_label
,
sampled_class_index
=
paddle
.
nn
.
functional
.
class_center_sample
(
label
,
self
.
num_classes
,
self
.
num_samples
,
group
=
True
)
self
.
assertRaises
(
ValueError
,
test_empty_label
)
self
.
assertRaises
(
ValueError
,
test_empty_label
)
self
.
assertRaises
(
ValueError
,
test_group_value
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/test_margin_cross_entropy_op.py
浏览文件 @
bff9e28e
...
@@ -400,8 +400,31 @@ class TestMarginCrossEntropyOpAPIError(unittest.TestCase):
...
@@ -400,8 +400,31 @@ class TestMarginCrossEntropyOpAPIError(unittest.TestCase):
return_softmax
=
True
,
return_softmax
=
True
,
reduction
=
None
)
reduction
=
None
)
def
test_group_value
():
for
place
in
self
.
places
:
with
paddle
.
fluid
.
dygraph
.
guard
(
place
):
labels_np
=
np
.
random
.
randint
(
0
,
self
.
num_class
,
(
self
.
batch_dim
,
),
dtype
=
"int64"
)
logits_np
=
np
.
random
.
uniform
(
-
0.99
,
0.99
,
[
self
.
batch_dim
,
self
.
num_class
]).
astype
(
self
.
dtype
)
labels
=
paddle
.
to_tensor
(
labels_np
)
logits
=
paddle
.
to_tensor
(
logits_np
)
loss
,
softmax
=
paddle
.
nn
.
functional
.
margin_cross_entropy
(
logits
,
labels
,
margin1
=
self
.
margin1
,
margin2
=
self
.
margin2
,
margin3
=
self
.
margin3
,
scale
=
self
.
scale
,
return_softmax
=
True
,
reduction
=
None
,
group
=
True
)
self
.
assertRaises
(
ValueError
,
test_dim
)
self
.
assertRaises
(
ValueError
,
test_dim
)
self
.
assertRaises
(
NotImplementedError
,
test_label_type
)
self
.
assertRaises
(
NotImplementedError
,
test_label_type
)
self
.
assertRaises
(
ValueError
,
test_group_value
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/nn/functional/common.py
浏览文件 @
bff9e28e
...
@@ -1651,16 +1651,21 @@ def class_center_sample(label, num_classes, num_samples, group=None):
...
@@ -1651,16 +1651,21 @@ def class_center_sample(label, num_classes, num_samples, group=None):
.. hint::
.. hint::
If the number of the positive class centers is greater than the input num_samples, it keeps all the positive
If the number of the positive class centers is greater than the input num_samples, it keeps all the positive
class centers and the shape of sampled_class_center will be [num_positive_class_centers].
class centers and the shape of sampled_class_center will be [num_positive_class_centers].
The API supports CPU, single GPU and multi GPU.
The API supports CPU, single GPU and multi GPU.
For data parallel mode, set ``group=False``.
For model parallel mode, set ``group=None`` or the group instance return by paddle.distributed.new_group.
Args:
Args:
label (Tensor): 1-D tensor with shape [N], each label in [0, num_classes)
label (Tensor): 1-D tensor with shape [N], each label in [0, num_classes)
num_classes (int): A positive integer to specify the number of classes at local rank.
num_classes (int): A positive integer to specify the number of classes at local rank.
Note that num_classes of each GPU can be different.
Note that num_classes of each GPU can be different.
num_samples (int): A positive integer to specify the number of class center to sample.
num_samples (int): A positive integer to specify the number of class center to sample.
group (Group, optional): The abstract representation of group.
group (Group, optional): The group instance return by paddle.distributed.new_group
See paddle.distributed.collective.Group. Default is ``None``.
or ``None`` for global default group or ``False`` for data parallel (do not communication cross ranks).
Default is ``None``.
Returns:
Returns:
Tuple of two ``Tensor`` : (remapped_label, sampled_class_center), remapped label using sampled class center,
Tuple of two ``Tensor`` : (remapped_label, sampled_class_center), remapped label using sampled class center,
...
@@ -1733,18 +1738,25 @@ def class_center_sample(label, num_classes, num_samples, group=None):
...
@@ -1733,18 +1738,25 @@ def class_center_sample(label, num_classes, num_samples, group=None):
#Tensor(shape=[7], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
#Tensor(shape=[7], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
# [0, 1, 2, 3, 5, 7, 8])
# [0, 1, 2, 3, 5, 7, 8])
"""
"""
if
group
is
not
None
and
not
group
.
is_member
():
if
not
(
group
==
False
or
group
is
None
or
hasattr
(
group
,
'is_member'
)):
raise
ValueError
(
'Expected group is False, None or instance of paddle.distributed.collective.Group
\
(got group: {})'
.
format
(
group
))
return
if
hasattr
(
group
,
'is_member'
)
and
not
group
.
is_member
():
return
return
ring_id
=
0
if
group
is
None
else
group
.
id
ring_id
=
0
rank
=
0
rank
=
0
nranks
=
1
nranks
=
1
if
core
.
is_compiled_with_dist
():
if
group
!=
False
:
parallel_env
=
paddle
.
distributed
.
ParallelEnv
()
if
core
.
is_compiled_with_dist
():
global_rank
=
parallel_env
.
rank
parallel_env
=
paddle
.
distributed
.
ParallelEnv
()
rank
=
global_rank
if
group
is
None
else
group
.
get_group_rank
(
global_rank
=
parallel_env
.
rank
global_rank
)
rank
=
global_rank
if
group
is
None
else
group
.
get_group_rank
(
nranks
=
parallel_env
.
world_size
if
group
is
None
else
group
.
nranks
global_rank
)
nranks
=
parallel_env
.
world_size
if
group
is
None
else
group
.
nranks
if
num_samples
>
num_classes
:
if
num_samples
>
num_classes
:
raise
ValueError
(
raise
ValueError
(
...
...
python/paddle/nn/functional/loss.py
浏览文件 @
bff9e28e
...
@@ -1119,14 +1119,19 @@ def margin_cross_entropy(logits,
...
@@ -1119,14 +1119,19 @@ def margin_cross_entropy(logits,
r
"""
r
"""
.. math::
.. math::
L=-\
\frac{1}{N}\sum^N_{i=1}\log\\frac{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\\neq y_i} e^{scos\
\theta_{y_i}}}
L=-\
frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\neq y_i} e^{scos
\theta_{y_i}}}
where the :math:`\
\
theta_{y_i}` is the angle between the feature :math:`x` and
where the :math:`\theta_{y_i}` is the angle between the feature :math:`x` and
the representation of class :math:`i`. The details of ArcFace loss
the representation of class :math:`i`. The details of ArcFace loss
could be referred to https://arxiv.org/abs/1801.07698.
could be referred to https://arxiv.org/abs/1801.07698.
.. hint::
.. hint::
The API supports model parallel and single GPU. And logits.shape[-1] can be different at each rank.
The API supports single GPU and multi GPU, and don't supports CPU.
For data parallel mode, set ``group=False``.
For model parallel mode, set ``group=None`` or the group instance return by paddle.distributed.new_group.
And logits.shape[-1] can be different at each rank.
Args:
Args:
logits (Tensor): shape[N, local_num_classes], the output of the normalized X multiply the normalized W.
logits (Tensor): shape[N, local_num_classes], the output of the normalized X multiply the normalized W.
...
@@ -1136,8 +1141,9 @@ def margin_cross_entropy(logits,
...
@@ -1136,8 +1141,9 @@ def margin_cross_entropy(logits,
margin2 (float, optional): m2 of margin loss, default value is `0.5`.
margin2 (float, optional): m2 of margin loss, default value is `0.5`.
margin3 (float, optional): m3 of margin loss, default value is `0.0`.
margin3 (float, optional): m3 of margin loss, default value is `0.0`.
scale (float, optional): s of margin loss, default value is `64.0`.
scale (float, optional): s of margin loss, default value is `64.0`.
group (Group, optional): The abstract representation of group, see paddle.distributed.collective.Group.
group (Group, optional): The group instance return by paddle.distributed.new_group
Default `None`.
or ``None`` for global default group or ``False`` for data parallel (do not communication cross ranks).
Default is ``None``.
return_softmax (bool, optional): Whether return softmax probability. Default value is `False`.
return_softmax (bool, optional): Whether return softmax probability. Default value is `False`.
reduction (str, optional): The candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
reduction (str, optional): The candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, return the average of loss;
If :attr:`reduction` is ``'mean'``, return the average of loss;
...
@@ -1296,24 +1302,32 @@ def margin_cross_entropy(logits,
...
@@ -1296,24 +1302,32 @@ def margin_cross_entropy(logits,
"""
"""
assert
reduction
in
[
'mean'
,
'sum'
,
'none'
,
None
]
assert
reduction
in
[
'mean'
,
'sum'
,
'none'
,
None
]
if
group
is
not
None
and
not
group
.
is_member
():
if
not
(
group
==
False
or
group
is
None
or
hasattr
(
group
,
'is_member'
)):
raise
ValueError
(
'Expected group is False, None or instance of paddle.distributed.collective.Group
\
(got group: {})'
.
format
(
group
))
return
return
ring_id
=
0
if
group
is
None
else
group
.
id
if
hasattr
(
group
,
'is_member'
)
and
not
group
.
is_member
():
return
ring_id
=
0
rank
=
0
rank
=
0
nranks
=
1
nranks
=
1
if
core
.
is_compiled_with_dist
():
if
group
!=
False
:
parallel_env
=
paddle
.
distributed
.
ParallelEnv
()
ring_id
=
0
if
group
is
None
else
group
.
id
global_rank
=
parallel_env
.
rank
if
core
.
is_compiled_with_dist
():
rank
=
global_rank
if
group
is
None
else
group
.
get_group_rank
(
parallel_env
=
paddle
.
distributed
.
ParallelEnv
()
global_rank
)
global_rank
=
parallel_env
.
rank
nranks
=
parallel_env
.
world_size
if
group
is
None
else
group
.
nranks
rank
=
global_rank
if
group
is
None
else
group
.
get_group_rank
(
global_rank
)
nranks
=
parallel_env
.
world_size
if
group
is
None
else
group
.
nranks
input_dims
=
len
(
list
(
logits
.
shape
))
input_dims
=
len
(
list
(
logits
.
shape
))
label_dims
=
len
(
list
(
label
.
shape
))
label_dims
=
len
(
list
(
label
.
shape
))
if
input_dims
-
1
!=
label_dims
and
input_dims
!=
label_dims
:
if
input_dims
-
1
!=
label_dims
and
input_dims
!=
label_dims
:
raise
ValueError
(
raise
ValueError
(
'Expected nput_dims - 1 = label_dims or input_dims == label_dims
\
'Expected
i
nput_dims - 1 = label_dims or input_dims == label_dims
\
(got nput_dims{}, label_dims{})'
.
format
(
input_dims
,
label_dims
))
(got nput_dims{}, label_dims{})'
.
format
(
input_dims
,
label_dims
))
if
input_dims
-
1
==
label_dims
:
if
input_dims
-
1
==
label_dims
:
label
=
paddle
.
unsqueeze
(
label
,
axis
=-
1
)
label
=
paddle
.
unsqueeze
(
label
,
axis
=-
1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录