MindSpore / docs, commit 4fdbc9c9

update lstm support information

Authored on Jun 17, 2020 by caojian05
Parent: 2228c317
Showing 3 changed files with 21 additions and 15 deletions (+21 -15):

docs/source_en/network_list.md (+1 -1)
docs/source_zh_cn/network_list.md (+1 -1)
tutorials/tutorial_code/lstm/main.py (+19 -13)
docs/source_en/network_list.md

@@ -15,4 +15,4 @@
 | Computer Version (CV) | Targets Detection | [YoloV3](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/yolov3.py) | Supported | Doing | Doing
 | Computer Version (CV) | Semantic Segmentation | [Deeplabv3](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/deeplabv3/src/deeplabv3.py) | Supported | Doing | Doing
 | Natural Language Processing (NLP) | Natural Language Understanding | [BERT](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/bert/src/bert_model.py) | Supported | Doing | Doing
-| Natural Language Processing (NLP) | Natural Language Understanding | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/lstm.py) | Doing | Supported | Doing
+| Natural Language Processing (NLP) | Natural Language Understanding | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/lstm/src/lstm.py) | Doing | Supported | Supported
docs/source_zh_cn/network_list.md

@@ -15,4 +15,4 @@
 | 计算机视觉(CV) | 目标检测(Targets Detection) | [YoloV3](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/yolov3.py) | Supported | Doing | Doing
 | 计算机视觉(CV) | 语义分割(Semantic Segmentation) | [Deeplabv3](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/deeplabv3/src/deeplabv3.py) | Supported | Doing | Doing
 | 自然语言处理(NLP) | 自然语言理解(Natural Language Understanding) | [BERT](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/bert/src/bert_model.py) | Supported | Doing | Doing
-| 自然语言处理(NLP) | 自然语言理解(Natural Language Understanding) | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/mindspore/model_zoo/lstm.py) | Doing | Supported | Doing
+| 自然语言处理(NLP) | 自然语言理解(Natural Language Understanding) | [SentimentNet](https://gitee.com/mindspore/mindspore/blob/master/model_zoo/lstm/src/lstm.py) | Doing | Supported | Supported
tutorials/tutorial_code/lstm/main.py

@@ -36,7 +36,7 @@ from mindspore.mindrecord import FileWriter
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 # Install gensim with 'pip install gensim'
 import gensim
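The only change in this hunk is the extra `TimeMonitor` import. As a minimal, hedged sketch (not part of the commit; the `build_callbacks` helper and its arguments are illustrative only), this is how the newly imported callback sits next to the checkpoint and loss callbacks that the script assembles further down:

```python
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

def build_callbacks(ds_train, cfg, ckpt_dir):
    """Illustrative helper: assemble the callbacks passed to model.train()."""
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=ckpt_dir, config=config_ck)
    loss_cb = LossMonitor()
    # data_size is the number of steps per epoch, which lets TimeMonitor report
    # per-epoch and per-step timing during training.
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    return [time_cb, ckpoint_cb, loss_cb]
```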
@@ -281,26 +281,25 @@ def create_dataset(base_path, batch_size, num_epochs, is_train):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
     parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'],
-                        help='Whether to perform data preprocessing')
+                        help='whether to preprocess data.')
     parser.add_argument('--mode', type=str, default="train", choices=['train', 'test'],
                         help='implement phase, set to train or test')
     # Download dataset from 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz' and extract to 'aclimdb_path'
     parser.add_argument('--aclimdb_path', type=str, default="./aclImdb",
-                        help='path where the dataset is store')
+                        help='path where the dataset is stored.')
     # Download glove from 'http://nlp.stanford.edu/data/glove.6B.zip' and extract to 'glove_path'
     # Add a new line '400000 300' at the beginning of 'glove.6B.300d.txt' with '40000' for total words and '300' for vector length
     parser.add_argument('--glove_path', type=str, default="./glove",
-                        help='path where the glove is store')
+                        help='path where the GloVe is stored.')
     parser.add_argument('--preprocess_path', type=str, default="./preprocess",
-                        help='path where the pre-process data is store')
-    parser.add_argument('--ckpt_path', type=str, default="./ckpt",
-                        help='if mode is test, must provide \
-                        path where the trained ckpt file')
+                        help='path where the pre-process data is stored.')
+    parser.add_argument('--ckpt_path', type=str, default="./",
+                        help='if mode is test, must provide path where the trained ckpt file.')
+    parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'],
+                        help='the target device to run, support "GPU", "CPU". Default: "GPU".')
     args = parser.parse_args()
-    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target=args.device_target)
     if args.preprocess == 'true':
         print("============== Starting Data Pre-processing ==============")
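This hunk replaces the hard-coded GPU target with a new `--device_target` option and forwards it to `context.set_context()`. A hedged, self-contained sketch of that switch in isolation (an illustration, not the commit's code, assuming the usual `from mindspore import context` import that MindSpore scripts of this era use):

```python
# Hedged sketch: the GPU/CPU switch introduced by the hunk above, shown on its own.
import argparse

from mindspore import context

parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'],
                    help='the target device to run, support "GPU", "CPU". Default: "GPU".')
args = parser.parse_args()

# Graph mode on whichever device was requested, e.g. `python main.py --device_target=CPU`.
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target=args.device_target)
```

With this in place, the rest of the script can branch on `args.device_target`, which is exactly what the next hunk does for `dataset_sink_mode`.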
@@ -329,13 +328,20 @@ if __name__ == '__main__':
         config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                      keep_checkpoint_max=cfg.keep_checkpoint_max)
         ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
-        model.train(cfg.num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb])
+        time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
+        if args.device_target == "CPU":
+            model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb], dataset_sink_mode=False)
+        else:
+            model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb])
     elif args.mode == 'test':
         print("============== Starting Testing ==============")
         ds_eval = create_dataset(args.preprocess_path, cfg.batch_size, 1, False)
         param_dict = load_checkpoint(args.ckpt_path)
         load_param_into_net(network, param_dict)
-        acc = model.eval(ds_eval)
+        if args.device_target == "CPU":
+            acc = model.eval(ds_eval, dataset_sink_mode=False)
+        else:
+            acc = model.eval(ds_eval)
         print("============== Accuracy:{} ==============".format(acc))
     else:
         raise RuntimeError('mode should be train or test, rather than {}'.format(args.mode))
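The two branches added here differ only in whether `dataset_sink_mode=False` is passed on CPU. A hedged sketch (the helper below is illustrative, not part of the commit, and assumes the default sink mode stays enabled on GPU as it did in MindSpore at the time) folds the same logic into a single flag:

```python
# Illustrative only: the CPU/GPU branching from the diff, reduced to one boolean.
def train_or_eval(model, cfg, ds_train, ds_eval, callbacks, device_target, mode):
    # Dataset sink mode is disabled on CPU, matching the branches added in the commit;
    # on GPU the usual sink-mode behaviour is kept.
    sink_mode = device_target != "CPU"
    if mode == 'train':
        model.train(cfg.num_epochs, ds_train, callbacks=callbacks, dataset_sink_mode=sink_mode)
        return None
    return model.eval(ds_eval, dataset_sink_mode=sink_mode)
```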