Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
d2c34a93
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d2c34a93
编写于
3月 23, 2023
作者:
W
wangna11BD
提交者:
GitHub
3月 23, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add log for export and inference in TIPC (#769)
* add log for TIPC * fix CI * add log
上级
0927444b
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
117 addition
and
5 deletion
+117
-5
ppgan/utils/config.py
ppgan/utils/config.py
+84
-1
tools/export_model.py
tools/export_model.py
+4
-1
tools/inference.py
tools/inference.py
+29
-3
未找到文件。
ppgan/utils/config.py
浏览文件 @
d2c34a93
...
...
@@ -14,6 +14,8 @@
import
os
import
yaml
from
.logger
import
get_logger
__all__
=
[
'get_config'
]
...
...
@@ -54,6 +56,84 @@ def parse_config(cfg_file):
create_attr_dict
(
yaml_config
)
return
yaml_config
Color
=
{
'RED'
:
'
\033
[31m'
,
'HEADER'
:
'
\033
[35m'
,
# deep purple
'PURPLE'
:
'
\033
[95m'
,
# purple
'OKBLUE'
:
'
\033
[94m'
,
'OKGREEN'
:
'
\033
[92m'
,
'WARNING'
:
'
\033
[93m'
,
'FAIL'
:
'
\033
[91m'
,
'ENDC'
:
'
\033
[0m'
}
def
coloring
(
message
,
color
=
"OKGREEN"
):
assert
color
in
Color
.
keys
()
if
os
.
environ
.
get
(
'PADDLEGAN_COLORING'
,
False
):
return
Color
[
color
]
+
str
(
message
)
+
Color
[
"ENDC"
]
else
:
return
message
def
print_dict
(
d
,
logger
,
delimiter
=
0
):
"""
Recursively visualize a dict and
indenting acrrording by the relationship of keys.
"""
placeholder
=
"-"
*
60
for
k
,
v
in
sorted
(
d
.
items
()):
if
isinstance
(
v
,
dict
):
logger
.
info
(
"{}{} : "
.
format
(
delimiter
*
" "
,
coloring
(
k
,
"HEADER"
)))
print_dict
(
v
,
logger
,
delimiter
+
4
)
elif
isinstance
(
v
,
list
)
and
len
(
v
)
>=
1
and
isinstance
(
v
[
0
],
dict
):
logger
.
info
(
"{}{} : "
.
format
(
delimiter
*
" "
,
coloring
(
str
(
k
),
"HEADER"
)))
for
value
in
v
:
print_dict
(
value
,
logger
,
delimiter
+
4
)
else
:
logger
.
info
(
"{}{} : {}"
.
format
(
delimiter
*
" "
,
coloring
(
k
,
"HEADER"
),
coloring
(
v
,
"OKGREEN"
)))
if
k
.
isupper
():
logger
.
info
(
placeholder
)
def
advertise
(
logger
):
"""
Show the advertising message like the following:
===========================================================
== PaddleGAN is powered by PaddlePaddle ! ==
===========================================================
== ==
== For more info please go to the following website. ==
== ==
== https://github.com/PaddlePaddle/PaddleGAN ==
===========================================================
"""
copyright
=
"PaddleGAN is powered by PaddlePaddle !"
ad
=
"For more info please go to the following website."
website
=
"https://github.com/PaddlePaddle/PaddleGAN"
AD_LEN
=
6
+
len
(
max
([
copyright
,
ad
,
website
],
key
=
len
))
logger
.
info
(
coloring
(
"
\n
{0}
\n
{1}
\n
{2}
\n
{3}
\n
{4}
\n
{5}
\n
{6}
\n
{7}
\n
"
.
format
(
"="
*
(
AD_LEN
+
4
),
"=={}=="
.
format
(
copyright
.
center
(
AD_LEN
)),
"="
*
(
AD_LEN
+
4
),
"=={}=="
.
format
(
' '
*
AD_LEN
),
"=={}=="
.
format
(
ad
.
center
(
AD_LEN
)),
"=={}=="
.
format
(
' '
*
AD_LEN
),
"=={}=="
.
format
(
website
.
center
(
AD_LEN
)),
"="
*
(
AD_LEN
+
4
),
),
"RED"
))
def
print_config
(
config
,
logger
):
"""
visualize configs
Arguments:
config: configs
"""
advertise
(
logger
)
print_dict
(
config
,
logger
)
def
override
(
dl
,
ks
,
v
):
"""
...
...
@@ -115,13 +195,16 @@ def override_config(config, options=None):
return
config
def
get_config
(
fname
,
overrides
=
None
,
show
=
Tru
e
):
def
get_config
(
fname
,
overrides
=
None
,
show
=
Fals
e
):
"""
Read config from file
"""
assert
os
.
path
.
exists
(
fname
),
(
'config file({}) is not exist'
.
format
(
fname
))
config
=
parse_config
(
fname
)
override_config
(
config
,
overrides
)
if
show
:
logger
=
get_logger
(
name
=
'ppgan.config'
)
print_config
(
config
,
logger
)
return
config
...
...
tools/export_model.py
浏览文件 @
d2c34a93
...
...
@@ -22,6 +22,7 @@ import ppgan
from
ppgan.utils.config
import
get_config
from
ppgan.utils.setup
import
setup
from
ppgan.engine.trainer
import
Trainer
from
ppgan.utils.logger
import
get_logger
def
parse_args
():
...
...
@@ -80,9 +81,11 @@ def main(args, cfg):
net
.
set_state_dict
(
state_dicts
[
net_name
])
model
.
export_model
(
cfg
.
export_model
,
args
.
output_dir
,
inputs_size
,
args
.
export_serving_model
,
args
.
model_name
)
logger
=
get_logger
(
name
=
'ppgan'
)
logger
.
info
(
"Export succeeded! The inference model exported has been saved in {}"
.
format
(
args
.
output_dir
))
if
__name__
==
"__main__"
:
args
=
parse_args
()
cfg
=
get_config
(
args
.
config_file
,
args
.
opt
)
cfg
=
get_config
(
args
.
config_file
,
args
.
opt
,
show
=
True
)
main
(
args
,
cfg
)
tools/inference.py
浏览文件 @
d2c34a93
...
...
@@ -16,6 +16,7 @@ from ppgan.utils.visual import save_image
from
ppgan.utils.visual
import
tensor2img
from
ppgan.utils.filesystem
import
makedirs
from
ppgan.metrics
import
build_metric
from
ppgan.utils.logger
import
get_logger
MODEL_CLASSES
=
[
"pix2pix"
,
"cyclegan"
,
"wav2lip"
,
"esrgan"
,
\
...
...
@@ -151,7 +152,7 @@ def create_predictor(model_path,
print
(
'trt set dynamic shape done!'
)
predictor
=
paddle
.
inference
.
create_predictor
(
config
)
return
predictor
return
predictor
,
config
def
setup_metrics
(
cfg
):
...
...
@@ -172,8 +173,8 @@ def main():
paddle
.
seed
(
args
.
seed
)
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
cfg
=
get_config
(
args
.
config_file
,
args
.
opt
)
predictor
=
create_predictor
(
args
.
model_path
,
args
.
device
,
args
.
run_mode
,
cfg
=
get_config
(
args
.
config_file
,
args
.
opt
,
show
=
True
)
predictor
,
config
=
create_predictor
(
args
.
model_path
,
args
.
device
,
args
.
run_mode
,
args
.
batch_size
,
args
.
min_subgraph_size
,
args
.
use_dynamic_shape
,
args
.
trt_min_shape
,
args
.
trt_max_shape
,
args
.
trt_opt_shape
,
...
...
@@ -222,15 +223,40 @@ def main():
metric
.
update
(
prediction
,
real_B
)
elif
model_type
==
"cyclegan"
:
import
auto_log
logger
=
get_logger
(
name
=
'ppgan'
)
size
=
data
[
'A'
].
shape
pid
=
os
.
getpid
()
auto_logger
=
auto_log
.
AutoLogger
(
model_name
=
args
.
model_type
,
model_precision
=
args
.
run_mode
,
batch_size
=
args
.
batch_size
,
data_shape
=
size
,
save_path
=
args
.
output_path
+
'auto_log.lpg'
,
inference_config
=
config
,
pids
=
pid
,
process_name
=
None
,
gpu_ids
=
None
,
time_keys
=
[
'preprocess_time'
,
'inference_time'
,
'postprocess_time'
],
warmup
=
0
)
auto_logger
.
times
.
start
()
real_A
=
data
[
'A'
].
numpy
()
input_handles
[
0
].
copy_from_cpu
(
real_A
)
auto_logger
.
times
.
stamp
()
predictor
.
run
()
auto_logger
.
times
.
stamp
()
prediction
=
output_handle
.
copy_to_cpu
()
prediction
=
paddle
.
to_tensor
(
prediction
)
image_numpy
=
tensor2img
(
prediction
[
0
],
min_max
)
save_image
(
image_numpy
,
os
.
path
.
join
(
args
.
output_path
,
"cyclegan/{}.png"
.
format
(
i
)))
logger
.
info
(
"Inference succeeded! The inference result has been saved in {}"
.
format
(
os
.
path
.
join
(
args
.
output_path
,
"cyclegan/{}.png"
.
format
(
i
))))
auto_logger
.
times
.
end
(
stamp
=
True
)
auto_logger
.
report
()
metric_file
=
os
.
path
.
join
(
args
.
output_path
,
"cyclegan/metric.txt"
)
real_B
=
paddle
.
to_tensor
(
data
[
'B'
])
for
metric
in
metrics
.
values
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录