From 6e304d045ccfab325575e559f2ebdaf499005cae Mon Sep 17 00:00:00 2001 From: wukesong Date: Wed, 1 Apr 2020 19:40:26 +0800 Subject: [PATCH] update lenet and alexnet --- chapter03/lenet/lenet.py | 1 - chapter04/alexnet/alexnet.py | 1 - chapter04/alexnet/main.py | 21 ++++++++++++--------- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/chapter03/lenet/lenet.py b/chapter03/lenet/lenet.py index deef872..5c08ed1 100644 --- a/chapter03/lenet/lenet.py +++ b/chapter03/lenet/lenet.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================ """LeNet.""" -import mindspore.ops.operations as P import mindspore.nn as nn from mindspore.common.initializer import TruncatedNormal diff --git a/chapter04/alexnet/alexnet.py b/chapter04/alexnet/alexnet.py index 5c5e723..b0c3260 100644 --- a/chapter04/alexnet/alexnet.py +++ b/chapter04/alexnet/alexnet.py @@ -14,7 +14,6 @@ # ============================================================================ """Alexnet.""" from config import alexnet_cfg as cfg -import mindspore.ops.operations as P import mindspore.nn as nn from mindspore.common.initializer import TruncatedNormal diff --git a/chapter04/alexnet/main.py b/chapter04/alexnet/main.py index 1bf940b..8ac5765 100644 --- a/chapter04/alexnet/main.py +++ b/chapter04/alexnet/main.py @@ -17,9 +17,9 @@ AlexNet example tutorial Usage: python alexnet.py with --device_target=GPU: After 20 epoch training, the accuracy is up to 80% +with --device_target=Ascend: After 10 epoch training, the accuracy is up to 81% """ -import os import argparse from config import alexnet_cfg as cfg from alexnet import AlexNet @@ -35,7 +35,7 @@ from mindspore.nn.metrics import Accuracy from mindspore.common import dtype as mstype -def create_dataset(data_path, batch_size=32, repeat_size=1): +def create_dataset(data_path, batch_size=32, repeat_size=1, mode="train"): """ create dataset for train or test """ @@ -46,21 +46,23 @@ def create_dataset(data_path, batch_size=32, repeat_size=1): resize_op = CV.Resize((cfg.image_height, cfg.image_width)) rescale_op = CV.Rescale(rescale, shift) normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) - random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4]) - random_horizontal_op = CV.RandomHorizontalFlip() + if mode == "train": + random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4]) + random_horizontal_op = CV.RandomHorizontalFlip() channel_swap_op = CV.HWC2CHW() typecast_op = C.TypeCast(mstype.int32) cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op) - cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op) - cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op) + if mode == "train": + cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op) + cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op) cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op) cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op) cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op) cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op) cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size) - cifar_ds = cifar_ds.repeat(repeat_size) cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True) + cifar_ds = cifar_ds.repeat(repeat_size) return cifar_ds @@ -88,7 +90,8 @@ if __name__ == "__main__": print("============== Starting Training ==============") ds_train = create_dataset(args.data_path, cfg.batch_size, - repeat_size) + repeat_size, + args.mode) config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max) ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck) @@ -98,7 +101,7 @@ if __name__ == "__main__": print("============== Starting Testing ==============") param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) - ds_eval = create_dataset(args.data_path) + ds_eval = create_dataset(args.data_path, mode=args.mode) acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) print("============== Accuracy:{} ==============".format(acc)) else: -- GitLab