From 913b5b03dfcf7091c9a81d9ee65343fb991a48f2 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 29 May 2020 03:17:32 +0800 Subject: [PATCH] modify --- model_zoo/deeplabv3/src/backbone/__init__.py | 8 ++++++ .../deeplabv3/src/backbone/resnet_deeplab.py | 3 +++ model_zoo/deeplabv3/src/ei_datasest.py | 25 +------------------ model_zoo/deeplabv3/src/md_dataset.py | 8 ++---- model_zoo/deeplabv3/train.py | 12 ++++----- 5 files changed, 20 insertions(+), 36 deletions(-) create mode 100644 model_zoo/deeplabv3/src/backbone/__init__.py diff --git a/model_zoo/deeplabv3/src/backbone/__init__.py b/model_zoo/deeplabv3/src/backbone/__init__.py new file mode 100644 index 000000000..4ccf21261 --- /dev/null +++ b/model_zoo/deeplabv3/src/backbone/__init__.py @@ -0,0 +1,8 @@ +from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \ + RootBlockBeta, resnet50_dl + +__all__= [ + "Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta", + "resnet50_dl" +] + \ No newline at end of file diff --git a/model_zoo/deeplabv3/src/backbone/resnet_deeplab.py b/model_zoo/deeplabv3/src/backbone/resnet_deeplab.py index fff77beca..187fd7ae7 100644 --- a/model_zoo/deeplabv3/src/backbone/resnet_deeplab.py +++ b/model_zoo/deeplabv3/src/backbone/resnet_deeplab.py @@ -532,3 +532,6 @@ class RootBlockBeta(nn.Cell): x = self.conv2(x) x = self.conv3(x) return x + +class resnet50_dl(fine_tune_batch_norm=False): + return ResNetV1(fine_tune_batch_norm) diff --git a/model_zoo/deeplabv3/src/ei_datasest.py b/model_zoo/deeplabv3/src/ei_datasest.py index b2a20999d..c139bf48e 100644 --- a/model_zoo/deeplabv3/src/ei_datasest.py +++ b/model_zoo/deeplabv3/src/ei_datasest.py @@ -17,7 +17,7 @@ import abc import os import time -from .utils.adapter import get_manifest_samples, get_raw_samples, read_image +from .utils.adapter import get_raw_samples, read_image class BaseDataset(object): @@ -62,29 +62,6 @@ class BaseDataset(object): pass -class HwVocManifestDataset(BaseDataset): - """ - Create dataset with manifest data. - - Args: - data_url (str): The path of data. - usage (str): Whether to use train or eval (default='train'). - - Returns: - Dataset. - """ - - def __init__(self, data_url, usage="train"): - super().__init__(data_url, usage) - - def _load_samples(self): - try: - self.samples = get_manifest_samples(self.data_url, self.usage) - except Exception as e: - print("load HwVocManifestDataset samples failed!!!") - raise e - - class HwVocRawDataset(BaseDataset): """ Create dataset with raw data. diff --git a/model_zoo/deeplabv3/src/md_dataset.py b/model_zoo/deeplabv3/src/md_dataset.py index b18d94cc0..f800f7e12 100644 --- a/model_zoo/deeplabv3/src/md_dataset.py +++ b/model_zoo/deeplabv3/src/md_dataset.py @@ -17,7 +17,7 @@ from PIL import Image import mindspore.dataset as de import mindspore.dataset.transforms.vision.c_transforms as C -from .ei_dataset import HwVocManifestDataset, HwVocRawDataset +from .ei_dataset import HwVocRawDataset from .utils import custom_transforms as tr @@ -77,10 +77,7 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"): Dataset. """ # create iter dataset - if data_url.endswith(".manifest"): - dataset = HwVocManifestDataset(data_url, usage=usage) - else: - dataset = HwVocRawDataset(data_url, usage=usage) + dataset = HwVocRawDataset(data_url, usage=usage) dataset_len = len(dataset) # wrapped with GeneratorDataset @@ -100,5 +97,4 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"): dataset = dataset.repeat(count=epoch_num) dataset.map_model = 4 - dataset.__loop_size__ = 1 return dataset diff --git a/model_zoo/deeplabv3/train.py b/model_zoo/deeplabv3/train.py index 7ffe7879c..ed625ede6 100644 --- a/model_zoo/deeplabv3/train.py +++ b/model_zoo/deeplabv3/train.py @@ -87,13 +87,13 @@ if __name__ == "__main__": keep_checkpoint_max=args_opt.save_checkpoint_num) ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck) callback.append(ckpoint_cb) - net = deeplabv3_resnet50(crop_size.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size], - infer_scale_sizes=crop_size.eval_scales, atrous_rates=crop_size.atrous_rates, - decoder_output_stride=crop_size.decoder_output_stride, output_stride = crop_size.output_stride, - fine_tune_batch_norm=crop_size.fine_tune_batch_norm, image_pyramid = crop_size.image_pyramid) + net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size], + infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, + decoder_output_stride=config.decoder_output_stride, output_stride = config.output_stride, + fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid = config.image_pyramid) net.set_train() model_fine_tune(args_opt, net, 'layer') - loss = OhemLoss(crop_size.seg_num_classes, crop_size.ignore_label) - opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=args_opt.learning_rate, momentum=args_opt.momentum, weight_decay=args_opt.weight_decay) + loss = OhemLoss(config.seg_num_classes, config.ignore_label) + opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay) model = Model(net, loss, opt) model.train(args_opt.epoch_size, train_dataset, callback) \ No newline at end of file -- GitLab