Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
746122f4
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
746122f4
编写于
8月 29, 2020
作者:
F
FlyingQianMM
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add analysis for seg dataset
上级
564417a4
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
956 addition
and
19 deletion
+956
-19
paddlex/__init__.py
paddlex/__init__.py
+1
-1
paddlex/cv/datasets/__init__.py
paddlex/cv/datasets/__init__.py
+1
-0
paddlex/cv/datasets/analysis.py
paddlex/cv/datasets/analysis.py
+366
-0
paddlex/cv/datasets/seg_dataset.py
paddlex/cv/datasets/seg_dataset.py
+4
-1
paddlex/cv/models/deeplabv3p.py
paddlex/cv/models/deeplabv3p.py
+3
-0
paddlex/cv/models/fast_scnn.py
paddlex/cv/models/fast_scnn.py
+3
-0
paddlex/cv/models/hrnet.py
paddlex/cv/models/hrnet.py
+3
-0
paddlex/cv/models/ppyolo.py
paddlex/cv/models/ppyolo.py
+1
-1
paddlex/cv/nets/segmentation/deeplabv3p.py
paddlex/cv/nets/segmentation/deeplabv3p.py
+7
-2
paddlex/cv/nets/segmentation/fast_scnn.py
paddlex/cv/nets/segmentation/fast_scnn.py
+7
-2
paddlex/cv/nets/segmentation/hrnet.py
paddlex/cv/nets/segmentation/hrnet.py
+7
-2
paddlex/cv/transforms/seg_transforms.py
paddlex/cv/transforms/seg_transforms.py
+17
-2
paddlex/remotesensing/__init__.py
paddlex/remotesensing/__init__.py
+0
-0
paddlex/remotesensing/train_demo.py
paddlex/remotesensing/train_demo.py
+5
-5
paddlex/remotesensing/transforms.py
paddlex/remotesensing/transforms.py
+25
-3
paddlex/remotesensing/utils/__init__.py
paddlex/remotesensing/utils/__init__.py
+0
-0
paddlex/remotesensing/utils/analyse.py
paddlex/remotesensing/utils/analyse.py
+506
-0
未找到文件。
paddlex/__init__.py
浏览文件 @
746122f4
...
@@ -32,7 +32,7 @@ from . import slim
...
@@ -32,7 +32,7 @@ from . import slim
from
.
import
convertor
from
.
import
convertor
from
.
import
tools
from
.
import
tools
from
.
import
deploy
from
.
import
deploy
from
.
import
RemoteS
ensing
from
.
import
remotes
ensing
try
:
try
:
import
pycocotools
import
pycocotools
...
...
paddlex/cv/datasets/__init__.py
浏览文件 @
746122f4
...
@@ -20,3 +20,4 @@ from .easydata_cls import EasyDataCls
...
@@ -20,3 +20,4 @@ from .easydata_cls import EasyDataCls
from
.easydata_det
import
EasyDataDet
from
.easydata_det
import
EasyDataDet
from
.easydata_seg
import
EasyDataSeg
from
.easydata_seg
import
EasyDataSeg
from
.dataset
import
generate_minibatch
from
.dataset
import
generate_minibatch
from
.analysis
import
Seg
paddlex/cv/datasets/analysis.py
0 → 100644
浏览文件 @
746122f4
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
import
numpy
as
np
import
os.path
as
osp
import
cv2
from
PIL
import
Image
import
pickle
import
threading
import
multiprocessing
as
mp
import
paddlex.utils.logging
as
logging
from
paddlex.utils
import
path_normalization
from
.dataset
import
get_encoding
class
Seg
:
def
__init__
(
self
,
data_dir
,
file_list
,
label_list
):
self
.
data_dir
=
data_dir
self
.
file_list_path
=
file_list
self
.
file_list
=
list
()
self
.
labels
=
list
()
with
open
(
label_list
,
encoding
=
get_encoding
(
label_list
))
as
f
:
for
line
in
f
:
item
=
line
.
strip
()
self
.
labels
.
append
(
item
)
with
open
(
file_list
,
encoding
=
get_encoding
(
file_list
))
as
f
:
for
line
in
f
:
if
line
.
count
(
" "
)
>
1
:
raise
Exception
(
"A space is defined as the separator, but it exists in image or label name {}."
.
format
(
line
))
items
=
line
.
strip
().
split
()
items
[
0
]
=
path_normalization
(
items
[
0
])
items
[
1
]
=
path_normalization
(
items
[
1
])
full_path_im
=
osp
.
join
(
data_dir
,
items
[
0
])
full_path_label
=
osp
.
join
(
data_dir
,
items
[
1
])
if
not
osp
.
exists
(
full_path_im
):
raise
IOError
(
'The image file {} is not exist!'
.
format
(
full_path_im
))
if
not
osp
.
exists
(
full_path_label
):
raise
IOError
(
'The image file {} is not exist!'
.
format
(
full_path_label
))
self
.
file_list
.
append
([
full_path_im
,
full_path_label
])
self
.
num_samples
=
len
(
self
.
file_list
)
@
staticmethod
def
decode_image
(
im
,
label
):
if
isinstance
(
im
,
np
.
ndarray
):
if
len
(
im
.
shape
)
!=
3
:
raise
Exception
(
"im should be 3-dimensions, but now is {}-dimensions"
.
format
(
len
(
im
.
shape
)))
else
:
try
:
im
=
cv2
.
imread
(
im
)
except
:
raise
ValueError
(
'Can
\'
t read The image file {}!'
.
format
(
im
))
im
=
im
.
astype
(
'float32'
)
if
label
is
not
None
:
if
isinstance
(
label
,
np
.
ndarray
):
if
len
(
label
.
shape
)
!=
2
:
raise
Exception
(
"label should be 2-dimensions, but now is {}-dimensions"
.
format
(
len
(
label
.
shape
)))
else
:
try
:
label
=
np
.
asarray
(
Image
.
open
(
label
))
except
:
ValueError
(
'Can
\'
t read The label file {}!'
.
format
(
label
))
im_height
,
im_width
,
_
=
im
.
shape
label_height
,
label_width
=
label
.
shape
if
im_height
!=
label_height
or
im_width
!=
label_width
:
raise
Exception
(
"The height or width of the image is not same as the label"
)
return
(
im
,
label
)
def
_get_shape
(
self
):
max_height
=
max
(
self
.
im_height_list
)
max_width
=
max
(
self
.
im_width_list
)
min_height
=
min
(
self
.
im_height_list
)
min_width
=
min
(
self
.
im_width_list
)
shape_info
=
{
'max_height'
:
max_height
,
'max_width'
:
max_width
,
'min_height'
:
min_height
,
'min_width'
:
min_width
,
}
return
shape_info
def
_get_label_pixel_info
(
self
):
pixel_num
=
np
.
dot
(
self
.
im_height_list
,
self
.
im_width_list
)
label_pixel_info
=
dict
()
for
label_value
,
label_value_num
in
zip
(
self
.
label_value_list
,
self
.
label_value_num_list
):
for
v
,
n
in
zip
(
label_value
,
label_value_num
):
if
v
not
in
label_pixel_info
.
keys
():
label_pixel_info
[
v
]
=
[
n
,
float
(
n
)
/
float
(
pixel_num
)]
else
:
label_pixel_info
[
v
][
0
]
+=
n
label_pixel_info
[
v
][
1
]
+=
float
(
n
)
/
float
(
pixel_num
)
return
label_pixel_info
def
_get_image_pixel_info
(
self
):
channel
=
max
([
len
(
im_value
)
for
im_value
in
self
.
im_value_list
])
im_pixel_info
=
[
dict
()
for
c
in
range
(
channel
)]
for
im_value
,
im_value_num
in
zip
(
self
.
im_value_list
,
self
.
im_value_num_list
):
for
c
in
range
(
channel
):
for
v
,
n
in
zip
(
im_value
[
c
],
im_value_num
[
c
]):
if
v
not
in
im_pixel_info
[
c
].
keys
():
im_pixel_info
[
c
][
v
]
=
n
else
:
im_pixel_info
[
c
][
v
]
+=
n
mode
=
osp
.
split
(
self
.
file_list_path
)[
-
1
].
split
(
'.'
)[
0
]
with
open
(
osp
.
join
(
self
.
data_dir
,
'{}_image_pixel_info.pkl'
.
format
(
mode
)),
'wb'
)
as
f
:
pickle
.
dump
(
im_pixel_info
,
f
)
import
matplotlib.pyplot
as
plt
plot_id
=
(
channel
//
3
+
1
)
*
100
+
31
for
c
in
range
(
channel
):
if
c
>
8
:
continue
plt
.
subplot
(
plot_id
+
c
)
plt
.
bar
(
im_pixel_info
[
c
].
keys
(),
im_pixel_info
[
c
].
values
(),
width
=
1
,
log
=
True
)
plt
.
xlabel
(
'image pixel value'
)
plt
.
ylabel
(
'number'
)
plt
.
title
(
'channel={}'
.
format
(
c
))
plt
.
savefig
(
osp
.
join
(
self
.
data_dir
,
'{}_image_pixel_info.png'
.
format
(
mode
)),
dpi
=
800
)
plt
.
close
()
return
im_pixel_info
def
_get_mean_std
(
self
):
im_mean
=
np
.
asarray
(
self
.
im_mean_list
)
im_mean
=
im_mean
.
sum
(
axis
=
0
)
im_mean
=
im_mean
/
len
(
self
.
file_list
)
im_mean
/=
255.
im_std
=
np
.
asarray
(
self
.
im_std_list
)
im_std
=
im_std
.
sum
(
axis
=
0
)
im_std
=
im_std
/
len
(
self
.
file_list
)
im_std
/=
255.
return
(
im_mean
,
im_std
)
def
_get_image_info
(
self
,
start
,
end
):
for
id
in
range
(
start
,
end
):
full_path_im
,
full_path_label
=
self
.
file_list
[
id
]
image
,
label
=
self
.
decode_image
(
full_path_im
,
full_path_label
)
height
,
width
,
channel
=
image
.
shape
self
.
im_height_list
[
id
]
=
height
self
.
im_width_list
[
id
]
=
width
self
.
im_channel_list
[
id
]
=
channel
self
.
im_mean_list
[
id
]
=
[
np
.
mean
(
image
[:,
:,
c
])
for
c
in
range
(
channel
)]
self
.
im_std_list
[
id
]
=
[
np
.
mean
(
image
[:,
:,
c
])
for
c
in
range
(
channel
)]
for
c
in
range
(
channel
):
unique
,
counts
=
np
.
unique
(
image
[:,
:,
c
],
return_counts
=
True
)
self
.
im_value_list
[
id
].
extend
([
unique
])
self
.
im_value_num_list
[
id
].
extend
([
counts
])
unique
,
counts
=
np
.
unique
(
label
,
return_counts
=
True
)
self
.
label_value_list
[
id
]
=
unique
self
.
label_value_num_list
[
id
]
=
counts
def
_get_clipped_mean_std
(
self
,
start
,
end
,
clip_min_value
,
clip_max_value
):
for
id
in
range
(
start
,
end
):
full_path_im
,
full_path_label
=
self
.
file_list
[
id
]
image
,
label
=
self
.
decode_image
(
full_path_im
,
full_path_label
)
for
c
in
range
(
self
.
channel_num
):
np
.
clip
(
image
[:,
:,
c
],
clip_min_value
[
c
],
clip_max_value
[
c
],
out
=
image
[:,
:,
c
])
image
[:,
:,
c
]
-=
clip_min_value
[
c
]
image
[:,
:,
c
]
/=
clip_max_value
[
c
]
-
clip_min_value
[
c
]
self
.
clipped_im_mean_list
[
id
]
=
[
image
[:,
:,
c
].
mean
()
for
c
in
range
(
self
.
channel_num
)
]
self
.
clipped_im_std_list
[
id
]
=
[
image
[:,
:,
c
].
std
()
for
c
in
range
(
self
.
channel_num
)]
def
analysis
(
self
):
self
.
im_mean_list
=
[[]
for
i
in
range
(
len
(
self
.
file_list
))]
self
.
im_std_list
=
[[]
for
i
in
range
(
len
(
self
.
file_list
))]
self
.
im_value_list
=
[[]
for
i
in
range
(
len
(
self
.
file_list
))]
self
.
im_value_num_list
=
[[]
for
i
in
range
(
len
(
self
.
file_list
))]
self
.
im_height_list
=
np
.
zeros
(
len
(
self
.
file_list
),
dtype
=
'int32'
)
self
.
im_width_list
=
np
.
zeros
(
len
(
self
.
file_list
),
dtype
=
'int32'
)
self
.
im_channel_list
=
np
.
zeros
(
len
(
self
.
file_list
),
dtype
=
'int32'
)
self
.
label_value_list
=
[[]
for
i
in
range
(
len
(
self
.
file_list
))]
self
.
label_value_num_list
=
[[]
for
i
in
range
(
len
(
self
.
file_list
))]
num_workers
=
mp
.
cpu_count
()
//
2
if
mp
.
cpu_count
()
//
2
<
8
else
8
num_workers
=
6
threads
=
[]
one_worker_file
=
len
(
self
.
file_list
)
//
num_workers
for
i
in
range
(
num_workers
):
start
=
one_worker_file
*
i
end
=
one_worker_file
*
(
i
+
1
)
if
i
<
num_workers
-
1
else
len
(
self
.
file_list
)
t
=
threading
.
Thread
(
target
=
self
.
_get_image_info
,
args
=
(
start
,
end
))
print
(
"===="
,
len
(
self
.
file_list
),
start
,
end
)
#t.daemon = True
threads
.
append
(
t
)
for
t
in
threads
:
t
.
start
()
for
t
in
threads
:
t
.
join
()
print
(
'ok'
)
import
time
import
sys
sys
.
exit
(
0
)
time
.
sleep
(
1000000
)
return
#self._get_image_info(0, len(self.file_list))
unique
,
counts
=
np
.
unique
(
self
.
im_channel_list
,
return_counts
=
True
)
print
(
'==== unique'
)
if
len
(
unique
)
>
1
:
raise
Exception
(
"There are {} kinds of image channels: {}."
.
format
(
len
(
unique
),
unique
[:]))
self
.
channel_num
=
unique
[
0
]
shape_info
=
self
.
_get_shape
()
print
(
'==== shape_info'
)
self
.
max_height
=
shape_info
[
'max_height'
]
self
.
max_width
=
shape_info
[
'max_width'
]
self
.
min_height
=
shape_info
[
'min_height'
]
self
.
min_width
=
shape_info
[
'min_width'
]
self
.
label_pixel_info
=
self
.
_get_label_pixel_info
()
print
(
'==== label_pixel_info'
)
self
.
im_pixel_info
=
self
.
_get_image_pixel_info
()
print
(
'==== im_pixel_info'
)
im_mean
,
im_std
=
self
.
_get_mean_std
()
print
(
'==== get_mean_std'
)
max_im_value
=
list
()
min_im_value
=
list
()
for
c
in
range
(
self
.
channel_num
):
max_im_value
.
append
(
max
(
self
.
im_pixel_info
[
c
].
keys
()))
min_im_value
.
append
(
min
(
self
.
im_pixel_info
[
c
].
keys
()))
self
.
max_im_value
=
np
.
asarray
(
max_im_value
)
self
.
min_im_value
=
np
.
asarray
(
min_im_value
)
logging
.
info
(
"############## The analysis results are as follows ##############
\n
"
)
logging
.
info
(
"{} samples in file {}
\n
"
.
format
(
len
(
self
.
file_list
),
self
.
file_list_path
))
logging
.
info
(
"Maximal image height: {} Maximal image width: {}.
\n
"
.
format
(
self
.
max_height
,
self
.
max_width
))
logging
.
info
(
"Minimal image height: {} Minimal image width: {}.
\n
"
.
format
(
self
.
min_height
,
self
.
min_width
))
logging
.
info
(
"Image channel is {}.
\n
"
.
format
(
self
.
channel_num
))
logging
.
info
(
"Image mean value: {} Image standard deviation: {} (normalized by 255, sorted by a BGR format).
\n
"
.
format
(
im_mean
,
im_std
))
logging
.
info
(
"Label pixel information is shown in a format of (label_id, the number of label_id, the ratio of label_id):"
)
for
v
,
(
n
,
r
)
in
self
.
label_pixel_info
.
items
():
logging
.
info
(
"({}, {}, {})"
.
format
(
v
,
n
,
r
))
mode
=
osp
.
split
(
self
.
file_list_path
)[
-
1
].
split
(
'.'
)[
0
]
saved_pkl_file
=
osp
.
join
(
self
.
data_dir
,
'{}_image_pixel_info.pkl'
.
format
(
mode
))
saved_png_file
=
osp
.
join
(
self
.
data_dir
,
'{}_image_pixel_info.png'
.
format
(
mode
))
logging
.
info
(
"Image pixel information is saved in the file '{}' and shown in the file '{}'"
.
format
(
saved_pkl_file
,
saved_png_file
))
def
cal_clipvalue_ratio
(
self
,
clip_min_value
,
clip_max_value
):
if
len
(
clip_min_value
)
!=
self
.
channel_num
or
len
(
clip_max_value
)
!=
self
.
channel_num
:
raise
Exception
(
"The length of clip_min_value or clip_max_value should be equal to the number of image channel {}."
.
format
(
self
.
channle_num
))
for
c
in
range
(
self
.
channel_num
):
if
clip_min_value
[
c
]
<
self
.
min_im_value
[
c
]
or
clip_min_value
[
c
]
>
self
.
max_im_value
[
c
]:
raise
Exception
(
"Clip_min_value of the channel {} is not in [{}, {}]"
.
format
(
c
,
self
.
min_im_value
[
c
],
self
.
max_im_value
[
c
]))
if
clip_max_value
[
c
]
<
self
.
min_im_value
[
c
]
or
clip_max_value
[
c
]
>
self
.
max_im_value
[
c
]:
raise
Exception
(
"Clip_max_value of the channel {} is not in [{}, {}]"
.
format
(
c
,
self
.
min_im_value
[
c
],
self
.
max_im_value
[
c
]))
clip_pixel_num
=
0
pixel_num
=
sum
(
self
.
im_pixel_info
[
c
].
values
())
for
v
,
n
in
self
.
im_pixel_info
[
c
].
items
():
if
v
<
clip_min_value
[
c
]
or
v
>
clip_max_value
[
c
]:
clip_pixel_num
+=
n
logging
.
info
(
"Channel {}, the ratio of pixels to be clipped = {}"
.
format
(
c
,
clip_pixel_num
/
pixel_num
))
def
cal_clipped_mean_std
(
self
,
clip_min_value
,
clip_max_value
):
for
c
in
range
(
self
.
channel_num
):
if
clip_min_value
[
c
]
<
self
.
min_im_value
[
c
]
or
clip_min_value
[
c
]
>
self
.
max_im_value
[
c
]:
raise
Exception
(
"Clip_min_value of the channel {} is not in [{}, {}]"
.
format
(
c
,
self
.
min_im_value
[
c
],
self
.
max_im_value
[
c
]))
if
clip_max_value
[
c
]
<
self
.
min_im_value
[
c
]
or
clip_max_value
[
c
]
>
self
.
max_im_value
[
c
]:
raise
Exception
(
"Clip_max_value of the channel {} is not in [{}, {}]"
.
format
(
c
,
self
.
min_im_value
[
c
],
self
.
max_im_value
[
c
]))
self
.
clipped_im_mean_list
=
[[]
for
i
in
range
(
len
(
self
.
file_list
))]
self
.
clipped_im_std_list
=
[[]
for
i
in
range
(
len
(
self
.
file_list
))]
num_workers
=
mp
.
cpu_count
()
//
2
if
mp
.
cpu_count
()
//
2
<
8
else
8
threads
=
[]
one_worker_file
=
len
(
self
.
file_list
)
//
num_workers
for
i
in
range
(
num_workers
):
start
=
one_worker_file
*
i
end
=
one_worker_file
*
(
i
+
1
)
if
i
<
num_workers
-
1
else
len
(
self
.
file_list
)
t
=
threading
.
Thread
(
target
=
self
.
_get_clipped_mean_std
,
args
=
(
start
,
end
,
clip_min_value
,
clip_max_value
))
threads
.
append
(
t
)
for
t
in
threads
:
t
.
setDaemon
(
True
)
t
.
start
()
t
.
join
()
im_mean
=
np
.
asarray
(
self
.
clipped_im_mean_list
)
im_mean
=
im_mean
.
sum
(
axis
=
0
)
im_mean
=
im_mean
/
len
(
self
.
file_list
)
im_std
=
np
.
asarray
(
self
.
clipped_im_std_list
)
im_std
=
im_std
.
sum
(
axis
=
0
)
im_std
=
im_std
/
len
(
self
.
file_list
)
logging
.
info
(
"Image mean value: {} Image standard deviation: {} (normalized by (clip_max_value - clip_min_value)).
\n
"
.
format
(
im_mean
,
im_std
))
paddlex/cv/datasets/seg_dataset.py
浏览文件 @
746122f4
...
@@ -20,7 +20,6 @@ import paddlex.utils.logging as logging
...
@@ -20,7 +20,6 @@ import paddlex.utils.logging as logging
from
paddlex.utils
import
path_normalization
from
paddlex.utils
import
path_normalization
from
.dataset
import
Dataset
from
.dataset
import
Dataset
from
.dataset
import
get_encoding
from
.dataset
import
get_encoding
from
.dataset
import
is_pic
class
SegDataset
(
Dataset
):
class
SegDataset
(
Dataset
):
...
@@ -64,6 +63,10 @@ class SegDataset(Dataset):
...
@@ -64,6 +63,10 @@ class SegDataset(Dataset):
self
.
labels
.
append
(
item
)
self
.
labels
.
append
(
item
)
with
open
(
file_list
,
encoding
=
get_encoding
(
file_list
))
as
f
:
with
open
(
file_list
,
encoding
=
get_encoding
(
file_list
))
as
f
:
for
line
in
f
:
for
line
in
f
:
if
line
.
count
(
" "
)
>
1
:
raise
Exception
(
"A space is defined as the separator, but it exists in image or label name {}."
.
format
(
line
))
items
=
line
.
strip
().
split
()
items
=
line
.
strip
().
split
()
items
[
0
]
=
path_normalization
(
items
[
0
])
items
[
0
]
=
path_normalization
(
items
[
0
])
items
[
1
]
=
path_normalization
(
items
[
1
])
items
[
1
]
=
path_normalization
(
items
[
1
])
...
...
paddlex/cv/models/deeplabv3p.py
浏览文件 @
746122f4
...
@@ -65,6 +65,7 @@ class DeepLabv3p(BaseAPI):
...
@@ -65,6 +65,7 @@ class DeepLabv3p(BaseAPI):
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
=
2
,
num_classes
=
2
,
input_channel
=
3
,
backbone
=
'MobileNetV2_x1.0'
,
backbone
=
'MobileNetV2_x1.0'
,
output_stride
=
16
,
output_stride
=
16
,
aspp_with_sep_conv
=
True
,
aspp_with_sep_conv
=
True
,
...
@@ -114,6 +115,7 @@ class DeepLabv3p(BaseAPI):
...
@@ -114,6 +115,7 @@ class DeepLabv3p(BaseAPI):
self
.
backbone
=
backbone
self
.
backbone
=
backbone
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
input_channel
=
input_channel
self
.
use_bce_loss
=
use_bce_loss
self
.
use_bce_loss
=
use_bce_loss
self
.
use_dice_loss
=
use_dice_loss
self
.
use_dice_loss
=
use_dice_loss
self
.
class_weight
=
class_weight
self
.
class_weight
=
class_weight
...
@@ -215,6 +217,7 @@ class DeepLabv3p(BaseAPI):
...
@@ -215,6 +217,7 @@ class DeepLabv3p(BaseAPI):
def
build_net
(
self
,
mode
=
'train'
):
def
build_net
(
self
,
mode
=
'train'
):
model
=
paddlex
.
cv
.
nets
.
segmentation
.
DeepLabv3p
(
model
=
paddlex
.
cv
.
nets
.
segmentation
.
DeepLabv3p
(
self
.
num_classes
,
self
.
num_classes
,
input_channel
=
self
.
input_channel
,
mode
=
mode
,
mode
=
mode
,
backbone
=
self
.
_get_backbone
(
self
.
backbone
),
backbone
=
self
.
_get_backbone
(
self
.
backbone
),
output_stride
=
self
.
output_stride
,
output_stride
=
self
.
output_stride
,
...
...
paddlex/cv/models/fast_scnn.py
浏览文件 @
746122f4
...
@@ -48,6 +48,7 @@ class FastSCNN(DeepLabv3p):
...
@@ -48,6 +48,7 @@ class FastSCNN(DeepLabv3p):
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
=
2
,
num_classes
=
2
,
input_channel
=
3
,
use_bce_loss
=
False
,
use_bce_loss
=
False
,
use_dice_loss
=
False
,
use_dice_loss
=
False
,
class_weight
=
None
,
class_weight
=
None
,
...
@@ -86,6 +87,7 @@ class FastSCNN(DeepLabv3p):
...
@@ -86,6 +87,7 @@ class FastSCNN(DeepLabv3p):
)
)
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
input_channel
=
input_channel
self
.
use_bce_loss
=
use_bce_loss
self
.
use_bce_loss
=
use_bce_loss
self
.
use_dice_loss
=
use_dice_loss
self
.
use_dice_loss
=
use_dice_loss
self
.
class_weight
=
class_weight
self
.
class_weight
=
class_weight
...
@@ -97,6 +99,7 @@ class FastSCNN(DeepLabv3p):
...
@@ -97,6 +99,7 @@ class FastSCNN(DeepLabv3p):
def
build_net
(
self
,
mode
=
'train'
):
def
build_net
(
self
,
mode
=
'train'
):
model
=
paddlex
.
cv
.
nets
.
segmentation
.
FastSCNN
(
model
=
paddlex
.
cv
.
nets
.
segmentation
.
FastSCNN
(
self
.
num_classes
,
self
.
num_classes
,
input_channel
=
self
.
input_channel
,
mode
=
mode
,
mode
=
mode
,
use_bce_loss
=
self
.
use_bce_loss
,
use_bce_loss
=
self
.
use_bce_loss
,
use_dice_loss
=
self
.
use_dice_loss
,
use_dice_loss
=
self
.
use_dice_loss
,
...
...
paddlex/cv/models/hrnet.py
浏览文件 @
746122f4
...
@@ -44,6 +44,7 @@ class HRNet(DeepLabv3p):
...
@@ -44,6 +44,7 @@ class HRNet(DeepLabv3p):
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
=
2
,
num_classes
=
2
,
input_channel
=
3
,
width
=
18
,
width
=
18
,
use_bce_loss
=
False
,
use_bce_loss
=
False
,
use_dice_loss
=
False
,
use_dice_loss
=
False
,
...
@@ -72,6 +73,7 @@ class HRNet(DeepLabv3p):
...
@@ -72,6 +73,7 @@ class HRNet(DeepLabv3p):
'Expect class_weight is a list or string but receive {}'
.
'Expect class_weight is a list or string but receive {}'
.
format
(
type
(
class_weight
)))
format
(
type
(
class_weight
)))
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
input_channel
=
input_channel
self
.
width
=
width
self
.
width
=
width
self
.
use_bce_loss
=
use_bce_loss
self
.
use_bce_loss
=
use_bce_loss
self
.
use_dice_loss
=
use_dice_loss
self
.
use_dice_loss
=
use_dice_loss
...
@@ -83,6 +85,7 @@ class HRNet(DeepLabv3p):
...
@@ -83,6 +85,7 @@ class HRNet(DeepLabv3p):
def
build_net
(
self
,
mode
=
'train'
):
def
build_net
(
self
,
mode
=
'train'
):
model
=
paddlex
.
cv
.
nets
.
segmentation
.
HRNet
(
model
=
paddlex
.
cv
.
nets
.
segmentation
.
HRNet
(
self
.
num_classes
,
self
.
num_classes
,
input_channel
=
self
.
input_channel
,
width
=
self
.
width
,
width
=
self
.
width
,
mode
=
mode
,
mode
=
mode
,
use_bce_loss
=
self
.
use_bce_loss
,
use_bce_loss
=
self
.
use_bce_loss
,
...
...
paddlex/cv/models/ppyolo.py
浏览文件 @
746122f4
...
@@ -36,7 +36,7 @@ class PPYOLO(BaseAPI):
...
@@ -36,7 +36,7 @@ class PPYOLO(BaseAPI):
Args:
Args:
num_classes (int): 类别数。默认为80。
num_classes (int): 类别数。默认为80。
backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd
']。默认为'ResNet50_v
d'。
backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd
_ssld']。默认为'ResNet50_vd_ssl
d'。
with_dcn_v2 (bool): Backbone是否使用DCNv2结构。默认为True。
with_dcn_v2 (bool): Backbone是否使用DCNv2结构。默认为True。
anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
...
...
paddlex/cv/nets/segmentation/deeplabv3p.py
浏览文件 @
746122f4
...
@@ -72,6 +72,7 @@ class DeepLabv3p(object):
...
@@ -72,6 +72,7 @@ class DeepLabv3p(object):
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
,
num_classes
,
backbone
,
backbone
,
input_channel
=
3
,
mode
=
'train'
,
mode
=
'train'
,
output_stride
=
16
,
output_stride
=
16
,
aspp_with_sep_conv
=
True
,
aspp_with_sep_conv
=
True
,
...
@@ -115,6 +116,7 @@ class DeepLabv3p(object):
...
@@ -115,6 +116,7 @@ class DeepLabv3p(object):
format
(
type
(
class_weight
)))
format
(
type
(
class_weight
)))
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
input_channel
=
input_channel
self
.
backbone
=
backbone
self
.
backbone
=
backbone
self
.
mode
=
mode
self
.
mode
=
mode
self
.
use_bce_loss
=
use_bce_loss
self
.
use_bce_loss
=
use_bce_loss
...
@@ -402,13 +404,16 @@ class DeepLabv3p(object):
...
@@ -402,13 +404,16 @@ class DeepLabv3p(object):
if
self
.
fixed_input_shape
is
not
None
:
if
self
.
fixed_input_shape
is
not
None
:
input_shape
=
[
input_shape
=
[
None
,
3
,
self
.
fixed_input_shape
[
1
],
self
.
fixed_input_shape
[
0
]
None
,
self
.
input_channel
,
self
.
fixed_input_shape
[
1
],
self
.
fixed_input_shape
[
0
]
]
]
inputs
[
'image'
]
=
fluid
.
data
(
inputs
[
'image'
]
=
fluid
.
data
(
dtype
=
'float32'
,
shape
=
input_shape
,
name
=
'image'
)
dtype
=
'float32'
,
shape
=
input_shape
,
name
=
'image'
)
else
:
else
:
inputs
[
'image'
]
=
fluid
.
data
(
inputs
[
'image'
]
=
fluid
.
data
(
dtype
=
'float32'
,
shape
=
[
None
,
3
,
None
,
None
],
name
=
'image'
)
dtype
=
'float32'
,
shape
=
[
None
,
self
.
input_channel
,
None
,
None
],
name
=
'image'
)
if
self
.
mode
==
'train'
:
if
self
.
mode
==
'train'
:
inputs
[
'label'
]
=
fluid
.
data
(
inputs
[
'label'
]
=
fluid
.
data
(
dtype
=
'int32'
,
shape
=
[
None
,
1
,
None
,
None
],
name
=
'label'
)
dtype
=
'int32'
,
shape
=
[
None
,
1
,
None
,
None
],
name
=
'label'
)
...
...
paddlex/cv/nets/segmentation/fast_scnn.py
浏览文件 @
746122f4
...
@@ -33,6 +33,7 @@ from .model_utils.loss import bce_loss
...
@@ -33,6 +33,7 @@ from .model_utils.loss import bce_loss
class
FastSCNN
(
object
):
class
FastSCNN
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
,
num_classes
,
input_channel
=
3
,
mode
=
'train'
,
mode
=
'train'
,
use_bce_loss
=
False
,
use_bce_loss
=
False
,
use_dice_loss
=
False
,
use_dice_loss
=
False
,
...
@@ -62,6 +63,7 @@ class FastSCNN(object):
...
@@ -62,6 +63,7 @@ class FastSCNN(object):
format
(
type
(
class_weight
)))
format
(
type
(
class_weight
)))
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
input_channel
=
input_channel
self
.
mode
=
mode
self
.
mode
=
mode
self
.
use_bce_loss
=
use_bce_loss
self
.
use_bce_loss
=
use_bce_loss
self
.
use_dice_loss
=
use_dice_loss
self
.
use_dice_loss
=
use_dice_loss
...
@@ -137,13 +139,16 @@ class FastSCNN(object):
...
@@ -137,13 +139,16 @@ class FastSCNN(object):
inputs
=
OrderedDict
()
inputs
=
OrderedDict
()
if
self
.
fixed_input_shape
is
not
None
:
if
self
.
fixed_input_shape
is
not
None
:
input_shape
=
[
input_shape
=
[
None
,
3
,
self
.
fixed_input_shape
[
1
],
self
.
fixed_input_shape
[
0
]
None
,
self
.
input_channel
,
self
.
fixed_input_shape
[
1
],
self
.
fixed_input_shape
[
0
]
]
]
inputs
[
'image'
]
=
fluid
.
data
(
inputs
[
'image'
]
=
fluid
.
data
(
dtype
=
'float32'
,
shape
=
input_shape
,
name
=
'image'
)
dtype
=
'float32'
,
shape
=
input_shape
,
name
=
'image'
)
else
:
else
:
inputs
[
'image'
]
=
fluid
.
data
(
inputs
[
'image'
]
=
fluid
.
data
(
dtype
=
'float32'
,
shape
=
[
None
,
3
,
None
,
None
],
name
=
'image'
)
dtype
=
'float32'
,
shape
=
[
None
,
self
.
input_channel
,
None
,
None
],
name
=
'image'
)
if
self
.
mode
==
'train'
:
if
self
.
mode
==
'train'
:
inputs
[
'label'
]
=
fluid
.
data
(
inputs
[
'label'
]
=
fluid
.
data
(
dtype
=
'int32'
,
shape
=
[
None
,
1
,
None
,
None
],
name
=
'label'
)
dtype
=
'int32'
,
shape
=
[
None
,
1
,
None
,
None
],
name
=
'label'
)
...
...
paddlex/cv/nets/segmentation/hrnet.py
浏览文件 @
746122f4
...
@@ -32,6 +32,7 @@ import paddlex
...
@@ -32,6 +32,7 @@ import paddlex
class
HRNet
(
object
):
class
HRNet
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
,
num_classes
,
input_channel
=
3
,
mode
=
'train'
,
mode
=
'train'
,
width
=
18
,
width
=
18
,
use_bce_loss
=
False
,
use_bce_loss
=
False
,
...
@@ -61,6 +62,7 @@ class HRNet(object):
...
@@ -61,6 +62,7 @@ class HRNet(object):
format
(
type
(
class_weight
)))
format
(
type
(
class_weight
)))
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
input_channel
=
input_channel
self
.
mode
=
mode
self
.
mode
=
mode
self
.
use_bce_loss
=
use_bce_loss
self
.
use_bce_loss
=
use_bce_loss
self
.
use_dice_loss
=
use_dice_loss
self
.
use_dice_loss
=
use_dice_loss
...
@@ -136,13 +138,16 @@ class HRNet(object):
...
@@ -136,13 +138,16 @@ class HRNet(object):
if
self
.
fixed_input_shape
is
not
None
:
if
self
.
fixed_input_shape
is
not
None
:
input_shape
=
[
input_shape
=
[
None
,
3
,
self
.
fixed_input_shape
[
1
],
self
.
fixed_input_shape
[
0
]
None
,
self
.
input_channel
,
self
.
fixed_input_shape
[
1
],
self
.
fixed_input_shape
[
0
]
]
]
inputs
[
'image'
]
=
fluid
.
data
(
inputs
[
'image'
]
=
fluid
.
data
(
dtype
=
'float32'
,
shape
=
input_shape
,
name
=
'image'
)
dtype
=
'float32'
,
shape
=
input_shape
,
name
=
'image'
)
else
:
else
:
inputs
[
'image'
]
=
fluid
.
data
(
inputs
[
'image'
]
=
fluid
.
data
(
dtype
=
'float32'
,
shape
=
[
None
,
3
,
None
,
None
],
name
=
'image'
)
dtype
=
'float32'
,
shape
=
[
None
,
self
.
input_channel
,
None
,
None
],
name
=
'image'
)
if
self
.
mode
==
'train'
:
if
self
.
mode
==
'train'
:
inputs
[
'label'
]
=
fluid
.
data
(
inputs
[
'label'
]
=
fluid
.
data
(
dtype
=
'int32'
,
shape
=
[
None
,
1
,
None
,
None
],
name
=
'label'
)
dtype
=
'int32'
,
shape
=
[
None
,
1
,
None
,
None
],
name
=
'label'
)
...
...
paddlex/cv/transforms/seg_transforms.py
浏览文件 @
746122f4
...
@@ -74,8 +74,22 @@ class Compose(SegTransform):
...
@@ -74,8 +74,22 @@ class Compose(SegTransform):
raise
ValueError
(
'Can
\'
t read The image file {}!'
.
format
(
im
))
raise
ValueError
(
'Can
\'
t read The image file {}!'
.
format
(
im
))
im
=
im
.
astype
(
'float32'
)
im
=
im
.
astype
(
'float32'
)
if
label
is
not
None
:
if
label
is
not
None
:
if
not
isinstance
(
label
,
np
.
ndarray
):
if
isinstance
(
label
,
np
.
ndarray
):
label
=
np
.
asarray
(
Image
.
open
(
label
))
if
len
(
label
.
shape
)
!=
2
:
raise
Exception
(
"label should be 2-dimensions, but now is {}-dimensions"
.
format
(
len
(
label
.
shape
)))
else
:
try
:
label
=
np
.
asarray
(
Image
.
open
(
label
))
except
:
ValueError
(
'Can
\'
t read The label file {}!'
.
format
(
label
))
im_height
,
im_width
,
_
=
im
.
shape
label_height
,
label_width
=
label
.
shape
if
im_height
!=
label_height
or
im_width
!=
label_width
:
raise
Exception
(
"The height or width of the image is not same as the label"
)
return
(
im
,
label
)
return
(
im
,
label
)
def
__call__
(
self
,
im
,
im_info
=
None
,
label
=
None
):
def
__call__
(
self
,
im
,
im_info
=
None
,
label
=
None
):
...
@@ -605,6 +619,7 @@ class Normalize(SegTransform):
...
@@ -605,6 +619,7 @@ class Normalize(SegTransform):
mean
=
np
.
array
(
self
.
mean
)[
np
.
newaxis
,
np
.
newaxis
,
:]
mean
=
np
.
array
(
self
.
mean
)[
np
.
newaxis
,
np
.
newaxis
,
:]
std
=
np
.
array
(
self
.
std
)[
np
.
newaxis
,
np
.
newaxis
,
:]
std
=
np
.
array
(
self
.
std
)[
np
.
newaxis
,
np
.
newaxis
,
:]
im
=
normalize
(
im
,
mean
,
std
,
self
.
min_val
,
self
.
max_val
)
im
=
normalize
(
im
,
mean
,
std
,
self
.
min_val
,
self
.
max_val
)
im
=
im
.
astype
(
'float32'
)
if
label
is
None
:
if
label
is
None
:
return
(
im
,
im_info
)
return
(
im
,
im_info
)
...
...
paddlex/
RemoteS
ensing/__init__.py
→
paddlex/
remotes
ensing/__init__.py
浏览文件 @
746122f4
文件已移动
paddlex/
RemoteS
ensing/train_demo.py
→
paddlex/
remotes
ensing/train_demo.py
浏览文件 @
746122f4
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
import
os.path
as
osp
import
os.path
as
osp
import
argparse
import
argparse
from
paddlex.seg
import
transforms
from
paddlex.seg
import
transforms
import
paddlex.
RemoteSensing.transforms
as
custom
_transforms
import
paddlex.
remotesensing.transforms
as
rs
_transforms
import
paddlex
as
pdx
import
paddlex
as
pdx
...
@@ -110,22 +110,22 @@ train_transforms = transforms.Compose([
...
@@ -110,22 +110,22 @@ train_transforms = transforms.Compose([
transforms
.
RandomHorizontalFlip
(
0.5
),
transforms
.
RandomHorizontalFlip
(
0.5
),
transforms
.
ResizeStepScaling
(
0.5
,
2.0
,
0.25
),
transforms
.
ResizeStepScaling
(
0.5
,
2.0
,
0.25
),
transforms
.
RandomPaddingCrop
(
im_padding_value
=
[
1000
]
*
channel
),
transforms
.
RandomPaddingCrop
(
im_padding_value
=
[
1000
]
*
channel
),
custom
_transforms
.
Clip
(
rs
_transforms
.
Clip
(
min_val
=
clip_min_value
,
max_val
=
clip_max_value
),
min_val
=
clip_min_value
,
max_val
=
clip_max_value
),
transforms
.
Normalize
(
transforms
.
Normalize
(
min_val
=
clip_min_value
,
max_val
=
clip_max_value
,
mean
=
mean
,
std
=
std
),
min_val
=
clip_min_value
,
max_val
=
clip_max_value
,
mean
=
mean
,
std
=
std
),
])
])
train_transforms
.
decode_image
=
custom
_transforms
.
decode_image
train_transforms
.
decode_image
=
rs
_transforms
.
decode_image
eval_transforms
=
transforms
.
Compose
([
eval_transforms
=
transforms
.
Compose
([
custom
_transforms
.
Clip
(
rs
_transforms
.
Clip
(
min_val
=
clip_min_value
,
max_val
=
clip_max_value
),
min_val
=
clip_min_value
,
max_val
=
clip_max_value
),
transforms
.
Normalize
(
transforms
.
Normalize
(
min_val
=
clip_min_value
,
max_val
=
clip_max_value
,
mean
=
mean
,
std
=
std
),
min_val
=
clip_min_value
,
max_val
=
clip_max_value
,
mean
=
mean
,
std
=
std
),
])
])
eval_transforms
.
decode_image
=
custom
_transforms
.
decode_image
eval_transforms
.
decode_image
=
rs
_transforms
.
decode_image
train_list
=
osp
.
join
(
data_dir
,
'train.txt'
)
train_list
=
osp
.
join
(
data_dir
,
'train.txt'
)
val_list
=
osp
.
join
(
data_dir
,
'val.txt'
)
val_list
=
osp
.
join
(
data_dir
,
'val.txt'
)
...
...
paddlex/
RemoteS
ensing/transforms.py
→
paddlex/
remotes
ensing/transforms.py
浏览文件 @
746122f4
...
@@ -2,17 +2,23 @@ import os
...
@@ -2,17 +2,23 @@ import os
import
os.path
as
osp
import
os.path
as
osp
import
imghdr
import
imghdr
import
gdal
import
gdal
gdal
.
UseExceptions
()
gdal
.
PushErrorHandler
(
'CPLQuietErrorHandler'
)
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
Image
from
paddlex.seg
import
transforms
from
paddlex.seg
import
transforms
import
paddlex.utils.logging
as
logging
def
read_img
(
img_path
):
def
read_img
(
img_path
):
img_format
=
imghdr
.
what
(
img_path
)
img_format
=
imghdr
.
what
(
img_path
)
name
,
ext
=
osp
.
splitext
(
img_path
)
name
,
ext
=
osp
.
splitext
(
img_path
)
if
img_format
==
'tiff'
or
ext
==
'.img'
:
if
img_format
==
'tiff'
or
ext
==
'.img'
:
dataset
=
gdal
.
Open
(
img_path
)
try
:
dataset
=
gdal
.
Open
(
img_path
)
except
:
logging
.
error
(
gdal
.
GetLastErrorMsg
())
if
dataset
==
None
:
if
dataset
==
None
:
raise
Exception
(
'Can not open'
,
img_path
)
raise
Exception
(
'Can not open'
,
img_path
)
im_data
=
dataset
.
ReadAsArray
()
im_data
=
dataset
.
ReadAsArray
()
...
@@ -36,9 +42,25 @@ def decode_image(im, label):
...
@@ -36,9 +42,25 @@ def decode_image(im, label):
im
=
read_img
(
im
)
im
=
read_img
(
im
)
except
:
except
:
raise
ValueError
(
'Can
\'
t read The image file {}!'
.
format
(
im
))
raise
ValueError
(
'Can
\'
t read The image file {}!'
.
format
(
im
))
im
=
im
.
astype
(
'float32'
)
if
label
is
not
None
:
if
label
is
not
None
:
if
not
isinstance
(
label
,
np
.
ndarray
):
if
isinstance
(
label
,
np
.
ndarray
):
label
=
read_img
(
label
)
if
len
(
label
.
shape
)
!=
2
:
raise
Exception
(
"label should be 2-dimensions, but now is {}-dimensions"
.
format
(
len
(
label
.
shape
)))
else
:
try
:
label
=
np
.
asarray
(
Image
.
open
(
label
))
except
:
ValueError
(
'Can
\'
t read The label file {}!'
.
format
(
label
))
im_height
,
im_width
,
_
=
im
.
shape
label_height
,
label_width
=
label
.
shape
if
im_height
!=
label_height
or
im_width
!=
label_width
:
raise
Exception
(
"The height or width of the image is not same as the label"
)
return
(
im
,
label
)
return
(
im
,
label
)
...
...
paddlex/remotesensing/utils/__init__.py
0 → 100644
浏览文件 @
746122f4
paddlex/remotesensing/utils/analyse.py
0 → 100644
浏览文件 @
746122f4
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
os
import
os.path
as
osp
import
sys
import
argparse
from
PIL
import
Image
from
tqdm
import
tqdm
import
imghdr
import
logging
import
pickle
import
gdal
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Data analyse and data check before training.'
)
parser
.
add_argument
(
'--data_dir'
,
dest
=
'data_dir'
,
help
=
'Dataset directory'
,
default
=
None
,
type
=
str
)
parser
.
add_argument
(
'--num_classes'
,
dest
=
'num_classes'
,
help
=
'Number of classes'
,
default
=
None
,
type
=
int
)
parser
.
add_argument
(
'--separator'
,
dest
=
'separator'
,
help
=
'file list separator'
,
default
=
" "
,
type
=
str
)
parser
.
add_argument
(
'--ignore_index'
,
dest
=
'ignore_index'
,
help
=
'Ignored class index'
,
default
=
255
,
type
=
int
)
if
len
(
sys
.
argv
)
==
1
:
parser
.
print_help
()
sys
.
exit
(
1
)
return
parser
.
parse_args
()
def
read_img
(
img_path
):
img_format
=
imghdr
.
what
(
img_path
)
name
,
ext
=
osp
.
splitext
(
img_path
)
if
img_format
==
'tiff'
or
ext
==
'.img'
:
dataset
=
gdal
.
Open
(
img_path
)
if
dataset
==
None
:
raise
Exception
(
'Can not open'
,
img_path
)
im_data
=
dataset
.
ReadAsArray
()
return
im_data
.
transpose
((
1
,
2
,
0
))
elif
ext
==
'.npy'
:
return
np
.
load
(
img_path
)
else
:
raise
Exception
(
'Not support {} image format!'
.
format
(
ext
))
def
img_pixel_statistics
(
img
,
img_value_num
,
img_min_value
,
img_max_value
):
channel
=
img
.
shape
[
2
]
means
=
np
.
zeros
(
channel
)
stds
=
np
.
zeros
(
channel
)
for
k
in
range
(
channel
):
img_k
=
img
[:,
:,
k
]
# count mean, std
means
[
k
]
=
np
.
mean
(
img_k
)
stds
[
k
]
=
np
.
std
(
img_k
)
# count min, max
min_value
=
np
.
min
(
img_k
)
max_value
=
np
.
max
(
img_k
)
if
img_max_value
[
k
]
<
max_value
:
img_max_value
[
k
]
=
max_value
if
img_min_value
[
k
]
>
min_value
:
img_min_value
[
k
]
=
min_value
# count the distribution of image value, value number
unique
,
counts
=
np
.
unique
(
img_k
,
return_counts
=
True
)
add_num
=
[]
max_unique
=
np
.
max
(
unique
)
add_len
=
max_unique
-
len
(
img_value_num
[
k
])
+
1
if
add_len
>
0
:
img_value_num
[
k
]
+=
([
0
]
*
add_len
)
for
i
in
range
(
len
(
unique
)):
value
=
unique
[
i
]
img_value_num
[
k
][
value
]
+=
counts
[
i
]
img_value_num
[
k
]
+=
add_num
return
means
,
stds
,
img_min_value
,
img_max_value
,
img_value_num
def
data_distribution_statistics
(
data_dir
,
img_value_num
,
logger
):
"""count the distribution of image value, value number
"""
logger
.
info
(
"
\n
-----------------------------
\n
The whole dataset statistics..."
)
if
not
img_value_num
:
return
logger
.
info
(
"
\n
Image pixel statistics:"
)
total_ratio
=
[]
[
total_ratio
.
append
([])
for
i
in
range
(
len
(
img_value_num
))]
for
k
in
range
(
len
(
img_value_num
)):
total_num
=
sum
(
img_value_num
[
k
])
total_ratio
[
k
]
=
[
i
/
total_num
for
i
in
img_value_num
[
k
]]
total_ratio
[
k
]
=
np
.
around
(
total_ratio
[
k
],
decimals
=
4
)
with
open
(
os
.
path
.
join
(
data_dir
,
'img_pixel_statistics.pkl'
),
'wb'
)
as
f
:
pickle
.
dump
([
total_ratio
,
img_value_num
],
f
)
def
data_range_statistics
(
img_min_value
,
img_max_value
,
logger
):
"""print min value, max value
"""
logger
.
info
(
"value range:
\n
img_min_value = {}
\n
img_max_value = {}"
.
format
(
img_min_value
,
img_max_value
))
def
cal_normalize_coefficient
(
total_means
,
total_stds
,
total_img_num
,
logger
):
"""count mean, std
"""
total_means
=
total_means
/
total_img_num
total_stds
=
total_stds
/
total_img_num
logger
.
info
(
"
\n
Count the channel-by-channel mean and std of the image:
\n
"
"mean = {}
\n
std = {}"
.
format
(
total_means
,
total_stds
))
def
error_print
(
str
):
return
""
.
join
([
"
\n
NOT PASS "
,
str
])
def
correct_print
(
str
):
return
""
.
join
([
"
\n
PASS "
,
str
])
def
pil_imread
(
file_path
):
"""read pseudo-color label"""
im
=
Image
.
open
(
file_path
)
return
np
.
asarray
(
im
)
def
get_img_shape_range
(
img
,
max_width
,
max_height
,
min_width
,
min_height
):
"""获取图片最大和最小宽高"""
img_shape
=
img
.
shape
height
,
width
=
img_shape
[
0
],
img_shape
[
1
]
max_height
=
max
(
height
,
max_height
)
max_width
=
max
(
width
,
max_width
)
min_height
=
min
(
height
,
min_height
)
min_width
=
min
(
width
,
min_width
)
return
max_width
,
max_height
,
min_width
,
min_height
def
get_img_channel_num
(
img
,
img_channels
):
"""获取图像的通道数"""
img_shape
=
img
.
shape
if
img_shape
[
-
1
]
not
in
img_channels
:
img_channels
.
append
(
img_shape
[
-
1
])
return
img_channels
def
is_label_single_channel
(
label
):
"""判断标签是否为灰度图"""
label_shape
=
label
.
shape
if
len
(
label_shape
)
==
2
:
return
True
else
:
return
False
def
image_label_shape_check
(
img
,
label
):
"""
验证图像和标注的大小是否匹配
"""
flag
=
True
img_height
=
img
.
shape
[
0
]
img_width
=
img
.
shape
[
1
]
label_height
=
label
.
shape
[
0
]
label_width
=
label
.
shape
[
1
]
if
img_height
!=
label_height
or
img_width
!=
label_width
:
flag
=
False
return
flag
def
ground_truth_check
(
label
,
label_path
):
"""
验证标注图像的格式
统计标注图类别和像素数
params:
label: 标注图
label_path: 标注图路径
return:
png_format: 返回是否是png格式图片
unique: 返回标注类别
counts: 返回标注的像素数
"""
if
imghdr
.
what
(
label_path
)
==
"png"
:
png_format
=
True
else
:
png_format
=
False
unique
,
counts
=
np
.
unique
(
label
,
return_counts
=
True
)
return
png_format
,
unique
,
counts
def
sum_label_check
(
label_classes
,
num_of_each_class
,
ignore_index
,
num_classes
,
total_label_classes
,
total_num_of_each_class
):
"""
统计所有标注图上的类别和每个类别的像素数
params:
label_classes: 标注类别
num_of_each_class: 各个类别的像素数目
"""
is_label_correct
=
True
if
ignore_index
in
label_classes
:
label_classes2
=
np
.
delete
(
label_classes
,
np
.
where
(
label_classes
==
ignore_index
))
else
:
label_classes2
=
label_classes
if
min
(
label_classes2
)
<
0
or
max
(
label_classes2
)
>
num_classes
-
1
:
is_label_correct
=
False
add_class
=
[]
add_num
=
[]
for
i
in
range
(
len
(
label_classes
)):
gi
=
label_classes
[
i
]
if
gi
in
total_label_classes
:
j
=
total_label_classes
.
index
(
gi
)
total_num_of_each_class
[
j
]
+=
num_of_each_class
[
i
]
else
:
add_class
.
append
(
gi
)
add_num
.
append
(
num_of_each_class
[
i
])
total_num_of_each_class
+=
add_num
total_label_classes
+=
add_class
return
is_label_correct
,
total_num_of_each_class
,
total_label_classes
def
label_class_check
(
num_classes
,
total_label_classes
,
total_num_of_each_class
,
wrong_labels
,
logger
):
"""
检查实际标注类别是否和配置参数`num_classes`,`ignore_index`匹配。
**NOTE:**
标注图像类别数值必须在[0~(`num_classes`-1)]范围内或者为`ignore_index`。
标注类别最好从0开始,否则可能影响精度。
"""
total_ratio
=
total_num_of_each_class
/
sum
(
total_num_of_each_class
)
total_ratio
=
np
.
around
(
total_ratio
,
decimals
=
4
)
total_nc
=
sorted
(
zip
(
total_label_classes
,
total_ratio
,
total_num_of_each_class
))
if
len
(
wrong_labels
)
==
0
and
not
total_nc
[
0
][
0
]:
logger
.
info
(
correct_print
(
"label class check!"
))
else
:
logger
.
info
(
error_print
(
"label class check!"
))
if
total_nc
[
0
][
0
]:
logger
.
info
(
"Warning: label classes should start from 0"
)
if
len
(
wrong_labels
)
>
0
:
logger
.
info
(
"fatal error: label class is out of range [0, {}]"
.
format
(
num_classes
-
1
))
for
i
in
wrong_labels
:
logger
.
debug
(
i
)
return
total_nc
def
label_class_statistics
(
total_nc
,
logger
):
"""
对标注图像进行校验,输出校验结果
"""
logger
.
info
(
"
\n
Label class statistics:
\n
"
"(label class, percentage, total pixel number) = {} "
.
format
(
total_nc
))
def
shape_check
(
shape_unequal_image
,
logger
):
"""输出shape校验结果"""
if
len
(
shape_unequal_image
)
==
0
:
logger
.
info
(
correct_print
(
"shape check"
))
logger
.
info
(
"All images are the same shape as the labels"
)
else
:
logger
.
info
(
error_print
(
"shape check"
))
logger
.
info
(
"Some images are not the same shape as the labels as follow: "
)
for
i
in
shape_unequal_image
:
logger
.
debug
(
i
)
def
separator_check
(
wrong_lines
,
file_list
,
separator
,
logger
):
"""检查分割符是否复合要求"""
if
len
(
wrong_lines
)
==
0
:
logger
.
info
(
correct_print
(
file_list
.
split
(
os
.
sep
)[
-
1
]
+
" DATASET.separator check"
))
else
:
logger
.
info
(
error_print
(
file_list
.
split
(
os
.
sep
)[
-
1
]
+
" DATASET.separator check"
))
logger
.
info
(
"The following list is not separated by {}"
.
format
(
separator
))
for
i
in
wrong_lines
:
logger
.
debug
(
i
)
def
imread_check
(
imread_failed
,
logger
):
if
len
(
imread_failed
)
==
0
:
logger
.
info
(
correct_print
(
"dataset reading check"
))
logger
.
info
(
"All images can be read successfully"
)
else
:
logger
.
info
(
error_print
(
"dataset reading check"
))
logger
.
info
(
"Failed to read {} images"
.
format
(
len
(
imread_failed
)))
for
i
in
imread_failed
:
logger
.
debug
(
i
)
def
single_channel_label_check
(
label_not_single_channel
,
logger
):
if
len
(
label_not_single_channel
)
==
0
:
logger
.
info
(
correct_print
(
"label single_channel check"
))
logger
.
info
(
"All label images are single_channel"
)
else
:
logger
.
info
(
error_print
(
"label single_channel check"
))
logger
.
info
(
"{} label images are not single_channel
\n
Label pixel statistics may be insignificant"
.
format
(
len
(
label_not_single_channel
)))
for
i
in
label_not_single_channel
:
logger
.
debug
(
i
)
def
img_shape_range_statistics
(
max_width
,
min_width
,
max_height
,
min_height
,
logger
):
logger
.
info
(
"
\n
Image size statistics:"
)
logger
.
info
(
"max width = {} min width = {} max height = {} min height = {}"
.
format
(
max_width
,
min_width
,
max_height
,
min_height
))
def
img_channels_statistics
(
img_channels
,
logger
):
logger
.
info
(
"
\n
Image channels statistics
\n
Image channels = {}"
.
format
(
np
.
unique
(
img_channels
)))
def
data_analyse_and_check
(
data_dir
,
num_classes
,
separator
,
ignore_index
,
logger
):
train_file_list
=
osp
.
join
(
data_dir
,
'train.txt'
)
val_file_list
=
osp
.
join
(
data_dir
,
'val.txt'
)
test_file_list
=
osp
.
join
(
data_dir
,
'test.txt'
)
total_img_num
=
0
has_label
=
False
for
file_list
in
[
train_file_list
,
val_file_list
,
test_file_list
]:
# initialization
imread_failed
=
[]
max_width
=
0
max_height
=
0
min_width
=
sys
.
float_info
.
max
min_height
=
sys
.
float_info
.
max
label_not_single_channel
=
[]
shape_unequal_image
=
[]
wrong_labels
=
[]
wrong_lines
=
[]
total_label_classes
=
[]
total_num_of_each_class
=
[]
img_channels
=
[]
with
open
(
file_list
,
'r'
)
as
fid
:
logger
.
info
(
"
\n
-----------------------------
\n
Check {}..."
.
format
(
file_list
))
lines
=
fid
.
readlines
()
if
not
lines
:
logger
.
info
(
"File list is empty!"
)
continue
for
line
in
tqdm
(
lines
):
line
=
line
.
strip
()
parts
=
line
.
split
(
separator
)
if
len
(
parts
)
==
1
:
if
file_list
==
train_file_list
or
file_list
==
val_file_list
:
logger
.
info
(
"Train or val list must have labels!"
)
break
img_name
=
parts
img_path
=
os
.
path
.
join
(
data_dir
,
img_name
[
0
])
try
:
img
=
read_img
(
img_path
)
except
Exception
as
e
:
imread_failed
.
append
((
line
,
str
(
e
)))
continue
elif
len
(
parts
)
==
2
:
has_label
=
True
img_name
,
label_name
=
parts
[
0
],
parts
[
1
]
img_path
=
os
.
path
.
join
(
data_dir
,
img_name
)
label_path
=
os
.
path
.
join
(
data_dir
,
label_name
)
try
:
img
=
read_img
(
img_path
)
label
=
pil_imread
(
label_path
)
except
Exception
as
e
:
imread_failed
.
append
((
line
,
str
(
e
)))
continue
is_single_channel
=
is_label_single_channel
(
label
)
if
not
is_single_channel
:
label_not_single_channel
.
append
(
line
)
continue
is_equal_img_label_shape
=
image_label_shape_check
(
img
,
label
)
if
not
is_equal_img_label_shape
:
shape_unequal_image
.
append
(
line
)
png_format
,
label_classes
,
num_of_each_class
=
ground_truth_check
(
label
,
label_path
)
is_label_correct
,
total_num_of_each_class
,
total_label_classes
=
sum_label_check
(
label_classes
,
num_of_each_class
,
ignore_index
,
num_classes
,
total_label_classes
,
total_num_of_each_class
)
if
not
is_label_correct
:
wrong_labels
.
append
(
line
)
else
:
wrong_lines
.
append
(
lines
)
continue
if
total_img_num
==
0
:
channel
=
img
.
shape
[
2
]
total_means
=
np
.
zeros
(
channel
)
total_stds
=
np
.
zeros
(
channel
)
img_min_value
=
[
sys
.
float_info
.
max
]
*
channel
img_max_value
=
[
0
]
*
channel
img_value_num
=
[]
[
img_value_num
.
append
([])
for
i
in
range
(
channel
)]
means
,
stds
,
img_min_value
,
img_max_value
,
img_value_num
=
img_pixel_statistics
(
img
,
img_value_num
,
img_min_value
,
img_max_value
)
total_means
+=
means
total_stds
+=
stds
max_width
,
max_height
,
min_width
,
min_height
=
get_img_shape_range
(
img
,
max_width
,
max_height
,
min_width
,
min_height
)
img_channels
=
get_img_channel_num
(
img
,
img_channels
)
total_img_num
+=
1
# data check
separator_check
(
wrong_lines
,
file_list
,
separator
,
logger
)
imread_check
(
imread_failed
,
logger
)
if
has_label
:
single_channel_label_check
(
label_not_single_channel
,
logger
)
shape_check
(
shape_unequal_image
,
logger
)
total_nc
=
label_class_check
(
num_classes
,
total_label_classes
,
total_num_of_each_class
,
wrong_labels
,
logger
)
# data analyse on train, validation, test set.
img_channels_statistics
(
img_channels
,
logger
)
img_shape_range_statistics
(
max_width
,
min_width
,
max_height
,
min_height
,
logger
)
if
has_label
:
label_class_statistics
(
total_nc
,
logger
)
# data analyse on the whole dataset.
data_range_statistics
(
img_min_value
,
img_max_value
,
logger
)
data_distribution_statistics
(
data_dir
,
img_value_num
,
logger
)
cal_normalize_coefficient
(
total_means
,
total_stds
,
total_img_num
,
logger
)
def
main
():
args
=
parse_args
()
data_dir
=
args
.
data_dir
ignore_index
=
args
.
ignore_index
num_classes
=
args
.
num_classes
separator
=
args
.
separator
logger
=
logging
.
getLogger
()
logger
.
setLevel
(
'DEBUG'
)
BASIC_FORMAT
=
"%(message)s"
formatter
=
logging
.
Formatter
(
BASIC_FORMAT
)
sh
=
logging
.
StreamHandler
()
sh
.
setFormatter
(
formatter
)
sh
.
setLevel
(
'INFO'
)
th
=
logging
.
FileHandler
(
os
.
path
.
join
(
data_dir
,
'data_analyse_and_check.log'
),
'w'
)
th
.
setFormatter
(
formatter
)
logger
.
addHandler
(
sh
)
logger
.
addHandler
(
th
)
data_analyse_and_check
(
data_dir
,
num_classes
,
separator
,
ignore_index
,
logger
)
print
(
"
\n
Detailed error information can be viewed in {}."
.
format
(
os
.
path
.
join
(
data_dir
,
'data_analyse_and_check.log'
)))
if
__name__
==
"__main__"
:
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录