Commit 0a97cb8a
Authored May 29, 2020 by yangyongjie; committed by unknown on May 29, 2020
add deeplabv3 to model zoo
Parent: 803a9159
Showing 19 changed files with 2057 additions and 0 deletions (+2057, -0)
example/deeplabv3_voc2012/README.md  +66  -0
example/deeplabv3_voc2012/evaluation.py  +53  -0
example/deeplabv3_voc2012/scripts/run_distribute_train.sh  +66  -0
example/deeplabv3_voc2012/scripts/run_eval.sh  +32  -0
example/deeplabv3_voc2012/scripts/run_standalone_train.sh  +38  -0
example/deeplabv3_voc2012/src/__init__.py  +23  -0
example/deeplabv3_voc2012/src/backbone/__init__.py  +21  -0
example/deeplabv3_voc2012/src/backbone/resnet_deeplab.py  +577  -0
example/deeplabv3_voc2012/src/config.py  +33  -0
example/deeplabv3_voc2012/src/deeplabv3.py  +457  -0
example/deeplabv3_voc2012/src/ei_dataset.py  +84  -0
example/deeplabv3_voc2012/src/losses.py  +63  -0
example/deeplabv3_voc2012/src/md_dataset.py  +115  -0
example/deeplabv3_voc2012/src/miou_precision.py  +72  -0
example/deeplabv3_voc2012/src/utils/__init__.py  +14  -0
example/deeplabv3_voc2012/src/utils/adapter.py  +67  -0
example/deeplabv3_voc2012/src/utils/custom_transforms.py  +148  -0
example/deeplabv3_voc2012/src/utils/file_io.py  +36  -0
example/deeplabv3_voc2012/train.py  +92  -0
example/deeplabv3_voc2012/README.md
0 → 100644
# DeepLab-V3 Example

## Description

This is an example of training DeepLabV3 with the PASCAL VOC 2012 dataset in MindSpore.

## Requirements

- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the VOC 2012 dataset for training.

> Notes: If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file.

## Running the Example

### Training

- Set options in `config.py`.
- Run `run_standalone_train.sh` for non-distributed training.

```bash
sh scripts/run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_DIR
```

- Run `run_distribute_train.sh` for distributed training.

```bash
sh scripts/run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_DIR MINDSPORE_HCCL_CONFIG_PATH
```

### Evaluation

Set options in `config.py`. Make sure 'data_file' and 'finetune_ckpt' are set to your own path.

- Run `run_eval.sh` for evaluation.

```bash
sh scripts/run_eval.sh DEVICE_ID DATA_DIR
```
## Options and Parameters

Parameters of the DeepLab-V3 model and options for training are set in the file `config.py`.

### Options:

```
config.py:
learning_rate                   Learning rate, default is 0.0014.
weight_decay                    Weight decay, default is 5e-5.
momentum                        Momentum, default is 0.97.
crop_size                       Image crop size [height, width] during training, default is 513.
eval_scales                     The scales to resize images for evaluation, default is [0.5, 0.75, 1.0, 1.25, 1.5, 1.75].
output_stride                   The ratio of input to output spatial resolution, default is 16.
ignore_label                    Ignore label value, default is 255.
seg_num_classes                 Number of semantic classes, including the background class (if it exists);
                                foreground classes + 1 background class for PASCAL VOC 2012, default is 21.
fine_tune_batch_norm            Whether to fine-tune the batch norm parameters, default is False.
atrous_rates                    Atrous rates for atrous spatial pyramid pooling, default is None.
decoder_output_stride           The ratio of input to output spatial resolution when employing a decoder
                                to refine segmentation results, default is None.
image_pyramid                   Input scales for multi-scale feature extraction, default is None.
```
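These options map onto the `config` EasyDict defined in `src/config.py` (added later in this commit); a minimal sketch of overriding a few of them before building the network, with illustrative values:

```python
from src.config import config  # the EasyDict shown in src/config.py below

config.learning_rate = 0.001        # override the 0.0014 default
config.fine_tune_batch_norm = True  # default is False
print(config.seg_num_classes)       # 21: 20 foreground classes + background
```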
### Parameters:

```
Parameters for dataset and network:
distribute                      Run distributed training, default is false.
epoch_size                      Epoch size, default is 6.
batch_size                      Batch size of the input dataset: N, default is 2.
data_url                        Train/evaluation data url, required.
checkpoint_url                  Checkpoint path, default is None.
enable_save_ckpt                Enable checkpoint saving, default is true.
save_checkpoint_steps           Save checkpoint steps, default is 1000.
save_checkpoint_num             Number of checkpoints to keep, default is 1.
```
\ No newline at end of file
example/deeplabv3_voc2012/evaluation.py
0 → 100644
浏览文件 @
0a97cb8a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation."""
import argparse
from mindspore import context
from mindspore import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.md_dataset import create_dataset
from src.losses import OhemLoss
from src.miou_precision import MiouPrecision
from src.deeplabv3 import deeplabv3_resnet50
from src.config import config

parser = argparse.ArgumentParser(description="Deeplabv3 evaluation")
parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
parser.add_argument('--data_url', required=True, default=None, help='Evaluation data url')
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
print(args_opt)

if __name__ == "__main__":
    args_opt.crop_size = config.crop_size
    args_opt.base_size = config.crop_size
    eval_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="eval")
    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
                             infer_scale_sizes=config.eval_scales,
                             atrous_rates=config.atrous_rates,
                             decoder_output_stride=config.decoder_output_stride,
                             output_stride=config.output_stride,
                             fine_tune_batch_norm=config.fine_tune_batch_norm,
                             image_pyramid=config.image_pyramid)
    param_dict = load_checkpoint(args_opt.checkpoint_url)
    load_param_into_net(net, param_dict)
    mIou = MiouPrecision(config.seg_num_classes)
    metrics = {'mIou': mIou}
    loss = OhemLoss(config.seg_num_classes, config.ignore_label)
    model = Model(net, loss, metrics=metrics)
    model.eval(eval_dataset)
example/deeplabv3_voc2012/scripts/run_distribute_train.sh
0 → 100644
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_DIR MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: bash run_distribute_train.sh 8 40 /path/zh-wiki/ /path/hccl.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="

EPOCH_SIZE=$2
DATA_DIR=$3
export MINDSPORE_HCCL_CONFIG_PATH=$4
export RANK_TABLE_FILE=$4
export RANK_SIZE=$1
cores=`cat /proc/cpuinfo | grep "processor" | wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
    start=`expr $i \* $avg_core_per_rank`
    export DEVICE_ID=$i
    export RANK_ID=$i
    export DEPLOY_MODE=0
    export GE_USE_STATIC_MEMORY=1
    end=`expr $start \+ $core_gap`
    cmdopt=$start"-"$end

    rm -rf LOG$i
    mkdir ./LOG$i
    cp *.py ./LOG$i
    cd ./LOG$i || exit
    echo "start training for rank $i, device $DEVICE_ID"
    mkdir -p ms_log
    CUR_DIR=`pwd`
    export GLOG_log_dir=${CUR_DIR}/ms_log
    export GLOG_logtostderr=0
    env > env.log
    taskset -c $cmdopt python ../train.py  \
    --distribute="true"  \
    --epoch_size=$EPOCH_SIZE  \
    --device_id=$DEVICE_ID  \
    --enable_save_ckpt="true"  \
    --checkpoint_url=""  \
    --save_checkpoint_steps=10000  \
    --save_checkpoint_num=1  \
    --data_url=$DATA_DIR > log.txt 2>&1 &
    cd ../
done
\ No newline at end of file
example/deeplabv3_voc2012/scripts/run_eval.sh
0 → 100644
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh DEVICE_ID DATA_DIR"
echo "for example: bash run_eval.sh 0 /path/zh-wiki/ "
echo "=============================================================================================================="
DEVICE_ID=$1
DATA_DIR=$2
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python evaluation.py  \
--device_id=$DEVICE_ID  \
--checkpoint_url=""  \
--data_url=$DATA_DIR > log.txt 2>&1 &
\ No newline at end of file
example/deeplabv3_voc2012/scripts/run_standalone_train.sh
0 → 100644
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_DIR"
echo "for example: bash run_standalone_train.sh 0 40 /path/zh-wiki/ "
echo "=============================================================================================================="
DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python train.py  \
--distribute="false"  \
--epoch_size=$EPOCH_SIZE  \
--device_id=$DEVICE_ID  \
--enable_save_ckpt="true"  \
--checkpoint_url=""  \
--save_checkpoint_steps=10000  \
--save_checkpoint_num=1  \
--data_url=$DATA_DIR > log.txt 2>&1 &
\ No newline at end of file
example/deeplabv3_voc2012/src/__init__.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init DeepLabv3."""
from .deeplabv3 import ASPP, DeepLabV3, deeplabv3_resnet50
from .backbone import *

__all__ = ["ASPP", "DeepLabV3", "deeplabv3_resnet50"]
__all__.extend(backbone.__all__)
example/deeplabv3_voc2012/src/backbone/__init__.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init backbone."""
from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \
    RootBlockBeta, resnet50_dl

__all__ = ["Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1",
           "RootBlockBeta", "resnet50_dl"]
example/deeplabv3_voc2012/src/backbone/resnet_deeplab.py
0 → 100644
This diff is collapsed.
example/deeplabv3_voc2012/src/config.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and evaluation.py
"""
from easydict import EasyDict as ed

config = ed({
    "learning_rate": 0.0014,
    "weight_decay": 0.00005,
    "momentum": 0.97,
    "crop_size": 513,
    "eval_scales": [0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
    "atrous_rates": None,
    "image_pyramid": None,
    "output_stride": 16,
    "fine_tune_batch_norm": False,
    "ignore_label": 255,
    "decoder_output_stride": None,
    "seg_num_classes": 21
})
example/deeplabv3_voc2012/src/deeplabv3.py
0 → 100644
This diff is collapsed.
example/deeplabv3_voc2012/src/ei_dataset.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Process Dataset."""
import abc
import os
import time

from .utils.adapter import get_raw_samples, read_image


class BaseDataset:
    """
    Create dataset.

    Args:
        data_url (str): The path of data.
        usage (str): Whether to use train or eval (default='train').

    Returns:
        Dataset.
    """
    def __init__(self, data_url, usage):
        self.data_url = data_url
        self.usage = usage
        self.cur_index = 0
        self.samples = []
        _s_time = time.time()
        self._load_samples()
        _e_time = time.time()
        print(f"load samples success~, time cost = {_e_time - _s_time}")

    def __getitem__(self, item):
        sample = self.samples[item]
        return self._next_data(sample)

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _next_data(sample):
        image_path = sample[0]
        mask_image_path = sample[1]

        image = read_image(image_path)
        mask_image = read_image(mask_image_path)
        return [image, mask_image]

    @abc.abstractmethod
    def _load_samples(self):
        pass


class HwVocRawDataset(BaseDataset):
    """
    Create dataset with raw data.

    Args:
        data_url (str): The path of data.
        usage (str): Whether to use train or eval (default='train').

    Returns:
        Dataset.
    """
    def __init__(self, data_url, usage="train"):
        super().__init__(data_url, usage)

    def _load_samples(self):
        try:
            self.samples = get_raw_samples(os.path.join(self.data_url, self.usage))
        except Exception as e:
            print("load HwVocRawDataset failed!!!")
            raise e
example/deeplabv3_voc2012/src/losses.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""OhemLoss."""
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F


class OhemLoss(nn.Cell):
    """Ohem loss cell."""
    def __init__(self, num, ignore_label):
        super(OhemLoss, self).__init__()
        self.mul = P.Mul()
        self.shape = P.Shape()
        self.one_hot = nn.OneHot(-1, num, 1.0, 0.0)
        self.squeeze = P.Squeeze()
        self.num = num
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.not_equal = P.NotEqual()
        self.equal = P.Equal()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.fill = P.Fill()
        self.transpose = P.Transpose()
        self.ignore_label = ignore_label
        self.loss_weight = 1.0

    def construct(self, logits, labels):
        logits = self.transpose(logits, (0, 2, 3, 1))
        logits = self.reshape(logits, (-1, self.num))
        labels = F.cast(labels, mstype.int32)
        labels = self.reshape(labels, (-1,))
        one_hot_labels = self.one_hot(labels)
        losses = self.cross_entropy(logits, one_hot_labels)[0]
        weights = self.cast(self.not_equal(labels, self.ignore_label), mstype.float32) * self.loss_weight
        weighted_losses = self.mul(losses, weights)
        loss = self.reduce_sum(weighted_losses, (0,))
        zeros = self.fill(mstype.float32, self.shape(weights), 0.0)
        ones = self.fill(mstype.float32, self.shape(weights), 1.0)
        present = self.select(self.equal(weights, zeros), zeros, ones)
        present = self.reduce_sum(present, (0,))

        zeros = self.fill(mstype.float32, self.shape(present), 0.0)
        min_control = self.fill(mstype.float32, self.shape(present), 1.0)
        present = self.select(self.equal(present, zeros), min_control, present)
        loss = loss / present
        return loss
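For intuition, the masked mean that `construct` computes is equivalent to this NumPy sketch (an illustration only, not part of the commit):

```python
import numpy as np

losses = np.array([0.5, 1.0, 2.0, 3.0], dtype=np.float32)  # per-pixel cross-entropy
labels = np.array([1, 255, 3, 255])                        # 255 is ignore_label
weights = (labels != 255).astype(np.float32)               # zero out ignored pixels
present = max(weights.sum(), 1.0)                          # like min_control: avoid division by zero
print((losses * weights).sum() / present)                  # 1.25, the mean over valid pixels only
```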
example/deeplabv3_voc2012/src/md_dataset.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset module."""
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C

from .ei_dataset import HwVocRawDataset
from .utils import custom_transforms as tr


class DataTransform:
    """Transform dataset for DeepLabV3."""

    def __init__(self, args, usage):
        self.args = args
        self.usage = usage

    def __call__(self, image, label):
        if self.usage == "train":
            return self._train(image, label)
        if self.usage == "eval":
            return self._eval(image, label)
        return None

    def _train(self, image, label):
        """
        Process training data.

        Args:
            image (list): Image data.
            label (list): Dataset label.
        """
        image = Image.fromarray(image)
        label = Image.fromarray(label)
        rsc_tr = tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size)
        image, label = rsc_tr(image, label)
        rhf_tr = tr.RandomHorizontalFlip()
        image, label = rhf_tr(image, label)
        nor_tr = tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        image, label = nor_tr(image, label)
        return image, label

    def _eval(self, image, label):
        """
        Process eval data.

        Args:
            image (list): Image data.
            label (list): Dataset label.
        """
        image = Image.fromarray(image)
        label = Image.fromarray(label)
        fsc_tr = tr.FixScaleCrop(crop_size=self.args.crop_size)
        image, label = fsc_tr(image, label)
        nor_tr = tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        image, label = nor_tr(image, label)
        return image, label


def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
    """
    Create Dataset for DeepLabV3.

    Args:
        args (dict): Train parameters.
        data_url (str): Dataset path.
        epoch_num (int): Epoch of dataset (default=1).
        batch_size (int): Batch size of dataset (default=1).
        usage (str): Whether to use for train or eval (default='train').

    Returns:
        Dataset.
    """
    # create iter dataset
    dataset = HwVocRawDataset(data_url, usage=usage)
    dataset_len = len(dataset)

    # wrapped with GeneratorDataset
    dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=None)
    dataset.set_dataset_size(dataset_len)
    dataset = dataset.map(input_columns=["image", "label"], operations=DataTransform(args, usage=usage))

    channelswap_op = C.HWC2CHW()
    dataset = dataset.map(input_columns="image", operations=channelswap_op)

    # 1464 samples / batch_size 8 = 183 batches
    # epoch_num is num of steps
    # 3658 steps / 183 = 20 epochs
    if usage == "train":
        dataset = dataset.shuffle(1464)
    dataset = dataset.batch(batch_size, drop_remainder=(usage == "train"))

    dataset = dataset.repeat(count=epoch_num)
    dataset.map_model = 4

    return dataset
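A minimal usage sketch for `create_dataset`, assuming `args` carries the `base_size`/`crop_size` fields that train.py attaches (the data path is hypothetical, and the dataset APIs are those of the MindSpore version this commit targets):

```python
from types import SimpleNamespace
from src.md_dataset import create_dataset

args = SimpleNamespace(base_size=513, crop_size=513)
ds = create_dataset(args, "/data/voc2012", epoch_num=1, batch_size=2, usage="train")
print(ds.get_dataset_size())  # number of batches per epoch
```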
example/deeplabv3_voc2012/src/miou_precision.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mIou."""
import numpy as np
from mindspore.nn.metrics.metric import Metric


def confuse_matrix(target, pred, n):
    k = (target >= 0) & (target < n)
    return np.bincount(n * target[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)


def iou(hist):
    denominator = hist.sum(1) + hist.sum(0) - np.diag(hist)
    res = np.diag(hist) / np.where(denominator > 0, denominator, 1)
    res = np.sum(res) / np.count_nonzero(denominator)
    return res


class MiouPrecision(Metric):
    """Calculate miou precision."""
    def __init__(self, num_class=21):
        super(MiouPrecision, self).__init__()
        if not isinstance(num_class, int):
            raise TypeError('num_class should be integer type, but got {}'.format(type(num_class)))
        if num_class < 1:
            raise ValueError('num_class must be at least 1, but got {}'.format(num_class))
        self._num_class = num_class
        self._mIoU = []
        self.clear()

    def clear(self):
        self._hist = np.zeros((self._num_class, self._num_class))
        self._mIoU = []

    def update(self, *inputs):
        if len(inputs) != 2:
            raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
        predict_in = self._convert_data(inputs[0])
        label_in = self._convert_data(inputs[1])
        if predict_in.shape[1] != self._num_class:
            raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
                             'classes'.format(self._num_class, predict_in.shape[1]))
        pred = np.argmax(predict_in, axis=1)
        label = label_in
        if len(label.flatten()) != len(pred.flatten()):
            print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten())))
            raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
                             'classes'.format(self._num_class, predict_in.shape[1]))
        self._hist = confuse_matrix(label.flatten(), pred.flatten(), self._num_class)
        mIoUs = iou(self._hist)
        self._mIoU.append(mIoUs)

    def eval(self):
        """
        Computes the mIoU categorical accuracy.
        """
        mIoU = np.nanmean(self._mIoU)
        print('mIoU = {}'.format(mIoU))
        return mIoU
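A toy check of `confuse_matrix` and `iou` (illustration only): four pixels, two classes, one pixel of class 0 mispredicted as class 1.

```python
import numpy as np
from src.miou_precision import confuse_matrix, iou

target = np.array([0, 0, 1, 1])
pred = np.array([0, 1, 1, 1])
hist = confuse_matrix(target, pred, 2)  # rows are ground truth, columns are predictions
print(hist)       # [[1 1]
                  #  [0 2]]
print(iou(hist))  # mean of IoU(class 0) = 1/2 and IoU(class 1) = 2/3, i.e. ~0.583
```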
example/deeplabv3_voc2012/src/utils/__init__.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
example/deeplabv3_voc2012/src/utils/adapter.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Adapter dataset."""
import fnmatch
import io
import os

import numpy as np
from PIL import Image

from ..utils import file_io


def get_raw_samples(data_url):
    """
    Get dataset from raw data.

    Args:
        data_url (str): Dataset path.

    Returns:
        list, a file list.
    """
    def _list_files(dir_path, pattern):
        full_files = []
        _, _, files = next(file_io.walk(dir_path))
        for f in files:
            if fnmatch.fnmatch(f.lower(), pattern.lower()):
                full_files.append(os.path.join(dir_path, f))
        return full_files

    img_files = _list_files(os.path.join(data_url, "Images"), "*.jpg")
    seg_files = _list_files(os.path.join(data_url, "SegmentationClassRaw"), "*.png")
    files = []
    for img_file in img_files:
        _, file_name = os.path.split(img_file)
        name, _ = os.path.splitext(file_name)
        seg_file = os.path.join(data_url, "SegmentationClassRaw", ".".join([name, "png"]))
        if seg_file in seg_files:
            files.append([img_file, seg_file])
    return files


def read_image(img_path):
    """
    Read image from file.

    Args:
        img_path (str): image path.
    """
    img = file_io.read(img_path.strip(), binary=True)
    data = io.BytesIO(img)
    img = Image.open(data)
    return np.array(img)
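A usage sketch under the directory layout the helpers above assume, i.e. `<root>/<usage>/Images/*.jpg` paired by basename with `<root>/<usage>/SegmentationClassRaw/*.png` (the path below is hypothetical):

```python
from src.utils.adapter import get_raw_samples, read_image

pairs = get_raw_samples("/data/voc2012/train")  # [[image_path, mask_path], ...]
img = read_image(pairs[0][0])                   # HWC uint8 array
mask = read_image(pairs[0][1])                  # HW label map
print(img.shape, mask.shape)
```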
example/deeplabv3_voc2012/src/utils/custom_transforms.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Random process dataset."""
import random

import numpy as np
from PIL import Image, ImageOps, ImageFilter


class Normalize:
    """Normalize a tensor image with mean and standard deviation.

    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """
    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, img, mask):
        img = np.array(img).astype(np.float32)
        mask = np.array(mask).astype(np.float32)
        return img, mask


class RandomHorizontalFlip:
    """Randomly decide whether to horizontal flip."""
    def __call__(self, img, mask):
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        return img, mask


class RandomRotate:
    """
    Randomly decide whether to rotate.

    Args:
        degree (float): The degree of rotate.
    """
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, img, mask):
        rotate_degree = random.uniform(-1 * self.degree, self.degree)
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)
        return img, mask


class RandomGaussianBlur:
    """Randomly decide whether to filter image with gaussian blur."""
    def __call__(self, img, mask):
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
        return img, mask


class RandomScaleCrop:
    """Randomly decide whether to scale and crop image."""
    def __init__(self, base_size, crop_size, fill=0):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill

    def __call__(self, img, mask):
        # random scale (short edge)
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop
        if short_size < self.crop_size:
            padh = self.crop_size - oh if oh < self.crop_size else 0
            padw = self.crop_size - ow if ow < self.crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        return img, mask


class FixScaleCrop:
    """Scale and crop image with fixing size."""
    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, img, mask):
        w, h = img.size
        if w > h:
            oh = self.crop_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.crop_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - self.crop_size) / 2.))
        y1 = int(round((h - self.crop_size) / 2.))
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        return img, mask


class FixedResize:
    """Resize image with fixing size."""
    def __init__(self, size):
        self.size = (size, size)

    def __call__(self, img, mask):
        assert img.size == mask.size
        img = img.resize(self.size, Image.BILINEAR)
        mask = mask.resize(self.size, Image.NEAREST)
        return img, mask
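A quick sketch chaining the training-time transforms the way `md_dataset.py` does (blank stand-in images; sizes follow the config's crop_size of 513):

```python
from PIL import Image
from src.utils import custom_transforms as tr

img = Image.new("RGB", (500, 375))   # stand-in for a VOC image
mask = Image.new("L", (500, 375))    # stand-in for its label mask
img, mask = tr.RandomScaleCrop(base_size=513, crop_size=513)(img, mask)
img, mask = tr.RandomHorizontalFlip()(img, mask)
img, mask = tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(img, mask)
print(img.shape, mask.shape)         # (513, 513, 3) (513, 513), float32 arrays
```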
example/deeplabv3_voc2012/src/utils/file_io.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""File operation module."""
import os


def _is_obs(url):
    return url.startswith("obs://") or url.startswith("s3://")


def read(url, binary=False):
    if _is_obs(url):
        # TODO read cloud file.
        return None
    with open(url, "rb" if binary else "r") as f:
        return f.read()


def walk(url):
    if _is_obs(url):
        # TODO read cloud file.
        return None
    return os.walk(url)
example/deeplabv3_voc2012/train.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train."""
import argparse
from mindspore import context
from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum
from mindspore import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from src.md_dataset import create_dataset
from src.losses import OhemLoss
from src.deeplabv3 import deeplabv3_resnet50
from src.config import config

parser = argparse.ArgumentParser(description="Deeplabv3 training")
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
parser.add_argument('--epoch_size', type=int, default=6, help='Epoch size.')
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
args_opt = parser.parse_args()
print(args_opt)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)


class LossCallBack(Callback):
    """
    Monitor the loss in training.

    Note:
        If per_print_times is 0, do not print loss.

    Args:
        per_print_times (int): Print loss every times. Default: 1.
    """
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                           str(cb_params.net_outputs)))


def model_fine_tune(flags, train_net, fix_weight_layer):
    """Load a checkpoint and freeze parameters whose name contains fix_weight_layer."""
    checkpoint_path = flags.checkpoint_url
    if checkpoint_path is None:
        return
    param_dict = load_checkpoint(checkpoint_path)
    load_param_into_net(train_net, param_dict)
    for para in train_net.trainable_params():
        if fix_weight_layer in para.name:
            para.requires_grad = False


if __name__ == "__main__":
    if args_opt.distribute == "true":
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
        init()
    args_opt.base_size = config.crop_size
    args_opt.crop_size = config.crop_size
    train_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size,
                                   usage="train")
    dataset_size = train_dataset.get_dataset_size()
    time_cb = TimeMonitor(data_size=dataset_size)
    callback = [time_cb, LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
        callback.append(ckpoint_cb)
    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
                             infer_scale_sizes=config.eval_scales,
                             atrous_rates=config.atrous_rates,
                             decoder_output_stride=config.decoder_output_stride,
                             output_stride=config.output_stride,
                             fine_tune_batch_norm=config.fine_tune_batch_norm,
                             image_pyramid=config.image_pyramid)
    net.set_train()
    model_fine_tune(args_opt, net, 'layer')
    loss = OhemLoss(config.seg_num_classes, config.ignore_label)
    opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and
                          'bias' not in x.name, net.trainable_params()),
                   learning_rate=config.learning_rate,
                   momentum=config.momentum,
                   weight_decay=config.weight_decay)
    model = Model(net, loss, opt)
    model.train(args_opt.epoch_size, train_dataset, callback)
\ No newline at end of file