Commit 90857a5d authored by caoying03

clean codes of text classification and ner.

Parent 68caa8ca
 import gzip
 import reader
-from network_conf import *
-from utils import *
+import paddle.v2 as paddle
+from network_conf import ner_net
+from utils import load_dict, load_reverse_dict


 def infer(model_path, batch_size, test_data_file, vocab_file, target_file):
......
@@ -2,8 +2,11 @@ import gzip
 import numpy as np
 import reader
-from utils import *
-from network_conf import *
+from utils import logger, load_dict, get_embedding
+from network_conf import ner_net
+import paddle.v2 as paddle
+import paddle.v2.evaluator as evaluator


 def main(train_data_file,
@@ -11,8 +14,12 @@ def main(train_data_file,
          vocab_file,
          target_file,
          emb_file,
+         model_save_dir,
          num_passes=10,
          batch_size=32):
+    if not os.path.exists(model_save_dir):
+        os.mkdir(model_save_dir)
+
     word_dict = load_dict(vocab_file)
     label_dict = load_dict(target_file)
@@ -77,8 +84,9 @@ def main(train_data_file,
         if isinstance(event, paddle.event.EndPass):
             # save parameters
-            with gzip.open("models/params_pass_%d.tar.gz" % event.pass_id,
-                           "w") as f:
+            with gzip.open(
+                    os.path.join(model_save_dir, "params_pass_%d.tar.gz" %
+                                 event.pass_id), "w") as f:
                 parameters.to_tar(f)

             result = trainer.test(reader=test_reader, feeding=feeding)
@@ -94,8 +102,8 @@ def main(train_data_file,
 if __name__ == "__main__":
     main(
-        train_data_file='data/train',
-        test_data_file='data/test',
-        vocab_file='data/vocab.txt',
-        target_file='data/target.txt',
-        emb_file='data/wordVectors.txt')
+        train_data_file="data/train",
+        test_data_file="data/test",
+        vocab_file="data/vocab.txt",
+        target_file="data/target.txt",
+        emb_file="data/wordVectors.txt")
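
Since main() now takes a model_save_dir argument, a caller has to supply it (or rely on a default defined elsewhere in the file, which this hunk does not show). A minimal sketch of such an invocation, where model_save_dir="models" is an assumed value:

    # Hypothetical call; model_save_dir="models" is assumed, not taken from the diff.
    main(
        train_data_file="data/train",
        test_data_file="data/test",
        vocab_file="data/vocab.txt",
        target_file="data/target.txt",
        emb_file="data/wordVectors.txt",
        model_save_dir="models")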
@@ -6,15 +6,15 @@ import gzip

 import paddle.v2 as paddle
-import network_conf
 import reader
-from utils import *
+from network_conf import fc_net, convolution_net
+from utils import logger, load_dict


 def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
           batch_size):
     def _infer_a_batch(inferer, test_batch, ids_2_word, ids_2_label):
-        probs = inferer.infer(input=test_batch, field=['value'])
+        probs = inferer.infer(input=test_batch, field=["value"])
         assert len(probs) == len(test_batch)
         for word_ids, prob in zip(test_batch, probs):
             word_text = " ".join([ids_2_word[id] for id in word_ids[0]])
@@ -22,7 +22,7 @@ def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
                 " ".join(["{:0.4f}".format(p)
                           for p in prob]), word_text))

-    logger.info('begin to predict...')
+    logger.info("begin to predict...")

     use_default_data = (data_dir is None)
     if use_default_data:
@@ -33,9 +33,9 @@ def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
         test_reader = paddle.dataset.imdb.test(word_dict)
     else:
         assert os.path.exists(
-            word_dict_path), 'the word dictionary file does not exist'
+            word_dict_path), "the word dictionary file does not exist"
         assert os.path.exists(
-            label_dict_path), 'the label dictionary file does not exist'
+            label_dict_path), "the label dictionary file does not exist"

         word_dict = load_dict(word_dict_path)
         word_reverse_dict = load_reverse_dict(word_dict_path)
@@ -52,7 +52,7 @@ def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
     # load the trained models
     parameters = paddle.parameters.Parameters.from_tar(
-        gzip.open(model_path, 'r'))
+        gzip.open(model_path, "r"))
     inferer = paddle.inference.Inference(
         output_layer=prob_layer, parameters=parameters)
@@ -70,19 +70,19 @@ def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
             test_batch = []


-if __name__ == '__main__':
-    model_path = 'dnn_params_pass_00000.tar.gz'
+if __name__ == "__main__":
+    model_path = "models/dnn_params_pass_00000.tar.gz"
     assert os.path.exists(model_path), "the trained model does not exist."

-    nn_type = 'dnn'
+    nn_type = "dnn"
     test_dir = None
     word_dict = None
     label_dict = None

-    if nn_type == 'dnn':
-        topology = network_conf.fc_net
-    elif nn_type == 'cnn':
-        topology = network_conf.convolution_net
+    if nn_type == "dnn":
+        topology = fc_net
+    elif nn_type == "cnn":
+        topology = convolution_net

     infer(
         topology=topology,
......
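
A checkpoint written to the new save directory can be loaded back the same way infer() does above, via paddle.parameters.Parameters.from_tar on a gzip file object. A minimal sketch, assuming training has already produced models/dnn_params_pass_00000.tar.gz (the default path used in __main__):

    import gzip
    import paddle.v2 as paddle

    # Load a saved parameter archive into a Parameters object; the path is
    # assumed to exist after at least one training pass has completed.
    with gzip.open("models/dnn_params_pass_00000.tar.gz", "r") as f:
        parameters = paddle.parameters.Parameters.from_tar(f)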
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import os
 import sys
 import gzip

 import paddle.v2 as paddle
-import network_conf
 import reader
-from utils import *
+from utils import logger, parse_train_cmd, build_dict, load_dict
+from network_conf import fc_net, convolution_net


 def train(topology,
@@ -15,6 +16,7 @@ def train(topology,
           test_data_dir=None,
           word_dict_path=None,
           label_dict_path=None,
+          model_save_dir="models",
           batch_size=32,
           num_passes=10):
     """
@@ -33,6 +35,8 @@ def train(topology,
     :params num_pass: train pass number
     :type num_pass: int
     """
+    if not os.path.exists(model_save_dir):
+        os.mkdir(model_save_dir)

     use_default_data = (train_data_dir is None)
@@ -136,8 +140,9 @@ def train(topology,
             result = trainer.test(reader=test_reader, feeding=feeding)
             logger.info("Test at Pass %d, %s \n" % (event.pass_id,
                                                     result.metrics))
-            with gzip.open("dnn_params_pass_%05d.tar.gz" % event.pass_id,
-                           "w") as f:
+            with gzip.open(
+                    os.path.join(model_save_dir, "dnn_params_pass_%05d.tar.gz" %
+                                 event.pass_id), "w") as f:
                 parameters.to_tar(f)

     trainer.train(
@@ -151,9 +156,9 @@ def train(topology,
 def main(args):
     if args.nn_type == "dnn":
-        topology = network_conf.fc_net
+        topology = fc_net
     elif args.nn_type == "cnn":
-        topology = network_conf.convolution_net
+        topology = convolution_net

     train(
         topology=topology,
@@ -162,7 +167,8 @@ def main(args):
         word_dict_path=args.word_dict,
         label_dict_path=args.label_dict,
         batch_size=args.batch_size,
-        num_passes=args.num_passes)
+        num_passes=args.num_passes,
+        model_save_dir=args.model_save_dir)


 if __name__ == "__main__":
......
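
The behavioral change in the training script is only where checkpoints land: each pass is now written under model_save_dir instead of the working directory. A quick sketch of the resulting path with illustrative values (model_save_dir and pass_id are chosen for the example, not taken from the diff):

    import os

    # Illustrative values only; the real ones come from train()'s arguments
    # and the Paddle event loop.
    model_save_dir = "models"
    pass_id = 3
    print(os.path.join(model_save_dir, "dnn_params_pass_%05d.tar.gz" % pass_id))
    # -> models/dnn_params_pass_00003.tar.gz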
@@ -5,7 +5,7 @@ import os
 import argparse
 from collections import defaultdict

-logger = logging.getLogger("logger")
+logger = logging.getLogger("paddle")
 logger.setLevel(logging.INFO)
@@ -60,6 +60,12 @@ def parse_train_cmd():
         help="the number of training examples in one forward/backward pass")
     parser.add_argument(
         "--num_passes", type=int, default=10, help="number of passes to train")
+    parser.add_argument(
+        "--model_save_dir",
+        type=str,
+        required=False,
+        help=("path to save the trained models."),
+        default="models")
     return parser.parse_args()
......
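
The new --model_save_dir flag is optional and defaults to "models", so existing command lines keep working. An assumed invocation for reference; the entry-point name train.py and any flags other than --num_passes and --model_save_dir are inferred from the args.* attributes used in main() rather than shown in this diff:

    python train.py --num_passes 5 --model_save_dir models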