提交 6e304d04 编写于 作者: W wukesong

update lenet and alexnet

上级 c929b056
......@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""LeNet."""
import mindspore.ops.operations as P
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
......
......@@ -14,7 +14,6 @@
# ============================================================================
"""Alexnet."""
from config import alexnet_cfg as cfg
import mindspore.ops.operations as P
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
......
......@@ -17,9 +17,9 @@ AlexNet example tutorial
Usage:
python alexnet.py
with --device_target=GPU: After 20 epoch training, the accuracy is up to 80%
with --device_target=Ascend: After 10 epoch training, the accuracy is up to 81%
"""
import os
import argparse
from config import alexnet_cfg as cfg
from alexnet import AlexNet
......@@ -35,7 +35,7 @@ from mindspore.nn.metrics import Accuracy
from mindspore.common import dtype as mstype
def create_dataset(data_path, batch_size=32, repeat_size=1):
def create_dataset(data_path, batch_size=32, repeat_size=1, mode="train"):
"""
create dataset for train or test
"""
......@@ -46,21 +46,23 @@ def create_dataset(data_path, batch_size=32, repeat_size=1):
resize_op = CV.Resize((cfg.image_height, cfg.image_width))
rescale_op = CV.Rescale(rescale, shift)
normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4])
random_horizontal_op = CV.RandomHorizontalFlip()
if mode == "train":
random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4])
random_horizontal_op = CV.RandomHorizontalFlip()
channel_swap_op = CV.HWC2CHW()
typecast_op = C.TypeCast(mstype.int32)
cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op)
if mode == "train":
cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op)
cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size)
cifar_ds = cifar_ds.repeat(repeat_size)
cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
cifar_ds = cifar_ds.repeat(repeat_size)
return cifar_ds
......@@ -88,7 +90,8 @@ if __name__ == "__main__":
print("============== Starting Training ==============")
ds_train = create_dataset(args.data_path,
cfg.batch_size,
repeat_size)
repeat_size,
args.mode)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck)
......@@ -98,7 +101,7 @@ if __name__ == "__main__":
print("============== Starting Testing ==============")
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
ds_eval = create_dataset(args.data_path)
ds_eval = create_dataset(args.data_path, mode=args.mode)
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc))
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册