Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
9fecdbaf
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9fecdbaf
编写于
10月 13, 2021
作者:
B
Bin Lu
提交者:
GitHub
10月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update center_loss.py
上级
1ac84b07
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
18 addition
and
18 deletion
+18
-18
ppocr/losses/center_loss.py
ppocr/losses/center_loss.py
+18
-18
未找到文件。
ppocr/losses/center_loss.py
浏览文件 @
9fecdbaf
...
...
@@ -27,7 +27,6 @@ class CenterLoss(nn.Layer):
"""
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
"""
def
__init__
(
self
,
num_classes
=
6625
,
feat_dim
=
96
,
...
...
@@ -37,8 +36,7 @@ class CenterLoss(nn.Layer):
self
.
num_classes
=
num_classes
self
.
feat_dim
=
feat_dim
self
.
centers
=
paddle
.
randn
(
shape
=
[
self
.
num_classes
,
self
.
feat_dim
]).
astype
(
"float64"
)
#random center
shape
=
[
self
.
num_classes
,
self
.
feat_dim
]).
astype
(
"float64"
)
if
init_center
:
assert
os
.
path
.
exists
(
...
...
@@ -60,22 +58,23 @@ class CenterLoss(nn.Layer):
batch_size
=
feats_reshape
.
shape
[
0
]
#calc feat * feat
dist1
=
paddle
.
sum
(
paddle
.
square
(
feats_reshape
),
axis
=
1
,
keepdim
=
True
)
dist1
=
paddle
.
expand
(
dist1
,
[
batch_size
,
self
.
num_classes
])
#calc l2 distance between feats and centers
square_feat
=
paddle
.
sum
(
paddle
.
square
(
feats_reshape
),
axis
=
1
,
keepdim
=
True
)
square_feat
=
paddle
.
expand
(
square_feat
,
[
batch_size
,
self
.
num_classes
])
#dist2 of centers
dist2
=
paddle
.
sum
(
paddle
.
square
(
self
.
centers
),
axis
=
1
,
keepdim
=
True
)
#num_classes
dist2
=
paddle
.
expand
(
dist2
,
[
self
.
num_classes
,
batch_size
]).
astype
(
"float64"
)
dist2
=
paddle
.
transpose
(
dist2
,
[
1
,
0
])
square_center
=
paddle
.
sum
(
paddle
.
square
(
self
.
centers
),
axis
=
1
,
keepdim
=
True
)
square_center
=
paddle
.
expand
(
square_center
,
[
self
.
num_classes
,
batch_size
]).
astype
(
"float64"
)
square_center
=
paddle
.
transpose
(
square_center
,
[
1
,
0
])
#first x * x + y * y
distmat
=
paddle
.
add
(
dist1
,
dist2
)
tmp
=
paddle
.
matmul
(
feats_reshape
,
distmat
=
paddle
.
add
(
square_feat
,
square_center
)
feat_dot_center
=
paddle
.
matmul
(
feats_reshape
,
paddle
.
transpose
(
self
.
centers
,
[
1
,
0
]))
distmat
=
distmat
-
2.0
*
tmp
distmat
=
distmat
-
2.0
*
feat_dot_center
#generate the mask
classes
=
paddle
.
arange
(
self
.
num_classes
).
astype
(
"int64"
)
...
...
@@ -83,7 +82,8 @@ class CenterLoss(nn.Layer):
paddle
.
unsqueeze
(
label
,
1
),
(
batch_size
,
self
.
num_classes
))
mask
=
paddle
.
equal
(
paddle
.
expand
(
classes
,
[
batch_size
,
self
.
num_classes
]),
label
).
astype
(
"float64"
)
#get mask
label
).
astype
(
"float64"
)
dist
=
paddle
.
multiply
(
distmat
,
mask
)
loss
=
paddle
.
sum
(
paddle
.
clip
(
dist
,
min
=
1e-12
,
max
=
1e+12
))
/
batch_size
return
{
'loss_center'
:
loss
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录