diff --git a/model_zoo/official/cv/alexnet/eval.py b/model_zoo/official/cv/alexnet/eval.py index 6a091aedd893351e3f40644921c1531f2638ca66..b8d7a87c367fbeb930052e414a7647c320370a6a 100644 --- a/model_zoo/official/cv/alexnet/eval.py +++ b/model_zoo/official/cv/alexnet/eval.py @@ -18,6 +18,7 @@ eval alexnet according to model file: python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt """ +import ast import argparse from src.config import alexnet_cfg as cfg from src.dataset import create_dataset_cifar10 @@ -36,7 +37,8 @@ if __name__ == "__main__": parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ path where the trained ckpt file') - parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True') + parser.add_argument('--dataset_sink_mode', type=ast.literal_eval, + default=True, help='dataset_sink_mode is False or True') args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) diff --git a/model_zoo/official/cv/alexnet/train.py b/model_zoo/official/cv/alexnet/train.py index 4512244b922d3a235953144e4945b6b484f34702..37d7ca1b60743640321d2b4c2029bb979e9619a1 100644 --- a/model_zoo/official/cv/alexnet/train.py +++ b/model_zoo/official/cv/alexnet/train.py @@ -18,6 +18,7 @@ train alexnet and get network model files(.ckpt) : python train.py --data_path /YourDataPath """ +import ast import argparse from src.config import alexnet_cfg as cfg from src.dataset import create_dataset_cifar10 @@ -38,7 +39,8 @@ if __name__ == "__main__": parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ path where the trained ckpt file') - parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True') + parser.add_argument('--dataset_sink_mode', type=ast.literal_eval, + default=True, help='dataset_sink_mode is False or True') args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) diff --git a/model_zoo/official/cv/lenet/eval.py b/model_zoo/official/cv/lenet/eval.py index bcd5503c399ccad0752fef450638725ea07bf5e8..4083a06400c23bd9ac0d6b5e484c0de0bac64997 100644 --- a/model_zoo/official/cv/lenet/eval.py +++ b/model_zoo/official/cv/lenet/eval.py @@ -19,6 +19,7 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt """ import os +import ast import argparse import mindspore.nn as nn from mindspore import context @@ -37,7 +38,8 @@ if __name__ == "__main__": help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide\ path where the trained ckpt file') - parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True') + parser.add_argument('--dataset_sink_mode', type=ast.literal_eval, + default=False, help='dataset_sink_mode is False or True') args = parser.parse_args() diff --git a/model_zoo/official/cv/lenet/train.py b/model_zoo/official/cv/lenet/train.py index 2c45c5b32748d41a1814c29c83d276ac803e5980..7cd379134aa02ba0bf41a389ddccef6c14c1bbca 100644 --- a/model_zoo/official/cv/lenet/train.py +++ b/model_zoo/official/cv/lenet/train.py @@ -19,6 +19,7 @@ python train.py --data_path /YourDataPath """ import os +import ast import argparse from src.config import mnist_cfg as cfg from src.dataset import create_dataset @@ -38,7 +39,8 @@ if __name__ == "__main__": help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ path where the trained ckpt file') - parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True') + parser.add_argument('--dataset_sink_mode', type=ast.literal_eval, default=True, + help='dataset_sink_mode is False or True') args = parser.parse_args()