Commit 9a5a9f45 (unverified)
Authored by LutaoChu on Apr 30, 2020; committed via GitHub on Apr 30, 2020

Add RemoteSensing scenario (#230)

Parent: 0a78a8cf
Showing 21 changed files with 3,502 additions and 0 deletions
contrib/RemoteSensing/__init__.py                  +24   -0
contrib/RemoteSensing/main.py                      +108  -0
contrib/RemoteSensing/models/__init__.py           +2    -0
contrib/RemoteSensing/models/base.py               +406  -0
contrib/RemoteSensing/models/load_model.py         +95   -0
contrib/RemoteSensing/models/unet.py               +316  -0
contrib/RemoteSensing/nets/__init__.py             +1    -0
contrib/RemoteSensing/nets/libs.py                 +219  -0
contrib/RemoteSensing/nets/loss.py                 +115  -0
contrib/RemoteSensing/nets/unet.py                 +268  -0
contrib/RemoteSensing/readers/__init__.py          +15   -0
contrib/RemoteSensing/readers/base.py              +249  -0
contrib/RemoteSensing/readers/reader.py            +90   -0
contrib/RemoteSensing/transforms/__init__.py       +16   -0
contrib/RemoteSensing/transforms/ops.py            +174  -0
contrib/RemoteSensing/transforms/transforms.py     +962  -0
contrib/RemoteSensing/utils/__init__.py            +18   -0
contrib/RemoteSensing/utils/logging.py             +46   -0
contrib/RemoteSensing/utils/metrics.py             +145  -0
contrib/RemoteSensing/utils/pretrain_weights.py    +13   -0
contrib/RemoteSensing/utils/utils.py               +220  -0
contrib/RemoteSensing/__init__.py (new file, mode 100644)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import utils
from . import nets
from . import models
from . import transforms
from . import readers
from .utils.utils import get_environ_info

env_info = get_environ_info()

log_level = 2
contrib/RemoteSensing/main.py (new file, mode 100644)

import sys
import os
import os.path as osp
import cv2
import numpy as np
from PIL import Image as Image

#================================setting========================
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
batch_size = 4
channel = 10
epochs = 1
save_dir = 'saved_model/snow2019_unet_all_channel_vertical'
data_dir = "../../../dataset/snow2019/all_channel_data/"
#=============================================================

sys.path.append(osp.join(os.getcwd(), '..'))
import RemoteSensing.transforms.transforms as T
from RemoteSensing.readers.reader import Reader
from RemoteSensing.models import UNet, load_model

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

train_list = osp.join(data_dir, 'train.txt')
val_list = osp.join(data_dir, 'val.txt')
label_list = osp.join(data_dir, 'labels.txt')

os.system('cp ./{} {}'.format(__file__, osp.join(save_dir, __file__)))

train_transforms = T.Compose([
    T.RandomVerticalFlip(0.5),
    T.RandomHorizontalFlip(0.5),
    T.ResizeStepScaling(0.5, 2.0, 0.25),
    T.RandomPaddingCrop(769),
    T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])

eval_transforms = T.Compose([
    T.Padding([1049, 1049]),
    T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])

test_transforms = T.Compose([
    T.Padding([1049, 1049]),
    T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])

train_reader = Reader(
    data_dir=data_dir,
    file_list=train_list,
    label_list=label_list,
    transforms=train_transforms,
    num_workers=8,
    buffer_size=16,
    shuffle=True,
    parallel_method='thread')

eval_reader = Reader(
    data_dir=data_dir,
    file_list=val_list,
    label_list=label_list,
    transforms=eval_transforms,
    num_workers=8,
    buffer_size=16,
    shuffle=False,
    parallel_method='thread')

model = UNet(
    num_classes=2,
    input_channel=channel,
    use_bce_loss=True,
    use_dice_loss=True)

model.train(
    num_epochs=epochs,
    train_reader=train_reader,
    train_batch_size=batch_size,
    eval_reader=eval_reader,
    save_interval_epochs=5,
    log_interval_steps=10,
    save_dir=save_dir,
    pretrain_weights=None,
    optimizer=None,
    learning_rate=0.01,
)

# predict
model = load_model(osp.join(save_dir, 'best_model'))
pred_dir = osp.join(save_dir, 'pred')
if not osp.exists(pred_dir):
    os.mkdir(pred_dir)

color_map = [0, 0, 0, 255, 255, 255]
with open(val_list) as f:
    lines = f.readlines()
    for line in lines:
        img_path = line.split(' ')[0]
        print('Predicting {}'.format(img_path))
        img_path_ = osp.join(data_dir, img_path)

        pred = model.predict(img_path_)

        pred_name = osp.basename(img_path).rstrip('npy') + 'png'
        pred_path = osp.join(pred_dir, pred_name)

        pred_mask = Image.fromarray(pred.astype(np.uint8), mode='P')
        pred_mask.putpalette(color_map)
        pred_mask.save(pred_path)
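Note on the list files consumed by Reader above: the predict loop takes the first space-separated field of each val.txt line as the image path, so train.txt and val.txt are presumably plain-text files with one "image_path label_path" pair per line, and labels.txt lists one class name per line. A purely illustrative layout (file and class names are hypothetical, not taken from the dataset):

train_data/tile_0001.npy train_label/tile_0001.png
train_data/tile_0002.npy train_label/tile_0002.png

and labels.txt, for the two classes trained here, might read:

background
snow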
contrib/RemoteSensing/models/__init__.py (new file, mode 100644)

from .load_model import *
from .unet import *
contrib/RemoteSensing/models/base.py (new file, mode 100644)

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import paddle.fluid as fluid
import os
import numpy as np
import time
import math
import yaml
import copy
import json
import functools
import RemoteSensing.utils.logging as logging
import RemoteSensing
from collections import OrderedDict
from os import path as osp
from paddle.fluid.framework import Program
from ..utils.pretrain_weights import get_pretrain_weights


def dict2str(dict_input):
    out = ''
    for k, v in dict_input.items():
        try:
            v = round(float(v), 6)
        except:
            pass
        out = out + '{}={}, '.format(k, v)
    return out.strip(', ')


class BaseAPI:
    def __init__(self):
        # Existing CV models all carry this attribute; it is also needed during eval.
        self.num_classes = None
        self.labels = None
        if RemoteSensing.env_info['place'] == 'cpu':
            self.places = fluid.cpu_places()
        else:
            self.places = fluid.cuda_places()
        self.exe = fluid.Executor(self.places[0])
        self.train_prog = None
        self.test_prog = None
        self.parallel_train_prog = None
        self.train_inputs = None
        self.test_inputs = None
        self.train_outputs = None
        self.test_outputs = None
        self.train_data_loader = None
        self.eval_metrics = None
        # A model loaded from an inference model cannot be trained again.
        self.trainable = True
        # Whether to synchronize BatchNorm mean/variance across multiple cards.
        self.sync_bn = False
        # Current model status.
        self.status = 'Normal'

    def _get_single_card_bs(self, batch_size):
        if batch_size % len(self.places) == 0:
            return int(batch_size // len(self.places))
        else:
            raise Exception("Please support correct batch_size, \
                which can be divided by available cards({}) in {}".format(
                RemoteSensing.env_info['num'], RemoteSensing.env_info['place']))

    def build_program(self):
        # Build the training network.
        self.train_inputs, self.train_outputs = self.build_net(mode='train')
        self.train_prog = fluid.default_main_program()
        startup_prog = fluid.default_startup_program()

        # Build the inference network.
        self.test_prog = fluid.Program()
        with fluid.program_guard(self.test_prog, startup_prog):
            with fluid.unique_name.guard():
                self.test_inputs, self.test_outputs = self.build_net(mode='test')
        self.test_prog = self.test_prog.clone(for_test=True)

    def arrange_transforms(self, transforms, mode='train'):
        # Append the arrange operation to transforms.
        if transforms.transforms[-1].__class__.__name__.startswith('Arrange'):
            transforms.transforms[-1] = \
                RemoteSensing.transforms.transforms.ArrangeSegmenter(mode=mode)
        else:
            transforms.transforms.append(
                RemoteSensing.transforms.transforms.ArrangeSegmenter(mode=mode))

    def build_train_data_loader(self, reader, batch_size):
        # Initialize the data_loader.
        if self.train_data_loader is None:
            self.train_data_loader = fluid.io.DataLoader.from_generator(
                feed_list=list(self.train_inputs.values()),
                capacity=64,
                use_double_buffer=True,
                iterable=True)
        batch_size_each_gpu = self._get_single_card_bs(batch_size)
        generator = reader.generator(
            batch_size=batch_size_each_gpu, drop_last=True)
        self.train_data_loader.set_sample_list_generator(
            reader.generator(batch_size=batch_size_each_gpu),
            places=self.places)

    def net_initialize(self,
                       startup_prog=None,
                       pretrain_weights=None,
                       fuse_bn=False,
                       save_dir='.',
                       sensitivities_file=None,
                       eval_metric_loss=0.05):
        if hasattr(self, 'backbone'):
            backbone = self.backbone
        else:
            backbone = self.__class__.__name__
        pretrain_weights = get_pretrain_weights(pretrain_weights, backbone, save_dir)
        if startup_prog is None:
            startup_prog = fluid.default_startup_program()
        self.exe.run(startup_prog)
        if pretrain_weights is not None:
            logging.info("Load pretrain weights from {}.".format(pretrain_weights))
            RemoteSensing.utils.utils.load_pretrain_weights(
                self.exe, self.train_prog, pretrain_weights, fuse_bn)
        # Pruning.
        if sensitivities_file is not None:
            from .slim.prune_config import get_sensitivities
            sensitivities_file = get_sensitivities(sensitivities_file, self, save_dir)
            from .slim.prune import get_params_ratios, prune_program
            prune_params_ratios = get_params_ratios(
                sensitivities_file, eval_metric_loss=eval_metric_loss)
            prune_program(self, prune_params_ratios)
            self.status = 'Prune'

    def get_model_info(self):
        info = dict()
        info['Model'] = self.__class__.__name__
        info['_Attributes'] = {}
        if 'self' in self.init_params:
            del self.init_params['self']
        if '__class__' in self.init_params:
            del self.init_params['__class__']
        info['_init_params'] = self.init_params

        info['_Attributes']['num_classes'] = self.num_classes
        info['_Attributes']['labels'] = self.labels
        try:
            primary_metric_key = list(self.eval_metrics.keys())[0]
            primary_metric_value = float(self.eval_metrics[primary_metric_key])
            info['_Attributes']['eval_metrics'] = {
                primary_metric_key: primary_metric_value
            }
        except:
            pass

        if hasattr(self, 'test_transforms'):
            if self.test_transforms is not None:
                info['Transforms'] = list()
                for op in self.test_transforms.transforms:
                    name = op.__class__.__name__
                    attr = op.__dict__
                    info['Transforms'].append({name: attr})
        return info

    def save_model(self, save_dir):
        if not osp.isdir(save_dir):
            if osp.exists(save_dir):
                os.remove(save_dir)
            os.makedirs(save_dir)
        fluid.save(self.train_prog, osp.join(save_dir, 'model'))
        model_info = self.get_model_info()
        model_info['status'] = self.status
        with open(osp.join(save_dir, 'model.yml'), encoding='utf-8', mode='w') as f:
            yaml.dump(model_info, f)
        # Save the evaluation results.
        if hasattr(self, 'eval_details'):
            with open(osp.join(save_dir, 'eval_details.json'), 'w') as f:
                json.dump(self.eval_details, f)

        if self.status == 'Prune':
            # Save the pruned parameter shapes.
            shapes = {}
            for block in self.train_prog.blocks:
                for param in block.all_parameters():
                    pd_var = fluid.global_scope().find_var(param.name)
                    pd_param = pd_var.get_tensor()
                    shapes[param.name] = np.array(pd_param).shape
            with open(osp.join(save_dir, 'prune.yml'), encoding='utf-8', mode='w') as f:
                yaml.dump(shapes, f)

        # Flag that the model was saved successfully.
        open(osp.join(save_dir, '.success'), 'w').close()
        logging.info("Model saved in {}.".format(save_dir))

    def export_inference_model(self, save_dir):
        test_input_names = [var.name for var in list(self.test_inputs.values())]
        test_outputs = list(self.test_outputs.values())
        if self.__class__.__name__ == 'MaskRCNN':
            from RemoteSensing.utils.save import save_mask_inference_model
            save_mask_inference_model(
                dirname=save_dir,
                executor=self.exe,
                params_filename='__params__',
                feeded_var_names=test_input_names,
                target_vars=test_outputs,
                main_program=self.test_prog)
        else:
            fluid.io.save_inference_model(
                dirname=save_dir,
                executor=self.exe,
                params_filename='__params__',
                feeded_var_names=test_input_names,
                target_vars=test_outputs,
                main_program=self.test_prog)
        model_info = self.get_model_info()
        model_info['status'] = 'Infer'

        # Save descriptions of the model's input and output variables.
        model_info['_ModelInputsOutputs'] = dict()
        model_info['_ModelInputsOutputs']['test_inputs'] = [
            [k, v.name] for k, v in self.test_inputs.items()
        ]
        model_info['_ModelInputsOutputs']['test_outputs'] = [
            [k, v.name] for k, v in self.test_outputs.items()
        ]

        with open(osp.join(save_dir, 'model.yml'), encoding='utf-8', mode='w') as f:
            yaml.dump(model_info, f)

        # Flag that the model was saved successfully.
        open(osp.join(save_dir, '.success'), 'w').close()
        logging.info("Model for inference deploy saved in {}.".format(save_dir))

    def train_loop(self,
                   num_epochs,
                   train_reader,
                   train_batch_size,
                   eval_reader=None,
                   save_interval_epochs=1,
                   log_interval_steps=10,
                   save_dir='output',
                   use_vdl=False):
        if not osp.isdir(save_dir):
            if osp.exists(save_dir):
                os.remove(save_dir)
            os.makedirs(save_dir)
        if use_vdl:
            from visualdl import LogWriter
            vdl_logdir = osp.join(save_dir, 'vdl_log')
        # Append the arrange operation to the transforms.
        self.arrange_transforms(
            transforms=train_reader.transforms, mode='train')
        # Build the train_data_loader.
        self.build_train_data_loader(
            reader=train_reader, batch_size=train_batch_size)

        if eval_reader is not None:
            self.eval_transforms = eval_reader.transforms
            self.test_transforms = copy.deepcopy(eval_reader.transforms)

        # Fetch the learning rate, which may change over time.
        lr = self.optimizer._learning_rate
        if isinstance(lr, fluid.framework.Variable):
            self.train_outputs['lr'] = lr

        # Run training on multiple cards.
        if self.parallel_train_prog is None:
            build_strategy = fluid.compiler.BuildStrategy()
            build_strategy.fuse_all_optimizer_ops = False
            if RemoteSensing.env_info['place'] != 'cpu' and len(self.places) > 1:
                build_strategy.sync_batch_norm = self.sync_bn
            exec_strategy = fluid.ExecutionStrategy()
            exec_strategy.num_iteration_per_drop_scope = 1
            self.parallel_train_prog = fluid.CompiledProgram(
                self.train_prog).with_data_parallel(
                    loss_name=self.train_outputs['loss'].name,
                    build_strategy=build_strategy,
                    exec_strategy=exec_strategy)

        total_num_steps = math.floor(train_reader.num_samples / train_batch_size)
        num_steps = 0
        time_stat = list()

        if use_vdl:
            # VisualDL component
            log_writer = LogWriter(vdl_logdir, sync_cycle=20)
            train_step_component = OrderedDict()
            eval_component = OrderedDict()

        best_accuracy_key = ""
        best_accuracy = -1.0
        best_model_epoch = 1
        for i in range(num_epochs):
            records = list()
            step_start_time = time.time()
            for step, data in enumerate(self.train_data_loader()):
                outputs = self.exe.run(
                    self.parallel_train_prog,
                    feed=data,
                    fetch_list=list(self.train_outputs.values()))
                outputs_avg = np.mean(np.array(outputs), axis=1)
                records.append(outputs_avg)

                # Estimate the time remaining until training finishes.
                current_time = time.time()
                step_cost_time = current_time - step_start_time
                step_start_time = current_time
                if len(time_stat) < 20:
                    time_stat.append(step_cost_time)
                else:
                    time_stat[num_steps % 20] = step_cost_time
                eta = ((num_epochs - i) * total_num_steps - step - 1) * np.mean(time_stat)
                eta_h = math.floor(eta / 3600)
                eta_m = math.floor((eta - eta_h * 3600) / 60)
                eta_s = int(eta - eta_h * 3600 - eta_m * 60)
                eta_str = "{}:{}:{}".format(eta_h, eta_m, eta_s)

                # Log the loss information every log_interval_steps.
                num_steps += 1
                if num_steps % log_interval_steps == 0:
                    step_metrics = OrderedDict(
                        zip(list(self.train_outputs.keys()), outputs_avg))

                    if use_vdl:
                        for k, v in step_metrics.items():
                            if k not in train_step_component.keys():
                                with log_writer.mode('Each_Step_while_Training') as step_logger:
                                    train_step_component[k] = step_logger.scalar(
                                        'Training: {}'.format(k))
                            train_step_component[k].add_record(num_steps, v)

                    logging.info(
                        "[TRAIN] Epoch={}/{}, Step={}/{}, {}, eta={}".format(
                            i + 1, num_epochs, step + 1, total_num_steps,
                            dict2str(step_metrics), eta_str))

            train_metrics = OrderedDict(
                zip(list(self.train_outputs.keys()), np.mean(records, axis=0)))
            logging.info('[TRAIN] Epoch {} finished, {} .'.format(
                i + 1, dict2str(train_metrics)))

            # Every save_interval_epochs, evaluate on the validation set and save the model.
            if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
                current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                if not osp.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                if eval_reader is not None:
                    # Detection currently only supports single-card evaluation; the eval batch
                    # size is the training batch size divided by the number of cards.
                    eval_batch_size = train_batch_size
                    self.eval_metrics, self.eval_details = self.evaluate(
                        eval_reader=eval_reader,
                        batch_size=eval_batch_size,
                        verbose=True,
                        epoch_id=i + 1,
                        return_details=True)
                    logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
                        i + 1, dict2str(self.eval_metrics)))
                    # Save the best model.
                    best_accuracy_key = list(self.eval_metrics.keys())[0]
                    current_accuracy = self.eval_metrics[best_accuracy_key]
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        best_model_epoch = i + 1
                        best_model_dir = osp.join(save_dir, "best_model")
                        self.save_model(save_dir=best_model_dir)

                    if use_vdl:
                        for k, v in self.eval_metrics.items():
                            if isinstance(v, list):
                                continue
                            if isinstance(v, np.ndarray):
                                if v.size > 1:
                                    continue
                            if k not in eval_component:
                                with log_writer.mode('Each_Epoch_on_Eval_Data') as eval_logger:
                                    eval_component[k] = eval_logger.scalar(
                                        'Evaluation: {}'.format(k))
                            eval_component[k].add_record(i + 1, v)
                self.save_model(save_dir=current_save_dir)
        logging.info(
            'Current evaluated best model in eval_reader is epoch_{}, {}={}'.format(
                best_model_epoch, best_accuracy_key, best_accuracy))
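dict2str above is what formats the metric dicts that train_loop writes to the log. A minimal standalone check of its output, assuming dict2str is imported from models/base.py:

from collections import OrderedDict
from RemoteSensing.models.base import dict2str

# Mirrors how train_loop assembles step metrics before logging them.
step_metrics = OrderedDict([('loss', 0.4321987), ('lr', 0.0099)])
print(dict2str(step_metrics))  # -> loss=0.432199, lr=0.0099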
contrib/RemoteSensing/models/load_model.py (new file, mode 100644)

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import os.path as osp
import six
import copy
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid.framework import Parameter
from ..utils import logging
import RemoteSensing


def load_model(model_dir):
    if not osp.exists(osp.join(model_dir, "model.yml")):
        raise Exception("There's not model.yml in {}".format(model_dir))
    with open(osp.join(model_dir, "model.yml")) as f:
        info = yaml.load(f.read(), Loader=yaml.Loader)
    status = info['status']

    if not hasattr(RemoteSensing.models, info['Model']):
        raise Exception("There's no attribute {} in RemoteSensing.models".format(
            info['Model']))

    model = getattr(RemoteSensing.models, info['Model'])(**info['_init_params'])
    if status == "Normal" or \
            status == "Prune":
        startup_prog = fluid.Program()
        model.test_prog = fluid.Program()
        with fluid.program_guard(model.test_prog, startup_prog):
            with fluid.unique_name.guard():
                model.test_inputs, model.test_outputs = model.build_net(mode='test')
        model.test_prog = model.test_prog.clone(for_test=True)
        model.exe.run(startup_prog)
        if status == "Prune":
            from .slim.prune import update_program
            model.test_prog = update_program(model.test_prog, model_dir,
                                             model.places[0])
        import pickle
        with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
            load_dict = pickle.load(f)
        fluid.io.set_program_state(model.test_prog, load_dict)

    elif status == "Infer" or \
            status == "Quant":
        [prog, input_names, outputs] = fluid.io.load_inference_model(
            model_dir, model.exe, params_filename='__params__')
        model.test_prog = prog
        test_outputs_info = info['_ModelInputsOutputs']['test_outputs']
        model.test_inputs = OrderedDict()
        model.test_outputs = OrderedDict()
        for name in input_names:
            model.test_inputs[name] = model.test_prog.global_block().var(name)
        for i, out in enumerate(outputs):
            var_desc = test_outputs_info[i]
            model.test_outputs[var_desc[0]] = out
    if 'Transforms' in info:
        model.test_transforms = build_transforms(info['Transforms'])
        model.eval_transforms = copy.deepcopy(model.test_transforms)

    if '_Attributes' in info:
        for k, v in info['_Attributes'].items():
            if k in model.__dict__:
                model.__dict__[k] = v

    logging.info("Model[{}] loaded.".format(info['Model']))
    return model


def build_transforms(transforms_info):
    from ..transforms import transforms as T
    transforms = list()
    for op_info in transforms_info:
        op_name = list(op_info.keys())[0]
        op_attr = op_info[op_name]
        if not hasattr(T, op_name):
            raise Exception("There's no operator named '{}' in transforms".format(op_name))
        transforms.append(getattr(T, op_name)(**op_attr))
    eval_transforms = T.Compose(transforms)
    return eval_transforms
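A minimal sketch of the round trip the two status branches above support, assuming a model was already trained and saved under save_dir as in main.py; paths and the input file are illustrative only:

import os.path as osp
from RemoteSensing.models import load_model

save_dir = 'saved_model/snow2019_unet_all_channel_vertical'  # illustrative path
model = load_model(osp.join(save_dir, 'best_model'))         # takes the status == 'Normal' branch
model.export_inference_model(osp.join(save_dir, 'infer'))    # writes model.yml with status 'Infer'

infer_model = load_model(osp.join(save_dir, 'infer'))        # takes the status == 'Infer' branch
pred = infer_model.predict('example_tile.npy')               # illustrative input file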
contrib/RemoteSensing/models/unet.py (new file, mode 100644)

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import os.path as osp
import numpy as np
import math
import cv2
import paddle.fluid as fluid
import RemoteSensing
import RemoteSensing.utils.logging as logging
from collections import OrderedDict
from .base import BaseAPI
from ..utils.metrics import ConfusionMatrix


class UNet(BaseAPI):
    """Build a UNet network and train, evaluate, predict with, and export it.

    Args:
        num_classes (int): Number of classes.
        upsample_mode (str): Upsampling method used by the UNet decoder. 'bilinear' uses
            bilinear interpolation; any other value uses transposed convolution.
            Default 'bilinear'.
        input_channel (int): Number of input channels. Default 3.
        use_bce_loss (bool): Whether to use BCE loss as the loss function; only valid for
            binary segmentation and may be combined with dice loss. Default False.
        use_dice_loss (bool): Whether to use dice loss as the loss function; only valid for
            binary segmentation and may be combined with BCE loss. When both use_bce_loss and
            use_dice_loss are False, softmax cross-entropy loss is used. Default False.
        class_weight (list/str): Per-class weights of the cross-entropy loss. If a list, its
            length must equal num_classes. If a str, it must be 'dynamic', in which case the
            weights are recomputed each round from the per-class pixel ratios, each class
            weighted by ratio * num_classes. The default None gives every class weight 1,
            i.e. the usual cross-entropy loss.
        ignore_index (int): Label value to ignore; pixels whose label equals ignore_index do
            not contribute to the loss. Default 255.

    Raises:
        ValueError: use_bce_loss or use_dice_loss is True while num_classes > 2.
        ValueError: class_weight is a list whose length does not equal num_classes, or a
            str other than 'dynamic'.
        TypeError: class_weight is neither None, a list, nor a str.
    """

    def __init__(self,
                 num_classes=2,
                 upsample_mode='bilinear',
                 input_channel=3,
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 ignore_index=255):
        self.init_params = locals()
        super(UNet, self).__init__()
        # dice_loss and bce_loss only apply to binary segmentation.
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss is only applicable to binary classfication")
        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes")
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError("if class_weight is string, must be dynamic!")
            else:
                raise TypeError(
                    'Expect class_weight is a list or string but receive {}'.format(
                        type(class_weight)))

        self.num_classes = num_classes
        self.upsample_mode = upsample_mode
        self.input_channel = input_channel
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.labels = None
        # A model loaded from an inference model cannot be trained again.
        self.trainable = True

    def build_net(self, mode='train'):
        model = RemoteSensing.nets.UNet(
            self.num_classes,
            mode=mode,
            upsample_mode=self.upsample_mode,
            input_channel=self.input_channel,
            use_bce_loss=self.use_bce_loss,
            use_dice_loss=self.use_dice_loss,
            class_weight=self.class_weight,
            ignore_index=self.ignore_index)
        inputs = model.generate_inputs()
        model_out = model.build_net(inputs)
        outputs = OrderedDict()
        if mode == 'train':
            self.optimizer.minimize(model_out)
            outputs['loss'] = model_out
        elif mode == 'eval':
            outputs['loss'] = model_out[0]
            outputs['pred'] = model_out[1]
            outputs['label'] = model_out[2]
            outputs['mask'] = model_out[3]
        else:
            outputs['pred'] = model_out[0]
            outputs['logit'] = model_out[1]
        return inputs, outputs

    def default_optimizer(self,
                          learning_rate,
                          num_epochs,
                          num_steps_each_epoch,
                          lr_decay_power=0.9):
        decay_step = num_epochs * num_steps_each_epoch
        lr_decay = fluid.layers.polynomial_decay(
            learning_rate,
            decay_step,
            end_learning_rate=0,
            power=lr_decay_power)
        optimizer = fluid.optimizer.Momentum(
            lr_decay,
            momentum=0.9,
            regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-05))
        return optimizer

    def train(self,
              num_epochs,
              train_reader,
              train_batch_size=2,
              eval_reader=None,
              save_interval_epochs=1,
              log_interval_steps=2,
              save_dir='output',
              pretrain_weights='COCO',
              optimizer=None,
              learning_rate=0.01,
              lr_decay_power=0.9,
              use_vdl=False,
              sensitivities_file=None,
              eval_metric_loss=0.05):
        """Train the model.

        Args:
            num_epochs (int): Number of training epochs.
            train_reader (RemoteSensing.readers): Training data reader.
            train_batch_size (int): Batch size for training; also used as the evaluation
                batch size. Default 2.
            eval_reader (RemoteSensing.readers): Evaluation data reader.
            save_interval_epochs (int): Model saving interval, in epochs. Default 1.
            log_interval_steps (int): Training-log interval, in steps. Default 2.
            save_dir (str): Directory the model is saved to. Default 'output'.
            pretrain_weights (str): If a path, load pretrained weights from it; if 'COCO',
                automatically download weights pretrained on COCO images; if None, use no
                pretrained weights. Default 'COCO'.
            optimizer (paddle.fluid.optimizer): Optimizer. When None, the default optimizer
                is used: fluid.optimizer.Momentum with polynomial learning-rate decay.
            learning_rate (float): Initial learning rate of the default optimizer.
                Default 0.01.
            lr_decay_power (float): Polynomial decay power of the default optimizer's
                learning rate. Default 0.9.
            use_vdl (bool): Whether to use VisualDL for visualization. Default False.
            sensitivities_file (str): If a path, load sensitivity information from it for
                pruning; if 'DEFAULT', automatically download sensitivity information
                obtained on ImageNet; if None, do not prune. Default None.
            eval_metric_loss (float): Tolerable accuracy loss. Default 0.05.

        Raises:
            ValueError: The model was loaded from an inference model.
        """
        if not self.trainable:
            raise ValueError(
                "Model is not trainable since it was loaded from a inference model.")
        self.labels = train_reader.labels
        if optimizer is None:
            num_steps_each_epoch = train_reader.num_samples // train_batch_size
            optimizer = self.default_optimizer(
                learning_rate=learning_rate,
                num_epochs=num_epochs,
                num_steps_each_epoch=num_steps_each_epoch,
                lr_decay_power=lr_decay_power)
        self.optimizer = optimizer
        # Build the training, evaluation, and inference networks.
        self.build_program()
        # Initialize the network weights.
        self.net_initialize(
            startup_prog=fluid.default_startup_program(),
            pretrain_weights=pretrain_weights,
            save_dir=save_dir,
            sensitivities_file=sensitivities_file,
            eval_metric_loss=eval_metric_loss)
        # Train.
        self.train_loop(
            num_epochs=num_epochs,
            train_reader=train_reader,
            train_batch_size=train_batch_size,
            eval_reader=eval_reader,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            use_vdl=use_vdl)

    def evaluate(self,
                 eval_reader,
                 batch_size=1,
                 verbose=True,
                 epoch_id=None,
                 return_details=False):
        """Evaluate the model.

        Args:
            eval_reader (RemoteSensing.readers): Evaluation data reader.
            batch_size (int): Batch size for evaluation. Default 1.
            verbose (bool): Whether to print logs. Default True.
            epoch_id (int): Epoch the evaluated model was trained to.
            return_details (bool): Whether to return detailed information. Default False.

        Returns:
            dict: When return_details is False, a dict with keys 'miou', 'categore_iou',
                'macc', 'category_acc' and 'kappa': mean IoU, per-class IoU, mean accuracy,
                per-class accuracy, and the kappa coefficient.
            tuple (metrics, eval_details): When return_details is True, additionally
                returns a dict eval_details with key 'confusion_matrix', the confusion
                matrix of the evaluation.
        """
        self.arrange_transforms(transforms=eval_reader.transforms, mode='eval')
        total_steps = math.ceil(eval_reader.num_samples * 1.0 / batch_size)
        conf_mat = ConfusionMatrix(self.num_classes, streaming=True)
        data_generator = eval_reader.generator(
            batch_size=batch_size, drop_last=False)
        if not hasattr(self, 'parallel_test_prog'):
            self.parallel_test_prog = fluid.CompiledProgram(
                self.test_prog).with_data_parallel(
                    share_vars_from=self.parallel_train_prog)
        batch_size_each_gpu = self._get_single_card_bs(batch_size)
        for step, data in enumerate(data_generator()):
            images = np.array([d[0] for d in data])
            labels = np.array([d[1] for d in data])
            num_samples = images.shape[0]
            if num_samples < batch_size:
                num_pad_samples = batch_size - num_samples
                pad_images = np.tile(images[0:1], (num_pad_samples, 1, 1, 1))
                images = np.concatenate([images, pad_images])
            feed_data = {'image': images}
            outputs = self.exe.run(
                self.parallel_test_prog,
                feed=feed_data,
                fetch_list=list(self.test_outputs.values()),
                return_numpy=True)
            pred = outputs[0]
            if num_samples < batch_size:
                pred = pred[0:num_samples]

            mask = labels != self.ignore_index
            conf_mat.calculate(pred=pred, label=labels, ignore=mask)
            _, iou = conf_mat.mean_iou()

            if verbose:
                logging.info("[EVAL] Epoch={}, Step={}/{}, iou={}".format(
                    epoch_id, step + 1, total_steps, iou))

        category_iou, miou = conf_mat.mean_iou()
        category_acc, macc = conf_mat.accuracy()

        metrics = OrderedDict(
            zip(['miou', 'categore_iou', 'macc', 'category_acc', 'kappa'],
                [miou, category_iou, macc, category_acc, conf_mat.kappa()]))
        if return_details:
            eval_details = {
                'confusion_matrix': conf_mat.confusion_matrix.tolist()
            }
            return metrics, eval_details
        return metrics

    def predict(self, im_file, transforms=None):
        """Predict.

        Args:
            img_file (str): Path of the image to predict.
            transforms (RemoteSensing.transforms): Data preprocessing operations.

        Returns:
            np.ndarray: The predicted grayscale mask.
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise Exception("transforms need to be defined, now is None.")
        if transforms is not None:
            self.arrange_transforms(transforms=transforms, mode='test')
            im, im_info = transforms(im_file)
        else:
            self.arrange_transforms(transforms=self.test_transforms, mode='test')
            im, im_info = self.test_transforms(im_file)
        im = np.expand_dims(im, axis=0)
        result = self.exe.run(
            self.test_prog,
            feed={'image': im},
            fetch_list=list(self.test_outputs.values()))
        pred = result[0]
        pred = np.squeeze(pred).astype(np.uint8)
        keys = list(im_info.keys())
        for k in keys[::-1]:
            if k == 'shape_before_resize':
                h, w = im_info[k][0], im_info[k][1]
                pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
            elif k == 'shape_before_padding':
                h, w = im_info[k][0], im_info[k][1]
                pred = pred[0:h, 0:w]
        return pred
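A short sketch of consuming what evaluate() returns, assuming model and eval_reader were set up and trained as in main.py:

# return_details=True additionally returns the streamed confusion matrix.
metrics, details = model.evaluate(eval_reader, batch_size=4, return_details=True)
print('miou={}, kappa={}'.format(metrics['miou'], metrics['kappa']))
print('per-class iou: {}'.format(metrics['categore_iou']))  # key spelled as defined above
print('confusion matrix: {}'.format(details['confusion_matrix']))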
contrib/RemoteSensing/nets/__init__.py (new file, mode 100644)

from .unet import UNet
contrib/RemoteSensing/nets/libs.py (new file, mode 100644)

# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import contextlib

bn_regularizer = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0)
name_scope = ""


@contextlib.contextmanager
def scope(name):
    global name_scope
    bk = name_scope
    name_scope = name_scope + name + '/'
    yield
    name_scope = bk


def max_pool(input, kernel, stride, padding):
    data = fluid.layers.pool2d(
        input,
        pool_size=kernel,
        pool_type='max',
        pool_stride=stride,
        pool_padding=padding)
    return data


def avg_pool(input, kernel, stride, padding=0):
    data = fluid.layers.pool2d(
        input,
        pool_size=kernel,
        pool_type='avg',
        pool_stride=stride,
        pool_padding=padding)
    return data


def group_norm(input, G, eps=1e-5, param_attr=None, bias_attr=None):
    N, C, H, W = input.shape
    if C % G != 0:
        for d in range(10):
            for t in [d, -d]:
                if G + t <= 0:
                    continue
                if C % (G + t) == 0:
                    G = G + t
                    break
            if C % G == 0:
                break
    assert C % G == 0, "group can not divide channle"
    x = fluid.layers.group_norm(
        input,
        groups=G,
        param_attr=param_attr,
        bias_attr=bias_attr,
        name=name_scope + 'group_norm')
    return x


def bn(*args,
       norm_type='bn',
       eps=1e-5,
       bn_momentum=0.99,
       group_norm=32,
       **kargs):
    if norm_type == 'bn':
        with scope('BatchNorm'):
            return fluid.layers.batch_norm(
                *args,
                epsilon=eps,
                momentum=bn_momentum,
                param_attr=fluid.ParamAttr(
                    name=name_scope + 'gamma', regularizer=bn_regularizer),
                bias_attr=fluid.ParamAttr(
                    name=name_scope + 'beta', regularizer=bn_regularizer),
                moving_mean_name=name_scope + 'moving_mean',
                moving_variance_name=name_scope + 'moving_variance',
                **kargs)
    elif norm_type == 'gn':
        with scope('GroupNorm'):
            return group_norm(
                args[0],
                group_norm,
                eps=eps,
                param_attr=fluid.ParamAttr(
                    name=name_scope + 'gamma', regularizer=bn_regularizer),
                bias_attr=fluid.ParamAttr(
                    name=name_scope + 'beta', regularizer=bn_regularizer))
    else:
        raise Exception("Unsupport norm type:" + norm_type)


def bn_relu(data, norm_type='bn', eps=1e-5):
    return fluid.layers.relu(bn(data, norm_type=norm_type, eps=eps))


def relu(data):
    return fluid.layers.relu(data)


def conv(*args, **kargs):
    kargs['param_attr'] = name_scope + 'weights'
    if 'bias_attr' in kargs and kargs['bias_attr']:
        kargs['bias_attr'] = fluid.ParamAttr(
            name=name_scope + 'biases',
            regularizer=None,
            initializer=fluid.initializer.ConstantInitializer(value=0.0))
    else:
        kargs['bias_attr'] = False
    return fluid.layers.conv2d(*args, **kargs)


def deconv(*args, **kargs):
    kargs['param_attr'] = name_scope + 'weights'
    if 'bias_attr' in kargs and kargs['bias_attr']:
        kargs['bias_attr'] = name_scope + 'biases'
    else:
        kargs['bias_attr'] = False
    return fluid.layers.conv2d_transpose(*args, **kargs)


def separate_conv(input, channel, stride, filter, dilation=1, act=None, eps=1e-5):
    param_attr = fluid.ParamAttr(
        name=name_scope + 'weights',
        regularizer=fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0),
        initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.33))
    with scope('depthwise'):
        input = conv(
            input,
            input.shape[1],
            filter,
            stride,
            groups=input.shape[1],
            padding=(filter // 2) * dilation,
            dilation=dilation,
            use_cudnn=False,
            param_attr=param_attr)
        input = bn(input, eps=eps)
        if act:
            input = act(input)

    param_attr = fluid.ParamAttr(
        name=name_scope + 'weights',
        regularizer=None,
        initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.06))
    with scope('pointwise'):
        input = conv(
            input, channel, 1, 1, groups=1, padding=0, param_attr=param_attr)
        input = bn(input, eps=eps)
        if act:
            input = act(input)
    return input


def conv_bn_layer(input,
                  filter_size,
                  num_filters,
                  stride,
                  padding,
                  channels=None,
                  num_groups=1,
                  if_act=True,
                  name=None,
                  use_cudnn=True):
    conv = fluid.layers.conv2d(
        input=input,
        num_filters=num_filters,
        filter_size=filter_size,
        stride=stride,
        padding=padding,
        groups=num_groups,
        act=None,
        use_cudnn=use_cudnn,
        param_attr=fluid.ParamAttr(name=name + '_weights'),
        bias_attr=False)
    bn_name = name + '_bn'
    bn = fluid.layers.batch_norm(
        input=conv,
        param_attr=fluid.ParamAttr(name=bn_name + "_scale"),
        bias_attr=fluid.ParamAttr(name=bn_name + "_offset"),
        moving_mean_name=bn_name + '_mean',
        moving_variance_name=bn_name + '_variance')
    if if_act:
        return fluid.layers.relu6(bn)
    else:
        return bn


def sigmoid_to_softmax(input):
    """
    one channel to two channel
    """
    logit = fluid.layers.sigmoid(input)
    logit_back = 1 - logit
    logit = fluid.layers.concat([logit_back, logit], axis=1)
    return logit
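sigmoid_to_softmax turns a single-channel logit into a two-channel probability map by concatenating 1 - sigmoid(x) and sigmoid(x) along the channel axis. A NumPy sketch of the same arithmetic, for intuition only (the real op works on fluid tensors):

import numpy as np

def sigmoid_to_softmax_np(logit):
    # logit: (N, 1, H, W) single-channel scores
    prob = 1.0 / (1.0 + np.exp(-logit))
    return np.concatenate([1.0 - prob, prob], axis=1)  # (N, 2, H, W)

x = np.array([[[[-2.0, 0.0, 2.0]]]])         # shape (1, 1, 1, 3)
print(sigmoid_to_softmax_np(x).sum(axis=1))  # every pixel's two channels sum to 1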
contrib/RemoteSensing/nets/loss.py (new file, mode 100644)

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import numpy as np


def softmax_with_loss(logit,
                      label,
                      ignore_mask=None,
                      num_classes=2,
                      weight=None,
                      ignore_index=255):
    ignore_mask = fluid.layers.cast(ignore_mask, 'float32')
    label = fluid.layers.elementwise_min(
        label, fluid.layers.assign(np.array([num_classes - 1], dtype=np.int32)))
    logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
    logit = fluid.layers.reshape(logit, [-1, num_classes])
    label = fluid.layers.reshape(label, [-1, 1])
    label = fluid.layers.cast(label, 'int64')
    ignore_mask = fluid.layers.reshape(ignore_mask, [-1, 1])

    if weight is None:
        loss, probs = fluid.layers.softmax_with_cross_entropy(
            logit, label, ignore_index=ignore_index, return_softmax=True)
    else:
        label_one_hot = fluid.layers.one_hot(input=label, depth=num_classes)
        if isinstance(weight, list):
            assert len(weight) == num_classes, \
                "weight length must equal num of classes"
            weight = fluid.layers.assign(np.array([weight], dtype='float32'))
        elif isinstance(weight, str):
            assert weight.lower() == 'dynamic', \
                'if weight is string, must be dynamic!'
            tmp = []
            total_num = fluid.layers.cast(fluid.layers.shape(label)[0], 'float32')
            for i in range(num_classes):
                cls_pixel_num = fluid.layers.reduce_sum(label_one_hot[:, i])
                ratio = total_num / (cls_pixel_num + 1)
                tmp.append(ratio)
            weight = fluid.layers.concat(tmp)
            weight = weight / fluid.layers.reduce_sum(weight) * num_classes
        elif isinstance(weight, fluid.layers.Variable):
            pass
        else:
            raise ValueError(
                'Expect weight is a list, string or Variable, but receive {}'.format(
                    type(weight)))
        weight = fluid.layers.reshape(weight, [1, num_classes])
        weighted_label_one_hot = fluid.layers.elementwise_mul(label_one_hot, weight)
        probs = fluid.layers.softmax(logit)
        loss = fluid.layers.cross_entropy(
            probs,
            weighted_label_one_hot,
            soft_label=True,
            ignore_index=ignore_index)
        weighted_label_one_hot.stop_gradient = True

    loss = loss * ignore_mask
    avg_loss = fluid.layers.mean(loss) / (fluid.layers.mean(ignore_mask) + 0.00001)

    label.stop_gradient = True
    ignore_mask.stop_gradient = True
    return avg_loss


# to change: how to apply ignore index and ignore mask
def dice_loss(logit, label, ignore_mask=None, epsilon=0.00001):
    if logit.shape[1] != 1 or label.shape[1] != 1 or ignore_mask.shape[1] != 1:
        raise Exception("dice loss is only applicable to one channel classfication")
    ignore_mask = fluid.layers.cast(ignore_mask, 'float32')
    logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
    label = fluid.layers.transpose(label, [0, 2, 3, 1])
    label = fluid.layers.cast(label, 'int64')
    ignore_mask = fluid.layers.transpose(ignore_mask, [0, 2, 3, 1])
    logit = fluid.layers.sigmoid(logit)
    logit = logit * ignore_mask
    label = label * ignore_mask
    reduce_dim = list(range(1, len(logit.shape)))
    inse = fluid.layers.reduce_sum(logit * label, dim=reduce_dim)
    dice_denominator = fluid.layers.reduce_sum(
        logit, dim=reduce_dim) + fluid.layers.reduce_sum(label, dim=reduce_dim)
    dice_score = 1 - inse * 2 / (dice_denominator + epsilon)
    label.stop_gradient = True
    ignore_mask.stop_gradient = True
    return fluid.layers.reduce_mean(dice_score)


def bce_loss(logit, label, ignore_mask=None, ignore_index=255):
    if logit.shape[1] != 1 or label.shape[1] != 1 or ignore_mask.shape[1] != 1:
        raise Exception("bce loss is only applicable to binary classfication")
    label = fluid.layers.cast(label, 'float32')
    loss = fluid.layers.sigmoid_cross_entropy_with_logits(
        x=logit,
        label=label,
        ignore_index=ignore_index,
        normalize=True)  # or False
    loss = fluid.layers.reduce_sum(loss)
    label.stop_gradient = True
    ignore_mask.stop_gradient = True
    return loss
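After the sigmoid and masking, the dice term above reduces to 1 - 2 * intersection / (sum of areas + epsilon) per sample. A NumPy sketch of that arithmetic on a toy mask, for intuition only (the real loss works on fluid tensors):

import numpy as np

def dice_score_np(prob, label, epsilon=0.00001):
    # prob and label are (H, W); prob is already sigmoid-activated, label is in {0, 1}.
    inse = np.sum(prob * label)
    denominator = np.sum(prob) + np.sum(label)
    return 1 - 2 * inse / (denominator + epsilon)

prob = np.array([[0.9, 0.1], [0.8, 0.2]])
label = np.array([[1, 0], [1, 0]])
print(dice_score_np(prob, label))  # ~0.15: small, since the prediction matches the label well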
contrib/RemoteSensing/nets/unet.py (new file, mode 100644)

# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict

import paddle.fluid as fluid
from .libs import scope, name_scope
from .libs import bn, bn_relu, relu
from .libs import conv, max_pool, deconv
from .libs import sigmoid_to_softmax
from .loss import softmax_with_loss
from .loss import dice_loss
from .loss import bce_loss


class UNet(object):
    """UNet implementation.

    `"U-Net: Convolutional Networks for Biomedical Image Segmentation"
    <https://arxiv.org/abs/1505.04597>`

    Args:
        num_classes (int): Number of classes.
        mode (str): Network mode; the network's inputs and outputs depend on it.
            When mode is 'train', the inputs are image (-1, 3, -1, -1) and label
            (-1, 1, -1, -1), and the network returns loss.
            When mode is 'eval', the inputs are image (-1, 3, -1, -1) and label
            (-1, 1, -1, -1), and the network returns loss, pred (a prediction of the same
            size as the input label whose values are the predicted classes), label, and
            mask (a bool mask of non-ignored pixels, the same size as label).
            When mode is 'test', the input is image (-1, 3, -1, -1), and the network
            returns pred (-1, 1, -1, -1) and logit (-1, num_classes, -1, -1), whose channel
            dimension holds the per-class probabilities.
        upsample_mode (str): Upsampling method used by the UNet decoder. 'bilinear' uses
            bilinear interpolation; any other value uses transposed convolution.
            Default 'bilinear'.
        input_channel (int): Number of input channels. Default 3.
        use_bce_loss (bool): Whether to use BCE loss as the loss function; only valid for
            binary segmentation and may be combined with dice loss.
        use_dice_loss (bool): Whether to use dice loss as the loss function; only valid for
            binary segmentation and may be combined with BCE loss. When both use_bce_loss
            and use_dice_loss are False, softmax cross-entropy loss is used.
        class_weight (list/str): Per-class weights of the cross-entropy loss. If a list,
            its length must equal num_classes. If a str, it must be 'dynamic', in which
            case the weights are recomputed each round from the per-class pixel ratios,
            each class weighted by ratio * num_classes. The default None gives every class
            weight 1, i.e. the usual cross-entropy loss.
        ignore_index (int): Label value to ignore; pixels whose label equals ignore_index
            do not contribute to the loss.

    Raises:
        ValueError: use_bce_loss or use_dice_loss is True while num_classes > 2.
        ValueError: class_weight is a list whose length does not equal num_classes, or a
            str other than 'dynamic'.
        TypeError: class_weight is neither None, a list, nor a str.
    """

    def __init__(self,
                 num_classes,
                 mode='train',
                 upsample_mode='bilinear',
                 input_channel=3,
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 ignore_index=255):
        # dice_loss and bce_loss only apply to binary segmentation.
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise Exception(
                "dice loss and bce loss is only applicable to binary classfication")

        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes")
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError("if class_weight is string, must be dynamic!")
            else:
                raise TypeError(
                    'Expect class_weight is a list or string but receive {}'.format(
                        type(class_weight)))
        self.num_classes = num_classes
        self.mode = mode
        self.upsample_mode = upsample_mode
        self.input_channel = input_channel
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        self.ignore_index = ignore_index

    def _double_conv(self, data, out_ch):
        param_attr = fluid.ParamAttr(
            name='weights',
            regularizer=fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0),
            initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.33))
        with scope("conv0"):
            data = bn_relu(
                conv(data, out_ch, 3, stride=1, padding=1, param_attr=param_attr))
        with scope("conv1"):
            data = bn_relu(
                conv(data, out_ch, 3, stride=1, padding=1, param_attr=param_attr))
        return data

    def _down(self, data, out_ch):
        # Downsampling: max_pool followed by two convolutions.
        with scope("down"):
            data = max_pool(data, 2, 2, 0)
            data = self._double_conv(data, out_ch)
        return data

    def _up(self, data, short_cut, out_ch):
        # Upsampling: upsample data (resize or deconv) and concat it with short_cut.
        param_attr = fluid.ParamAttr(
            name='weights',
            regularizer=fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0),
            initializer=fluid.initializer.XavierInitializer(),
        )
        with scope("up"):
            if self.upsample_mode == 'bilinear':
                short_cut_shape = fluid.layers.shape(short_cut)
                data = fluid.layers.resize_bilinear(data, short_cut_shape[2:])
            else:
                data = deconv(
                    data,
                    out_ch // 2,
                    filter_size=2,
                    stride=2,
                    padding=0,
                    param_attr=param_attr)
            data = fluid.layers.concat([data, short_cut], axis=1)
            data = self._double_conv(data, out_ch)
        return data

    def _encode(self, data):
        # Encoder.
        short_cuts = []
        with scope("encode"):
            with scope("block1"):
                data = self._double_conv(data, 64)
                short_cuts.append(data)
            with scope("block2"):
                data = self._down(data, 128)
                short_cuts.append(data)
            with scope("block3"):
                data = self._down(data, 256)
                short_cuts.append(data)
            with scope("block4"):
                data = self._down(data, 512)
                short_cuts.append(data)
            with scope("block5"):
                data = self._down(data, 512)
        return data, short_cuts

    def _decode(self, data, short_cuts):
        # Decoder, symmetric to the encoder.
        with scope("decode"):
            with scope("decode1"):
                data = self._up(data, short_cuts[3], 256)
            with scope("decode2"):
                data = self._up(data, short_cuts[2], 128)
            with scope("decode3"):
                data = self._up(data, short_cuts[1], 64)
            with scope("decode4"):
                data = self._up(data, short_cuts[0], 64)
        return data

    def _get_logit(self, data, num_classes):
        # The output channels of the last convolution are set by the number of classes.
        param_attr = fluid.ParamAttr(
            name='weights',
            regularizer=fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0),
            initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
        with scope("logit"):
            data = conv(
                data, num_classes, 3, stride=1, padding=1, param_attr=param_attr)
        return data

    def _get_loss(self, logit, label, mask):
        avg_loss = 0
        if not (self.use_dice_loss or self.use_bce_loss):
            avg_loss += softmax_with_loss(
                logit,
                label,
                mask,
                num_classes=self.num_classes,
                weight=self.class_weight,
                ignore_index=self.ignore_index)
        else:
            if self.use_dice_loss:
                avg_loss += dice_loss(logit, label, mask)
            if self.use_bce_loss:
                avg_loss += bce_loss(
                    logit, label, mask, ignore_index=self.ignore_index)
        return avg_loss

    def generate_inputs(self):
        inputs = OrderedDict()
        inputs['image'] = fluid.data(
            dtype='float32',
            shape=[None, self.input_channel, None, None],
            name='image')
        if self.mode == 'train':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        elif self.mode == 'eval':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        return inputs

    def build_net(self, inputs):
        # In binary segmentation with dice_loss or bce_loss, the final logit has one channel.
        if self.use_dice_loss or self.use_bce_loss:
            self.num_classes = 1
        image = inputs['image']
        encode_data, short_cuts = self._encode(image)
        decode_data = self._decode(encode_data, short_cuts)
        logit = self._get_logit(decode_data, self.num_classes)

        if self.num_classes == 1:
            out = sigmoid_to_softmax(logit)
            out = fluid.layers.transpose(out, [0, 2, 3, 1])
        else:
            out = fluid.layers.transpose(logit, [0, 2, 3, 1])

        pred = fluid.layers.argmax(out, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])

        if self.mode == 'train':
            label = inputs['label']
            mask = label != self.ignore_index
            return self._get_loss(logit, label, mask)

        elif self.mode == 'eval':
            label = inputs['label']
            mask = label != self.ignore_index
            loss = self._get_loss(logit, label, mask)
            return loss, pred, label, mask
        else:
            if self.num_classes == 1:
                logit = sigmoid_to_softmax(logit)
            else:
                logit = fluid.layers.softmax(logit, axis=1)
            return pred, logit
contrib/RemoteSensing/readers/__init__.py (new file, mode 100644)

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .reader import Reader
contrib/RemoteSensing/readers/base.py
0 → 100644
浏览文件 @
9a5a9f45
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from threading import Thread
import multiprocessing
import collections
import numpy as np
import six
import sys
import copy
import random
import platform
import chardet
from ..utils import logging


class EndSignal():
    pass


def is_pic(img_name):
    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
    suffix = img_name.split('.')[-1]
    if suffix not in valid_suffix:
        return False
    return True


def is_valid(sample):
    if sample is None:
        return False
    if isinstance(sample, tuple):
        for s in sample:
            if s is None:
                return False
            elif isinstance(s, np.ndarray) and s.size == 0:
                return False
            elif isinstance(s, collections.Sequence) and len(s) == 0:
                return False
    return True


def get_encoding(path):
    f = open(path, 'rb')
    data = f.read()
    file_encoding = chardet.detect(data).get('encoding')
    return file_encoding


def multithread_reader(mapper,
                       reader,
                       num_workers=4,
                       buffer_size=1024,
                       batch_size=8,
                       drop_last=True):
    from queue import Queue
    end = EndSignal()

    # define a worker to read samples from reader to in_queue
    def read_worker(reader, in_queue):
        for i in reader():
            in_queue.put(i)
        in_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue
    def handle_worker(in_queue, out_queue, mapper):
        sample = in_queue.get()
        while not isinstance(sample, EndSignal):
            if len(sample) == 2:
                r = mapper(sample[0], sample[1])
            elif len(sample) == 3:
                r = mapper(sample[0], sample[1], sample[2])
            else:
                raise Exception('The sample\'s length must be 2 or 3.')
            if is_valid(r):
                out_queue.put(r)
            sample = in_queue.get()
        in_queue.put(end)
        out_queue.put(end)

    def xreader():
        in_queue = Queue(buffer_size)
        out_queue = Queue(buffer_size)
        # start a read worker in a thread
        target = read_worker
        t = Thread(target=target, args=(reader, in_queue))
        t.daemon = True
        t.start()
        # start several handle_workers
        target = handle_worker
        args = (in_queue, out_queue, mapper)
        workers = []
        for i in range(num_workers):
            worker = Thread(target=target, args=args)
            worker.daemon = True
            workers.append(worker)
        for w in workers:
            w.start()

        batch_data = []
        sample = out_queue.get()
        while not isinstance(sample, EndSignal):
            batch_data.append(sample)
            if len(batch_data) == batch_size:
                batch_data = GenerateMiniBatch(batch_data)
                yield batch_data
                batch_data = []
            sample = out_queue.get()
        finish = 1
        while finish < num_workers:
            sample = out_queue.get()
            if isinstance(sample, EndSignal):
                finish += 1
            else:
                batch_data.append(sample)
                if len(batch_data) == batch_size:
                    batch_data = GenerateMiniBatch(batch_data)
                    yield batch_data
                    batch_data = []
        if not drop_last and len(batch_data) != 0:
            batch_data = GenerateMiniBatch(batch_data)
            yield batch_data
            batch_data = []

    return xreader


def multiprocess_reader(mapper,
                        reader,
                        num_workers=4,
                        buffer_size=1024,
                        batch_size=8,
                        drop_last=True):
    from .shared_queue import SharedQueue as Queue

    def _read_into_queue(samples, mapper, queue):
        end = EndSignal()
        try:
            for sample in samples:
                if sample is None:
                    raise ValueError("sample has None")
                if len(sample) == 2:
                    result = mapper(sample[0], sample[1])
                elif len(sample) == 3:
                    result = mapper(sample[0], sample[1], sample[2])
                else:
                    raise Exception('The sample\'s length must be 2 or 3.')
                if is_valid(result):
                    queue.put(result)
            queue.put(end)
        except:
            queue.put("")
            six.reraise(*sys.exc_info())

    def queue_reader():
        queue = Queue(buffer_size, memsize=3 * 1024**3)
        total_samples = [[] for i in range(num_workers)]
        for i, sample in enumerate(reader()):
            index = i % num_workers
            total_samples[index].append(sample)
        for i in range(num_workers):
            p = multiprocessing.Process(
                target=_read_into_queue,
                args=(total_samples[i], mapper, queue))
            p.start()

        finish_num = 0
        batch_data = list()
        while finish_num < num_workers:
            sample = queue.get()
            if isinstance(sample, EndSignal):
                finish_num += 1
            elif sample == "":
                raise ValueError("multiprocess reader raises an exception")
            else:
                batch_data.append(sample)
                if len(batch_data) == batch_size:
                    batch_data = GenerateMiniBatch(batch_data)
                    yield batch_data
                    batch_data = []
        if len(batch_data) != 0 and not drop_last:
            batch_data = GenerateMiniBatch(batch_data)
            yield batch_data
            batch_data = []

    return queue_reader


def GenerateMiniBatch(batch_data):
    if len(batch_data) == 1:
        return batch_data
    width = [data[0].shape[2] for data in batch_data]
    height = [data[0].shape[1] for data in batch_data]
    if len(set(width)) == 1 and len(set(height)) == 1:
        return batch_data
    max_shape = np.array([data[0].shape for data in batch_data]).max(axis=0)
    padding_batch = []
    for data in batch_data:
        im_c, im_h, im_w = data[0].shape[:]
        padding_im = np.zeros((im_c, max_shape[1], max_shape[2]),
                              dtype=np.float32)
        padding_im[:, :im_h, :im_w] = data[0]
        padding_batch.append((padding_im, ) + data[1:])
    return padding_batch


class BaseReader:
    def __init__(self,
                 transforms=None,
                 num_workers=4,
                 buffer_size=100,
                 parallel_method='thread',
                 shuffle=False):
        if transforms is None:
            raise Exception("transform should be defined.")
        self.transforms = transforms
        self.num_workers = num_workers
        self.buffer_size = buffer_size
        self.parallel_method = parallel_method
        self.shuffle = shuffle

    def generator(self, batch_size=1, drop_last=True):
        self.batch_size = batch_size
        parallel_reader = multithread_reader
        if self.parallel_method == "process":
            if platform.platform().startswith("Windows"):
                logging.debug(
                    "multiprocess_reader is not supported in Windows platform, force to use multithread_reader."
                )
            else:
                parallel_reader = multiprocess_reader
        return parallel_reader(
            self.transforms,
            self.iterator,
            num_workers=self.num_workers,
            buffer_size=self.buffer_size,
            batch_size=batch_size,
            drop_last=drop_last)
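
A quick standalone sketch (not part of this diff) of what GenerateMiniBatch does when a batch mixes image sizes: every CHW image is zero-padded up to the largest height and width in the batch, while the rest of each sample tuple is passed through untouched. It assumes contrib/RemoteSensing is importable as the RemoteSensing package (with PaddlePaddle installed), as the modules above expect.

import numpy as np
from RemoteSensing.readers.base import GenerateMiniBatch

# two CHW samples with different spatial sizes
sample_a = (np.ones((3, 100, 120), dtype=np.float32), 'label_a.png')
sample_b = (np.ones((3, 90, 130), dtype=np.float32), 'label_b.png')
batch = GenerateMiniBatch([sample_a, sample_b])
# both images are now padded to (3, 100, 130)
print([item[0].shape for item in batch])
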
contrib/RemoteSensing/readers/reader.py
0 → 100644
浏览文件 @
9a5a9f45
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
import os.path as osp
import random
from ..utils import logging
from .base import BaseReader
from .base import get_encoding
from collections import OrderedDict
from .base import is_pic


class Reader(BaseReader):
    """Reader for semantic segmentation datasets: loads the samples and applies
    the configured preprocessing to each of them.

    Args:
        data_dir (str): Root directory of the dataset.
        file_list (str): Path of the file listing image/annotation pairs
            (each line holds paths relative to data_dir).
        label_list (str): Path of the file listing the class names of the dataset.
        transforms (list): Preprocessing/augmentation operators applied to every sample.
        num_workers (int): Number of threads or processes used for preprocessing. Default: 4.
        buffer_size (int): Queue length (in samples) used during preprocessing. Default: 100.
        parallel_method (str): Parallelization method used for preprocessing,
            either 'thread' or 'process'. Default: 'thread'.
        shuffle (bool): Whether to shuffle the samples. Default: False.
    """

    def __init__(self,
                 data_dir,
                 file_list,
                 label_list,
                 transforms=None,
                 num_workers=4,
                 buffer_size=100,
                 parallel_method='thread',
                 shuffle=False):
        super(Reader, self).__init__(
            transforms=transforms,
            num_workers=num_workers,
            buffer_size=buffer_size,
            parallel_method=parallel_method,
            shuffle=shuffle)
        self.file_list = OrderedDict()
        self.labels = list()
        self._epoch = 0

        with open(label_list, encoding=get_encoding(label_list)) as f:
            for line in f:
                item = line.strip()
                self.labels.append(item)

        with open(file_list, encoding=get_encoding(file_list)) as f:
            for line in f:
                items = line.strip().split()
                full_path_im = osp.join(data_dir, items[0])
                full_path_label = osp.join(data_dir, items[1])
                if not osp.exists(full_path_im):
                    raise IOError('The image file {} does not exist!'.format(
                        full_path_im))
                if not osp.exists(full_path_label):
                    raise IOError('The label file {} does not exist!'.format(
                        full_path_label))
                self.file_list[full_path_im] = full_path_label
        self.num_samples = len(self.file_list)
        logging.info("{} samples in file {}".format(
            len(self.file_list), file_list))

    def iterator(self):
        self._epoch += 1
        self._pos = 0
        files = list(self.file_list.keys())
        if self.shuffle:
            random.shuffle(files)
        files = files[:self.num_samples]
        self.num_samples = len(files)
        for f in files:
            label_path = self.file_list[f]
            sample = [f, None, label_path]
            yield sample
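
For orientation, a hedged usage sketch (not part of this diff): wiring Reader to a dataset directory with a train list and a label list, then iterating batches through generator(). The file names and the transform classes (Compose, RandomHorizontalFlip, Normalize) are assumptions for illustration only, since transforms/transforms.py is collapsed in this view.

import RemoteSensing.transforms.transforms as T
from RemoteSensing.readers.reader import Reader

train_transforms = T.Compose([T.RandomHorizontalFlip(), T.Normalize()])
train_reader = Reader(
    data_dir='dataset',
    file_list='dataset/train_list.txt',
    label_list='dataset/labels.txt',
    transforms=train_transforms,
    num_workers=4,
    shuffle=True)
# each yielded batch is a list of transformed samples, padded by GenerateMiniBatch
for batch in train_reader.generator(batch_size=4, drop_last=True):
    break
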
contrib/RemoteSensing/transforms/__init__.py
0 → 100644
浏览文件 @
9a5a9f45
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import transforms
from . import ops
contrib/RemoteSensing/transforms/ops.py
0 → 100644
浏览文件 @
9a5a9f45
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import math
import numpy as np
from PIL import Image, ImageEnhance


def normalize(im, mean, std):
    im = im.astype(np.float32, copy=False) / 255.0
    im -= mean
    im /= std
    return im


def permute(im, to_bgr=False):
    im = np.swapaxes(im, 1, 2)
    im = np.swapaxes(im, 1, 0)
    if to_bgr:
        im = im[[2, 1, 0], :, :]
    return im


def _resize(im, shape):
    return cv2.resize(im, shape)


def resize_short(im, short_size=224):
    percent = float(short_size) / min(im.shape[0], im.shape[1])
    resized_width = int(round(im.shape[1] * percent))
    resized_height = int(round(im.shape[0] * percent))
    im = _resize(im, shape=(resized_width, resized_height))
    return im


def resize_long(im, long_size=224, interpolation=cv2.INTER_LINEAR):
    value = max(im.shape[0], im.shape[1])
    scale = float(long_size) / float(value)
    im = cv2.resize(
        im, (0, 0), fx=scale, fy=scale, interpolation=interpolation)
    return im


def random_crop(im,
                crop_size=224,
                lower_scale=0.08,
                lower_ratio=3. / 4,
                upper_ratio=4. / 3):
    scale = [lower_scale, 1.0]
    ratio = [lower_ratio, upper_ratio]
    aspect_ratio = math.sqrt(np.random.uniform(*ratio))
    w = 1. * aspect_ratio
    h = 1. / aspect_ratio
    bound = min((float(im.shape[0]) / im.shape[1]) / (h**2),
                (float(im.shape[1]) / im.shape[0]) / (w**2))
    scale_max = min(scale[1], bound)
    scale_min = min(scale[0], bound)
    target_area = im.shape[0] * im.shape[1] * np.random.uniform(
        scale_min, scale_max)
    target_size = math.sqrt(target_area)
    w = int(target_size * w)
    h = int(target_size * h)
    i = np.random.randint(0, im.shape[0] - h + 1)
    j = np.random.randint(0, im.shape[1] - w + 1)
    im = im[i:i + h, j:j + w, :]
    im = _resize(im, shape=(crop_size, crop_size))
    return im


def center_crop(im, crop_size=224):
    height, width = im.shape[:2]
    w_start = (width - crop_size) // 2
    h_start = (height - crop_size) // 2
    w_end = w_start + crop_size
    h_end = h_start + crop_size
    im = im[h_start:h_end, w_start:w_end, :]
    return im


def horizontal_flip(im):
    if len(im.shape) == 3:
        im = im[:, ::-1, :]
    elif len(im.shape) == 2:
        im = im[:, ::-1]
    return im


def vertical_flip(im):
    if len(im.shape) == 3:
        im = im[::-1, :, :]
    elif len(im.shape) == 2:
        im = im[::-1, :]
    return im


def bgr2rgb(im):
    return im[:, :, ::-1]


def brightness(im, brightness_lower, brightness_upper):
    brightness_delta = np.random.uniform(brightness_lower, brightness_upper)
    im = ImageEnhance.Brightness(im).enhance(brightness_delta)
    return im


def contrast(im, contrast_lower, contrast_upper):
    contrast_delta = np.random.uniform(contrast_lower, contrast_upper)
    im = ImageEnhance.Contrast(im).enhance(contrast_delta)
    return im


def saturation(im, saturation_lower, saturation_upper):
    saturation_delta = np.random.uniform(saturation_lower, saturation_upper)
    im = ImageEnhance.Color(im).enhance(saturation_delta)
    return im


def hue(im, hue_lower, hue_upper):
    hue_delta = np.random.uniform(hue_lower, hue_upper)
    im = np.array(im.convert('HSV'))
    im[:, :, 0] = im[:, :, 0] + hue_delta
    im = Image.fromarray(im, mode='HSV').convert('RGB')
    return im


def rotate(im, rotate_lower, rotate_upper):
    rotate_delta = np.random.uniform(rotate_lower, rotate_upper)
    im = im.rotate(int(rotate_delta))
    return im


def resize_padding(im, max_side_len=2400):
    '''
    resize image to a size multiple of 32 which is required by the network
    :param im: the resized image
    :param max_side_len: limit of max image size to avoid out of memory in gpu
    :return: the resized image and the resize ratio
    '''
    h, w, _ = im.shape

    resize_w = w
    resize_h = h

    # limit the max side
    if max(resize_h, resize_w) > max_side_len:
        ratio = float(max_side_len) / resize_h if resize_h > resize_w \
            else float(max_side_len) / resize_w
    else:
        ratio = 1.
    resize_h = int(resize_h * ratio)
    resize_w = int(resize_w * ratio)

    resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32 - 1) * 32
    resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32 - 1) * 32
    resize_h = max(32, resize_h)
    resize_w = max(32, resize_w)
    im = cv2.resize(im, (int(resize_w), int(resize_h)))
    #im = cv2.resize(im, (512, 512))
    ratio_h = resize_h / float(h)
    ratio_w = resize_w / float(w)
    _ratio = np.array([ratio_h, ratio_w]).reshape(-1, 2)
    return im, _ratio
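
A minimal sketch (not from the diff) of the typical preprocessing order built from the ops above: normalize an HWC uint8 image to zero-mean floats, then permute it to the CHW layout the network expects.

import numpy as np
from RemoteSensing.transforms.ops import normalize, permute

im = np.random.randint(0, 256, size=(256, 256, 3)).astype('uint8')  # HWC, uint8
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
im = normalize(im, mean, std)    # scaled to [0, 1], then (x - mean) / std per channel
im = permute(im, to_bgr=False)   # HWC -> CHW
print(im.shape)                  # (3, 256, 256)
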
contrib/RemoteSensing/transforms/transforms.py
0 → 100644
浏览文件 @
9a5a9f45
此差异已折叠。
contrib/RemoteSensing/utils/__init__.py
0 → 100644
浏览文件 @
9a5a9f45
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import logging
from . import utils
from .metrics import ConfusionMatrix
from .utils import *
contrib/RemoteSensing/utils/logging.py
0 → 100644
浏览文件 @
9a5a9f45
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import os
import sys
import RemoteSensing

levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}


def log(level=2, message=""):
    current_time = time.time()
    time_array = time.localtime(current_time)
    current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
    if RemoteSensing.log_level >= level:
        print("{} [{}]\t{}".format(current_time, levels[level],
                                   message).encode("utf-8").decode("latin1"))
        sys.stdout.flush()


def debug(message=""):
    log(level=3, message=message)


def info(message=""):
    log(level=2, message=message)


def warning(message=""):
    log(level=1, message=message)


def error(message=""):
    log(level=0, message=message)
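
A small usage sketch (not part of the diff): the package-level log_level set in RemoteSensing/__init__.py (2, i.e. INFO) controls which of these calls actually print.

import RemoteSensing
from RemoteSensing.utils import logging

RemoteSensing.log_level = 3        # also show DEBUG output
logging.info("start training")     # printed at INFO
logging.debug("batch_size = 4")    # printed only because log_level >= 3
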
contrib/RemoteSensing/utils/metrics.py
0 → 100644
浏览文件 @
9a5a9f45
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import numpy as np
from scipy.sparse import csr_matrix


class ConfusionMatrix(object):
    """
    Confusion Matrix for segmentation evaluation
    """

    def __init__(self, num_classes=2, streaming=False):
        self.confusion_matrix = np.zeros([num_classes, num_classes],
                                         dtype='int64')
        self.num_classes = num_classes
        self.streaming = streaming

    def calculate(self, pred, label, ignore=None):
        # If not in streaming mode, clear the matrix every time `calculate` is called
        if not self.streaming:
            self.zero_matrix()

        label = np.transpose(label, (0, 2, 3, 1))
        ignore = np.transpose(ignore, (0, 2, 3, 1))
        mask = np.array(ignore) == 1

        label = np.asarray(label)[mask]
        pred = np.asarray(pred)[mask]

        one = np.ones_like(pred)
        # Accumulate ([row=label, col=pred], 1) into a sparse matrix
        spm = csr_matrix((one, (label, pred)),
                         shape=(self.num_classes, self.num_classes))
        spm = spm.todense()
        self.confusion_matrix += spm

    def zero_matrix(self):
        """ Clear confusion matrix """
        self.confusion_matrix = np.zeros([self.num_classes, self.num_classes],
                                         dtype='int64')

    def mean_iou(self):
        iou_list = []
        avg_iou = 0
        # TODO: use numpy sum axis api to simplify
        vji = np.zeros(self.num_classes, dtype=int)
        vij = np.zeros(self.num_classes, dtype=int)
        for j in range(self.num_classes):
            v_j = 0
            for i in range(self.num_classes):
                v_j += self.confusion_matrix[j][i]
            vji[j] = v_j

        for i in range(self.num_classes):
            v_i = 0
            for j in range(self.num_classes):
                v_i += self.confusion_matrix[j][i]
            vij[i] = v_i

        for c in range(self.num_classes):
            total = vji[c] + vij[c] - self.confusion_matrix[c][c]
            if total == 0:
                iou = 0
            else:
                iou = float(self.confusion_matrix[c][c]) / total
            avg_iou += iou
            iou_list.append(iou)
        avg_iou = float(avg_iou) / float(self.num_classes)
        return np.array(iou_list), avg_iou

    def accuracy(self):
        total = self.confusion_matrix.sum()
        total_right = 0
        for c in range(self.num_classes):
            total_right += self.confusion_matrix[c][c]
        if total == 0:
            avg_acc = 0
        else:
            avg_acc = float(total_right) / total

        vij = np.zeros(self.num_classes, dtype=int)
        for i in range(self.num_classes):
            v_i = 0
            for j in range(self.num_classes):
                v_i += self.confusion_matrix[j][i]
            vij[i] = v_i

        acc_list = []
        for c in range(self.num_classes):
            if vij[c] == 0:
                acc = 0
            else:
                acc = self.confusion_matrix[c][c] / float(vij[c])
            acc_list.append(acc)
        return np.array(acc_list), avg_acc

    def kappa(self):
        vji = np.zeros(self.num_classes)
        vij = np.zeros(self.num_classes)
        for j in range(self.num_classes):
            v_j = 0
            for i in range(self.num_classes):
                v_j += self.confusion_matrix[j][i]
            vji[j] = v_j

        for i in range(self.num_classes):
            v_i = 0
            for j in range(self.num_classes):
                v_i += self.confusion_matrix[j][i]
            vij[i] = v_i

        total = self.confusion_matrix.sum()

        # avoid spillovers
        # TODO: is it reasonable to hard code 10000.0?
        total = float(total) / 10000.0
        vji = vji / 10000.0
        vij = vij / 10000.0

        tp = 0
        tc = 0
        for c in range(self.num_classes):
            tp += vji[c] * vij[c]
            tc += self.confusion_matrix[c][c]

        tc = tc / 10000.0
        pe = tp / (total * total)
        po = tc / total

        kappa = (po - pe) / (1 - pe)
        return kappa
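
A hedged usage sketch (not from the diff): accumulating a streaming ConfusionMatrix on a toy 2x2 prediction and reading out mIoU, accuracy, and kappa. Judging from the transposes inside calculate(), label and ignore are expected in NCHW layout while pred is already channel-last; that layout is an assumption inferred from the code above, not a documented contract.

import numpy as np
from RemoteSensing.utils.metrics import ConfusionMatrix

cm = ConfusionMatrix(num_classes=2, streaming=True)
pred = np.array([0, 1, 1, 1]).reshape((1, 2, 2, 1))    # N, H, W, 1
label = np.array([0, 1, 0, 1]).reshape((1, 1, 2, 2))   # N, C, H, W
ignore = np.ones((1, 1, 2, 2), dtype='int64')          # 1 = count this pixel
cm.calculate(pred=pred, label=label, ignore=ignore)
category_iou, miou = cm.mean_iou()
category_acc, acc = cm.accuracy()
print(miou, acc, cm.kappa())
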
contrib/RemoteSensing/utils/pretrain_weights.py
0 → 100644
浏览文件 @
9a5a9f45
import RemoteSensing
import os
import os.path as osp


def get_pretrain_weights(flag, backbone, save_dir):
    if flag is None:
        return None
    elif osp.isdir(flag):
        return flag
    else:
        raise Exception(
            "pretrain_weights need to be defined as directory path.")
contrib/RemoteSensing/utils/utils.py
0 → 100644
浏览文件 @
9a5a9f45
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
import os
import os.path as osp
import numpy as np
import six
import yaml
import math
from . import logging


def seconds_to_hms(seconds):
    h = math.floor(seconds / 3600)
    m = math.floor((seconds - h * 3600) / 60)
    s = int(seconds - h * 3600 - m * 60)
    hms_str = "{}:{}:{}".format(h, m, s)
    return hms_str


def setting_environ_flags():
    if 'FLAGS_eager_delete_tensor_gb' not in os.environ:
        os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
    if 'FLAGS_allocator_strategy' not in os.environ:
        os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        if os.environ["CUDA_VISIBLE_DEVICES"].count("-1") > 0:
            os.environ["CUDA_VISIBLE_DEVICES"] = ""


def get_environ_info():
    setting_environ_flags()
    import paddle.fluid as fluid
    info = dict()
    info['place'] = 'cpu'
    info['num'] = int(os.environ.get('CPU_NUM', 1))
    if os.environ.get('CUDA_VISIBLE_DEVICES', None) != "":
        if hasattr(fluid.core, 'get_cuda_device_count'):
            gpu_num = 0
            try:
                gpu_num = fluid.core.get_cuda_device_count()
            except:
                os.environ['CUDA_VISIBLE_DEVICES'] = ''
                pass
            if gpu_num > 0:
                info['place'] = 'cuda'
                info['num'] = fluid.core.get_cuda_device_count()
    return info


def parse_param_file(param_file, return_shape=True):
    from paddle.fluid.proto.framework_pb2 import VarType
    f = open(param_file, 'rb')
    version = np.fromstring(f.read(4), dtype='int32')
    lod_level = np.fromstring(f.read(8), dtype='int64')
    for i in range(int(lod_level)):
        _size = np.fromstring(f.read(8), dtype='int64')
        _ = f.read(_size)
    version = np.fromstring(f.read(4), dtype='int32')
    tensor_desc = VarType.TensorDesc()
    tensor_desc_size = np.fromstring(f.read(4), dtype='int32')
    tensor_desc.ParseFromString(f.read(int(tensor_desc_size)))
    tensor_shape = tuple(tensor_desc.dims)
    if return_shape:
        f.close()
        return tuple(tensor_desc.dims)
    if tensor_desc.data_type != 5:
        raise Exception(
            "Unexpected data type while parsing {}".format(param_file))
    data_size = 4
    for i in range(len(tensor_shape)):
        data_size *= tensor_shape[i]
    weight = np.fromstring(f.read(data_size), dtype='float32')
    f.close()
    return np.reshape(weight, tensor_shape)


def fuse_bn_weights(exe, main_prog, weights_dir):
    import paddle.fluid as fluid
    logging.info("Try to fuse weights of batch_norm...")
    bn_vars = list()
    for block in main_prog.blocks:
        ops = list(block.ops)
        for op in ops:
            if op.type == 'affine_channel':
                scale_name = op.input('Scale')[0]
                bias_name = op.input('Bias')[0]
                prefix = scale_name[:-5]
                mean_name = prefix + 'mean'
                variance_name = prefix + 'variance'
                if not osp.exists(osp.join(
                        weights_dir, mean_name)) or not osp.exists(
                            osp.join(weights_dir, variance_name)):
                    logging.info(
                        "There's no batch_norm weight found to fuse, skip fuse_bn."
                    )
                    return

                bias = block.var(bias_name)
                pretrained_shape = parse_param_file(
                    osp.join(weights_dir, bias_name))
                actual_shape = tuple(bias.shape)
                if pretrained_shape != actual_shape:
                    continue
                bn_vars.append(
                    [scale_name, bias_name, mean_name, variance_name])
    eps = 1e-5
    for names in bn_vars:
        scale_name, bias_name, mean_name, variance_name = names
        scale = parse_param_file(
            osp.join(weights_dir, scale_name), return_shape=False)
        bias = parse_param_file(
            osp.join(weights_dir, bias_name), return_shape=False)
        mean = parse_param_file(
            osp.join(weights_dir, mean_name), return_shape=False)
        variance = parse_param_file(
            osp.join(weights_dir, variance_name), return_shape=False)
        bn_std = np.sqrt(np.add(variance, eps))
        new_scale = np.float32(np.divide(scale, bn_std))
        new_bias = bias - mean * new_scale
        scale_tensor = fluid.global_scope().find_var(scale_name).get_tensor()
        bias_tensor = fluid.global_scope().find_var(bias_name).get_tensor()
        scale_tensor.set(new_scale, exe.place)
        bias_tensor.set(new_bias, exe.place)
    if len(bn_vars) == 0:
        logging.info(
            "There's no batch_norm weight found to fuse, skip fuse_bn.")
    else:
        logging.info("{} batch_norm ops have been fused.".format(len(bn_vars)))


def load_pdparams(exe, main_prog, model_dir):
    import paddle.fluid as fluid
    from paddle.fluid.proto.framework_pb2 import VarType
    from paddle.fluid.framework import Program

    vars_to_load = list()
    import pickle
    with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
        params_dict = pickle.load(f) if six.PY2 else pickle.load(
            f, encoding='latin1')

    unused_vars = list()
    for var in main_prog.list_vars():
        if not isinstance(var, fluid.framework.Parameter):
            continue
        if var.name not in params_dict:
            raise Exception("{} is not in saved model".format(var.name))
        if var.shape != params_dict[var.name].shape:
            unused_vars.append(var.name)
            logging.warning(
                "[SKIP] Shape of pretrained weight {} doesn't match. (Pretrained: {}, Actual: {})"
                .format(var.name, params_dict[var.name].shape, var.shape))
            continue
        vars_to_load.append(var)
        logging.debug("Weight {} will be loaded".format(var.name))
    for var_name in unused_vars:
        del params_dict[var_name]
    fluid.io.set_program_state(main_prog, params_dict)

    if len(vars_to_load) == 0:
        logging.warning(
            "There are no pretrained weights loaded, maybe you should check your pretrained model!"
        )
    else:
        logging.info("Loaded {} variables from {}.".format(
            len(vars_to_load), model_dir))


def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
    if not osp.exists(weights_dir):
        raise Exception("Path {} does not exist.".format(weights_dir))
    if osp.exists(osp.join(weights_dir, "model.pdparams")):
        return load_pdparams(exe, main_prog, weights_dir)
    import paddle.fluid as fluid
    vars_to_load = list()
    for var in main_prog.list_vars():
        if not isinstance(var, fluid.framework.Parameter):
            continue
        if not osp.exists(osp.join(weights_dir, var.name)):
            logging.debug(
                "[SKIP] Pretrained weight {}/{} doesn't exist".format(
                    weights_dir, var.name))
            continue
        pretrained_shape = parse_param_file(osp.join(weights_dir, var.name))
        actual_shape = tuple(var.shape)
        if pretrained_shape != actual_shape:
            logging.warning(
                "[SKIP] Shape of pretrained weight {}/{} doesn't match. (Pretrained: {}, Actual: {})"
                .format(weights_dir, var.name, pretrained_shape, actual_shape))
            continue
        vars_to_load.append(var)
        logging.debug("Weight {} will be loaded".format(var.name))

    fluid.io.load_vars(
        executor=exe,
        dirname=weights_dir,
        main_program=main_prog,
        vars=vars_to_load)
    if len(vars_to_load) == 0:
        logging.warning(
            "There are no pretrained weights loaded, maybe you should check your pretrained model!"
        )
    else:
        logging.info("Loaded {} variables from {}.".format(
            len(vars_to_load), weights_dir))
    if fuse_bn:
        fuse_bn_weights(exe, main_prog, weights_dir)
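
A short usage sketch (not part of the diff): the two helpers a training script would typically call first, querying the execution environment and formatting an elapsed time.

from RemoteSensing.utils.utils import get_environ_info, seconds_to_hms

env = get_environ_info()
print(env['place'], env['num'])   # e.g. 'cuda', 1 when a GPU is visible
print(seconds_to_hms(3725))       # '1:2:5'
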