PaddlePaddle / PaddleSeg

Commit 362338cb (unverified)
Authored on Aug 21, 2020 by wuzewu; committed by GitHub on Aug 21, 2020

Merge pull request #350 from wuyefeilin/dygraph
add environment information collection
Parents: e5336bb5, 30ca35ed
Showing 12 changed files with 156 additions and 44 deletions (+156 -44)

dygraph/benchmark/deeplabv3p.py       +6    -1
dygraph/benchmark/hrnet.py            +6    -1
dygraph/core/infer.py                 +2    -2
dygraph/core/train.py                 +3    -3
dygraph/core/val.py                   +7    -7
dygraph/infer.py                      +1    -1
dygraph/train.py                      +7    -1
dygraph/utils/__init__.py             +2    -1
dygraph/utils/get_environ_info.py     +113  -0
dygraph/utils/logger.py               +0    -0
dygraph/utils/utils.py                +8    -26
dygraph/val.py                        +1    -1
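
Taken together, these files add a richer dygraph.utils.get_environ_info helper, rename dygraph.utils.logging to dygraph.utils.logger, and make the entry scripts log the collected environment before picking a device. A minimal usage sketch of that pattern, assembled from the diffs below (illustrative only, not part of the commit):

    import paddle.fluid as fluid
    from paddle.fluid.dygraph.parallel import ParallelEnv
    from dygraph.utils import get_environ_info, logger

    env_info = get_environ_info()  # dict of platform, CUDA and library versions
    logger.info('\n'.join('{}: {}'.format(k, v) for k, v in env_info.items()))

    # Use the GPU only if Paddle was built with CUDA and at least one GPU is counted;
    # 'and' short-circuits, so 'GPUs used' is only read when CUDA support is present.
    places = fluid.CUDAPlace(ParallelEnv().dev_id) \
        if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
        else fluid.CPUPlace()
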
dygraph/benchmark/deeplabv3p.py

@@ -21,6 +21,7 @@ from dygraph.datasets import DATASETS
 import dygraph.transforms as T
 from dygraph.models import MODELS
 from dygraph.utils import get_environ_info
+from dygraph.utils import logger
 from dygraph.core import train

@@ -129,8 +130,12 @@ def parse_args():
 def main(args):
     env_info = get_environ_info()
+    info = ['{}: {}'.format(k, v) for k, v in env_info.items()]
+    info = '\n'.join(['\n', format('Environment Information', '-^48s')] + info +
+                     ['-' * 48])
+    logger.info(info)
     places = fluid.CUDAPlace(ParallelEnv().dev_id) \
-        if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
+        if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
         else fluid.CPUPlace()

     if args.dataset not in DATASETS:
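
The added logging block relies on Python's format-spec mini-language: format('Environment Information', '-^48s') centers the title in a 48-character field padded with dashes. A standalone sketch of the string that logger.info(info) receives (the dict values here are made up for illustration):

    env_info = {'Paddle compiled with cuda': True, 'GPUs used': 1}  # example values only

    info = ['{}: {}'.format(k, v) for k, v in env_info.items()]
    info = '\n'.join(['\n', format('Environment Information', '-^48s')] + info +
                     ['-' * 48])
    print(info)  # dash-padded banner, one "key: value" line per entry, then a 48-dash rule
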
dygraph/benchmark/hrnet.py

@@ -21,6 +21,7 @@ from dygraph.datasets import DATASETS
 import dygraph.transforms as T
 from dygraph.models import MODELS
 from dygraph.utils import get_environ_info
+from dygraph.utils import logger
 from dygraph.core import train

@@ -129,8 +130,12 @@ def parse_args():
 def main(args):
     env_info = get_environ_info()
+    info = ['{}: {}'.format(k, v) for k, v in env_info.items()]
+    info = '\n'.join(['\n', format('Environment Information', '-^48s')] + info +
+                     ['-' * 48])
+    logger.info(info)
     places = fluid.CUDAPlace(ParallelEnv().dev_id) \
-        if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
+        if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
         else fluid.CPUPlace()

     if args.dataset not in DATASETS:
dygraph/core/infer.py

@@ -21,7 +21,7 @@ import cv2
 import tqdm

 from dygraph import utils
-import dygraph.utils.logging as logging
+import dygraph.utils.logger as logger


 def mkdir(path):

@@ -39,7 +39,7 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
     added_saved_dir = os.path.join(save_dir, 'added')
     pred_saved_dir = os.path.join(save_dir, 'prediction')

-    logging.info("Start to predict...")
+    logger.info("Start to predict...")
     for im, im_info, im_path in tqdm.tqdm(test_dataset):
         im = to_variable(im)
         pred, _ = model(im)
dygraph/core/train.py

@@ -19,7 +19,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.fluid.io import DataLoader
 from paddle.incubate.hapi.distributed import DistributedBatchSampler

-import dygraph.utils.logging as logging
+import dygraph.utils.logger as logger
 from dygraph.utils import load_pretrained_model
 from dygraph.utils import resume
 from dygraph.utils import Timer, calculate_eta

@@ -111,7 +111,7 @@ def train(model,
                 train_batch_cost = 0.0
                 remain_steps = total_steps - num_steps
                 eta = calculate_eta(remain_steps, avg_train_batch_cost)
-                logging.info(
+                logger.info(
                     "[TRAIN] Epoch={}/{}, Step={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
                     .format(epoch + 1, num_epochs, step + 1, steps_per_epoch,
                             avg_loss * nranks, lr, avg_train_batch_cost,

@@ -152,7 +152,7 @@ def train(model,
                     best_model_dir = os.path.join(save_dir, "best_model")
                     fluid.save_dygraph(model.state_dict(),
                                        os.path.join(best_model_dir, 'model'))
-                    logging.info(
+                    logger.info(
                         'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}'
                         .format(best_model_epoch, best_mean_iou))
dygraph/core/val.py

@@ -20,7 +20,7 @@ import cv2
 from paddle.fluid.dygraph.base import to_variable
 import paddle.fluid as fluid

-import dygraph.utils.logging as logging
+import dygraph.utils.logger as logger
 from dygraph.utils import ConfusionMatrix
 from dygraph.utils import Timer, calculate_eta

@@ -39,7 +39,7 @@ def evaluate(model,
     total_steps = len(eval_dataset)
     conf_mat = ConfusionMatrix(num_classes, streaming=True)

-    logging.info(
+    logger.info(
         "Start to evaluating(total_samples={}, total_steps={})...".format(
             len(eval_dataset), total_steps))
     timer = Timer()

@@ -69,7 +69,7 @@ def evaluate(model,
             time_step = timer.elapsed_time()
             remain_step = total_steps - step - 1
-            logging.debug(
+            logger.debug(
                 "[EVAL] Epoch={}, Step={}/{}, iou={:4f}, sec/step={:.4f} | ETA {}".
                 format(epoch_id, step + 1, total_steps, iou, time_step,
                        calculate_eta(remain_step, time_step)))

@@ -77,9 +77,9 @@ def evaluate(model,
     category_iou, miou = conf_mat.mean_iou()
     category_acc, macc = conf_mat.accuracy()
-    logging.info("[EVAL] #Images={} mAcc={:.4f} mIoU={:.4f}".format(
+    logger.info("[EVAL] #Images={} mAcc={:.4f} mIoU={:.4f}".format(
         len(eval_dataset), macc, miou))
-    logging.info("[EVAL] Category IoU: " + str(category_iou))
+    logger.info("[EVAL] Category IoU: " + str(category_iou))
-    logging.info("[EVAL] Category Acc: " + str(category_acc))
+    logger.info("[EVAL] Category Acc: " + str(category_acc))
-    logging.info("[EVAL] Kappa:{:.4f} ".format(conf_mat.kappa()))
+    logger.info("[EVAL] Kappa:{:.4f} ".format(conf_mat.kappa()))
     return miou, macc
dygraph/infer.py

@@ -84,7 +84,7 @@ def parse_args():
 def main(args):
     env_info = get_environ_info()
     places = fluid.CUDAPlace(ParallelEnv().dev_id) \
-        if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
+        if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
         else fluid.CPUPlace()

     if args.dataset not in DATASETS:
dygraph/train.py

@@ -22,6 +22,7 @@ import dygraph.transforms as T
 #from dygraph.models import MODELS
 from dygraph.cvlibs import manager
 from dygraph.utils import get_environ_info
+from dygraph.utils import logger
 from dygraph.core import train

@@ -130,8 +131,13 @@ def parse_args():
 def main(args):
     env_info = get_environ_info()
+    info = ['{}: {}'.format(k, v) for k, v in env_info.items()]
+    info = '\n'.join(['\n', format('Environment Information', '-^48s')] + info +
+                     ['-' * 48])
+    logger.info(info)
+
     places = fluid.CUDAPlace(ParallelEnv().dev_id) \
-        if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
+        if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
         else fluid.CPUPlace()

     if args.dataset not in DATASETS:
dygraph/utils/__init__.py

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from . import logging
+from . import logger
 from . import download
 from .metrics import ConfusionMatrix
 from .utils import *
 from .timer import Timer, calculate_eta
+from .get_environ_info import get_environ_info
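
With these exports, callers can import both the renamed logger module and the new helper straight from the package, which is exactly what the entry scripts above do. A short sketch (illustrative):

    from dygraph.utils import get_environ_info, logger

    logger.info('Collected keys: {}'.format(list(get_environ_info().keys())))
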
dygraph/utils/get_environ_info.py (new file, mode 100644)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from collections import OrderedDict
import subprocess
import glob

import paddle
import paddle.fluid as fluid
import cv2

IS_WINDOWS = sys.platform == 'win32'


def _find_cuda_home():
    '''Finds the CUDA install path. It refers to the implementation of
    pytorch <https://github.com/pytorch/pytorch/blob/master/torch/utils/cpp_extension.py>.
    '''
    # Guess #1
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Guess #2
        try:
            which = 'where' if IS_WINDOWS else 'which'
            nvcc = subprocess.check_output([which,
                                            'nvcc']).decode().rstrip('\r\n')
            cuda_home = os.path.dirname(os.path.dirname(nvcc))
        except Exception:
            # Guess #3
            if IS_WINDOWS:
                cuda_homes = glob.glob(
                    'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
                if len(cuda_homes) == 0:
                    cuda_home = ''
                else:
                    cuda_home = cuda_homes[0]
            else:
                cuda_home = '/usr/local/cuda'
            if not os.path.exists(cuda_home):
                cuda_home = None
    return cuda_home


def _get_nvcc_info(cuda_home):
    if cuda_home is not None and os.path.isdir(cuda_home):
        try:
            nvcc = os.path.join(cuda_home, 'bin/nvcc')
            nvcc = subprocess.check_output(
                "{} -V".format(nvcc), shell=True).decode()
            nvcc = nvcc.strip().split('\n')[-1]
        except subprocess.SubprocessError:
            nvcc = "Not Available"
    return nvcc


def _get_gpu_info():
    try:
        gpu_info = subprocess.check_output(['nvidia-smi',
                                            '-L']).decode().strip()
        gpu_info = gpu_info.split('\n')
        for i in range(len(gpu_info)):
            gpu_info[i] = ' '.join(gpu_info[i].split(' ')[:4])
    except:
        gpu_info = ' Can not get GPU information. Please make sure CUDA have been installed successfully.'
    return gpu_info


def get_environ_info():
    """collect environment information"""
    env_info = {}
    env_info['System Platform'] = sys.platform
    if env_info['System Platform'] == 'linux':
        lsb_v = subprocess.check_output(['lsb_release', '-v']).decode().strip()
        lsb_v = lsb_v.replace('\t', ' ')
        lsb_d = subprocess.check_output(['lsb_release', '-d']).decode().strip()
        lsb_d = lsb_d.replace('\t', ' ')
        env_info['LSB'] = [lsb_v, lsb_d]

    env_info['Python'] = sys.version.replace('\n', '')

    compiled_with_cuda = paddle.fluid.is_compiled_with_cuda()
    env_info['Paddle compiled with cuda'] = compiled_with_cuda
    if compiled_with_cuda:
        cuda_home = _find_cuda_home()
        env_info['NVCC'] = _get_nvcc_info(cuda_home)
        gpu_nums = fluid.core.get_cuda_device_count()
        env_info['GPUs used'] = gpu_nums
        env_info['CUDA_VISIBLE_DEVICES'] = os.environ.get(
            'CUDA_VISIBLE_DEVICES')
        env_info['GPU'] = _get_gpu_info()

    gcc = subprocess.check_output(['gcc', '--version']).decode()
    gcc = gcc.strip().split('\n')[0]
    env_info['GCC'] = gcc

    env_info['PaddlePaddle'] = paddle.__version__
    env_info['OpenCV'] = cv2.__version__

    return env_info
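
The helper returns a plain dict. Per the code above it always reports 'System Platform', 'Python', 'Paddle compiled with cuda', 'GCC', 'PaddlePaddle' and 'OpenCV', adds 'LSB' on Linux, and adds 'NVCC', 'GPUs used', 'CUDA_VISIBLE_DEVICES' and 'GPU' only when Paddle was compiled with CUDA. A hedged usage sketch:

    from dygraph.utils import get_environ_info

    env = get_environ_info()
    for key, value in env.items():  # e.g. env['PaddlePaddle'] == paddle.__version__
        print('{}: {}'.format(key, value))
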
dygraph/utils/logging.py → dygraph/utils/logger.py

File moved.
dygraph/utils/utils.py

@@ -18,7 +18,7 @@ import math
 import cv2
 import paddle.fluid as fluid

-from . import logging
+from . import logger


 def seconds_to_hms(seconds):

@@ -29,27 +29,9 @@ def seconds_to_hms(seconds):
     return hms_str


-def get_environ_info():
-    info = dict()
-    info['place'] = 'cpu'
-    info['num'] = int(os.environ.get('CPU_NUM', 1))
-    if os.environ.get('CUDA_VISIBLE_DEVICES', None) != "":
-        if hasattr(fluid.core, 'get_cuda_device_count'):
-            gpu_num = 0
-            try:
-                gpu_num = fluid.core.get_cuda_device_count()
-            except:
-                os.environ['CUDA_VISIBLE_DEVICES'] = ''
-                pass
-            if gpu_num > 0:
-                info['place'] = 'cuda'
-                info['num'] = fluid.core.get_cuda_device_count()
-    return info
-
-
 def load_pretrained_model(model, pretrained_model):
     if pretrained_model is not None:
-        logging.info('Load pretrained model from {}'.format(pretrained_model))
+        logger.info('Load pretrained model from {}'.format(pretrained_model))
         if os.path.exists(pretrained_model):
             ckpt_path = os.path.join(pretrained_model, 'model')
             try:

@@ -62,10 +44,10 @@ def load_pretrained_model(model, pretrained_model):
         num_params_loaded = 0
         for k in keys:
             if k not in para_state_dict:
-                logging.warning("{} is not in pretrained model".format(k))
+                logger.warning("{} is not in pretrained model".format(k))
             elif list(para_state_dict[k].shape) != list(
                     model_state_dict[k].shape):
-                logging.warning(
+                logger.warning(
                     "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
                     .format(k, para_state_dict[k].shape,
                             model_state_dict[k].shape))

@@ -73,7 +55,7 @@ def load_pretrained_model(model, pretrained_model):
                 model_state_dict[k] = para_state_dict[k]
                 num_params_loaded += 1
         model.set_dict(model_state_dict)
-        logging.info("There are {}/{} varaibles are loaded.".format(
+        logger.info("There are {}/{} varaibles are loaded.".format(
            num_params_loaded, len(model_state_dict)))
     else:

@@ -81,12 +63,12 @@ def load_pretrained_model(model, pretrained_model):
             'The pretrained model directory is not Found: {}'.format(
                 pretrained_model))
     else:
-        logging.info('No pretrained model to load, train from scratch')
+        logger.info('No pretrained model to load, train from scratch')


 def resume(model, optimizer, resume_model):
     if resume_model is not None:
-        logging.info('Resume model from {}'.format(resume_model))
+        logger.info('Resume model from {}'.format(resume_model))
         if os.path.exists(resume_model):
             resume_model = os.path.normpath(resume_model)
             ckpt_path = os.path.join(resume_model, 'model')

@@ -102,7 +84,7 @@ def resume(model, optimizer, resume_model):
             'The resume model directory is not Found: {}'.format(
                 resume_model))
     else:
-        logging.info('No model need to resume')
+        logger.info('No model need to resume')


 def visualize(image, result, save_dir=None, weight=0.6):
dygraph/val.py

@@ -72,7 +72,7 @@ def parse_args():
 def main(args):
     env_info = get_environ_info()
     places = fluid.CUDAPlace(ParallelEnv().dev_id) \
-        if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
+        if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
         else fluid.CPUPlace()

     if args.dataset not in DATASETS: