Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
9beb154b
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 2 年 前同步成功
通知
118
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9beb154b
编写于
3月 09, 2023
作者:
G
gaotingquan
提交者:
Wei Shengyu
3月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support ShiTu
上级
a41b201e
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
205 addition
and
184 deletion
+205
-184
ppcls/data/__init__.py
ppcls/data/__init__.py
+9
-8
ppcls/engine/engine.py
ppcls/engine/engine.py
+1
-0
ppcls/engine/evaluation/__init__.py
ppcls/engine/evaluation/__init__.py
+6
-6
ppcls/engine/evaluation/retrieval.py
ppcls/engine/evaluation/retrieval.py
+186
-166
ppcls/engine/train/__init__.py
ppcls/engine/train/__init__.py
+2
-3
ppcls/engine/train/train_progressive.py
ppcls/engine/train/train_progressive.py
+1
-1
未找到文件。
ppcls/data/__init__.py
浏览文件 @
9beb154b
...
@@ -88,14 +88,15 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
...
@@ -88,14 +88,15 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random
.
seed
(
worker_seed
)
random
.
seed
(
worker_seed
)
def
build_dataloader
(
config
,
mode
,
seed
=
None
):
def
build_dataloader
(
config
,
*
mode
,
seed
=
None
):
assert
mode
in
[
dataloader_config
=
config
[
"DataLoader"
]
'Train'
,
'Eval'
,
'Test'
,
'Gallery'
,
'Query'
,
'UnLabelTrain'
for
m
in
mode
:
],
"Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
assert
m
in
[
assert
mode
in
config
[
"DataLoader"
].
keys
(),
"{} config not in yaml"
.
format
(
'Train'
,
'Eval'
,
'Test'
,
'Gallery'
,
'Query'
,
'UnLabelTrain'
mode
)
],
"Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
assert
m
in
dataloader_config
.
keys
(),
"{} config not in yaml"
.
format
(
m
)
dataloader_config
=
config
[
"DataLoader"
][
mode
]
dataloader_config
=
dataloader_config
[
m
]
class_num
=
config
[
"Arch"
].
get
(
"class_num"
,
None
)
class_num
=
config
[
"Arch"
].
get
(
"class_num"
,
None
)
epochs
=
config
[
"Global"
][
"epochs"
]
epochs
=
config
[
"Global"
][
"epochs"
]
use_dali
=
config
[
"Global"
].
get
(
"use_dali"
,
False
)
use_dali
=
config
[
"Global"
].
get
(
"use_dali"
,
False
)
...
...
ppcls/engine/engine.py
浏览文件 @
9beb154b
...
@@ -22,6 +22,7 @@ from paddle import nn
...
@@ -22,6 +22,7 @@ from paddle import nn
import
numpy
as
np
import
numpy
as
np
import
random
import
random
from
..utils.amp
import
AMPForwardDecorator
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
from
ppcls.utils.logger
import
init_logger
from
ppcls.utils.logger
import
init_logger
from
ppcls.utils.config
import
print_config
from
ppcls.utils.config
import
print_config
...
...
ppcls/engine/evaluation/__init__.py
浏览文件 @
9beb154b
...
@@ -13,17 +13,17 @@
...
@@ -13,17 +13,17 @@
# limitations under the License.
# limitations under the License.
from
.classification
import
ClassEval
from
.classification
import
ClassEval
from
.retrieval
import
retrieval_e
val
from
.retrieval
import
RetrievalE
val
from
.adaface
import
adaface_eval
from
.adaface
import
adaface_eval
def
build_eval_func
(
config
,
mode
,
model
):
def
build_eval_func
(
config
,
mode
,
model
):
if
mode
not
in
[
"eval"
,
"train"
]:
if
mode
not
in
[
"eval"
,
"train"
]:
return
None
return
None
eval_mode
=
config
[
"Global"
].
get
(
"eval_mode"
,
None
)
task
=
config
[
"Global"
].
get
(
"task"
,
"classification"
)
if
eval_mode
is
None
:
if
task
==
"classification"
:
config
[
"Global"
][
"eval_mode"
]
=
"classification"
return
ClassEval
(
config
,
mode
,
model
)
return
ClassEval
(
config
,
mode
,
model
)
elif
task
==
"retrieval"
:
return
RetrievalEval
(
config
,
mode
,
model
)
else
:
else
:
return
getattr
(
sys
.
modules
[
__name__
],
eval_mode
+
"_eval"
)(
config
,
raise
Exception
()
mode
,
model
)
ppcls/engine/evaluation/retrieval.py
浏览文件 @
9beb154b
...
@@ -21,182 +21,202 @@ import numpy as np
...
@@ -21,182 +21,202 @@ import numpy as np
import
paddle
import
paddle
import
scipy
import
scipy
from
ppcls.utils
import
all_gather
,
logger
from
...utils.misc
import
AverageMeter
from
...utils
import
all_gather
,
logger
from
...data
import
build_dataloader
def
retrieval_eval
(
engine
,
epoch_id
=
0
):
from
...loss
import
build_loss
engine
.
model
.
eval
()
from
...metric
import
build_metrics
# step1. prepare query and gallery features
if
engine
.
gallery_query_dataloader
is
not
None
:
gallery_feat
,
gallery_label
,
gallery_camera
=
compute_feature
(
class
RetrievalEval
(
object
):
engine
,
"gallery_query"
)
def
__init__
(
self
,
config
,
mode
,
model
):
query_feat
,
query_label
,
query_camera
=
gallery_feat
,
gallery_label
,
gallery_camera
self
.
config
=
config
else
:
self
.
model
=
model
gallery_feat
,
gallery_label
,
gallery_camera
=
compute_feature
(
self
.
print_batch_step
=
self
.
config
[
"Global"
][
"print_batch_step"
]
engine
,
"gallery"
)
self
.
use_dali
=
self
.
config
[
"Global"
].
get
(
"use_dali"
,
False
)
query_feat
,
query_label
,
query_camera
=
compute_feature
(
engine
,
self
.
eval_metric_func
=
build_metrics
(
self
.
config
,
"Eval"
)
"query"
)
self
.
eval_loss_func
=
build_loss
(
self
.
config
,
"Eval"
)
self
.
output_info
=
dict
()
# step2. split features into feature blocks for saving memory
num_query
=
len
(
query_feat
)
self
.
gallery_query_dataloader
=
None
block_size
=
engine
.
config
[
"Global"
].
get
(
"sim_block_size"
,
64
)
if
len
(
self
.
config
[
"DataLoader"
][
"Eval"
].
keys
())
==
1
:
sections
=
[
block_size
]
*
(
num_query
//
block_size
)
self
.
gallery_query_dataloader
=
build_dataloader
(
self
.
config
,
if
num_query
%
block_size
>
0
:
"Eval"
)
sections
.
append
(
num_query
%
block_size
)
else
:
self
.
gallery_dataloader
=
build_dataloader
(
self
.
config
,
"Eval"
,
query_feat_blocks
=
paddle
.
split
(
query_feat
,
sections
)
"Gallery"
)
query_label_blocks
=
paddle
.
split
(
query_label
,
sections
)
self
.
query_dataloader
=
build_dataloader
(
self
.
config
,
"Eval"
,
query_camera_blocks
=
paddle
.
split
(
"Query"
)
query_camera
,
sections
)
if
query_camera
is
not
None
else
None
metric_key
=
None
def
__call__
(
self
,
epoch_id
=
0
):
self
.
model
.
eval
()
# step3. compute metric
if
engine
.
eval_loss_func
is
None
:
# step1. prepare query and gallery features
metric_dict
=
{
metric_key
:
0.0
}
if
self
.
gallery_query_dataloader
is
not
None
:
else
:
gallery_feat
,
gallery_label
,
gallery_camera
=
self
.
compute_feature
(
use_reranking
=
engine
.
config
[
"Global"
].
get
(
"re_ranking"
,
False
)
"gallery_query"
)
logger
.
info
(
f
"re_ranking=
{
use_reranking
}
"
)
query_feat
,
query_label
,
query_camera
=
gallery_feat
,
gallery_label
,
gallery_camera
if
use_reranking
:
# compute distance matrix
distmat
=
compute_re_ranking_dist
(
query_feat
,
gallery_feat
,
engine
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
),
20
,
6
,
0.3
)
# exclude illegal distance
if
query_camera
is
not
None
:
camera_mask
=
query_camera
!=
gallery_camera
.
t
()
label_mask
=
query_label
!=
gallery_label
.
t
()
keep_mask
=
label_mask
|
camera_mask
distmat
=
keep_mask
.
astype
(
query_feat
.
dtype
)
*
distmat
+
(
~
keep_mask
).
astype
(
query_feat
.
dtype
)
*
(
distmat
.
max
()
+
1
)
else
:
keep_mask
=
None
# compute metric with all samples
metric_dict
=
engine
.
eval_metric_func
(
-
distmat
,
query_label
,
gallery_label
,
keep_mask
)
else
:
else
:
metric_dict
=
defaultdict
(
float
)
gallery_feat
,
gallery_label
,
gallery_camera
=
self
.
compute_feature
(
for
block_idx
,
block_feat
in
enumerate
(
query_feat_blocks
):
"gallery"
)
query_feat
,
query_label
,
query_camera
=
self
.
compute_feature
(
"query"
)
# step2. split features into feature blocks for saving memory
num_query
=
len
(
query_feat
)
block_size
=
self
.
config
[
"Global"
].
get
(
"sim_block_size"
,
64
)
sections
=
[
block_size
]
*
(
num_query
//
block_size
)
if
num_query
%
block_size
>
0
:
sections
.
append
(
num_query
%
block_size
)
query_feat_blocks
=
paddle
.
split
(
query_feat
,
sections
)
query_label_blocks
=
paddle
.
split
(
query_label
,
sections
)
query_camera_blocks
=
paddle
.
split
(
query_camera
,
sections
)
if
query_camera
is
not
None
else
None
metric_key
=
None
# step3. compute metric
if
self
.
eval_loss_func
is
None
:
metric_dict
=
{
metric_key
:
0.0
}
else
:
use_reranking
=
self
.
config
[
"Global"
].
get
(
"re_ranking"
,
False
)
logger
.
info
(
f
"re_ranking=
{
use_reranking
}
"
)
if
use_reranking
:
# compute distance matrix
# compute distance matrix
distmat
=
paddle
.
matmul
(
distmat
=
compute_re_ranking_dist
(
block_feat
,
gallery_feat
,
transpose_y
=
True
)
query_feat
,
gallery_feat
,
self
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
),
20
,
6
,
0.3
)
# exclude illegal distance
# exclude illegal distance
if
query_camera
is
not
None
:
if
query_camera
is
not
None
:
camera_mask
=
query_camera_blocks
[
camera_mask
=
query_camera
!=
gallery_camera
.
t
()
block_idx
]
!=
gallery_camera
.
t
()
label_mask
=
query_label
!=
gallery_label
.
t
()
label_mask
=
query_label_blocks
[
block_idx
]
!=
gallery_label
.
t
()
keep_mask
=
label_mask
|
camera_mask
keep_mask
=
label_mask
|
camera_mask
distmat
=
keep_mask
.
astype
(
query_feat
.
dtype
)
*
distmat
distmat
=
keep_mask
.
astype
(
query_feat
.
dtype
)
*
distmat
+
(
~
keep_mask
).
astype
(
query_feat
.
dtype
)
*
(
distmat
.
max
()
+
1
)
else
:
else
:
keep_mask
=
None
keep_mask
=
None
# compute metric by block
# compute metric with all samples
metric_block
=
engine
.
eval_metric_func
(
metric_dict
=
self
.
eval_metric_func
(
-
distmat
,
query_label
,
distmat
,
query_label_blocks
[
block_idx
],
gallery_label
,
gallery_label
,
keep_mask
)
keep_mask
)
else
:
# accumulate metric
metric_dict
=
defaultdict
(
float
)
for
key
in
metric_block
:
for
block_idx
,
block_feat
in
enumerate
(
query_feat_blocks
):
metric_dict
[
key
]
+=
metric_block
[
key
]
*
block_feat
.
shape
[
# compute distance matrix
0
]
/
num_query
distmat
=
paddle
.
matmul
(
block_feat
,
gallery_feat
,
transpose_y
=
True
)
metric_info_list
=
[]
# exclude illegal distance
for
key
,
value
in
metric_dict
.
items
():
if
query_camera
is
not
None
:
metric_info_list
.
append
(
f
"
{
key
}
:
{
value
:.
5
f
}
"
)
camera_mask
=
query_camera_blocks
[
if
metric_key
is
None
:
block_idx
]
!=
gallery_camera
.
t
()
metric_key
=
key
label_mask
=
query_label_blocks
[
metric_msg
=
", "
.
join
(
metric_info_list
)
block_idx
]
!=
gallery_label
.
t
()
logger
.
info
(
f
"[Eval][Epoch
{
epoch_id
}
][Avg]
{
metric_msg
}
"
)
keep_mask
=
label_mask
|
camera_mask
distmat
=
keep_mask
.
astype
(
query_feat
.
dtype
)
*
distmat
return
metric_dict
[
metric_key
]
else
:
keep_mask
=
None
# compute metric by block
def
compute_feature
(
engine
,
name
=
"gallery"
):
metric_block
=
self
.
eval_metric_func
(
if
name
==
"gallery"
:
distmat
,
query_label_blocks
[
block_idx
],
gallery_label
,
dataloader
=
engine
.
gallery_dataloader
keep_mask
)
elif
name
==
"query"
:
# accumulate metric
dataloader
=
engine
.
query_dataloader
for
key
in
metric_block
:
elif
name
==
"gallery_query"
:
metric_dict
[
key
]
+=
metric_block
[
dataloader
=
engine
.
gallery_query_dataloader
key
]
*
block_feat
.
shape
[
0
]
/
num_query
else
:
raise
ValueError
(
metric_info_list
=
[]
f
"Only support gallery or query or gallery_query dataset, but got
{
name
}
"
for
key
,
value
in
metric_dict
.
items
():
)
metric_info_list
.
append
(
f
"
{
key
}
:
{
value
:.
5
f
}
"
)
if
metric_key
is
None
:
all_feat
=
[]
metric_key
=
key
all_label
=
[]
metric_msg
=
", "
.
join
(
metric_info_list
)
all_camera
=
[]
logger
.
info
(
f
"[Eval][Epoch
{
epoch_id
}
][Avg]
{
metric_msg
}
"
)
has_camera
=
False
for
idx
,
batch
in
enumerate
(
dataloader
):
# load is very time-consuming
return
metric_dict
[
metric_key
]
if
idx
%
engine
.
config
[
"Global"
][
"print_batch_step"
]
==
0
:
logger
.
info
(
def
compute_feature
(
self
,
name
=
"gallery"
):
f
"
{
name
}
feature calculation process: [
{
idx
}
/
{
len
(
dataloader
)
}
]"
if
name
==
"gallery"
:
dataloader
=
self
.
gallery_dataloader
elif
name
==
"query"
:
dataloader
=
self
.
query_dataloader
elif
name
==
"gallery_query"
:
dataloader
=
self
.
gallery_query_dataloader
else
:
raise
ValueError
(
f
"Only support gallery or query or gallery_query dataset, but got
{
name
}
"
)
)
batch
=
[
paddle
.
to_tensor
(
x
)
for
x
in
batch
]
all_feat
=
[]
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
all_label
=
[]
if
len
(
batch
)
>=
3
:
all_camera
=
[]
has_camera
=
True
has_camera
=
False
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
for
idx
,
batch
in
enumerate
(
dataloader
):
# load is very time-consuming
if
engine
.
amp
and
engine
.
amp_eval
:
if
idx
%
self
.
print_batch_step
==
0
:
with
paddle
.
amp
.
auto_cast
(
logger
.
info
(
custom_black_list
=
{
f
"
{
name
}
feature calculation process: [
{
idx
}
/
{
len
(
dataloader
)
}
]"
"flatten_contiguous_range"
,
"greater_than"
)
},
level
=
engine
.
amp_level
):
batch
=
[
paddle
.
to_tensor
(
x
)
for
x
in
batch
]
out
=
engine
.
model
(
batch
[
0
])
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
else
:
if
len
(
batch
)
>=
3
:
out
=
engine
.
model
(
batch
[
0
])
has_camera
=
True
if
"Student"
in
out
:
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
out
=
out
[
"Student"
]
out
=
self
.
model
(
batch
)
# get features
if
engine
.
config
[
"Global"
].
get
(
"retrieval_feature_from"
,
if
"Student"
in
out
:
"features"
)
==
"features"
:
out
=
out
[
"Student"
]
# use output from neck as feature
batch_feat
=
out
[
"features"
]
# get features
else
:
if
self
.
config
[
"Global"
].
get
(
"retrieval_feature_from"
,
# use output from backbone as feature
"features"
)
==
"features"
:
batch_feat
=
out
[
"backbone"
]
# use output from neck as feature
batch_feat
=
out
[
"features"
]
# do norm(optional)
else
:
if
engine
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
):
# use output from backbone as feature
batch_feat
=
paddle
.
nn
.
functional
.
normalize
(
batch_feat
,
p
=
2
)
batch_feat
=
out
[
"backbone"
]
# do binarize(optional)
# do norm(optional)
if
engine
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"round"
:
if
self
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
):
batch_feat
=
paddle
.
round
(
batch_feat
).
astype
(
"float32"
)
*
2.0
-
1.0
batch_feat
=
paddle
.
nn
.
functional
.
normalize
(
batch_feat
,
p
=
2
)
elif
engine
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"sign"
:
batch_feat
=
paddle
.
sign
(
batch_feat
).
astype
(
"float32"
)
# do binarize(optional)
if
self
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"round"
:
if
paddle
.
distributed
.
get_world_size
()
>
1
:
batch_feat
=
paddle
.
round
(
batch_feat
).
astype
(
all_feat
.
append
(
all_gather
(
batch_feat
))
"float32"
)
*
2.0
-
1.0
all_label
.
append
(
all_gather
(
batch
[
1
]))
elif
self
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"sign"
:
if
has_camera
:
batch_feat
=
paddle
.
sign
(
batch_feat
).
astype
(
"float32"
)
all_camera
.
append
(
all_gather
(
batch
[
2
]))
if
paddle
.
distributed
.
get_world_size
()
>
1
:
all_feat
.
append
(
all_gather
(
batch_feat
))
all_label
.
append
(
all_gather
(
batch
[
1
]))
if
has_camera
:
all_camera
.
append
(
all_gather
(
batch
[
2
]))
else
:
all_feat
.
append
(
batch_feat
)
all_label
.
append
(
batch
[
1
])
if
has_camera
:
all_camera
.
append
(
batch
[
2
])
if
self
.
use_dali
:
dataloader
.
reset
()
all_feat
=
paddle
.
concat
(
all_feat
)
all_label
=
paddle
.
concat
(
all_label
)
if
has_camera
:
all_camera
=
paddle
.
concat
(
all_camera
)
else
:
else
:
all_feat
.
append
(
batch_feat
)
all_camera
=
None
all_label
.
append
(
batch
[
1
])
# discard redundant padding sample(s) at the end
if
has_camera
:
total_samples
=
dataloader
.
size
if
self
.
use_dali
else
len
(
all_camera
.
append
(
batch
[
2
])
dataloader
.
dataset
)
all_feat
=
all_feat
[:
total_samples
]
if
engine
.
use_dali
:
all_label
=
all_label
[:
total_samples
]
dataloader
.
reset
()
if
has_camera
:
all_camera
=
all_camera
[:
total_samples
]
all_feat
=
paddle
.
concat
(
all_feat
)
all_label
=
paddle
.
concat
(
all_label
)
logger
.
info
(
f
"Build
{
name
}
done, all feat shape:
{
all_feat
.
shape
}
"
)
if
has_camera
:
return
all_feat
,
all_label
,
all_camera
all_camera
=
paddle
.
concat
(
all_camera
)
else
:
all_camera
=
None
# discard redundant padding sample(s) at the end
total_samples
=
dataloader
.
size
if
engine
.
use_dali
else
len
(
dataloader
.
dataset
)
all_feat
=
all_feat
[:
total_samples
]
all_label
=
all_label
[:
total_samples
]
if
has_camera
:
all_camera
=
all_camera
[:
total_samples
]
logger
.
info
(
f
"Build
{
name
}
done, all feat shape:
{
all_feat
.
shape
}
"
)
return
all_feat
,
all_label
,
all_camera
def
k_reciprocal_neighbor
(
rank
:
np
.
ndarray
,
p
:
int
,
k
:
int
)
->
np
.
ndarray
:
def
k_reciprocal_neighbor
(
rank
:
np
.
ndarray
,
p
:
int
,
k
:
int
)
->
np
.
ndarray
:
...
...
ppcls/engine/train/__init__.py
浏览文件 @
9beb154b
...
@@ -22,9 +22,8 @@ from .train_progressive import train_epoch_progressive
...
@@ -22,9 +22,8 @@ from .train_progressive import train_epoch_progressive
def
build_train_func
(
config
,
mode
,
model
,
eval_func
):
def
build_train_func
(
config
,
mode
,
model
,
eval_func
):
if
mode
!=
"train"
:
if
mode
!=
"train"
:
return
None
return
None
train_mode
=
config
[
"Global"
].
get
(
"task"
,
None
)
task
=
config
[
"Global"
].
get
(
"task"
,
"classification"
)
if
train_mode
is
None
:
if
task
==
"classification"
or
task
==
"retrieval"
:
config
[
"Global"
][
"task"
]
=
"classification"
return
ClassTrainer
(
config
,
model
,
eval_func
)
return
ClassTrainer
(
config
,
model
,
eval_func
)
else
:
else
:
return
getattr
(
sys
.
modules
[
__name__
],
"train_epoch_"
+
train_mode
)(
return
getattr
(
sys
.
modules
[
__name__
],
"train_epoch_"
+
train_mode
)(
...
...
ppcls/engine/train/train_progressive.py
浏览文件 @
9beb154b
...
@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
...
@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
from
ppcls.data
import
build_dataloader
from
ppcls.data
import
build_dataloader
from
ppcls.utils
import
logger
,
type_name
from
ppcls.utils
import
logger
,
type_name
from
.
regular_train_epoch
import
regular_train_epoch
from
.
classification
import
ClassTrainer
def
train_epoch_progressive
(
engine
,
epoch_id
,
print_batch_step
):
def
train_epoch_progressive
(
engine
,
epoch_id
,
print_batch_step
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录