diff --git a/chapter03/lenet/lenet.py b/chapter03/lenet/lenet.py index deef8729ac888bca34d6de5d0cf6b9a284ce0a25..5c08ed1b7719d2e53a59d7e841a992f9faa45598 100644 --- a/chapter03/lenet/lenet.py +++ b/chapter03/lenet/lenet.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================ """LeNet.""" -import mindspore.ops.operations as P import mindspore.nn as nn from mindspore.common.initializer import TruncatedNormal diff --git a/chapter04/alexnet/alexnet.py b/chapter04/alexnet/alexnet.py index 5c5e723aae442b3e5bdb7679da8a29d401605750..b0c326046104c0a75da54af250d339a0f47a4d14 100644 --- a/chapter04/alexnet/alexnet.py +++ b/chapter04/alexnet/alexnet.py @@ -14,7 +14,6 @@ # ============================================================================ """Alexnet.""" from config import alexnet_cfg as cfg -import mindspore.ops.operations as P import mindspore.nn as nn from mindspore.common.initializer import TruncatedNormal diff --git a/chapter04/alexnet/main.py b/chapter04/alexnet/main.py index 1bf940bfcfa95160f3230675ca0d41373f0b7eb2..8ac57657adc52d1488506a7e6b7c2c4b4978cbdc 100644 --- a/chapter04/alexnet/main.py +++ b/chapter04/alexnet/main.py @@ -17,9 +17,9 @@ AlexNet example tutorial Usage: python alexnet.py with --device_target=GPU: After 20 epoch training, the accuracy is up to 80% +with --device_target=Ascend: After 10 epoch training, the accuracy is up to 81% """ -import os import argparse from config import alexnet_cfg as cfg from alexnet import AlexNet @@ -35,7 +35,7 @@ from mindspore.nn.metrics import Accuracy from mindspore.common import dtype as mstype -def create_dataset(data_path, batch_size=32, repeat_size=1): +def create_dataset(data_path, batch_size=32, repeat_size=1, mode="train"): """ create dataset for train or test """ @@ -46,21 +46,23 @@ def create_dataset(data_path, batch_size=32, repeat_size=1): resize_op = CV.Resize((cfg.image_height, cfg.image_width)) rescale_op = CV.Rescale(rescale, shift) normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) - random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4]) - random_horizontal_op = CV.RandomHorizontalFlip() + if mode == "train": + random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4]) + random_horizontal_op = CV.RandomHorizontalFlip() channel_swap_op = CV.HWC2CHW() typecast_op = C.TypeCast(mstype.int32) cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op) - cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op) - cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op) + if mode == "train": + cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op) + cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op) cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op) cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op) cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op) cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op) cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size) - cifar_ds = cifar_ds.repeat(repeat_size) cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True) + cifar_ds = cifar_ds.repeat(repeat_size) return cifar_ds @@ -88,7 +90,8 @@ if __name__ == "__main__": print("============== Starting Training ==============") ds_train = create_dataset(args.data_path, cfg.batch_size, - repeat_size) + repeat_size, + args.mode) config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max) ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck) @@ -98,7 +101,7 @@ if __name__ == "__main__": print("============== Starting Testing ==============") param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) - ds_eval = create_dataset(args.data_path) + ds_eval = create_dataset(args.data_path, mode=args.mode) acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) print("============== Accuracy:{} ==============".format(acc)) else: