Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ac2e2f9a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ac2e2f9a
编写于
6月 12, 2020
作者:
U
unknown
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
updata deeplabv3 CI
上级
c3d78e2a
变更
14
展开全部
隐藏空白更改
内联
并排
Showing
14 changed file
with
1819 addition
and
0 deletion
+1819
-0
tests/st/networks/models/deeplabv3/src/__init__.py
tests/st/networks/models/deeplabv3/src/__init__.py
+23
-0
tests/st/networks/models/deeplabv3/src/backbone/__init__.py
tests/st/networks/models/deeplabv3/src/backbone/__init__.py
+21
-0
tests/st/networks/models/deeplabv3/src/backbone/resnet_deeplab.py
.../networks/models/deeplabv3/src/backbone/resnet_deeplab.py
+577
-0
tests/st/networks/models/deeplabv3/src/config.py
tests/st/networks/models/deeplabv3/src/config.py
+38
-0
tests/st/networks/models/deeplabv3/src/deeplabv3.py
tests/st/networks/models/deeplabv3/src/deeplabv3.py
+457
-0
tests/st/networks/models/deeplabv3/src/ei_dataset.py
tests/st/networks/models/deeplabv3/src/ei_dataset.py
+84
-0
tests/st/networks/models/deeplabv3/src/losses.py
tests/st/networks/models/deeplabv3/src/losses.py
+63
-0
tests/st/networks/models/deeplabv3/src/md_dataset.py
tests/st/networks/models/deeplabv3/src/md_dataset.py
+116
-0
tests/st/networks/models/deeplabv3/src/miou_precision.py
tests/st/networks/models/deeplabv3/src/miou_precision.py
+72
-0
tests/st/networks/models/deeplabv3/src/utils/__init__.py
tests/st/networks/models/deeplabv3/src/utils/__init__.py
+14
-0
tests/st/networks/models/deeplabv3/src/utils/adapter.py
tests/st/networks/models/deeplabv3/src/utils/adapter.py
+67
-0
tests/st/networks/models/deeplabv3/src/utils/custom_transforms.py
.../networks/models/deeplabv3/src/utils/custom_transforms.py
+149
-0
tests/st/networks/models/deeplabv3/src/utils/file_io.py
tests/st/networks/models/deeplabv3/src/utils/file_io.py
+36
-0
tests/st/networks/models/deeplabv3/test_deeplabv3.py
tests/st/networks/models/deeplabv3/test_deeplabv3.py
+102
-0
未找到文件。
tests/st/networks/models/deeplabv3/src/__init__.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init DeepLabv3."""
from
.deeplabv3
import
ASPP
,
DeepLabV3
,
deeplabv3_resnet50
from
.backbone
import
*
__all__
=
[
"ASPP"
,
"DeepLabV3"
,
"deeplabv3_resnet50"
]
__all__
.
extend
(
backbone
.
__all__
)
tests/st/networks/models/deeplabv3/src/backbone/__init__.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init backbone."""
from
.resnet_deeplab
import
Subsample
,
DepthwiseConv2dNative
,
SpaceToBatch
,
BatchToSpace
,
ResNetV1
,
\
RootBlockBeta
,
resnet50_dl
__all__
=
[
"Subsample"
,
"DepthwiseConv2dNative"
,
"SpaceToBatch"
,
"BatchToSpace"
,
"ResNetV1"
,
"RootBlockBeta"
,
"resnet50_dl"
]
tests/st/networks/models/deeplabv3/src/backbone/resnet_deeplab.py
0 → 100644
浏览文件 @
ac2e2f9a
此差异已折叠。
点击以展开。
tests/st/networks/models/deeplabv3/src/config.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and evaluation.py
"""
from
easydict
import
EasyDict
as
ed
config
=
ed
({
"learning_rate"
:
0.0014
,
"weight_decay"
:
0.00005
,
"momentum"
:
0.97
,
"crop_size"
:
513
,
"eval_scales"
:
[
0.5
,
0.75
,
1.0
,
1.25
,
1.5
,
1.75
],
"atrous_rates"
:
None
,
"image_pyramid"
:
None
,
"output_stride"
:
16
,
"fine_tune_batch_norm"
:
False
,
"ignore_label"
:
255
,
"decoder_output_stride"
:
None
,
"seg_num_classes"
:
21
,
"epoch_size"
:
6
,
"batch_size"
:
2
,
"enable_save_ckpt"
:
True
,
"save_checkpoint_steps"
:
10000
,
"save_checkpoint_num"
:
1
})
tests/st/networks/models/deeplabv3/src/deeplabv3.py
0 → 100644
浏览文件 @
ac2e2f9a
此差异已折叠。
点击以展开。
tests/st/networks/models/deeplabv3/src/ei_dataset.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Process Dataset."""
import
abc
import
os
import
time
from
.utils.adapter
import
get_raw_samples
,
read_image
class
BaseDataset
:
"""
Create dataset.
Args:
data_url (str): The path of data.
usage (str): Whether to use train or eval (default='train').
Returns:
Dataset.
"""
def
__init__
(
self
,
data_url
,
usage
):
self
.
data_url
=
data_url
self
.
usage
=
usage
self
.
cur_index
=
0
self
.
samples
=
[]
_s_time
=
time
.
time
()
self
.
_load_samples
()
_e_time
=
time
.
time
()
print
(
f
"load samples success~, time cost =
{
_e_time
-
_s_time
}
"
)
def
__getitem__
(
self
,
item
):
sample
=
self
.
samples
[
item
]
return
self
.
_next_data
(
sample
)
def
__len__
(
self
):
return
len
(
self
.
samples
)
@
staticmethod
def
_next_data
(
sample
):
image_path
=
sample
[
0
]
mask_image_path
=
sample
[
1
]
image
=
read_image
(
image_path
)
mask_image
=
read_image
(
mask_image_path
)
return
[
image
,
mask_image
]
@
abc
.
abstractmethod
def
_load_samples
(
self
):
pass
class
HwVocRawDataset
(
BaseDataset
):
"""
Create dataset with raw data.
Args:
data_url (str): The path of data.
usage (str): Whether to use train or eval (default='train').
Returns:
Dataset.
"""
def
__init__
(
self
,
data_url
,
usage
=
"train"
):
super
().
__init__
(
data_url
,
usage
)
def
_load_samples
(
self
):
try
:
self
.
samples
=
get_raw_samples
(
os
.
path
.
join
(
self
.
data_url
,
self
.
usage
))
except
Exception
as
e
:
print
(
"load HwVocRawDataset failed!!!"
)
raise
e
tests/st/networks/models/deeplabv3/src/losses.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""OhemLoss."""
import
mindspore.nn
as
nn
import
mindspore.common.dtype
as
mstype
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
functional
as
F
class
OhemLoss
(
nn
.
Cell
):
"""Ohem loss cell."""
def
__init__
(
self
,
num
,
ignore_label
):
super
(
OhemLoss
,
self
).
__init__
()
self
.
mul
=
P
.
Mul
()
self
.
shape
=
P
.
Shape
()
self
.
one_hot
=
nn
.
OneHot
(
-
1
,
num
,
1.0
,
0.0
)
self
.
squeeze
=
P
.
Squeeze
()
self
.
num
=
num
self
.
cross_entropy
=
P
.
SoftmaxCrossEntropyWithLogits
()
self
.
mean
=
P
.
ReduceMean
()
self
.
select
=
P
.
Select
()
self
.
reshape
=
P
.
Reshape
()
self
.
cast
=
P
.
Cast
()
self
.
not_equal
=
P
.
NotEqual
()
self
.
equal
=
P
.
Equal
()
self
.
reduce_sum
=
P
.
ReduceSum
(
keep_dims
=
False
)
self
.
fill
=
P
.
Fill
()
self
.
transpose
=
P
.
Transpose
()
self
.
ignore_label
=
ignore_label
self
.
loss_weight
=
1.0
def
construct
(
self
,
logits
,
labels
):
logits
=
self
.
transpose
(
logits
,
(
0
,
2
,
3
,
1
))
logits
=
self
.
reshape
(
logits
,
(
-
1
,
self
.
num
))
labels
=
F
.
cast
(
labels
,
mstype
.
int32
)
labels
=
self
.
reshape
(
labels
,
(
-
1
,))
one_hot_labels
=
self
.
one_hot
(
labels
)
losses
=
self
.
cross_entropy
(
logits
,
one_hot_labels
)[
0
]
weights
=
self
.
cast
(
self
.
not_equal
(
labels
,
self
.
ignore_label
),
mstype
.
float32
)
*
self
.
loss_weight
weighted_losses
=
self
.
mul
(
losses
,
weights
)
loss
=
self
.
reduce_sum
(
weighted_losses
,
(
0
,))
zeros
=
self
.
fill
(
mstype
.
float32
,
self
.
shape
(
weights
),
0.0
)
ones
=
self
.
fill
(
mstype
.
float32
,
self
.
shape
(
weights
),
1.0
)
present
=
self
.
select
(
self
.
equal
(
weights
,
zeros
),
zeros
,
ones
)
present
=
self
.
reduce_sum
(
present
,
(
0
,))
zeros
=
self
.
fill
(
mstype
.
float32
,
self
.
shape
(
present
),
0.0
)
min_control
=
self
.
fill
(
mstype
.
float32
,
self
.
shape
(
present
),
1.0
)
present
=
self
.
select
(
self
.
equal
(
present
,
zeros
),
min_control
,
present
)
loss
=
loss
/
present
return
loss
tests/st/networks/models/deeplabv3/src/md_dataset.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset module."""
from
PIL
import
Image
import
mindspore.dataset
as
de
import
mindspore.dataset.transforms.vision.c_transforms
as
C
import
numpy
as
np
from
.ei_dataset
import
HwVocRawDataset
from
.utils
import
custom_transforms
as
tr
class
DataTransform
:
"""Transform dataset for DeepLabV3."""
def
__init__
(
self
,
args
,
usage
):
self
.
args
=
args
self
.
usage
=
usage
def
__call__
(
self
,
image
,
label
):
if
self
.
usage
==
"train"
:
return
self
.
_train
(
image
,
label
)
if
self
.
usage
==
"eval"
:
return
self
.
_eval
(
image
,
label
)
return
None
def
_train
(
self
,
image
,
label
):
"""
Process training data.
Args:
image (list): Image data.
label (list): Dataset label.
"""
image
=
Image
.
fromarray
(
image
)
label
=
Image
.
fromarray
(
label
)
rsc_tr
=
tr
.
RandomScaleCrop
(
base_size
=
self
.
args
.
base_size
,
crop_size
=
self
.
args
.
crop_size
)
image
,
label
=
rsc_tr
(
image
,
label
)
rhf_tr
=
tr
.
RandomHorizontalFlip
()
image
,
label
=
rhf_tr
(
image
,
label
)
image
=
np
.
array
(
image
).
astype
(
np
.
float32
)
label
=
np
.
array
(
label
).
astype
(
np
.
float32
)
return
image
,
label
def
_eval
(
self
,
image
,
label
):
"""
Process eval data.
Args:
image (list): Image data.
label (list): Dataset label.
"""
image
=
Image
.
fromarray
(
image
)
label
=
Image
.
fromarray
(
label
)
fsc_tr
=
tr
.
FixScaleCrop
(
crop_size
=
self
.
args
.
crop_size
)
image
,
label
=
fsc_tr
(
image
,
label
)
image
=
np
.
array
(
image
).
astype
(
np
.
float32
)
label
=
np
.
array
(
label
).
astype
(
np
.
float32
)
return
image
,
label
def
create_dataset
(
args
,
data_url
,
epoch_num
=
1
,
batch_size
=
1
,
usage
=
"train"
,
shuffle
=
True
):
"""
Create Dataset for DeepLabV3.
Args:
args (dict): Train parameters.
data_url (str): Dataset path.
epoch_num (int): Epoch of dataset (default=1).
batch_size (int): Batch size of dataset (default=1).
usage (str): Whether is use to train or eval (default='train').
Returns:
Dataset.
"""
# create iter dataset
dataset
=
HwVocRawDataset
(
data_url
,
usage
=
usage
)
dataset_len
=
len
(
dataset
)
# wrapped with GeneratorDataset
dataset
=
de
.
GeneratorDataset
(
dataset
,
[
"image"
,
"label"
],
sampler
=
None
)
dataset
.
set_dataset_size
(
dataset_len
)
dataset
=
dataset
.
map
(
input_columns
=
[
"image"
,
"label"
],
operations
=
DataTransform
(
args
,
usage
=
usage
))
channelswap_op
=
C
.
HWC2CHW
()
dataset
=
dataset
.
map
(
input_columns
=
"image"
,
operations
=
channelswap_op
)
# 1464 samples / batch_size 8 = 183 batches
# epoch_num is num of steps
# 3658 steps / 183 = 20 epochs
if
usage
==
"train"
and
shuffle
:
dataset
=
dataset
.
shuffle
(
1464
)
dataset
=
dataset
.
batch
(
batch_size
,
drop_remainder
=
(
usage
==
"train"
))
dataset
=
dataset
.
repeat
(
count
=
epoch_num
)
dataset
.
map_model
=
4
return
dataset
tests/st/networks/models/deeplabv3/src/miou_precision.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mIou."""
import
numpy
as
np
from
mindspore.nn.metrics.metric
import
Metric
def
confuse_matrix
(
target
,
pred
,
n
):
k
=
(
target
>=
0
)
&
(
target
<
n
)
return
np
.
bincount
(
n
*
target
[
k
].
astype
(
int
)
+
pred
[
k
],
minlength
=
n
**
2
).
reshape
(
n
,
n
)
def
iou
(
hist
):
denominator
=
hist
.
sum
(
1
)
+
hist
.
sum
(
0
)
-
np
.
diag
(
hist
)
res
=
np
.
diag
(
hist
)
/
np
.
where
(
denominator
>
0
,
denominator
,
1
)
res
=
np
.
sum
(
res
)
/
np
.
count_nonzero
(
denominator
)
return
res
class
MiouPrecision
(
Metric
):
"""Calculate miou precision."""
def
__init__
(
self
,
num_class
=
21
):
super
(
MiouPrecision
,
self
).
__init__
()
if
not
isinstance
(
num_class
,
int
):
raise
TypeError
(
'num_class should be integer type, but got {}'
.
format
(
type
(
num_class
)))
if
num_class
<
1
:
raise
ValueError
(
'num_class must be at least 1, but got {}'
.
format
(
num_class
))
self
.
_num_class
=
num_class
self
.
_mIoU
=
[]
self
.
clear
()
def
clear
(
self
):
self
.
_hist
=
np
.
zeros
((
self
.
_num_class
,
self
.
_num_class
))
self
.
_mIoU
=
[]
def
update
(
self
,
*
inputs
):
if
len
(
inputs
)
!=
2
:
raise
ValueError
(
'Need 2 inputs (y_pred, y), but got {}'
.
format
(
len
(
inputs
)))
predict_in
=
self
.
_convert_data
(
inputs
[
0
])
label_in
=
self
.
_convert_data
(
inputs
[
1
])
if
predict_in
.
shape
[
1
]
!=
self
.
_num_class
:
raise
ValueError
(
'Class number not match, last input data contain {} classes, but current data contain {} '
'classes'
.
format
(
self
.
_num_class
,
predict_in
.
shape
[
1
]))
pred
=
np
.
argmax
(
predict_in
,
axis
=
1
)
label
=
label_in
if
len
(
label
.
flatten
())
!=
len
(
pred
.
flatten
()):
print
(
'Skipping: len(gt) = {:d}, len(pred) = {:d}'
.
format
(
len
(
label
.
flatten
()),
len
(
pred
.
flatten
())))
raise
ValueError
(
'Class number not match, last input data contain {} classes, but current data contain {} '
'classes'
.
format
(
self
.
_num_class
,
predict_in
.
shape
[
1
]))
self
.
_hist
=
confuse_matrix
(
label
.
flatten
(),
pred
.
flatten
(),
self
.
_num_class
)
mIoUs
=
iou
(
self
.
_hist
)
self
.
_mIoU
.
append
(
mIoUs
)
def
eval
(
self
):
"""
Computes the mIoU categorical accuracy.
"""
mIoU
=
np
.
nanmean
(
self
.
_mIoU
)
print
(
'mIoU = {}'
.
format
(
mIoU
))
return
mIoU
tests/st/networks/models/deeplabv3/src/utils/__init__.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
tests/st/networks/models/deeplabv3/src/utils/adapter.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Adapter dataset."""
import
fnmatch
import
io
import
os
import
numpy
as
np
from
PIL
import
Image
from
..utils
import
file_io
def
get_raw_samples
(
data_url
):
"""
Get dataset from raw data.
Args:
data_url (str): Dataset path.
Returns:
list, a file list.
"""
def
_list_files
(
dir_path
,
pattern
):
full_files
=
[]
_
,
_
,
files
=
next
(
file_io
.
walk
(
dir_path
))
for
f
in
files
:
if
fnmatch
.
fnmatch
(
f
.
lower
(),
pattern
.
lower
()):
full_files
.
append
(
os
.
path
.
join
(
dir_path
,
f
))
return
full_files
img_files
=
_list_files
(
os
.
path
.
join
(
data_url
,
"Images"
),
"*.jpg"
)
seg_files
=
_list_files
(
os
.
path
.
join
(
data_url
,
"SegmentationClassRaw"
),
"*.png"
)
files
=
[]
for
img_file
in
img_files
:
_
,
file_name
=
os
.
path
.
split
(
img_file
)
name
,
_
=
os
.
path
.
splitext
(
file_name
)
seg_file
=
os
.
path
.
join
(
data_url
,
"SegmentationClassRaw"
,
"."
.
join
([
name
,
"png"
]))
if
seg_file
in
seg_files
:
files
.
append
([
img_file
,
seg_file
])
return
files
def
read_image
(
img_path
):
"""
Read image from file.
Args:
img_path (str): image path.
"""
img
=
file_io
.
read
(
img_path
.
strip
(),
binary
=
True
)
data
=
io
.
BytesIO
(
img
)
img
=
Image
.
open
(
data
)
return
np
.
array
(
img
)
tests/st/networks/models/deeplabv3/src/utils/custom_transforms.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Random process dataset."""
import
random
import
numpy
as
np
from
PIL
import
Image
,
ImageOps
,
ImageFilter
class
Normalize
:
"""Normalize a tensor image with mean and standard deviation.
Args:
mean (tuple): means for each channel.
std (tuple): standard deviations for each channel.
"""
def
__init__
(
self
,
mean
=
(
0.
,
0.
,
0.
),
std
=
(
1.
,
1.
,
1.
)):
self
.
mean
=
mean
self
.
std
=
std
def
__call__
(
self
,
img
,
mask
):
img
=
np
.
array
(
img
).
astype
(
np
.
float32
)
mask
=
np
.
array
(
mask
).
astype
(
np
.
float32
)
img
=
((
img
-
self
.
mean
)
/
self
.
std
).
astype
(
np
.
float32
)
return
img
,
mask
class
RandomHorizontalFlip
:
"""Randomly decide whether to horizontal flip."""
def
__call__
(
self
,
img
,
mask
):
if
random
.
random
()
<
0.5
:
img
=
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
mask
=
mask
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
return
img
,
mask
class
RandomRotate
:
"""
Randomly decide whether to rotate.
Args:
degree (float): The degree of rotate.
"""
def
__init__
(
self
,
degree
):
self
.
degree
=
degree
def
__call__
(
self
,
img
,
mask
):
rotate_degree
=
random
.
uniform
(
-
1
*
self
.
degree
,
self
.
degree
)
img
=
img
.
rotate
(
rotate_degree
,
Image
.
BILINEAR
)
mask
=
mask
.
rotate
(
rotate_degree
,
Image
.
NEAREST
)
return
img
,
mask
class
RandomGaussianBlur
:
"""Randomly decide whether to filter image with gaussian blur."""
def
__call__
(
self
,
img
,
mask
):
if
random
.
random
()
<
0.5
:
img
=
img
.
filter
(
ImageFilter
.
GaussianBlur
(
radius
=
random
.
random
()))
return
img
,
mask
class
RandomScaleCrop
:
"""Randomly decide whether to scale and crop image."""
def
__init__
(
self
,
base_size
,
crop_size
,
fill
=
0
):
self
.
base_size
=
base_size
self
.
crop_size
=
crop_size
self
.
fill
=
fill
def
__call__
(
self
,
img
,
mask
):
# random scale (short edge)
short_size
=
random
.
randint
(
int
(
self
.
base_size
*
0.5
),
int
(
self
.
base_size
*
2.0
))
w
,
h
=
img
.
size
if
h
>
w
:
ow
=
short_size
oh
=
int
(
1.0
*
h
*
ow
/
w
)
else
:
oh
=
short_size
ow
=
int
(
1.0
*
w
*
oh
/
h
)
img
=
img
.
resize
((
ow
,
oh
),
Image
.
BILINEAR
)
mask
=
mask
.
resize
((
ow
,
oh
),
Image
.
NEAREST
)
# pad crop
if
short_size
<
self
.
crop_size
:
padh
=
self
.
crop_size
-
oh
if
oh
<
self
.
crop_size
else
0
padw
=
self
.
crop_size
-
ow
if
ow
<
self
.
crop_size
else
0
img
=
ImageOps
.
expand
(
img
,
border
=
(
0
,
0
,
padw
,
padh
),
fill
=
0
)
mask
=
ImageOps
.
expand
(
mask
,
border
=
(
0
,
0
,
padw
,
padh
),
fill
=
self
.
fill
)
# random crop crop_size
w
,
h
=
img
.
size
x1
=
random
.
randint
(
0
,
w
-
self
.
crop_size
)
y1
=
random
.
randint
(
0
,
h
-
self
.
crop_size
)
img
=
img
.
crop
((
x1
,
y1
,
x1
+
self
.
crop_size
,
y1
+
self
.
crop_size
))
mask
=
mask
.
crop
((
x1
,
y1
,
x1
+
self
.
crop_size
,
y1
+
self
.
crop_size
))
return
img
,
mask
class
FixScaleCrop
:
"""Scale and crop image with fixing size."""
def
__init__
(
self
,
crop_size
):
self
.
crop_size
=
crop_size
def
__call__
(
self
,
img
,
mask
):
w
,
h
=
img
.
size
if
w
>
h
:
oh
=
self
.
crop_size
ow
=
int
(
1.0
*
w
*
oh
/
h
)
else
:
ow
=
self
.
crop_size
oh
=
int
(
1.0
*
h
*
ow
/
w
)
img
=
img
.
resize
((
ow
,
oh
),
Image
.
BILINEAR
)
mask
=
mask
.
resize
((
ow
,
oh
),
Image
.
NEAREST
)
# center crop
w
,
h
=
img
.
size
x1
=
int
(
round
((
w
-
self
.
crop_size
)
/
2.
))
y1
=
int
(
round
((
h
-
self
.
crop_size
)
/
2.
))
img
=
img
.
crop
((
x1
,
y1
,
x1
+
self
.
crop_size
,
y1
+
self
.
crop_size
))
mask
=
mask
.
crop
((
x1
,
y1
,
x1
+
self
.
crop_size
,
y1
+
self
.
crop_size
))
return
img
,
mask
class
FixedResize
:
"""Resize image with fixing size."""
def
__init__
(
self
,
size
):
self
.
size
=
(
size
,
size
)
def
__call__
(
self
,
img
,
mask
):
assert
img
.
size
==
mask
.
size
img
=
img
.
resize
(
self
.
size
,
Image
.
BILINEAR
)
mask
=
mask
.
resize
(
self
.
size
,
Image
.
NEAREST
)
return
img
,
mask
tests/st/networks/models/deeplabv3/src/utils/file_io.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""File operation module."""
import
os
def
_is_obs
(
url
):
return
url
.
startswith
(
"obs://"
)
or
url
.
startswith
(
"s3://"
)
def
read
(
url
,
binary
=
False
):
if
_is_obs
(
url
):
# TODO read cloud file.
return
None
with
open
(
url
,
"rb"
if
binary
else
"r"
)
as
f
:
return
f
.
read
()
def
walk
(
url
):
if
_is_obs
(
url
):
# TODO read cloud file.
return
None
return
os
.
walk
(
url
)
tests/st/networks/models/deeplabv3/test_deeplabv3.py
0 → 100644
浏览文件 @
ac2e2f9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train."""
import
argparse
import
time
import
pytest
import
numpy
as
np
from
mindspore
import
context
,
Tensor
from
mindspore.nn.optim.momentum
import
Momentum
from
mindspore
import
Model
from
mindspore.train.callback
import
Callback
from
src.md_dataset
import
create_dataset
from
src.losses
import
OhemLoss
from
src.deeplabv3
import
deeplabv3_resnet50
from
src.config
import
config
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
#--train
#--eval
# --Images
# --2008_001135.jpg
# --2008_001404.jpg
# --SegmentationClassRaw
# --2008_001135.png
# --2008_001404.png
data_url
=
"/home/workspace/mindspore_dataset/voc/voc2012"
class
LossCallBack
(
Callback
):
"""
Monitor the loss in training.
Note:
if per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def
__init__
(
self
,
data_size
,
per_print_times
=
1
):
super
(
LossCallBack
,
self
).
__init__
()
if
not
isinstance
(
per_print_times
,
int
)
or
per_print_times
<
0
:
raise
ValueError
(
"print_step must be int and >= 0"
)
self
.
data_size
=
data_size
self
.
_per_print_times
=
per_print_times
self
.
time
=
1000
self
.
loss
=
0
def
epoch_begin
(
self
,
run_context
):
self
.
epoch_time
=
time
.
time
()
def
step_end
(
self
,
run_context
):
cb_params
=
run_context
.
original_args
()
epoch_mseconds
=
(
time
.
time
()
-
self
.
epoch_time
)
*
1000
self
.
time
=
epoch_mseconds
/
self
.
data_size
self
.
loss
=
cb_params
.
net_outputs
print
(
"epoch: {}, step: {}, outputs are {}"
.
format
(
cb_params
.
cur_epoch_num
,
cb_params
.
cur_step_num
,
str
(
cb_params
.
net_outputs
)))
def
model_fine_tune
(
train_net
,
fix_weight_layer
):
for
para
in
train_net
.
trainable_params
():
para
.
set_parameter_data
(
Tensor
(
np
.
ones
(
para
.
data
.
shape
).
astype
(
np
.
float32
)
*
0.02
))
if
fix_weight_layer
in
para
.
name
:
para
.
requires_grad
=
False
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
def
test_deeplabv3_1p
():
start_time
=
time
.
time
()
epoch_size
=
100
args_opt
=
argparse
.
Namespace
(
base_size
=
513
,
crop_size
=
513
,
batch_size
=
2
)
args_opt
.
base_size
=
config
.
crop_size
args_opt
.
crop_size
=
config
.
crop_size
args_opt
.
batch_size
=
config
.
batch_size
train_dataset
=
create_dataset
(
args_opt
,
data_url
,
epoch_size
,
config
.
batch_size
,
usage
=
"eval"
)
dataset_size
=
train_dataset
.
get_dataset_size
()
callback
=
LossCallBack
(
dataset_size
)
net
=
deeplabv3_resnet50
(
config
.
seg_num_classes
,
[
config
.
batch_size
,
3
,
args_opt
.
crop_size
,
args_opt
.
crop_size
],
infer_scale_sizes
=
config
.
eval_scales
,
atrous_rates
=
config
.
atrous_rates
,
decoder_output_stride
=
config
.
decoder_output_stride
,
output_stride
=
config
.
output_stride
,
fine_tune_batch_norm
=
config
.
fine_tune_batch_norm
,
image_pyramid
=
config
.
image_pyramid
)
net
.
set_train
()
model_fine_tune
(
net
,
'layer'
)
loss
=
OhemLoss
(
config
.
seg_num_classes
,
config
.
ignore_label
)
opt
=
Momentum
(
filter
(
lambda
x
:
'beta'
not
in
x
.
name
and
'gamma'
not
in
x
.
name
and
'depth'
not
in
x
.
name
and
'bias'
not
in
x
.
name
,
net
.
trainable_params
()),
learning_rate
=
config
.
learning_rate
,
momentum
=
config
.
momentum
,
weight_decay
=
config
.
weight_decay
)
model
=
Model
(
net
,
loss
,
opt
)
model
.
train
(
epoch_size
,
train_dataset
,
callback
)
print
(
time
.
time
()
-
start_time
)
print
(
"expect loss: "
,
callback
.
loss
)
print
(
"expect time: "
,
callback
.
time
)
expect_loss
=
0.92
expect_time
=
40
assert
callback
.
loss
.
asnumpy
()
<=
expect_loss
assert
callback
.
time
<=
expect_time
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录