提交 6e304d04 编写于 作者: W wukesong

update lenet and alexnet

上级 c929b056
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""LeNet.""" """LeNet."""
import mindspore.ops.operations as P
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal from mindspore.common.initializer import TruncatedNormal
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
"""Alexnet.""" """Alexnet."""
from config import alexnet_cfg as cfg from config import alexnet_cfg as cfg
import mindspore.ops.operations as P
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal from mindspore.common.initializer import TruncatedNormal
......
...@@ -17,9 +17,9 @@ AlexNet example tutorial ...@@ -17,9 +17,9 @@ AlexNet example tutorial
Usage: Usage:
python alexnet.py python alexnet.py
with --device_target=GPU: After 20 epoch training, the accuracy is up to 80% with --device_target=GPU: After 20 epoch training, the accuracy is up to 80%
with --device_target=Ascend: After 10 epoch training, the accuracy is up to 81%
""" """
import os
import argparse import argparse
from config import alexnet_cfg as cfg from config import alexnet_cfg as cfg
from alexnet import AlexNet from alexnet import AlexNet
...@@ -35,7 +35,7 @@ from mindspore.nn.metrics import Accuracy ...@@ -35,7 +35,7 @@ from mindspore.nn.metrics import Accuracy
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
def create_dataset(data_path, batch_size=32, repeat_size=1): def create_dataset(data_path, batch_size=32, repeat_size=1, mode="train"):
""" """
create dataset for train or test create dataset for train or test
""" """
...@@ -46,21 +46,23 @@ def create_dataset(data_path, batch_size=32, repeat_size=1): ...@@ -46,21 +46,23 @@ def create_dataset(data_path, batch_size=32, repeat_size=1):
resize_op = CV.Resize((cfg.image_height, cfg.image_width)) resize_op = CV.Resize((cfg.image_height, cfg.image_width))
rescale_op = CV.Rescale(rescale, shift) rescale_op = CV.Rescale(rescale, shift)
normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4]) if mode == "train":
random_horizontal_op = CV.RandomHorizontalFlip() random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4])
random_horizontal_op = CV.RandomHorizontalFlip()
channel_swap_op = CV.HWC2CHW() channel_swap_op = CV.HWC2CHW()
typecast_op = C.TypeCast(mstype.int32) typecast_op = C.TypeCast(mstype.int32)
cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op) cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op) if mode == "train":
cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op) cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op) cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op) cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op) cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op) cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op)
cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size) cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size)
cifar_ds = cifar_ds.repeat(repeat_size)
cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True) cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
cifar_ds = cifar_ds.repeat(repeat_size)
return cifar_ds return cifar_ds
...@@ -88,7 +90,8 @@ if __name__ == "__main__": ...@@ -88,7 +90,8 @@ if __name__ == "__main__":
print("============== Starting Training ==============") print("============== Starting Training ==============")
ds_train = create_dataset(args.data_path, ds_train = create_dataset(args.data_path,
cfg.batch_size, cfg.batch_size,
repeat_size) repeat_size,
args.mode)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max) keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck) ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck)
...@@ -98,7 +101,7 @@ if __name__ == "__main__": ...@@ -98,7 +101,7 @@ if __name__ == "__main__":
print("============== Starting Testing ==============") print("============== Starting Testing ==============")
param_dict = load_checkpoint(args.ckpt_path) param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
ds_eval = create_dataset(args.data_path) ds_eval = create_dataset(args.data_path, mode=args.mode)
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc)) print("============== Accuracy:{} ==============".format(acc))
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册