PaddlePaddle / PaddleClas

Commit 69f563d2
Committed on Jun 03, 2021 by weishengyu

    rename losses -> loss

Parent: 51f0b78b

Showing 10 changed files with 124 additions and 95 deletions (+124 -95):

    ppcls/engine/trainer.py    +1  -1
    ppcls/loss/__init__.py     +0  -0   (moved)
    ppcls/loss/celoss.py       +0  -0   (moved)
    ppcls/loss/centerloss.py   +17 -10
    ppcls/loss/comfunc.py      +8  -7
    ppcls/loss/emlloss.py      +36 -28
    ppcls/loss/msmloss.py      +24 -18
    ppcls/loss/npairsloss.py   +12 -11
    ppcls/loss/trihardloss.py  +26 -20
    ppcls/loss/triplet.py      +0  -0   (moved)

Apart from the directory rename and the import fix in ppcls/engine/trainer.py, the in-file edits are line-wrapping reflows (plus a comment tweak in npairsloss.py); the resulting code of each changed hunk is shown below.
ppcls/engine/trainer.py

```diff
@@ -30,7 +30,7 @@ from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model
-from ppcls.losses import build_loss
+from ppcls.loss import build_loss
 from ppcls.arch.loss_metrics import build_metrics
 from ppcls.optimizer import build_optimizer
 from ppcls.utils.save_load import load_dygraph_pretrain
```
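The functional change here is only the import path. `build_loss` itself lives in the moved `ppcls/loss/__init__.py` and is not shown in this diff; as orientation only, a factory in this style typically instantiates the configured loss layers and sums their weighted dict outputs (every loss in this commit returns a dict such as `{"msmloss": loss}`). The sketch below is a hypothetical illustration, not the repository's actual code:

```python
# Hypothetical sketch only; the real build_loss is in ppcls/loss/__init__.py.
import paddle


class CombinedLoss(paddle.nn.Layer):  # name assumed for illustration
    def __init__(self, loss_layers, weights):
        super().__init__()
        self.loss_layers = loss_layers
        self.weights = weights

    def forward(self, input, target):
        total = 0.0
        for layer, weight in zip(self.loss_layers, self.weights):
            # each layer returns a dict like {"trihardloss": tensor}
            for value in layer(input, target).values():
                total = total + weight * value
        return {"loss": total}
```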
ppcls/losses/__init__.py → ppcls/loss/__init__.py
File moved.

ppcls/losses/celoss.py → ppcls/loss/celoss.py
File moved.
ppcls/losses/centerloss.py → ppcls/loss/centerloss.py

@@ -5,12 +5,15 @@ import paddle
```python
import paddle.nn as nn
import paddle.nn.functional as F


class CenterLoss(nn.Layer):
    def __init__(self, num_classes=5013, feat_dim=2048):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.centers = paddle.randn(
            shape=[self.num_classes, self.feat_dim]).astype(
                "float64")  #random center

    def __call__(self, input, target):
        """
```

@@ -23,25 +26,29 @@ class CenterLoss(nn.Layer):
```python
        #calc feat * feat
        dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)
        dist1 = paddle.expand(dist1, [batch_size, self.num_classes])

        #dist2 of centers
        dist2 = paddle.sum(
            paddle.square(self.centers), axis=1, keepdim=True)  #num_classes
        dist2 = paddle.expand(
            dist2, [self.num_classes, batch_size]).astype("float64")
        dist2 = paddle.transpose(dist2, [1, 0])

        #first x * x + y * y
        distmat = paddle.add(dist1, dist2)
        tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
        distmat = distmat - 2.0 * tmp

        #generate the mask
        classes = paddle.arange(self.num_classes).astype("int64")
        labels = paddle.expand(
            paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
        mask = paddle.equal(
            paddle.expand(classes, [batch_size, self.num_classes]),
            labels).astype("float64")  #get mask

        dist = paddle.multiply(distmat, mask)
        loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
        return {'CenterLoss': loss}
```
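`distmat` is the standard expansion of pairwise squared Euclidean distances between features and centers, ||f_i - c_j||^2 = ||f_i||^2 + ||c_j||^2 - 2 f_i . c_j, which the expand/transpose/matmul sequence above computes without materializing a (batch, classes, dim) difference tensor. A quick NumPy check of that identity (illustrative, not part of the commit):

```python
import numpy as np

feats = np.random.randn(4, 8)    # batch_size=4, feat_dim=8
centers = np.random.randn(6, 8)  # num_classes=6

# expansion used by CenterLoss above
dist1 = np.sum(feats**2, axis=1, keepdims=True)      # (4, 1), broadcast to (4, 6)
dist2 = np.sum(centers**2, axis=1, keepdims=True).T  # (1, 6)
distmat = dist1 + dist2 - 2.0 * feats @ centers.T

# naive pairwise distances for comparison
naive = ((feats[:, None, :] - centers[None, :, :])**2).sum(-1)
assert np.allclose(distmat, naive)
```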
ppcls/losses/comfunc.py → ppcls/loss/comfunc.py

@@ -18,26 +18,27 @@ from __future__ import print_function
```python
import numpy as np


def rerange_index(batch_size, samples_each_class):
    tmp = np.arange(0, batch_size * batch_size)
    tmp = tmp.reshape(-1, batch_size)
    rerange_index = []

    for i in range(batch_size):
        step = i // samples_each_class
        start = step * samples_each_class
        end = (step + 1) * samples_each_class

        pos_idx = []
        neg_idx = []
        for j, k in enumerate(tmp[i]):
            if j >= start and j < end:
                if j == i:
                    pos_idx.insert(0, k)
                else:
                    pos_idx.append(k)
            else:
                neg_idx.append(k)
        rerange_index += (pos_idx + neg_idx)

    rerange_index = np.array(rerange_index).astype(np.int32)
    # ...
```
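To see what the re-ranking does, a small worked example (illustrative): with `batch_size=4` and `samples_each_class=2`, each row of the result indexes the flattened 4x4 distance matrix so that the self-match comes first, then the remaining same-class entry, then all negatives. This ordering is what the `paddle.split(..., num_or_sections=[1, samples_each_class - 1, ...])` calls in the loss layers below rely on.

```python
from ppcls.loss.comfunc import rerange_index

idx = rerange_index(batch_size=4, samples_each_class=2)
print(idx.reshape(4, 4))
# [[ 0  1  2  3]
#  [ 5  4  6  7]
#  [10 11  8  9]
#  [15 14 12 13]]
# row i: [self-distance, same-class distance, negatives...] of sample i
```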
ppcls/losses/emlloss.py → ppcls/loss/emlloss.py

@@ -21,56 +21,64 @@ import paddle
```python
import numpy as np
from .comfunc import rerange_index


class EmlLoss(paddle.nn.Layer):
    def __init__(self, batch_size=40, samples_each_class=2):
        super(EmlLoss, self).__init__()
        assert (batch_size % samples_each_class == 0)
        self.samples_each_class = samples_each_class
        self.batch_size = batch_size
        self.rerange_index = rerange_index(batch_size, samples_each_class)
        self.thresh = 20.0
        self.beta = 100000

    def surrogate_function(self, beta, theta, bias):
        x = theta * paddle.exp(bias)
        output = paddle.log(1 + beta * x) / math.log(1 + beta)
        return output

    def surrogate_function_approximate(self, beta, theta, bias):
        output = (paddle.log(theta) + bias + math.log(beta)
                  ) / math.log(1 + beta)
        return output

    def surrogate_function_stable(self, beta, theta, target, thresh):
        max_gap = paddle.to_tensor(thresh, dtype='float32')
        max_gap.stop_gradient = True

        target_max = paddle.maximum(target, max_gap)
        target_min = paddle.minimum(target, max_gap)

        loss1 = self.surrogate_function(beta, theta, target_min)
        loss2 = self.surrogate_function_approximate(beta, theta, target_max)
        bias = self.surrogate_function(beta, theta, max_gap)
        loss = loss1 + loss2 - bias
        return loss

    def forward(self, input, target=None):
        features = input["features"]
        samples_each_class = self.samples_each_class
        batch_size = self.batch_size
        rerange_index = self.rerange_index

        #calc distance
        diffs = paddle.unsqueeze(
            features, axis=1) - paddle.unsqueeze(features, axis=0)
        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)

        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
        rerange_index = paddle.to_tensor(rerange_index)
        tmp = paddle.gather(tmp, index=rerange_index)
        similary_matrix = paddle.reshape(tmp, shape=[-1, batch_size])

        ignore, pos, neg = paddle.split(
            similary_matrix,
            num_or_sections=[
                1, samples_each_class - 1, batch_size - samples_each_class
            ],
            axis=1)
        ignore.stop_gradient = True

        pos_max = paddle.max(pos, axis=1, keepdim=True)
        pos = paddle.exp(pos - pos_max)
```

@@ -79,11 +87,11 @@ class EmlLoss(paddle.nn.Layer):
```python
        neg_min = paddle.min(neg, axis=1, keepdim=True)
        neg = paddle.exp(neg_min - neg)
        neg_mean = paddle.mean(neg, axis=1, keepdim=True)
        bias = pos_max - neg_min
        theta = paddle.multiply(neg_mean, pos_mean)

        loss = self.surrogate_function_stable(self.beta, theta, bias,
                                              self.thresh)
        loss = paddle.mean(loss)
        return {"emlloss": loss}
```
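Why three surrogate variants: `surrogate_function` computes log(1 + beta * theta * exp(bias)) / log(1 + beta), and the direct `paddle.exp(bias)` overflows float32 for large gaps, so above `thresh` the code switches to the large-argument form (log(theta) + bias + log(beta)) / log(1 + beta), which follows from log(1 + x) ≈ log(x) for x >> 1. A quick float64 sanity check of the agreement (illustrative, not part of the commit):

```python
import math

beta, theta = 100000.0, 1.5


def exact(bias):
    return math.log(1 + beta * theta * math.exp(bias)) / math.log(1 + beta)


def approx(bias):
    return (math.log(theta) + bias + math.log(beta)) / math.log(1 + beta)


for bias in [20.0, 40.0, 60.0]:
    print(bias, exact(bias), approx(bias))
# past thresh=20 the two branches already agree to near machine precision,
# while exp(bias) alone would overflow float32 around bias ~ 88
```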
ppcls/losses/msmloss.py → ppcls/loss/msmloss.py

@@ -18,6 +18,7 @@ from __future__ import print_function
```python
import paddle
from .comfunc import rerange_index


class MSMLoss(paddle.nn.Layer):
    """
    MSMLoss Loss, based on triplet loss. USE P * K samples.
```

@@ -31,42 +32,47 @@
```python
    ]
    only consider samples_each_class = 2
    """

    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
        super(MSMLoss, self).__init__()
        self.margin = margin
        self.samples_each_class = samples_each_class
        self.batch_size = batch_size
        self.rerange_index = rerange_index(batch_size, samples_each_class)

    def forward(self, input, target=None):
        #normalization
        features = input["features"]
        features = self._nomalize(features)
        samples_each_class = self.samples_each_class
        rerange_index = paddle.to_tensor(self.rerange_index)

        #calc sm
        diffs = paddle.unsqueeze(
            features, axis=1) - paddle.unsqueeze(features, axis=0)
        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)

        #rerange
        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
        tmp = paddle.gather(tmp, index=rerange_index)
        similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])

        #split
        ignore, pos, neg = paddle.split(
            similary_matrix,
            num_or_sections=[1, samples_each_class - 1, -1],
            axis=1)
        ignore.stop_gradient = True

        hard_pos = paddle.max(pos)
        hard_neg = paddle.min(neg)

        loss = hard_pos + self.margin - hard_neg
        loss = paddle.nn.ReLU()(loss)
        return {"msmloss": loss}

    def _nomalize(self, input):
        input_norm = paddle.sqrt(
            paddle.sum(paddle.square(input), axis=1, keepdim=True))
        return paddle.divide(input, input_norm)
```
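A minimal smoke test of the call convention (illustrative; assumes a batch grouped as `samples_each_class` consecutive samples per class, as the docstring describes, and the feature dict produced upstream in the trainer):

```python
import paddle
from ppcls.loss.msmloss import MSMLoss

# 4 classes x 2 samples per class, grouped consecutively -> batch_size=8
loss_fn = MSMLoss(batch_size=8, samples_each_class=2, margin=0.1)
feats = paddle.randn([8, 128])
out = loss_fn({"features": feats})  # forward() reads input["features"]
print(out["msmloss"])  # scalar: ReLU(hardest_pos + margin - hardest_neg)
```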
ppcls/losses/npairsloss.py → ppcls/loss/npairsloss.py

@@ -3,12 +3,12 @@ from __future__ import division
```python
from __future__ import print_function
import paddle


class NpairsLoss(paddle.nn.Layer):
    def __init__(self, reg_lambda=0.01):
        super(NpairsLoss, self).__init__()
        self.reg_lambda = reg_lambda

    def forward(self, input, target=None):
        """
        anchor and positive(should include label)
```

@@ -16,22 +16,23 @@
```python
        features = input["features"]
        reg_lambda = self.reg_lambda
        batch_size = features.shape[0]
        fea_dim = features.shape[1]
        num_class = batch_size // 2

        #reshape
        out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim])
        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1)
        anc_feas = paddle.squeeze(anc_feas, axis=1)
        pos_feas = paddle.squeeze(pos_feas, axis=1)

        #get similarity matrix
        similarity_matrix = paddle.matmul(
            anc_feas, pos_feas, transpose_y=True)  #get similarity matrix
        sparse_labels = paddle.arange(0, num_class, dtype='int64')
        xentloss = paddle.nn.CrossEntropyLoss()(
            similarity_matrix, sparse_labels)  #by default: mean

        #l2 norm
        reg = paddle.mean(paddle.sum(paddle.square(features), axis=1))
        l2loss = 0.5 * reg_lambda * reg
        return {"npairsloss": xentloss + l2loss}
```
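The reshape to `[-1, 2, fea_dim]` assumes the batch is interleaved as (anchor, positive, anchor, positive, ...). Row i of `similarity_matrix` then scores anchor i against every positive, and the true match lies on the diagonal, which is exactly what `sparse_labels = paddle.arange(0, num_class)` encodes. A minimal call (illustrative):

```python
import paddle
from ppcls.loss.npairsloss import NpairsLoss

loss_fn = NpairsLoss(reg_lambda=0.01)
feats = paddle.randn([16, 64])  # 8 interleaved (anchor, positive) pairs
out = loss_fn({"features": feats})
print(out["npairsloss"])  # softmax cross-entropy + 0.5 * reg_lambda * mean L2
```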
ppcls/losses/trihardloss.py → ppcls/loss/trihardloss.py

@@ -19,6 +19,7 @@ from __future__ import print_function
```python
import paddle
from .comfunc import rerange_index


class TriHardLoss(paddle.nn.Layer):
    """
    TriHard Loss, based on triplet loss. USE P * K samples.
```

@@ -32,45 +33,50 @@
```python
    ]
    only consider samples_each_class = 2
    """

    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
        super(TriHardLoss, self).__init__()
        self.margin = margin
        self.samples_each_class = samples_each_class
        self.batch_size = batch_size
        self.rerange_index = rerange_index(batch_size, samples_each_class)

    def forward(self, input, target=None):
        features = input["features"]
        assert (self.batch_size == features.shape[0])

        #normalization
        features = self._nomalize(features)
        samples_each_class = self.samples_each_class
        rerange_index = paddle.to_tensor(self.rerange_index)

        #calc sm
        diffs = paddle.unsqueeze(
            features, axis=1) - paddle.unsqueeze(features, axis=0)
        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)

        #rerange
        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
        tmp = paddle.gather(tmp, index=rerange_index)
        similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])

        #split
        ignore, pos, neg = paddle.split(
            similary_matrix,
            num_or_sections=[1, samples_each_class - 1, -1],
            axis=1)
        ignore.stop_gradient = True

        hard_pos = paddle.max(pos, axis=1)
        hard_neg = paddle.min(neg, axis=1)

        loss = hard_pos + self.margin - hard_neg
        loss = paddle.nn.ReLU()(loss)
        loss = paddle.mean(loss)
        return {"trihardloss": loss}

    def _nomalize(self, input):
        input_norm = paddle.sqrt(
            paddle.sum(paddle.square(input), axis=1, keepdim=True))
        return paddle.divide(input, input_norm)
```
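The substantive difference from MSMLoss above is the mining granularity: here `paddle.max(pos, axis=1)` and `paddle.min(neg, axis=1)` pick the hardest positive and negative per anchor and the per-anchor losses are averaged, whereas MSMLoss's global `paddle.max(pos)` / `paddle.min(neg)` reduce the whole batch to a single hardest pair. A NumPy illustration of the two reductions:

```python
import numpy as np

# pretend squared distances to positives for a batch of 2 anchors
pos = np.array([[0.9, 0.2],
                [0.1, 0.8]])

print(pos.max())        # 0.9        -> MSM-style: one hardest positive per batch
print(pos.max(axis=1))  # [0.9 0.8]  -> TriHard-style: hardest positive per anchor
```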
ppcls/losses/triplet.py → ppcls/loss/triplet.py
File moved.