提交 0a97cb8a 编写于 作者: Y yangyongjie 提交者: unknown

add deeplabv3 to model zoo

上级 803a9159
# Deeplab-V3 Example
## Description
This is an example of training DeepLabv3 with PASCAL VOC 2012 dataset in MindSpore.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the VOC 2012 dataset for training.
> Notes:
If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file.
## Running the Example
### Training
- Set options in config.py.
- Run `run_standalone_train.sh` for non-distributed training.
``` bash
sh scripts/run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_DIR
```
- Run `run_distribute_train.sh` for distributed training.
``` bash
sh scripts/run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_DIR MINDSPORE_HCCL_CONFIG_PATH
```
### Evaluation
Set options in evaluation_config.py. Make sure the 'data_file' and 'finetune_ckpt' are set to your own path.
- Run run_eval.sh for evaluation.
``` bash
sh scripts/run_eval.sh DEVICE_ID DATA_DIR
```
## Options and Parameters
It contains of parameters of Deeplab-V3 model and options for training, which is set in file config.py.
### Options:
```
config.py:
learning_rate Learning rate, default is 0.0014.
weight_decay Weight decay, default is 5e-5.
momentum Momentum, default is 0.97.
crop_size Image crop size [height, width] during training, default is 513.
eval_scales The scales to resize images for evaluation, default is [0.5, 0.75, 1.0, 1.25, 1.5, 1.75].
output_stride The ratio of input to output spatial resolution, default is 16.
ignore_label Ignore label value, default is 255.
seg_num_classes Number of semantic classes, including the background class (if exists).
foreground classes + 1 background class in the PASCAL VOC 2012 dataset, default is 21.
fine_tune_batch_norm Fine tune the batch norm parameters or not, default is False.
atrous_rates Atrous rates for atrous spatial pyramid pooling, default is None.
decoder_output_stride The ratio of input to output spatial resolution when employing decoder
to refine segmentation results, default is None.
image_pyramid Input scales for multi-scale feature extraction, default is None.
```
### Parameters:
```
Parameters for dataset and network:
distribute Run distribute, default is false.
epoch_size Epoch size, default is 6.
batch_size batch size of input dataset: N, default is 2.
data_url Train/Evaluation data url, required.
checkpoint_url Checkpoint path, default is None.
enable_save_ckpt Enable save checkpoint, default is true.
save_checkpoint_steps Save checkpoint steps, default is 1000.
save_checkpoint_num Save checkpoint numbers, default is 1.
```
\ No newline at end of file
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation."""
import argparse
from mindspore import context
from mindspore import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.md_dataset import create_dataset
from src.losses import OhemLoss
from src.miou_precision import MiouPrecision
from src.deeplabv3 import deeplabv3_resnet50
from src.config import config
parser = argparse.ArgumentParser(description="Deeplabv3 evaluation")
parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
parser.add_argument('--data_url', required=True, default=None, help='Evaluation data url')
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
print(args_opt)
if __name__ == "__main__":
args_opt.crop_size = config.crop_size
args_opt.base_size = config.crop_size
eval_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="eval")
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
param_dict = load_checkpoint(args_opt.checkpoint_url)
load_param_into_net(net, param_dict)
mIou = MiouPrecision(config.seg_num_classes)
metrics = {'mIou': mIou}
loss = OhemLoss(config.seg_num_classes, config.ignore_label)
model = Model(net, loss, metrics=metrics)
model.eval(eval_dataset)
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_DIR MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: bash run_distribute_train.sh 8 40 /path/zh-wiki/ /path/hccl.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="
EPOCH_SIZE=$2
DATA_DIR=$3
export MINDSPORE_HCCL_CONFIG_PATH=$4
export RANK_TABLE_FILE=$4
export RANK_SIZE=$1
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
start=`expr $i \* $avg_core_per_rank`
export DEVICE_ID=$i
export RANK_ID=$i
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ $core_gap`
cmdopt=$start"-"$end
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cd ./LOG$i || exit
echo "start training for rank $i, device $DEVICE_ID"
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
env > env.log
taskset -c $cmdopt python ../train.py \
--distribute="true" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--enable_save_ckpt="true" \
--checkpoint_url="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_url=$DATA_DIR > log.txt 2>&1 &
cd ../
done
\ No newline at end of file
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_eval.sh DEVICE_ID DATA_DIR"
echo "for example: bash run_eval.sh /path/zh-wiki/ "
echo "=============================================================================================================="
DEVICE_ID=$1
DATA_DIR=$2
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python evaluation.py \
--device_id=$DEVICE_ID \
--checkpoint_url="" \
--data_url=$DATA_DIR > log.txt 2>&1 &
\ No newline at end of file
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR"
echo "for example: bash run_standalone_train.sh 0 40 /path/zh-wiki/ "
echo "=============================================================================================================="
DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python train.py \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--enable_save_ckpt="true" \
--checkpoint_url="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_url=$DATA_DIR > log.txt 2>&1 &
\ No newline at end of file
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init DeepLabv3."""
from .deeplabv3 import ASPP, DeepLabV3, deeplabv3_resnet50
from .backbone import *
__all__ = [
"ASPP", "DeepLabV3", "deeplabv3_resnet50"
]
__all__.extend(backbone.__all__)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init backbone."""
from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \
RootBlockBeta, resnet50_dl
__all__ = [
"Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta", "resnet50_dl"
]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and evaluation.py
"""
from easydict import EasyDict as ed
config = ed({
"learning_rate": 0.0014,
"weight_decay": 0.00005,
"momentum": 0.97,
"crop_size": 513,
"eval_scales": [0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
"atrous_rates": None,
"image_pyramid": None,
"output_stride": 16,
"fine_tune_batch_norm": False,
"ignore_label": 255,
"decoder_output_stride": None,
"seg_num_classes": 21
})
此差异已折叠。
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Process Dataset."""
import abc
import os
import time
from .utils.adapter import get_raw_samples, read_image
class BaseDataset:
"""
Create dataset.
Args:
data_url (str): The path of data.
usage (str): Whether to use train or eval (default='train').
Returns:
Dataset.
"""
def __init__(self, data_url, usage):
self.data_url = data_url
self.usage = usage
self.cur_index = 0
self.samples = []
_s_time = time.time()
self._load_samples()
_e_time = time.time()
print(f"load samples success~, time cost = {_e_time - _s_time}")
def __getitem__(self, item):
sample = self.samples[item]
return self._next_data(sample)
def __len__(self):
return len(self.samples)
@staticmethod
def _next_data(sample):
image_path = sample[0]
mask_image_path = sample[1]
image = read_image(image_path)
mask_image = read_image(mask_image_path)
return [image, mask_image]
@abc.abstractmethod
def _load_samples(self):
pass
class HwVocRawDataset(BaseDataset):
"""
Create dataset with raw data.
Args:
data_url (str): The path of data.
usage (str): Whether to use train or eval (default='train').
Returns:
Dataset.
"""
def __init__(self, data_url, usage="train"):
super().__init__(data_url, usage)
def _load_samples(self):
try:
self.samples = get_raw_samples(os.path.join(self.data_url, self.usage))
except Exception as e:
print("load HwVocRawDataset failed!!!")
raise e
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""OhemLoss."""
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
class OhemLoss(nn.Cell):
"""Ohem loss cell."""
def __init__(self, num, ignore_label):
super(OhemLoss, self).__init__()
self.mul = P.Mul()
self.shape = P.Shape()
self.one_hot = nn.OneHot(-1, num, 1.0, 0.0)
self.squeeze = P.Squeeze()
self.num = num
self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean()
self.select = P.Select()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.not_equal = P.NotEqual()
self.equal = P.Equal()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.fill = P.Fill()
self.transpose = P.Transpose()
self.ignore_label = ignore_label
self.loss_weight = 1.0
def construct(self, logits, labels):
logits = self.transpose(logits, (0, 2, 3, 1))
logits = self.reshape(logits, (-1, self.num))
labels = F.cast(labels, mstype.int32)
labels = self.reshape(labels, (-1,))
one_hot_labels = self.one_hot(labels)
losses = self.cross_entropy(logits, one_hot_labels)[0]
weights = self.cast(self.not_equal(labels, self.ignore_label), mstype.float32) * self.loss_weight
weighted_losses = self.mul(losses, weights)
loss = self.reduce_sum(weighted_losses, (0,))
zeros = self.fill(mstype.float32, self.shape(weights), 0.0)
ones = self.fill(mstype.float32, self.shape(weights), 1.0)
present = self.select(self.equal(weights, zeros), zeros, ones)
present = self.reduce_sum(present, (0,))
zeros = self.fill(mstype.float32, self.shape(present), 0.0)
min_control = self.fill(mstype.float32, self.shape(present), 1.0)
present = self.select(self.equal(present, zeros), min_control, present)
loss = loss / present
return loss
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset module."""
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C
from .ei_dataset import HwVocRawDataset
from .utils import custom_transforms as tr
class DataTransform:
"""Transform dataset for DeepLabV3."""
def __init__(self, args, usage):
self.args = args
self.usage = usage
def __call__(self, image, label):
if self.usage == "train":
return self._train(image, label)
if self.usage == "eval":
return self._eval(image, label)
return None
def _train(self, image, label):
"""
Process training data.
Args:
image (list): Image data.
label (list): Dataset label.
"""
image = Image.fromarray(image)
label = Image.fromarray(label)
rsc_tr = tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size)
image, label = rsc_tr(image, label)
rhf_tr = tr.RandomHorizontalFlip()
image, label = rhf_tr(image, label)
nor_tr = tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
image, label = nor_tr(image, label)
return image, label
def _eval(self, image, label):
"""
Process eval data.
Args:
image (list): Image data.
label (list): Dataset label.
"""
image = Image.fromarray(image)
label = Image.fromarray(label)
fsc_tr = tr.FixScaleCrop(crop_size=self.args.crop_size)
image, label = fsc_tr(image, label)
nor_tr = tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
image, label = nor_tr(image, label)
return image, label
def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
"""
Create Dataset for DeepLabV3.
Args:
args (dict): Train parameters.
data_url (str): Dataset path.
epoch_num (int): Epoch of dataset (default=1).
batch_size (int): Batch size of dataset (default=1).
usage (str): Whether is use to train or eval (default='train').
Returns:
Dataset.
"""
# create iter dataset
dataset = HwVocRawDataset(data_url, usage=usage)
dataset_len = len(dataset)
# wrapped with GeneratorDataset
dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=None)
dataset.set_dataset_size(dataset_len)
dataset = dataset.map(input_columns=["image", "label"], operations=DataTransform(args, usage=usage))
channelswap_op = C.HWC2CHW()
dataset = dataset.map(input_columns="image", operations=channelswap_op)
# 1464 samples / batch_size 8 = 183 batches
# epoch_num is num of steps
# 3658 steps / 183 = 20 epochs
if usage == "train":
dataset = dataset.shuffle(1464)
dataset = dataset.batch(batch_size, drop_remainder=(usage == "train"))
dataset = dataset.repeat(count=epoch_num)
dataset.map_model = 4
return dataset
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mIou."""
import numpy as np
from mindspore.nn.metrics.metric import Metric
def confuse_matrix(target, pred, n):
k = (target >= 0) & (target < n)
return np.bincount(n * target[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)
def iou(hist):
denominator = hist.sum(1) + hist.sum(0) - np.diag(hist)
res = np.diag(hist) / np.where(denominator > 0, denominator, 1)
res = np.sum(res) / np.count_nonzero(denominator)
return res
class MiouPrecision(Metric):
"""Calculate miou precision."""
def __init__(self, num_class=21):
super(MiouPrecision, self).__init__()
if not isinstance(num_class, int):
raise TypeError('num_class should be integer type, but got {}'.format(type(num_class)))
if num_class < 1:
raise ValueError('num_class must be at least 1, but got {}'.format(num_class))
self._num_class = num_class
self._mIoU = []
self.clear()
def clear(self):
self._hist = np.zeros((self._num_class, self._num_class))
self._mIoU = []
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
predict_in = self._convert_data(inputs[0])
label_in = self._convert_data(inputs[1])
if predict_in.shape[1] != self._num_class:
raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
'classes'.format(self._num_class, predict_in.shape[1]))
pred = np.argmax(predict_in, axis=1)
label = label_in
if len(label.flatten()) != len(pred.flatten()):
print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten())))
raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
'classes'.format(self._num_class, predict_in.shape[1]))
self._hist = confuse_matrix(label.flatten(), pred.flatten(), self._num_class)
mIoUs = iou(self._hist)
self._mIoU.append(mIoUs)
def eval(self):
"""
Computes the mIoU categorical accuracy.
"""
mIoU = np.nanmean(self._mIoU)
print('mIoU = {}'.format(mIoU))
return mIoU
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Adapter dataset."""
import fnmatch
import io
import os
import numpy as np
from PIL import Image
from ..utils import file_io
def get_raw_samples(data_url):
"""
Get dataset from raw data.
Args:
data_url (str): Dataset path.
Returns:
list, a file list.
"""
def _list_files(dir_path, pattern):
full_files = []
_, _, files = next(file_io.walk(dir_path))
for f in files:
if fnmatch.fnmatch(f.lower(), pattern.lower()):
full_files.append(os.path.join(dir_path, f))
return full_files
img_files = _list_files(os.path.join(data_url, "Images"), "*.jpg")
seg_files = _list_files(os.path.join(data_url, "SegmentationClassRaw"), "*.png")
files = []
for img_file in img_files:
_, file_name = os.path.split(img_file)
name, _ = os.path.splitext(file_name)
seg_file = os.path.join(data_url, "SegmentationClassRaw", ".".join([name, "png"]))
if seg_file in seg_files:
files.append([img_file, seg_file])
return files
def read_image(img_path):
"""
Read image from file.
Args:
img_path (str): image path.
"""
img = file_io.read(img_path.strip(), binary=True)
data = io.BytesIO(img)
img = Image.open(data)
return np.array(img)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Random process dataset."""
import random
import numpy as np
from PIL import Image, ImageOps, ImageFilter
class Normalize:
"""Normalize a tensor image with mean and standard deviation.
Args:
mean (tuple): means for each channel.
std (tuple): standard deviations for each channel.
"""
def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
self.mean = mean
self.std = std
def __call__(self, img, mask):
img = np.array(img).astype(np.float32)
mask = np.array(mask).astype(np.float32)
return img, mask
class RandomHorizontalFlip:
"""Randomly decide whether to horizontal flip."""
def __call__(self, img, mask):
if random.random() < 0.5:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
return img, mask
class RandomRotate:
"""
Randomly decide whether to rotate.
Args:
degree (float): The degree of rotate.
"""
def __init__(self, degree):
self.degree = degree
def __call__(self, img, mask):
rotate_degree = random.uniform(-1 * self.degree, self.degree)
img = img.rotate(rotate_degree, Image.BILINEAR)
mask = mask.rotate(rotate_degree, Image.NEAREST)
return img, mask
class RandomGaussianBlur:
"""Randomly decide whether to filter image with gaussian blur."""
def __call__(self, img, mask):
if random.random() < 0.5:
img = img.filter(ImageFilter.GaussianBlur(
radius=random.random()))
return img, mask
class RandomScaleCrop:
"""Randomly decide whether to scale and crop image."""
def __init__(self, base_size, crop_size, fill=0):
self.base_size = base_size
self.crop_size = crop_size
self.fill = fill
def __call__(self, img, mask):
# random scale (short edge)
short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
w, h = img.size
if h > w:
ow = short_size
oh = int(1.0 * h * ow / w)
else:
oh = short_size
ow = int(1.0 * w * oh / h)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# pad crop
if short_size < self.crop_size:
padh = self.crop_size - oh if oh < self.crop_size else 0
padw = self.crop_size - ow if ow < self.crop_size else 0
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
# random crop crop_size
w, h = img.size
x1 = random.randint(0, w - self.crop_size)
y1 = random.randint(0, h - self.crop_size)
img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
return img, mask
class FixScaleCrop:
"""Scale and crop image with fixing size."""
def __init__(self, crop_size):
self.crop_size = crop_size
def __call__(self, img, mask):
w, h = img.size
if w > h:
oh = self.crop_size
ow = int(1.0 * w * oh / h)
else:
ow = self.crop_size
oh = int(1.0 * h * ow / w)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# center crop
w, h = img.size
x1 = int(round((w - self.crop_size) / 2.))
y1 = int(round((h - self.crop_size) / 2.))
img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
return img, mask
class FixedResize:
"""Resize image with fixing size."""
def __init__(self, size):
self.size = (size, size)
def __call__(self, img, mask):
assert img.size == mask.size
img = img.resize(self.size, Image.BILINEAR)
mask = mask.resize(self.size, Image.NEAREST)
return img, mask
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""File operation module."""
import os
def _is_obs(url):
return url.startswith("obs://") or url.startswith("s3://")
def read(url, binary=False):
if _is_obs(url):
# TODO read cloud file.
return None
with open(url, "rb" if binary else "r") as f:
return f.read()
def walk(url):
if _is_obs(url):
# TODO read cloud file.
return None
return os.walk(url)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train."""
import argparse
from mindspore import context
from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum
from mindspore import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from src.md_dataset import create_dataset
from src.losses import OhemLoss
from src.deeplabv3 import deeplabv3_resnet50
from src.config import config
parser = argparse.ArgumentParser(description="Deeplabv3 training")
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
parser.add_argument('--epoch_size', type=int, default=6, help='Epoch size.')
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
args_opt = parser.parse_args()
print(args_opt)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
class LossCallBack(Callback):
"""
Monitor the loss in training.
Note:
if per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0")
self._per_print_times = per_print_times
def step_end(self, run_context):
cb_params = run_context.original_args()
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs)))
def model_fine_tune(flags, train_net, fix_weight_layer):
checkpoint_path = flags.checkpoint_url
if checkpoint_path is None:
return
param_dict = load_checkpoint(checkpoint_path)
load_param_into_net(train_net, param_dict)
for para in train_net.trainable_params():
if fix_weight_layer in para.name:
para.requires_grad = False
if __name__ == "__main__":
if args_opt.distribute == "true":
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
init()
args_opt.base_size = config.crop_size
args_opt.crop_size = config.crop_size
train_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="train")
dataset_size = train_dataset.get_dataset_size()
time_cb = TimeMonitor(data_size=dataset_size)
callback = [time_cb, LossCallBack()]
if args_opt.enable_save_ckpt == "true":
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
callback.append(ckpoint_cb)
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
net.set_train()
model_fine_tune(args_opt, net, 'layer')
loss = OhemLoss(config.seg_num_classes, config.ignore_label)
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
model = Model(net, loss, opt)
model.train(args_opt.epoch_size, train_dataset, callback)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册