Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
12c78eee
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
12c78eee
编写于
9月 23, 2020
作者:
L
Liu Yiqun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Calculate the average time for benchmark.
上级
e41decb6
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
54 addition
and
13 deletion
+54
-13
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+21
-13
ppgan/utils/timer.py
ppgan/utils/timer.py
+33
-0
未找到文件。
ppgan/engine/trainer.py
浏览文件 @
12c78eee
...
...
@@ -11,6 +11,7 @@ from ..datasets.builder import build_dataloader
from
..models.builder
import
build_model
from
..utils.visual
import
tensor2img
,
save_image
from
..utils.filesystem
import
save
,
load
,
makedirs
from
..utils.timer
import
TimeAverager
from
..metric.psnr_ssim
import
calculate_psnr
,
calculate_ssim
...
...
@@ -61,30 +62,37 @@ class Trainer:
paddle
.
DataParallel
(
net
,
strategy
))
def
train
(
self
):
reader_cost_averager
=
TimeAverager
()
batch_cost_averager
=
TimeAverager
()
for
epoch
in
range
(
self
.
start_epoch
,
self
.
epochs
):
self
.
current_epoch
=
epoch
start_time
=
step_start_time
=
time
.
time
()
for
i
,
data
in
enumerate
(
self
.
train_dataloader
):
data_time
=
time
.
time
()
reader_cost_averager
.
record
(
time
.
time
()
-
step_start_time
)
self
.
batch_id
=
i
# unpack data from dataset and apply preprocessing
# data input should be dict
self
.
model
.
set_input
(
data
)
self
.
model
.
optimize_parameters
()
self
.
data_time
=
data_time
-
step_start_time
self
.
step_time
=
time
.
time
()
-
step_start_time
batch_cost_averager
.
record
(
time
.
time
()
-
step_start_time
)
if
i
%
self
.
log_interval
==
0
:
self
.
data_time
=
reader_cost_averager
.
get_average
()
self
.
step_time
=
batch_cost_averager
.
get_average
()
self
.
print_log
()
reader_cost_averager
.
reset
()
batch_cost_averager
.
reset
()
if
i
%
self
.
visual_interval
==
0
:
self
.
visual
(
'visual_train'
)
step_start_time
=
time
.
time
()
self
.
logger
.
info
(
'train one epoch time: {}'
.
format
(
time
.
time
()
-
start_time
))
self
.
logger
.
info
(
'train one epoch time: {}'
.
format
(
time
.
time
()
-
start_time
))
if
self
.
validate_interval
>
-
1
and
epoch
%
self
.
validate_interval
:
self
.
validate
()
self
.
model
.
lr_scheduler
.
step
()
...
...
@@ -94,8 +102,8 @@ class Trainer:
def
validate
(
self
):
if
not
hasattr
(
self
,
'val_dataloader'
):
self
.
val_dataloader
=
build_dataloader
(
self
.
cfg
.
dataset
.
val
,
is_train
=
False
)
self
.
val_dataloader
=
build_dataloader
(
self
.
cfg
.
dataset
.
val
,
is_train
=
False
)
metric_result
=
{}
...
...
@@ -141,8 +149,8 @@ class Trainer:
self
.
visual
(
'visual_val'
,
visual_results
=
visual_results
)
if
i
%
self
.
log_interval
==
0
:
self
.
logger
.
info
(
'val iter: [%d/%d]'
%
(
i
,
len
(
self
.
val_dataloader
)))
self
.
logger
.
info
(
'val iter: [%d/%d]'
%
(
i
,
len
(
self
.
val_dataloader
)))
for
metric_name
in
metric_result
.
keys
():
metric_result
[
metric_name
]
/=
len
(
self
.
val_dataloader
.
dataset
)
...
...
@@ -152,8 +160,8 @@ class Trainer:
def
test
(
self
):
if
not
hasattr
(
self
,
'test_dataloader'
):
self
.
test_dataloader
=
build_dataloader
(
self
.
cfg
.
dataset
.
test
,
is_train
=
False
)
self
.
test_dataloader
=
build_dataloader
(
self
.
cfg
.
dataset
.
test
,
is_train
=
False
)
# data[0]: img, data[1]: img path index
# test batch size must be 1
...
...
@@ -177,8 +185,8 @@ class Trainer:
self
.
visual
(
'visual_test'
,
visual_results
=
visual_results
)
if
i
%
self
.
log_interval
==
0
:
self
.
logger
.
info
(
'Test iter: [%d/%d]'
%
(
i
,
len
(
self
.
test_dataloader
)))
self
.
logger
.
info
(
'Test iter: [%d/%d]'
%
(
i
,
len
(
self
.
test_dataloader
)))
def
print_log
(
self
):
losses
=
self
.
model
.
get_current_losses
()
...
...
ppgan/utils/timer.py
0 → 100644
浏览文件 @
12c78eee
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
time
class
TimeAverager
(
object
):
def
__init__
(
self
):
self
.
reset
()
def
reset
(
self
):
self
.
_cnt
=
0
self
.
_total_time
=
0
def
record
(
self
,
usetime
):
self
.
_cnt
+=
1
self
.
_total_time
+=
usetime
def
get_average
(
self
):
if
self
.
_cnt
==
0
:
return
0
return
self
.
_total_time
/
self
.
_cnt
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录