diff --git a/docs/source_en/network_list.md b/docs/source_en/network_list.md index a46c8e8260d2a493bab80b956d56a218bc194fba..d8a428bfd720fdb1afb111e3a0edb74d2d55b586 100644 --- a/docs/source_en/network_list.md +++ b/docs/source_en/network_list.md @@ -15,4 +15,4 @@ | Computer Version (CV) | Targets Detection | [YoloV3](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/yolov3.py) | Supported | Doing | Doing | Computer Version (CV) | Semantic Segmentation | [Deeplabv3](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/deeplabv3/src/deeplabv3.py) | Supported | Doing | Doing | Natural Language Processing (NLP) | Natural Language Understanding | [BERT](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/bert/src/bert_model.py) | Supported | Doing | Doing -| Natural Language Processing (NLP) | Natural Language Understanding | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/lstm.py) | Doing | Supported | Doing +| Natural Language Processing (NLP) | Natural Language Understanding | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/lstm/src/lstm.py) | Doing | Supported | Supported diff --git a/docs/source_zh_cn/network_list.md b/docs/source_zh_cn/network_list.md index b3ed4486e77a1a41abe889b3ccfe944a18d3481a..1d813e6da0234f02584ba1dd8b442807f9d5961a 100644 --- a/docs/source_zh_cn/network_list.md +++ b/docs/source_zh_cn/network_list.md @@ -15,4 +15,4 @@ | 计算机视觉(CV) | 目标检测(Targets Detection) | [YoloV3](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/yolov3.py) | Supported | Doing | Doing | 计算机视觉(CV) | 语义分割(Semantic Segmentation) | [Deeplabv3](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/deeplabv3/src/deeplabv3.py) | Supported | Doing | Doing | 自然语言处理(NLP) | 自然语言理解(Natural Language Understanding) | [BERT](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/bert/src/bert_model.py) | Supported | Doing | Doing -| 自然语言处理(NLP) | 自然语言理解(Natural Language Understanding) | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/lstm.py) | Doing | Supported | Doing +| 自然语言处理(NLP) | 自然语言理解(Natural Language Understanding) | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/lstm/src/lstm.py) | Doing | Supported | Supported diff --git a/tutorials/tutorial_code/lstm/main.py b/tutorials/tutorial_code/lstm/main.py index dde13eb85f6a583bca848aad66a596843a75f8c0..cccf171b65051b33fb9e89931aca6824effb1907 100644 --- a/tutorials/tutorial_code/lstm/main.py +++ b/tutorials/tutorial_code/lstm/main.py @@ -36,7 +36,7 @@ from mindspore.mindrecord import FileWriter from mindspore.train import Model from mindspore.nn.metrics import Accuracy from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor # Install gensim with 'pip install gensim' import gensim @@ -281,26 +281,25 @@ def create_dataset(base_path, batch_size, num_epochs, is_train): if __name__ == '__main__': parser = argparse.ArgumentParser(description='MindSpore LSTM Example') parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'], - help='Whether to perform data preprocessing') + help='whether to preprocess data.') parser.add_argument('--mode', type=str, default="train", choices=['train', 'test'], help='implement phase, set to train or test') - # Download dataset from 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz' and extract to 'aclimdb_path' parser.add_argument('--aclimdb_path', type=str, default="./aclImdb", - help='path where the dataset is store') - # Download glove from 'http://nlp.stanford.edu/data/glove.6B.zip' and extract to 'glove_path' - # Add a new line '400000 300' at the beginning of 'glove.6B.300d.txt' with '40000' for total words and '300' for vector length + help='path where the dataset is stored.') parser.add_argument('--glove_path', type=str, default="./glove", - help='path where the glove is store') + help='path where the GloVe is stored.') parser.add_argument('--preprocess_path', type=str, default="./preprocess", - help='path where the pre-process data is store') - parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if mode is test, must provide\ - path where the trained ckpt file') + help='path where the pre-process data is stored.') + parser.add_argument('--ckpt_path', type=str, default="./", + help='if mode is test, must provide path where the trained ckpt file.') + parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'], + help='the target device to run, support "GPU", "CPU". Default: "GPU".') args = parser.parse_args() context.set_context( mode=context.GRAPH_MODE, save_graphs=False, - device_target="GPU") + device_target=args.device_target) if args.preprocess == 'true': print("============== Starting Data Pre-processing ==============") @@ -329,13 +328,20 @@ if __name__ == '__main__': config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max) ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck) - model.train(cfg.num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb]) + time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) + if args.device_target == "CPU": + model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb], dataset_sink_mode=False) + else: + model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb]) elif args.mode == 'test': print("============== Starting Testing ==============") ds_eval = create_dataset(args.preprocess_path, cfg.batch_size, 1, False) param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) - acc = model.eval(ds_eval) + if args.device_target == "CPU": + acc = model.eval(ds_eval, dataset_sink_mode=False) + else: + acc = model.eval(ds_eval) print("============== Accuracy:{} ==============".format(acc)) else: raise RuntimeError('mode should be train or test, rather than {}'.format(args.mode))