提交 4fdbc9c9 编写于 作者: C caojian05

update lstm support infomation

上级 2228c317
...@@ -15,4 +15,4 @@ ...@@ -15,4 +15,4 @@
| Computer Version (CV) | Targets Detection | [YoloV3](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/yolov3.py) | Supported | Doing | Doing | Computer Version (CV) | Targets Detection | [YoloV3](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/yolov3.py) | Supported | Doing | Doing
| Computer Version (CV) | Semantic Segmentation | [Deeplabv3](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/deeplabv3/src/deeplabv3.py) | Supported | Doing | Doing | Computer Version (CV) | Semantic Segmentation | [Deeplabv3](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/deeplabv3/src/deeplabv3.py) | Supported | Doing | Doing
| Natural Language Processing (NLP) | Natural Language Understanding | [BERT](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/bert/src/bert_model.py) | Supported | Doing | Doing | Natural Language Processing (NLP) | Natural Language Understanding | [BERT](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/bert/src/bert_model.py) | Supported | Doing | Doing
| Natural Language Processing (NLP) | Natural Language Understanding | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/lstm.py) | Doing | Supported | Doing | Natural Language Processing (NLP) | Natural Language Understanding | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/lstm/src/lstm.py) | Doing | Supported | Supported
...@@ -15,4 +15,4 @@ ...@@ -15,4 +15,4 @@
| 计算机视觉(CV) | 目标检测(Targets Detection) | [YoloV3](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/yolov3.py) | Supported | Doing | Doing | 计算机视觉(CV) | 目标检测(Targets Detection) | [YoloV3](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/yolov3.py) | Supported | Doing | Doing
| 计算机视觉(CV) | 语义分割(Semantic Segmentation) | [Deeplabv3](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/deeplabv3/src/deeplabv3.py) | Supported | Doing | Doing | 计算机视觉(CV) | 语义分割(Semantic Segmentation) | [Deeplabv3](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/deeplabv3/src/deeplabv3.py) | Supported | Doing | Doing
| 自然语言处理(NLP) | 自然语言理解(Natural Language Understanding) | [BERT](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/bert/src/bert_model.py) | Supported | Doing | Doing | 自然语言处理(NLP) | 自然语言理解(Natural Language Understanding) | [BERT](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/bert/src/bert_model.py) | Supported | Doing | Doing
| 自然语言处理(NLP) | 自然语言理解(Natural Language Understanding) | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/lstm.py) | Doing | Supported | Doing | 自然语言处理(NLP) | 自然语言理解(Natural Language Understanding) | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/lstm/src/lstm.py) | Doing | Supported | Supported
...@@ -36,7 +36,7 @@ from mindspore.mindrecord import FileWriter ...@@ -36,7 +36,7 @@ from mindspore.mindrecord import FileWriter
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
# Install gensim with 'pip install gensim' # Install gensim with 'pip install gensim'
import gensim import gensim
...@@ -281,26 +281,25 @@ def create_dataset(base_path, batch_size, num_epochs, is_train): ...@@ -281,26 +281,25 @@ def create_dataset(base_path, batch_size, num_epochs, is_train):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='MindSpore LSTM Example') parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'], parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'],
help='Whether to perform data preprocessing') help='whether to preprocess data.')
parser.add_argument('--mode', type=str, default="train", choices=['train', 'test'], parser.add_argument('--mode', type=str, default="train", choices=['train', 'test'],
help='implement phase, set to train or test') help='implement phase, set to train or test')
# Download dataset from 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz' and extract to 'aclimdb_path'
parser.add_argument('--aclimdb_path', type=str, default="./aclImdb", parser.add_argument('--aclimdb_path', type=str, default="./aclImdb",
help='path where the dataset is store') help='path where the dataset is stored.')
# Download glove from 'http://nlp.stanford.edu/data/glove.6B.zip' and extract to 'glove_path'
# Add a new line '400000 300' at the beginning of 'glove.6B.300d.txt' with '40000' for total words and '300' for vector length
parser.add_argument('--glove_path', type=str, default="./glove", parser.add_argument('--glove_path', type=str, default="./glove",
help='path where the glove is store') help='path where the GloVe is stored.')
parser.add_argument('--preprocess_path', type=str, default="./preprocess", parser.add_argument('--preprocess_path', type=str, default="./preprocess",
help='path where the pre-process data is store') help='path where the pre-process data is stored.')
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if mode is test, must provide\ parser.add_argument('--ckpt_path', type=str, default="./",
path where the trained ckpt file') help='if mode is test, must provide path where the trained ckpt file.')
parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'],
help='the target device to run, support "GPU", "CPU". Default: "GPU".')
args = parser.parse_args() args = parser.parse_args()
context.set_context( context.set_context(
mode=context.GRAPH_MODE, mode=context.GRAPH_MODE,
save_graphs=False, save_graphs=False,
device_target="GPU") device_target=args.device_target)
if args.preprocess == 'true': if args.preprocess == 'true':
print("============== Starting Data Pre-processing ==============") print("============== Starting Data Pre-processing ==============")
...@@ -329,13 +328,20 @@ if __name__ == '__main__': ...@@ -329,13 +328,20 @@ if __name__ == '__main__':
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max) keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck) ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
model.train(cfg.num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb]) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
if args.device_target == "CPU":
model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb], dataset_sink_mode=False)
else:
model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb])
elif args.mode == 'test': elif args.mode == 'test':
print("============== Starting Testing ==============") print("============== Starting Testing ==============")
ds_eval = create_dataset(args.preprocess_path, cfg.batch_size, 1, False) ds_eval = create_dataset(args.preprocess_path, cfg.batch_size, 1, False)
param_dict = load_checkpoint(args.ckpt_path) param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
acc = model.eval(ds_eval) if args.device_target == "CPU":
acc = model.eval(ds_eval, dataset_sink_mode=False)
else:
acc = model.eval(ds_eval)
print("============== Accuracy:{} ==============".format(acc)) print("============== Accuracy:{} ==============".format(acc))
else: else:
raise RuntimeError('mode should be train or test, rather than {}'.format(args.mode)) raise RuntimeError('mode should be train or test, rather than {}'.format(args.mode))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册