From eb9e2ed78b11e44f78ba3b97c7e4a6c27cdb6865 Mon Sep 17 00:00:00 2001 From: dessyang Date: Fri, 14 Aug 2020 15:22:32 -0400 Subject: [PATCH] Add an example of training NASNet in MindSpore fix pylint --- model_zoo/official/cv/nasnet/README.md | 111 +++ model_zoo/official/cv/nasnet/eval.py | 53 + model_zoo/official/cv/nasnet/export.py | 39 + .../scripts/run_distribute_train_for_gpu.sh | 17 + .../cv/nasnet/scripts/run_eval_for_gpu.sh | 19 + .../scripts/run_standalone_train_for_gpu.sh | 19 + model_zoo/official/cv/nasnet/src/config.py | 56 ++ model_zoo/official/cv/nasnet/src/dataset.py | 70 ++ model_zoo/official/cv/nasnet/src/loss.py | 38 + .../official/cv/nasnet/src/lr_generator.py | 43 + .../official/cv/nasnet/src/nasnet_a_mobile.py | 937 ++++++++++++++++++ model_zoo/official/cv/nasnet/train.py | 117 +++ 12 files changed, 1519 insertions(+) create mode 100755 model_zoo/official/cv/nasnet/README.md create mode 100755 model_zoo/official/cv/nasnet/eval.py create mode 100755 model_zoo/official/cv/nasnet/export.py create mode 100755 model_zoo/official/cv/nasnet/scripts/run_distribute_train_for_gpu.sh create mode 100755 model_zoo/official/cv/nasnet/scripts/run_eval_for_gpu.sh create mode 100755 model_zoo/official/cv/nasnet/scripts/run_standalone_train_for_gpu.sh create mode 100755 model_zoo/official/cv/nasnet/src/config.py create mode 100755 model_zoo/official/cv/nasnet/src/dataset.py create mode 100755 model_zoo/official/cv/nasnet/src/loss.py create mode 100755 model_zoo/official/cv/nasnet/src/lr_generator.py create mode 100755 model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py create mode 100755 model_zoo/official/cv/nasnet/train.py diff --git a/model_zoo/official/cv/nasnet/README.md b/model_zoo/official/cv/nasnet/README.md new file mode 100755 index 000000000..f9bf9d17a --- /dev/null +++ b/model_zoo/official/cv/nasnet/README.md @@ -0,0 +1,111 @@ +# NASNet Example + +## Description + +This is an example of training NASNet-A-Mobile in MindSpore. + +## Requirements + +- Install [Mindspore](http://www.mindspore.cn/install/en). +- Download the dataset. + +## Structure + +```shell +. +└─nasnet + ├─README.md + ├─scripts + ├─run_standalone_train_for_gpu.sh # launch standalone training with gpu platform(1p) + ├─run_distribute_train_for_gpu.sh # launch distributed training with gpu platform(8p) + └─run_eval_for_gpu.sh # launch evaluating with gpu platform + ├─src + ├─config.py # parameter configuration + ├─dataset.py # data preprocessing + ├─loss.py # Customized CrossEntropy loss function + ├─lr_generator.py # learning rate generator + ├─nasnet_a_mobile.py # network definition + ├─eval.py # eval net + ├─export.py # convert checkpoint + └─train.py # train net + +``` + +## Parameter Configuration + +Parameters for both training and evaluating can be set in config.py + +``` +'random_seed': 1, # fix random seed +'rank': 0, # local rank of distributed +'group_size': 1, # world size of distributed +'work_nums': 8, # number of workers to read the data +'epoch_size': 250, # total epoch numbers +'keep_checkpoint_max': 100, # max numbers to keep checkpoints +'ckpt_path': './checkpoint/', # save checkpoint path +'is_save_on_master': 1 # save checkpoint on rank0, distributed parameters +'batch_size': 32, # input batchsize +'num_classes': 1000, # dataset class numbers +'label_smooth_factor': 0.1, # label smoothing factor +'aux_factor': 0.4, # loss factor of aux logit +'lr_init': 0.04, # initiate learning rate +'lr_decay_rate': 0.97, # decay rate of learning rate +'num_epoch_per_decay': 2.4, # decay epoch number +'weight_decay': 0.00004, # weight decay +'momentum': 0.9, # momentum +'opt_eps': 1.0, # epsilon +'rmsprop_decay': 0.9, # rmsprop decay +'loss_scale': 1, # loss scale + +``` + + + +## Running the example + +### Train + +#### Usage + +``` +# distribute training example(8p) +sh run_distribute_train_for_gpu.sh DATA_DIR +# standalone training +sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR +``` + +#### Launch + +```bash +# distributed training example(8p) for GPU +sh scripts/run_distribute_train_for_gpu.sh /dataset/train +# standalone training example for GPU +sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train +``` + +#### Result + +You can find checkpoint file together with result in log. + +### Evaluation + +#### Usage + +``` +# Evaluation +sh run_eval_for_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT +``` + +#### Launch + +```bash +# Evaluation with checkpoint +sh scripts/run_eval_for_gpu.sh 0 /dataset/val ./checkpoint/nasnet-a-mobile-rank0-248_10009.ckpt +``` + +> checkpoint can be produced in training process. + +#### Result + +Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log. + diff --git a/model_zoo/official/cv/nasnet/eval.py b/model_zoo/official/cv/nasnet/eval.py new file mode 100755 index 000000000..822975920 --- /dev/null +++ b/model_zoo/official/cv/nasnet/eval.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""evaluate imagenet""" +import argparse +import os + +import mindspore.nn as nn +from mindspore import context +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from src.config import nasnet_a_mobile_config_gpu as cfg +from src.dataset import create_dataset +from src.nasnet_a_mobile import NASNetAMobile +from src.loss import CrossEntropy_Val + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='image classification evaluation') + parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of nasnet_a_mobile (Default: None)') + parser.add_argument('--dataset_path', type=str, default='', help='Dataset path') + parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform') + args_opt = parser.parse_args() + + if args_opt.platform == 'Ascend': + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(device_id=device_id) + + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform) + net = NASNetAMobile(num_classes=cfg.num_classes, is_training=False) + ckpt = load_checkpoint(args_opt.checkpoint) + load_param_into_net(net, ckpt) + net.set_train(False) + dataset = create_dataset(args_opt.dataset_path, cfg, False) + loss = CrossEntropy_Val(smooth_factor=0.1, num_classes=cfg.num_classes) + eval_metrics = {'Loss': nn.Loss(), + 'Top1-Acc': nn.Top1CategoricalAccuracy(), + 'Top5-Acc': nn.Top5CategoricalAccuracy()} + model = Model(net, loss, optimizer=None, metrics=eval_metrics) + metrics = model.eval(dataset) + print("metric: ", metrics) diff --git a/model_zoo/official/cv/nasnet/export.py b/model_zoo/official/cv/nasnet/export.py new file mode 100755 index 000000000..9afa84a22 --- /dev/null +++ b/model_zoo/official/cv/nasnet/export.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +##############export checkpoint file into geir and onnx models################# +""" +import argparse +import numpy as np + +import mindspore as ms +from mindspore import Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net, export + +from src.config import nasnet_a_mobile_config_gpu as cfg +from src.nasnet_a_mobile import NASNetAMobile + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='checkpoint export') + parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of nasnet_a_mobile (Default: None)') + args_opt = parser.parse_args() + + net = NASNetAMobile(num_classes=cfg.num_classes, is_training=False) + param_dict = load_checkpoint(args_opt.checkpoint) + load_param_into_net(net, param_dict) + + input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, cfg.image_size, cfg.image_size]), ms.float32) + export(net, input_arr, file_name=cfg.onnx_filename, file_format="ONNX") + export(net, input_arr, file_name=cfg.geir_filename, file_format="GEIR") diff --git a/model_zoo/official/cv/nasnet/scripts/run_distribute_train_for_gpu.sh b/model_zoo/official/cv/nasnet/scripts/run_distribute_train_for_gpu.sh new file mode 100755 index 000000000..305f1dcff --- /dev/null +++ b/model_zoo/official/cv/nasnet/scripts/run_distribute_train_for_gpu.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DATA_DIR=$1 +mpirun --allow-run-as-root -n 8 python ./train.py --is_distributed --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 & diff --git a/model_zoo/official/cv/nasnet/scripts/run_eval_for_gpu.sh b/model_zoo/official/cv/nasnet/scripts/run_eval_for_gpu.sh new file mode 100755 index 000000000..0ecd63a43 --- /dev/null +++ b/model_zoo/official/cv/nasnet/scripts/run_eval_for_gpu.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DEVICE_ID=$1 +DATA_DIR=$2 +PATH_CHECKPOINT=$3 +CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./eval.py --platform 'GPU' --dataset_path $DATA_DIR --checkpoint $PATH_CHECKPOINT > eval.log 2>&1 & diff --git a/model_zoo/official/cv/nasnet/scripts/run_standalone_train_for_gpu.sh b/model_zoo/official/cv/nasnet/scripts/run_standalone_train_for_gpu.sh new file mode 100755 index 000000000..7b856bbcf --- /dev/null +++ b/model_zoo/official/cv/nasnet/scripts/run_standalone_train_for_gpu.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DEVICE_ID=$1 +DATA_DIR=$2 +CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./train.py --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 & + diff --git a/model_zoo/official/cv/nasnet/src/config.py b/model_zoo/official/cv/nasnet/src/config.py new file mode 100755 index 000000000..2646e600e --- /dev/null +++ b/model_zoo/official/cv/nasnet/src/config.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in main.py +""" +from easydict import EasyDict as edict + + +nasnet_a_mobile_config_gpu = edict({ + 'random_seed': 1, + 'rank': 0, + 'group_size': 1, + 'work_nums': 8, + 'epoch_size': 312, + 'keep_checkpoint_max': 100, + 'ckpt_path': './nasnet_a_mobile_checkpoint/', + 'is_save_on_master': 0, + + ### Dataset Config + 'batch_size': 32, + 'image_size': 224, + 'num_classes': 1000, + + ### Loss Config + 'label_smooth_factor': 0.1, + 'aux_factor': 0.4, + + ### Learning Rate Config + # 'lr_decay_method': 'exponential', + 'lr_init': 0.04, + 'lr_decay_rate': 0.97, + 'num_epoch_per_decay': 2.4, + + ### Optimization Config + 'weight_decay': 0.00004, + 'momentum': 0.9, + 'opt_eps': 1.0, + 'rmsprop_decay': 0.9, + "loss_scale": 1, + + ### onnx&air Config + 'onnx_filename': 'nasnet_a_mobile.onnx', + 'air_filename': 'nasnet_a_mobile.air' +}) diff --git a/model_zoo/official/cv/nasnet/src/dataset.py b/model_zoo/official/cv/nasnet/src/dataset.py new file mode 100755 index 000000000..c5e2d0303 --- /dev/null +++ b/model_zoo/official/cv/nasnet/src/dataset.py @@ -0,0 +1,70 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Data operations, will be used in train.py and eval.py +""" +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.c_transforms as C2 +import mindspore.dataset.transforms.vision.c_transforms as C + + +def create_dataset(dataset_path, config, do_train, repeat_num=1): + """ + create a train or eval dataset + + Args: + dataset_path(string): the path of dataset. + config(dict): config of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1. + + Returns: + dataset + """ + rank = config.rank + group_size = config.group_size + if group_size == 1: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=config.work_nums, shuffle=True) + else: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=config.work_nums, shuffle=True, + num_shards=group_size, shard_id=rank) + # define map operations + if do_train: + trans = [ + C.RandomCropDecodeResize(config.image_size), + C.RandomHorizontalFlip(prob=0.5), + C.RandomColorAdjust(brightness=0.4, saturation=0.5) # fast mode + #C.RandomColorAdjust(brightness=0.4, contrast=0.5, saturation=0.5, hue=0.2) + ] + else: + trans = [ + C.Decode(), + C.Resize(int(config.image_size/0.875)), + C.CenterCrop(config.image_size) + ] + trans += [ + C.Rescale(1.0 / 255.0, 0.0), + C.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + C.HWC2CHW() + ] + type_cast_op = C2.TypeCast(mstype.int32) + ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=config.work_nums) + ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=config.work_nums) + # apply batch operations + ds = ds.batch(config.batch_size, drop_remainder=True) + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + return ds diff --git a/model_zoo/official/cv/nasnet/src/loss.py b/model_zoo/official/cv/nasnet/src/loss.py new file mode 100755 index 000000000..ea4ae4e40 --- /dev/null +++ b/model_zoo/official/cv/nasnet/src/loss.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""define evaluation loss function for network.""" +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore import Tensor +from mindspore.common import dtype as mstype +import mindspore.nn as nn + + +class CrossEntropy_Val(_Loss): + """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process""" + def __init__(self, smooth_factor=0, num_classes=1000): + super(CrossEntropy_Val, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + + def construct(self, logits, label): + one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value) + loss_logit = self.ce(logits, one_hot_label) + loss_logit = self.mean(loss_logit, 0) + return loss_logit diff --git a/model_zoo/official/cv/nasnet/src/lr_generator.py b/model_zoo/official/cv/nasnet/src/lr_generator.py new file mode 100755 index 000000000..f3a6d3135 --- /dev/null +++ b/model_zoo/official/cv/nasnet/src/lr_generator.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""learning rate exponential decay generator""" +import math +import numpy as np + + +def get_lr(lr_init, lr_decay_rate, num_epoch_per_decay, total_epochs, steps_per_epoch, is_stair=False): + """ + generate learning rate array + + Args: + lr_init(float): init learning rate + lr_decay_rate (float): + total_epochs(int): total epoch of training + steps_per_epoch(int): steps of one epoch + is_stair(bool): If `True` decay the learning rate at discrete intervals + + Returns: + np.array, learning rate array + """ + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + decay_steps = steps_per_epoch * num_epoch_per_decay + for i in range(total_steps): + p = i/decay_steps + if is_stair: + p = math.floor(p) + lr_each_step.append(lr_init * math.pow(lr_decay_rate, p)) + learning_rate = np.array(lr_each_step).astype(np.float32) + return learning_rate diff --git a/model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py b/model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py new file mode 100755 index 000000000..fbf13ec99 --- /dev/null +++ b/model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py @@ -0,0 +1,937 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""NASNet-A-Mobile model definition""" +import numpy as np + +from mindspore import Tensor +import mindspore.nn as nn +from mindspore.nn.loss.loss import _Loss +import mindspore.ops.operations as P +import mindspore.ops.functional as F +import mindspore.ops.composite as C +import mindspore.common.dtype as mstype +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.train.parallel_utils import ParallelMode +from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean + + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 10.0 + +clip_grad = C.MultitypeFuncGraph("clip_grad") + + +# pylint: disable=consider-using-in +@clip_grad.register("Number", "Number", "Tensor") +def _clip_grad(clip_type, clip_value, grad): + """ + Clip gradients. + + Inputs: + clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. + clip_value (float): Specifies how much to clip. + grad (tuple[Tensor]): Gradients. + + Outputs: + tuple[Tensor]: clipped gradients. + """ + if clip_type != 0 and clip_type != 1: + return grad + dt = F.dtype(grad) + if clip_type == 0: + new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) + else: + new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) + return new_grad + + +class CrossEntropy(_Loss): + """the redefined loss function with SoftmaxCrossEntropyWithLogits""" + def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4): + super(CrossEntropy, self).__init__() + self.factor = factor + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + + def construct(self, logits, label): + logit, aux = logits + one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) + loss_logit = self.ce(logit, one_hot_label) + loss_logit = self.mean(loss_logit, 0) + one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value) + loss_aux = self.ce(aux, one_hot_label_aux) + loss_aux = self.mean(loss_aux, 0) + return loss_logit + self.factor*loss_aux + + +class AuxLogits(nn.Cell): + + def __init__(self, in_channels, out_channels, name=None): + super(AuxLogits, self).__init__() + self.relu = nn.ReLU() + self.pool = nn.AvgPool2d(5, stride=3, pad_mode='valid') + self.conv = nn.Conv2d(in_channels, 128, kernel_size=1) + self.bn = nn.BatchNorm2d(128) + self.conv_1 = nn.Conv2d(128, 768, (4, 4), pad_mode='valid') + self.bn_1 = nn.BatchNorm2d(768) + self.flatten = nn.Flatten() + if name == 'large': + self.fc = nn.Dense(6912, out_channels) # large: 6912, mobile:768 + else: + self.fc = nn.Dense(768, out_channels) + + def construct(self, x): + x = self.relu(x) + x = self.pool(x) + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + x = self.conv_1(x) + x = self.bn_1(x) + x = self.relu(x) + x = self.flatten(x) + x = self.fc(x) + return x + + +class SeparableConv2d(nn.Cell): + + def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=dw_kernel, + stride=dw_stride, pad_mode='pad', padding=dw_padding, group=in_channels, + has_bias=bias) + self.pointwise_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, + pad_mode='pad', has_bias=bias) + + def construct(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Cell): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False): + super(BranchSeparables, self).__init__() + self.relu = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, in_channels, kernel_size, stride, padding, bias=bias + ) + self.bn_sep_1 = nn.BatchNorm2d(num_features=in_channels, eps=0.001, momentum=0.9, affine=True) + self.relu1 = nn.ReLU() + self.separable_2 = SeparableConv2d( + in_channels, out_channels, kernel_size, 1, padding, bias=bias + ) + self.bn_sep_2 = nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.9, affine=True) + + def construct(self, x): + x = self.relu(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class BranchSeparablesStem(nn.Cell): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False): + super(BranchSeparablesStem, self).__init__() + self.relu = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, out_channels, kernel_size, stride, padding, bias=bias + ) + self.bn_sep_1 = nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.9, affine=True) + self.relu1 = nn.ReLU() + self.separable_2 = SeparableConv2d( + out_channels, out_channels, kernel_size, 1, padding, bias=bias + ) + self.bn_sep_2 = nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.9, affine=True) + + def construct(self, x): + x = self.relu(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class BranchSeparablesReduction(BranchSeparables): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False): + BranchSeparables.__init__( + self, in_channels, out_channels, kernel_size, stride, padding, bias + ) + self.padding = nn.Pad(paddings=((0, 0), (0, 0), (z_padding, 0), (z_padding, 0)), mode="CONSTANT") + + def construct(self, x): + x = self.relu(x) + x = self.padding(x) + x = self.separable_1(x) + x = x[:, :, 1:, 1:] + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class CellStem0(nn.Cell): + + def __init__(self, stem_filters, num_filters=42): + super(CellStem0, self).__init__() + self.num_filters = num_filters + self.stem_filters = stem_filters + self.conv_1x1 = nn.SequentialCell([ + nn.ReLU(), + nn.Conv2d(in_channels=self.stem_filters, out_channels=self.num_filters, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False), + nn.BatchNorm2d(num_features=self.num_filters, eps=0.001, momentum=0.9, affine=True) + ]) + + self.comb_iter_0_left = BranchSeparables( + self.num_filters, self.num_filters, 5, 2, 2 + ) + self.comb_iter_0_right = BranchSeparablesStem( + self.stem_filters, self.num_filters, 7, 2, 3, bias=False + ) + + self.comb_iter_1_left = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') + self.comb_iter_1_right = BranchSeparablesStem( + self.stem_filters, self.num_filters, 7, 2, 3, bias=False + ) + + self.comb_iter_2_left = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same') + self.comb_iter_2_right = BranchSeparablesStem( + self.stem_filters, self.num_filters, 5, 2, 2, bias=False + ) + + self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same') + + self.comb_iter_4_left = BranchSeparables( + self.num_filters, self.num_filters, 3, 1, 1, bias=False + ) + self.comb_iter_4_right = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') + + def construct(self, x): + x1 = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x1) + x_comb_iter_0_right = self.comb_iter_0_right(x) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x1) + x_comb_iter_1_right = self.comb_iter_1_right(x) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x1) + x_comb_iter_2_right = self.comb_iter_2_right(x) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x1) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = P.Concat(1)((x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4)) + return x_out + + +class CellStem1(nn.Cell): + + def __init__(self, stem_filters, num_filters): + super(CellStem1, self).__init__() + self.num_filters = num_filters + self.stem_filters = stem_filters + self.conv_1x1 = nn.SequentialCell([ + nn.ReLU(), + nn.Conv2d(in_channels=2*self.num_filters, out_channels=self.num_filters, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False), + nn.BatchNorm2d(num_features=self.num_filters, eps=0.001, momentum=0.9, affine=True)]) + + self.relu = nn.ReLU() + self.path_1 = nn.SequentialCell([ + nn.AvgPool2d(kernel_size=1, stride=2, pad_mode='valid'), + nn.Conv2d(in_channels=self.stem_filters, out_channels=self.num_filters//2, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False)]) + + self.path_2 = nn.CellList([]) + self.path_2.append(nn.Pad(paddings=((0, 0), (0, 0), (0, 1), (0, 1)), mode="CONSTANT")) + self.path_2.append( + nn.AvgPool2d(kernel_size=1, stride=2, pad_mode='valid') + ) + self.path_2.append( + nn.Conv2d(in_channels=self.stem_filters, out_channels=self.num_filters//2, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False) + ) + + self.final_path_bn = nn.BatchNorm2d(num_features=self.num_filters, eps=0.001, momentum=0.9, affine=True) + + self.comb_iter_0_left = BranchSeparables( + self.num_filters, + self.num_filters, + 5, + 2, + 2, + bias=False + ) + self.comb_iter_0_right = BranchSeparables( + self.num_filters, + self.num_filters, + 7, + 2, + 3, + bias=False + ) + + self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, pad_mode='same') + self.comb_iter_1_right = BranchSeparables( + self.num_filters, + self.num_filters, + 7, + 2, + 3, + bias=False + ) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, pad_mode='same') + self.comb_iter_2_right = BranchSeparables( + self.num_filters, + self.num_filters, + 5, + 2, + 2, + bias=False + ) + + self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same') + + self.comb_iter_4_left = BranchSeparables( + self.num_filters, + self.num_filters, + 3, + 1, + 1, + bias=False + ) + self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, pad_mode='same') + self.shape = P.Shape() + + def construct(self, x_conv0, x_stem_0): + x_left = self.conv_1x1(x_stem_0) + x_relu = self.relu(x_conv0) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2[0](x_relu) + x_path2 = x_path2[:, :, 1:, 1:] + x_path2 = self.path_2[1](x_path2) + x_path2 = self.path_2[2](x_path2) + # final path + x_right = self.final_path_bn(P.Concat(1)((x_path1, x_path2))) + + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_right) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_left) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_left) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = P.Concat(1)((x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4)) + return x_out + + +class FirstCell(nn.Cell): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(FirstCell, self).__init__() + self.conv_1x1 = nn.SequentialCell([ + nn.ReLU(), + nn.Conv2d(in_channels=in_channels_right, out_channels=out_channels_right, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False), + nn.BatchNorm2d(num_features=out_channels_right, eps=0.001, momentum=0.9, affine=True)]) + + self.relu = nn.ReLU() + self.path_1 = nn.SequentialCell([ + nn.AvgPool2d(kernel_size=1, stride=2, pad_mode='valid'), + nn.Conv2d(in_channels=in_channels_left, out_channels=out_channels_left, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False)]) + + self.path_2 = nn.CellList([]) + self.path_2.append(nn.Pad(paddings=((0, 0), (0, 0), (0, 1), (0, 1)), mode="CONSTANT")) + self.path_2.append( + nn.AvgPool2d(kernel_size=1, stride=2, pad_mode='valid') + ) + self.path_2.append( + nn.Conv2d(in_channels=in_channels_left, out_channels=out_channels_left, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False) + ) + + self.final_path_bn = nn.BatchNorm2d(num_features=out_channels_left*2, eps=0.001, momentum=0.9, affine=True) + + self.comb_iter_0_left = BranchSeparables( + out_channels_right, out_channels_right, 5, 1, 2, bias=False + ) + self.comb_iter_0_right = BranchSeparables( + out_channels_right, out_channels_right, 3, 1, 1, bias=False + ) + + self.comb_iter_1_left = BranchSeparables( + out_channels_right, out_channels_right, 5, 1, 2, bias=False + ) + self.comb_iter_1_right = BranchSeparables( + out_channels_right, out_channels_right, 3, 1, 1, bias=False + ) + + self.comb_iter_2_left = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same') + + self.comb_iter_3_left = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same') + self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same') + + self.comb_iter_4_left = BranchSeparables( + out_channels_right, out_channels_right, 3, 1, 1, bias=False + ) + + def construct(self, x, x_prev): + x_relu = self.relu(x_prev) + x_path1 = self.path_1(x_relu) + x_path2 = self.path_2[0](x_relu) + x_path2 = x_path2[:, :, 1:, 1:] + x_path2 = self.path_2[1](x_path2) + x_path2 = self.path_2[2](x_path2) + # final path + x_left = self.final_path_bn(P.Concat(1)((x_path1, x_path2))) + + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = P.Concat(1)((x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4)) + return x_out + + +class NormalCell(nn.Cell): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(NormalCell, self).__init__() + self.conv_prev_1x1 = nn.SequentialCell([ + nn.ReLU(), + nn.Conv2d(in_channels=in_channels_left, out_channels=out_channels_left, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False), + nn.BatchNorm2d(num_features=out_channels_left, eps=0.001, momentum=0.9, affine=True)]) + + self.conv_1x1 = nn.SequentialCell([ + nn.ReLU(), + nn.Conv2d(in_channels=in_channels_right, out_channels=out_channels_right, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False), + nn.BatchNorm2d(num_features=out_channels_right, eps=0.001, momentum=0.9, affine=True)]) + + self.comb_iter_0_left = BranchSeparables( + out_channels_right, out_channels_right, 5, 1, 2, bias=False + ) + self.comb_iter_0_right = BranchSeparables( + out_channels_left, out_channels_left, 3, 1, 1, bias=False + ) + + self.comb_iter_1_left = BranchSeparables( + out_channels_left, out_channels_left, 5, 1, 2, bias=False + ) + self.comb_iter_1_right = BranchSeparables( + out_channels_left, out_channels_left, 3, 1, 1, bias=False + ) + + self.comb_iter_2_left = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same') + + self.comb_iter_3_left = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same') + self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same') + + self.comb_iter_4_left = BranchSeparables( + out_channels_right, out_channels_right, 3, 1, 1, bias=False + ) + + def construct(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = P.Concat(1)((x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4)) + return x_out + + +class ReductionCell0(nn.Cell): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(ReductionCell0, self).__init__() + self.conv_prev_1x1 = nn.SequentialCell([ + nn.ReLU(), + nn.Conv2d(in_channels=in_channels_left, out_channels=out_channels_left, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False), + nn.BatchNorm2d(num_features=out_channels_left, eps=0.001, momentum=0.9, affine=True)]) + + self.conv_1x1 = nn.SequentialCell([ + nn.ReLU(), + nn.Conv2d(in_channels=in_channels_right, out_channels=out_channels_right, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False), + nn.BatchNorm2d(num_features=out_channels_right, eps=0.001, momentum=0.9, affine=True)]) + + self.comb_iter_0_left = BranchSeparablesReduction( + out_channels_right, out_channels_right, 5, 2, 2, bias=False + ) + self.comb_iter_0_right = BranchSeparablesReduction( + out_channels_right, out_channels_right, 7, 2, 3, bias=False + ) + + self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, pad_mode='same') + self.comb_iter_1_right = BranchSeparablesReduction( + out_channels_right, out_channels_right, 7, 2, 3, bias=False + ) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, pad_mode='same') + self.comb_iter_2_right = BranchSeparablesReduction( + out_channels_right, out_channels_right, 5, 2, 2, bias=False + ) + + self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same') + + self.comb_iter_4_left = BranchSeparablesReduction( + out_channels_right, out_channels_right, 3, 1, 1, bias=False + ) + self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, pad_mode='same') + + def construct(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = P.Concat(1)((x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4)) + return x_out + + +class ReductionCell1(nn.Cell): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(ReductionCell1, self).__init__() + self.conv_prev_1x1 = nn.SequentialCell([ + nn.ReLU(), + nn.Conv2d(in_channels=in_channels_left, out_channels=out_channels_left, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False), + nn.BatchNorm2d(num_features=out_channels_left, eps=0.001, momentum=0.9, affine=True)]) + + self.conv_1x1 = nn.SequentialCell([ + nn.ReLU(), + nn.Conv2d(in_channels=in_channels_right, out_channels=out_channels_right, kernel_size=1, stride=1, + pad_mode='pad', has_bias=False), + nn.BatchNorm2d(num_features=out_channels_right, eps=0.001, momentum=0.9, affine=True)]) + + self.comb_iter_0_left = BranchSeparables( + out_channels_right, + out_channels_right, + 5, + 2, + 2, + bias=False + ) + self.comb_iter_0_right = BranchSeparables( + out_channels_right, + out_channels_right, + 7, + 2, + 3, + bias=False + ) + + self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, pad_mode='same') + self.comb_iter_1_right = BranchSeparables( + out_channels_right, + out_channels_right, + 7, + 2, + 3, + bias=False + ) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, pad_mode='same') + self.comb_iter_2_right = BranchSeparables( + out_channels_right, + out_channels_right, + 5, + 2, + 2, + bias=False + ) + + self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same') + + self.comb_iter_4_left = BranchSeparables( + out_channels_right, + out_channels_right, + 3, + 1, + 1, + bias=False + ) + self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, pad_mode='same') + + def construct(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = P.Concat(1)((x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4)) + return x_out + + +class NASNetAMobile(nn.Cell): + """Neural Architecture Search (NAS). + + Reference: + Zoph et al. Learning Transferable Architectures + for Scalable Image Recognition. CVPR 2018. + - ``nasnetamobile``: NASNet-A Mobile. + """ + + def __init__(self, num_classes, is_training=True, + stem_filters=32, penultimate_filters=1056, filters_multiplier=2): + super(NASNetAMobile, self).__init__() + self.is_training = is_training + self.stem_filters = stem_filters + self.penultimate_filters = penultimate_filters + self.filters_multiplier = filters_multiplier + + filters = self.penultimate_filters//24 + # 24 is default value for the architecture + + self.conv0 = nn.SequentialCell([ + nn.Conv2d(in_channels=3, out_channels=self.stem_filters, kernel_size=3, stride=2, pad_mode='pad', padding=0, + has_bias=False), + nn.BatchNorm2d(num_features=self.stem_filters, eps=0.001, momentum=0.9, affine=True) + ]) + + self.cell_stem_0 = CellStem0( + self.stem_filters, num_filters=filters//(filters_multiplier**2) + ) + self.cell_stem_1 = CellStem1( + self.stem_filters, num_filters=filters//filters_multiplier + ) + + self.cell_0 = FirstCell( + in_channels_left=filters, + out_channels_left=filters//2, # 1, 0.5 + in_channels_right=2*filters, + out_channels_right=filters + ) # 2, 1 + self.cell_1 = NormalCell( + in_channels_left=2*filters, + out_channels_left=filters, # 2, 1 + in_channels_right=6*filters, + out_channels_right=filters + ) # 6, 1 + self.cell_2 = NormalCell( + in_channels_left=6*filters, + out_channels_left=filters, # 6, 1 + in_channels_right=6*filters, + out_channels_right=filters + ) # 6, 1 + self.cell_3 = NormalCell( + in_channels_left=6*filters, + out_channels_left=filters, # 6, 1 + in_channels_right=6*filters, + out_channels_right=filters + ) # 6, 1 + + self.reduction_cell_0 = ReductionCell0( + in_channels_left=6*filters, + out_channels_left=2*filters, # 6, 2 + in_channels_right=6*filters, + out_channels_right=2*filters + ) # 6, 2 + + self.cell_6 = FirstCell( + in_channels_left=6*filters, + out_channels_left=filters, # 6, 1 + in_channels_right=8*filters, + out_channels_right=2*filters + ) # 8, 2 + self.cell_7 = NormalCell( + in_channels_left=8*filters, + out_channels_left=2*filters, # 8, 2 + in_channels_right=12*filters, + out_channels_right=2*filters + ) # 12, 2 + self.cell_8 = NormalCell( + in_channels_left=12*filters, + out_channels_left=2*filters, # 12, 2 + in_channels_right=12*filters, + out_channels_right=2*filters + ) # 12, 2 + self.cell_9 = NormalCell( + in_channels_left=12*filters, + out_channels_left=2*filters, # 12, 2 + in_channels_right=12*filters, + out_channels_right=2*filters + ) # 12, 2 + + if is_training: + self.aux_logits = AuxLogits(in_channels=12*filters, out_channels=num_classes) + + self.reduction_cell_1 = ReductionCell1( + in_channels_left=12*filters, + out_channels_left=4*filters, # 12, 4 + in_channels_right=12*filters, + out_channels_right=4*filters + ) # 12, 4 + + self.cell_12 = FirstCell( + in_channels_left=12*filters, + out_channels_left=2*filters, # 12, 2 + in_channels_right=16*filters, + out_channels_right=4*filters + ) # 16, 4 + self.cell_13 = NormalCell( + in_channels_left=16*filters, + out_channels_left=4*filters, # 16, 4 + in_channels_right=24*filters, + out_channels_right=4*filters + ) # 24, 4 + self.cell_14 = NormalCell( + in_channels_left=24*filters, + out_channels_left=4*filters, # 24, 4 + in_channels_right=24*filters, + out_channels_right=4*filters + ) # 24, 4 + self.cell_15 = NormalCell( + in_channels_left=24*filters, + out_channels_left=4*filters, # 24, 4 + in_channels_right=24*filters, + out_channels_right=4*filters + ) # 24, 4 + + self.relu = nn.ReLU() + self.dropout = nn.Dropout(keep_prob=0.5) + self.classifier = nn.Dense(in_channels=24*filters, out_channels=num_classes) + self.shape = P.Shape() + self.reshape = P.Reshape() + self._initialize_weights() + + def _initialize_weights(self): + self.init_parameters_data() + for _, m in self.cells_and_names(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0]*m.kernel_size[1]*m.out_channels + m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2./n), + m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_parameter_data( + Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_parameter_data( + Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) + elif isinstance(m, nn.Dense): + m.weight.set_parameter_data(Tensor(np.random.normal( + 0, 0.01, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + + def construct(self, x): + x_conv0 = self.conv0(x) + x_stem_0 = self.cell_stem_0(x_conv0) + x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0) + + x_cell_0 = self.cell_0(x_stem_1, x_stem_0) + x_cell_1 = self.cell_1(x_cell_0, x_stem_1) + x_cell_2 = self.cell_2(x_cell_1, x_cell_0) + x_cell_3 = self.cell_3(x_cell_2, x_cell_1) + + x_reduction_cell_0 = self.reduction_cell_0(x_cell_3, x_cell_2) + + x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_3) + x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0) + x_cell_8 = self.cell_8(x_cell_7, x_cell_6) + x_cell_9 = self.cell_9(x_cell_8, x_cell_7) + + if self.is_training: + aux_logits = self.aux_logits(x_cell_9) + else: + aux_logits = None + + x_reduction_cell_1 = self.reduction_cell_1(x_cell_9, x_cell_8) + + x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_9) + x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1) + x_cell_14 = self.cell_14(x_cell_13, x_cell_12) + x_cell_15 = self.cell_15(x_cell_14, x_cell_13) + + x_cell_15 = self.relu(x_cell_15) + x_cell_15 = nn.AvgPool2d(F.shape(x_cell_15)[2:])(x_cell_15) # global average pool + x_cell_15 = self.reshape(x_cell_15, (self.shape(x_cell_15)[0], -1,)) + x_cell_15 = self.dropout(x_cell_15) + logits = self.classifier(x_cell_15) + + if self.is_training: + return logits, aux_logits + return logits + + +class NASNetAMobileWithLoss(nn.Cell): + """ + Provide nasnet-a-mobile training loss through network. + + Args: + config (dict): The config of nasnet-a-mobile. + is_training (bool): Specifies whether to use the training mode. + + Returns: + Tensor: the loss of the network. + """ + + def __init__(self, config, is_training=True): + super(NASNetAMobileWithLoss, self).__init__() + self.network = NASNetAMobile(config.num_classes, is_training) + self.loss = CrossEntropy(smooth_factor=config.label_smooth_factor, + num_classes=config.num_classes, factor=config.aux_factor) + self.cast = P.Cast() + + def construct(self, data, label): + prediction_scores = self.network(data) + total_loss = self.loss(prediction_scores, label) + return self.cast(total_loss, mstype.float32) + + +class NASNetAMobileTrainOneStepWithClipGradient(nn.Cell): + + def __init__(self, network, optimizer, sens=1.0): + super(NASNetAMobileTrainOneStepWithClipGradient, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.network.add_flags(defer_inline=True) + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.hyper_map = C.HyperMap() + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + parallel_mode = _get_parallel_mode() + if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): + self.reducer_flag = True + if self.reducer_flag: + mean = _get_mirror_mean() + degree = _get_device_num() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + def construct(self, *inputs): + weights = self.weights + loss = self.network(*inputs) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(*inputs, sens) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + return F.depend(loss, self.optimizer(grads)) diff --git a/model_zoo/official/cv/nasnet/train.py b/model_zoo/official/cv/nasnet/train.py new file mode 100755 index 000000000..02e9b9f29 --- /dev/null +++ b/model_zoo/official/cv/nasnet/train.py @@ -0,0 +1,117 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train imagenet.""" +import argparse +import os +import random +import numpy as np + +from mindspore import Tensor +from mindspore import context +from mindspore import ParallelMode +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.nn.optim.rmsprop import RMSProp +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore import dataset as de + +from src.config import nasnet_a_mobile_config_gpu as cfg +from src.dataset import create_dataset +from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobileTrainOneStepWithClipGradient +from src.lr_generator import get_lr + + +random.seed(cfg.random_seed) +np.random.seed(cfg.random_seed) +de.config.set_seed(cfg.random_seed) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='image classification training') + parser.add_argument('--dataset_path', type=str, default='', help='Dataset path') + parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint') + parser.add_argument('--is_distributed', action='store_true', default=False, + help='distributed training') + parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform') + args_opt = parser.parse_args() + + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False) + if os.getenv('DEVICE_ID', "not_set").isdigit(): + context.set_context(device_id=int(os.getenv('DEVICE_ID'))) + + # init distributed + if args_opt.is_distributed: + if args_opt.platform == "Ascend": + init() + else: + init("nccl") + cfg.rank = get_rank() + cfg.group_size = get_group_size() + parallel_mode = ParallelMode.DATA_PARALLEL + context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size, + parameter_broadcast=True, mirror_mean=True) + else: + cfg.rank = 0 + cfg.group_size = 1 + + # dataloader + dataset = create_dataset(args_opt.dataset_path, cfg, True) + batches_per_epoch = dataset.get_dataset_size() + + # network + net_with_loss = NASNetAMobileWithLoss(cfg) + if args_opt.resume: + ckpt = load_checkpoint(args_opt.resume) + load_param_into_net(net_with_loss, ckpt) + + # learning rate schedule + lr = get_lr(lr_init=cfg.lr_init, lr_decay_rate=cfg.lr_decay_rate, + num_epoch_per_decay=cfg.num_epoch_per_decay, total_epochs=cfg.epoch_size, + steps_per_epoch=batches_per_epoch, is_stair=True) + lr = Tensor(lr) + + # optimizer + decayed_params = [] + no_decayed_params = [] + for param in net_with_loss.trainable_params(): + if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name: + decayed_params.append(param) + else: + no_decayed_params.append(param) + group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay}, + {'params': no_decayed_params}, + {'order_params': net_with_loss.trainable_params()}] + optimizer = RMSProp(group_params, lr, decay=cfg.rmsprop_decay, weight_decay=cfg.weight_decay, + momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale) + + net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer) + net_with_grads.set_train() + model = Model(net_with_grads) + + print("============== Starting Training ==============") + loss_cb = LossMonitor(per_print_times=batches_per_epoch) + time_cb = TimeMonitor(data_size=batches_per_epoch) + callbacks = [loss_cb, time_cb] + config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix=f"nasnet-a-mobile-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck) + if args_opt.is_distributed & cfg.is_save_on_master: + if cfg.rank == 0: + callbacks.append(ckpoint_cb) + model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True) + else: + callbacks.append(ckpoint_cb) + model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True) + print("train success") -- GitLab