提交 49ed4299 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1452 add train and eval script for LSTM

Merge pull request !1452 from caojian05/ms_master_dev
# LSTM Example
## Description
This example is for LSTM model training and evaluation.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the dataset aclImdb_v1.
> Unzip the aclImdb_v1 dataset to any path you want and the folder structure should be as follows:
> ```
> .
> ├── train # train dataset
> └── test # infer dataset
> ```
- Download the GloVe file.
> Unzip the glove.6B.zip to any path you want and the folder structure should be as follows:
> ```
> .
> ├── glove.6B.100d.txt
> ├── glove.6B.200d.txt
> ├── glove.6B.300d.txt # we will use this one later.
> └── glove.6B.50d.txt
> ```
> Adding a new line at the beginning of the file which named `glove.6B.300d.txt`.
> It means reading a total of 400,000 words, each represented by a 300-latitude word vector.
> ```
> 400000 300
> ```
## Running the Example
### Training
```
python train.py --preprocess=true --aclimdb_path=your_imdb_path --glove_path=your_glove_path > out.train.log 2>&1 &
```
The python command above will run in the background, you can view the results through the file `out.train.log`.
After training, you'll get some checkpoint files under the script folder by default.
You will get the loss value as following:
```
# grep "loss is " out.train.log
epoch: 1 step: 390, loss is 0.6003723
epcoh: 2 step: 390, loss is 0.35312173
...
```
### Evaluation
```
python eval.py --ckpt_path=./lstm-20-390.ckpt > out.eval.log 2>&1 &
```
The above python command will run in the background, you can view the results through the file `out.eval.log`.
You will get the accuracy as following:
```
# grep "acc" out.eval.log
result: {'acc': 0.83}
```
## Usage:
### Training
```
usage: train.py [--preprocess {true,false}] [--aclimdb_path ACLIMDB_PATH]
[--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH]
[--ckpt_path CKPT_PATH] [--device_target {GPU,CPU}]
parameters/options:
--preprocess whether to preprocess data.
--aclimdb_path path where the dataset is stored.
--glove_path path where the GloVe is stored.
--preprocess_path path where the pre-process data is stored.
--ckpt_path the path to save the checkpoint file.
--device_target the target device to run, support "GPU", "CPU".
```
### Evaluation
```
usage: eval.py [--preprocess {true,false}] [--aclimdb_path ACLIMDB_PATH]
[--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH]
[--ckpt_path CKPT_PATH] [--device_target {GPU,CPU}]
parameters/options:
--preprocess whether to preprocess data.
--aclimdb_path path where the dataset is stored.
--glove_path path where the GloVe is stored.
--preprocess_path path where the pre-process data is stored.
--ckpt_path the checkpoint file path used to evaluate model.
--device_target the target device to run, support "GPU", "CPU".
```
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting
"""
from easydict import EasyDict as edict
# LSTM CONFIG
lstm_cfg = edict({
'num_classes': 2,
'learning_rate': 0.1,
'momentum': 0.9,
'num_epochs': 20,
'batch_size': 64,
'embed_size': 300,
'num_hiddens': 100,
'num_layers': 2,
'bidirectional': True,
'save_checkpoint_steps': 390,
'keep_checkpoint_max': 10
})
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import numpy as np
from imdb import ImdbParser
import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter
def create_dataset(data_home, batch_size, repeat_num=1, training=True):
"""Data operations."""
ds.config.set_seed(1)
data_dir = os.path.join(data_home, "aclImdb_train.mindrecord0")
if not training:
data_dir = os.path.join(data_home, "aclImdb_test.mindrecord0")
data_set = ds.MindDataset(data_dir, columns_list=["feature", "label"], num_parallel_workers=4)
# apply map operations on images
data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
data_set = data_set.repeat(count=repeat_num)
return data_set
def _convert_to_mindrecord(data_home, features, labels, weight_np=None, training=True):
"""
convert imdb dataset to mindrecoed dataset
"""
if weight_np is not None:
np.savetxt(os.path.join(data_home, 'weight.txt'), weight_np)
# write mindrecord
schema_json = {"id": {"type": "int32"},
"label": {"type": "int32"},
"feature": {"type": "int32", "shape": [-1]}}
data_dir = os.path.join(data_home, "aclImdb_train.mindrecord")
if not training:
data_dir = os.path.join(data_home, "aclImdb_test.mindrecord")
def get_imdb_data(features, labels):
data_list = []
for i, (label, feature) in enumerate(zip(labels, features)):
data_json = {"id": i,
"label": int(label),
"feature": feature.reshape(-1)}
data_list.append(data_json)
return data_list
writer = FileWriter(data_dir, shard_num=4)
data = get_imdb_data(features, labels)
writer.add_schema(schema_json, "nlp_schema")
writer.add_index(["id", "label"])
writer.write_raw_data(data)
writer.commit()
def convert_to_mindrecord(embed_size, aclimdb_path, preprocess_path, glove_path):
"""
convert imdb dataset to mindrecoed dataset
"""
parser = ImdbParser(aclimdb_path, glove_path, embed_size)
parser.parse()
if not os.path.exists(preprocess_path):
print(f"preprocess path {preprocess_path} is not exist")
os.makedirs(preprocess_path)
train_features, train_labels, train_weight_np = parser.get_datas('train')
_convert_to_mindrecord(preprocess_path, train_features, train_labels, train_weight_np)
test_features, test_labels, _ = parser.get_datas('test')
_convert_to_mindrecord(preprocess_path, test_features, test_labels, training=False)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train lstm example on aclImdb########################
python eval.py --ckpt_path=./lstm-20-390.ckpt
"""
import argparse
import os
import numpy as np
from config import lstm_cfg as cfg
from dataset import create_dataset, convert_to_mindrecord
from mindspore import Tensor, nn, Model, context
from mindspore.model_zoo.lstm import SentimentNet
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'],
help='whether to preprocess data.')
parser.add_argument('--aclimdb_path', type=str, default="./aclImdb",
help='path where the dataset is stored.')
parser.add_argument('--glove_path', type=str, default="./glove",
help='path where the GloVe is stored.')
parser.add_argument('--preprocess_path', type=str, default="./preprocess",
help='path where the pre-process data is stored.')
parser.add_argument('--ckpt_path', type=str, default=None,
help='the checkpoint file path used to evaluate model.')
parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'],
help='the target device to run, support "GPU", "CPU". Default: "GPU".')
args = parser.parse_args()
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target=args.device_target)
if args.preprocess == "true":
print("============== Starting Data Pre-processing ==============")
convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path)
embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32)
network = SentimentNet(vocab_size=embedding_table.shape[0],
embed_size=cfg.embed_size,
num_hiddens=cfg.num_hiddens,
num_layers=cfg.num_layers,
bidirectional=cfg.bidirectional,
num_classes=cfg.num_classes,
weight=Tensor(embedding_table),
batch_size=cfg.batch_size)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
loss_cb = LossMonitor()
model = Model(network, loss, opt, {'acc': Accuracy()})
print("============== Starting Testing ==============")
ds_eval = create_dataset(args.preprocess_path, cfg.batch_size, training=False)
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
if args.device_target == "CPU":
acc = model.eval(ds_eval, dataset_sink_mode=False)
else:
acc = model.eval(ds_eval)
print("============== Accuracy:{} ==============".format(acc))
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
imdb dataset parser.
"""
import os
from itertools import chain
import gensim
import numpy as np
class ImdbParser():
"""
parse aclImdb data to features and labels.
sentence->tokenized->encoded->padding->features
"""
def __init__(self, imdb_path, glove_path, embed_size=300):
self.__segs = ['train', 'test']
self.__label_dic = {'pos': 1, 'neg': 0}
self.__imdb_path = imdb_path
self.__glove_dim = embed_size
self.__glove_file = os.path.join(glove_path, 'glove.6B.' + str(self.__glove_dim) + 'd.txt')
# properties
self.__imdb_datas = {}
self.__features = {}
self.__labels = {}
self.__vacab = {}
self.__word2idx = {}
self.__weight_np = {}
self.__wvmodel = None
def parse(self):
"""
parse imdb data to memory
"""
self.__wvmodel = gensim.models.KeyedVectors.load_word2vec_format(self.__glove_file)
for seg in self.__segs:
self.__parse_imdb_datas(seg)
self.__parse_features_and_labels(seg)
self.__gen_weight_np(seg)
def __parse_imdb_datas(self, seg):
"""
load data from txt
"""
data_lists = []
for label_name, label_id in self.__label_dic.items():
sentence_dir = os.path.join(self.__imdb_path, seg, label_name)
for file in os.listdir(sentence_dir):
with open(os.path.join(sentence_dir, file), mode='r', encoding='utf8') as f:
sentence = f.read().replace('\n', '')
data_lists.append([sentence, label_id])
self.__imdb_datas[seg] = data_lists
def __parse_features_and_labels(self, seg):
"""
parse features and labels
"""
features = []
labels = []
for sentence, label in self.__imdb_datas[seg]:
features.append(sentence)
labels.append(label)
self.__features[seg] = features
self.__labels[seg] = labels
# update feature to tokenized
self.__updata_features_to_tokenized(seg)
# parse vacab
self.__parse_vacab(seg)
# encode feature
self.__encode_features(seg)
# padding feature
self.__padding_features(seg)
def __updata_features_to_tokenized(self, seg):
tokenized_features = []
for sentence in self.__features[seg]:
tokenized_sentence = [word.lower() for word in sentence.split(" ")]
tokenized_features.append(tokenized_sentence)
self.__features[seg] = tokenized_features
def __parse_vacab(self, seg):
# vocab
tokenized_features = self.__features[seg]
vocab = set(chain(*tokenized_features))
self.__vacab[seg] = vocab
# word_to_idx: {'hello': 1, 'world':111, ... '<unk>': 0}
word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
word_to_idx['<unk>'] = 0
self.__word2idx[seg] = word_to_idx
def __encode_features(self, seg):
""" encode word to index """
word_to_idx = self.__word2idx['train']
encoded_features = []
for tokenized_sentence in self.__features[seg]:
encoded_sentence = []
for word in tokenized_sentence:
encoded_sentence.append(word_to_idx.get(word, 0))
encoded_features.append(encoded_sentence)
self.__features[seg] = encoded_features
def __padding_features(self, seg, maxlen=500, pad=0):
""" pad all features to the same length """
padded_features = []
for feature in self.__features[seg]:
if len(feature) >= maxlen:
padded_feature = feature[:maxlen]
else:
padded_feature = feature
while len(padded_feature) < maxlen:
padded_feature.append(pad)
padded_features.append(padded_feature)
self.__features[seg] = padded_features
def __gen_weight_np(self, seg):
"""
generate weight by gensim
"""
weight_np = np.zeros((len(self.__word2idx[seg]), self.__glove_dim), dtype=np.float32)
for word, idx in self.__word2idx[seg].items():
if word not in self.__wvmodel:
continue
word_vector = self.__wvmodel.get_vector(word)
weight_np[idx, :] = word_vector
self.__weight_np[seg] = weight_np
def get_datas(self, seg):
"""
return features, labels, and weight
"""
features = np.array(self.__features[seg]).astype(np.int32)
labels = np.array(self.__labels[seg]).astype(np.int32)
weight = np.array(self.__weight_np[seg])
return features, labels, weight
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train lstm example on aclImdb########################
python train.py --preprocess=true --aclimdb_path=your_imdb_path --glove_path=your_glove_path
"""
import argparse
import os
import numpy as np
from config import lstm_cfg as cfg
from dataset import convert_to_mindrecord
from dataset import create_dataset
from mindspore import Tensor, nn, Model, context
from mindspore.model_zoo.lstm import SentimentNet
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'],
help='whether to preprocess data.')
parser.add_argument('--aclimdb_path', type=str, default="./aclImdb",
help='path where the dataset is stored.')
parser.add_argument('--glove_path', type=str, default="./glove",
help='path where the GloVe is stored.')
parser.add_argument('--preprocess_path', type=str, default="./preprocess",
help='path where the pre-process data is stored.')
parser.add_argument('--ckpt_path', type=str, default="./",
help='the path to save the checkpoint file.')
parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'],
help='the target device to run, support "GPU", "CPU". Default: "GPU".')
args = parser.parse_args()
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target=args.device_target)
if args.preprocess == "true":
print("============== Starting Data Pre-processing ==============")
convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path)
embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32)
network = SentimentNet(vocab_size=embedding_table.shape[0],
embed_size=cfg.embed_size,
num_hiddens=cfg.num_hiddens,
num_layers=cfg.num_layers,
bidirectional=cfg.bidirectional,
num_classes=cfg.num_classes,
weight=Tensor(embedding_table),
batch_size=cfg.batch_size)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
loss_cb = LossMonitor()
model = Model(network, loss, opt, {'acc': Accuracy()})
print("============== Starting Training ==============")
ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
if args.device_target == "CPU":
model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb], dataset_sink_mode=False)
else:
model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("============== Training Success ==============")
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LSTM."""
import math
import numpy as np
from mindspore import Parameter, Tensor, nn
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
def init_lstm_weight(
input_size,
hidden_size,
num_layers,
bidirectional,
has_bias=True):
"""Initialize lstm weight."""
num_directions = 1
if bidirectional:
num_directions = 2
weight_size = 0
gate_size = 4 * hidden_size
for layer in range(num_layers):
for _ in range(num_directions):
input_layer_size = input_size if layer == 0 else hidden_size * num_directions
weight_size += gate_size * input_layer_size
weight_size += gate_size * hidden_size
if has_bias:
weight_size += 2 * gate_size
stdv = 1 / math.sqrt(hidden_size)
w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
w = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
return w
# Initialize short-term memory (h) and long-term memory (c) to 0
def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
"""init default input."""
num_directions = 1
if bidirectional:
num_directions = 2
h = Tensor(
np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
c = Tensor(
np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
return h, c
class SentimentNet(nn.Cell):
"""Sentiment network structure."""
def __init__(self,
vocab_size,
embed_size,
num_hiddens,
num_layers,
bidirectional,
num_classes,
weight,
batch_size):
super(SentimentNet, self).__init__()
# Mapp words to vectors
self.embedding = nn.Embedding(vocab_size,
embed_size,
embedding_table=weight)
self.embedding.embedding_table.requires_grad = False
self.trans = P.Transpose()
self.perm = (1, 0, 2)
self.encoder = nn.LSTM(input_size=embed_size,
hidden_size=num_hiddens,
num_layers=num_layers,
has_bias=True,
bidirectional=bidirectional,
dropout=0.0)
w_init = init_lstm_weight(
embed_size,
num_hiddens,
num_layers,
bidirectional)
self.encoder.weight = w_init
self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
self.concat = P.Concat(1)
if bidirectional:
self.decoder = nn.Dense(num_hiddens * 4, num_classes)
else:
self.decoder = nn.Dense(num_hiddens * 2, num_classes)
def construct(self, inputs):
# input:(64,500,300)
embeddings = self.embedding(inputs)
embeddings = self.trans(embeddings, self.perm)
output, _ = self.encoder(embeddings, (self.h, self.c))
# states[i] size(64,200) -> encoding.size(64,400)
encoding = self.concat((output[0], output[1]))
outputs = self.decoder(encoding)
return outputs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册