提交 90857a5d 编写于 作者: C caoying03

clean codes of text classification and ner.

上级 68caa8ca
import gzip
import reader
from network_conf import *
from utils import *
import paddle.v2 as paddle
from network_conf import ner_net
from utils import load_dict, load_reverse_dict
def infer(model_path, batch_size, test_data_file, vocab_file, target_file):
......
......@@ -2,8 +2,11 @@ import gzip
import numpy as np
import reader
from utils import *
from network_conf import *
from utils import logger, load_dict, get_embedding
from network_conf import ner_net
import paddle.v2 as paddle
import paddle.v2.evaluator as evaluator
def main(train_data_file,
......@@ -11,8 +14,12 @@ def main(train_data_file,
vocab_file,
target_file,
emb_file,
model_save_dir,
num_passes=10,
batch_size=32):
if not os.path.exists(model_save_dir):
os.mkdir(model_save_dir)
word_dict = load_dict(vocab_file)
label_dict = load_dict(target_file)
......@@ -77,8 +84,9 @@ def main(train_data_file,
if isinstance(event, paddle.event.EndPass):
# save parameters
with gzip.open("models/params_pass_%d.tar.gz" % event.pass_id,
"w") as f:
with gzip.open(
os.path.join(model_save_dir, "params_pass_%d.tar.gz" %
event.pass_id), "w") as f:
parameters.to_tar(f)
result = trainer.test(reader=test_reader, feeding=feeding)
......@@ -94,8 +102,8 @@ def main(train_data_file,
if __name__ == "__main__":
main(
train_data_file='data/train',
test_data_file='data/test',
vocab_file='data/vocab.txt',
target_file='data/target.txt',
emb_file='data/wordVectors.txt')
train_data_file="data/train",
test_data_file="data/test",
vocab_file="data/vocab.txt",
target_file="data/target.txt",
emb_file="data/wordVectors.txt")
......@@ -6,15 +6,15 @@ import gzip
import paddle.v2 as paddle
import network_conf
import reader
from utils import *
from network_conf import fc_net, convolution_net
from utils import logger, load_dict
def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
batch_size):
def _infer_a_batch(inferer, test_batch, ids_2_word, ids_2_label):
probs = inferer.infer(input=test_batch, field=['value'])
probs = inferer.infer(input=test_batch, field=["value"])
assert len(probs) == len(test_batch)
for word_ids, prob in zip(test_batch, probs):
word_text = " ".join([ids_2_word[id] for id in word_ids[0]])
......@@ -22,7 +22,7 @@ def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
" ".join(["{:0.4f}".format(p)
for p in prob]), word_text))
logger.info('begin to predict...')
logger.info("begin to predict...")
use_default_data = (data_dir is None)
if use_default_data:
......@@ -33,9 +33,9 @@ def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
test_reader = paddle.dataset.imdb.test(word_dict)
else:
assert os.path.exists(
word_dict_path), 'the word dictionary file does not exist'
word_dict_path), "the word dictionary file does not exist"
assert os.path.exists(
label_dict_path), 'the label dictionary file does not exist'
label_dict_path), "the label dictionary file does not exist"
word_dict = load_dict(word_dict_path)
word_reverse_dict = load_reverse_dict(word_dict_path)
......@@ -52,7 +52,7 @@ def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
# load the trained models
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(model_path, 'r'))
gzip.open(model_path, "r"))
inferer = paddle.inference.Inference(
output_layer=prob_layer, parameters=parameters)
......@@ -70,19 +70,19 @@ def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
test_batch = []
if __name__ == '__main__':
model_path = 'dnn_params_pass_00000.tar.gz'
if __name__ == "__main__":
model_path = "models/dnn_params_pass_00000.tar.gz"
assert os.path.exists(model_path), "the trained model does not exist."
nn_type = 'dnn'
nn_type = "dnn"
test_dir = None
word_dict = None
label_dict = None
if nn_type == 'dnn':
topology = network_conf.fc_net
elif nn_type == 'cnn':
topology = network_conf.convolution_net
if nn_type == "dnn":
topology = fc_net
elif nn_type == "cnn":
topology = convolution_net
infer(
topology=topology,
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import gzip
import paddle.v2 as paddle
import network_conf
import reader
from utils import *
from utils import logger, parse_train_cmd, build_dict, load_dict
from network_conf import fc_net, convolution_net
def train(topology,
......@@ -15,6 +16,7 @@ def train(topology,
test_data_dir=None,
word_dict_path=None,
label_dict_path=None,
model_save_dir="models",
batch_size=32,
num_passes=10):
"""
......@@ -33,6 +35,8 @@ def train(topology,
:params num_pass: train pass number
:type num_pass: int
"""
if not os.path.exists(model_save_dir):
os.mkdir(model_save_dir)
use_default_data = (train_data_dir is None)
......@@ -136,8 +140,9 @@ def train(topology,
result = trainer.test(reader=test_reader, feeding=feeding)
logger.info("Test at Pass %d, %s \n" % (event.pass_id,
result.metrics))
with gzip.open("dnn_params_pass_%05d.tar.gz" % event.pass_id,
"w") as f:
with gzip.open(
os.path.join(model_save_dir, "dnn_params_pass_%05d.tar.gz" %
event.pass_id), "w") as f:
parameters.to_tar(f)
trainer.train(
......@@ -151,9 +156,9 @@ def train(topology,
def main(args):
if args.nn_type == "dnn":
topology = network_conf.fc_net
topology = fc_net
elif args.nn_type == "cnn":
topology = network_conf.convolution_net
topology = convolution_net
train(
topology=topology,
......@@ -162,7 +167,8 @@ def main(args):
word_dict_path=args.word_dict,
label_dict_path=args.label_dict,
batch_size=args.batch_size,
num_passes=args.num_passes)
num_passes=args.num_passes,
model_save_dir=args.model_save_dir)
if __name__ == "__main__":
......
......@@ -5,7 +5,7 @@ import os
import argparse
from collections import defaultdict
logger = logging.getLogger("logger")
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
......@@ -60,6 +60,12 @@ def parse_train_cmd():
help="the number of training examples in one forward/backward pass")
parser.add_argument(
"--num_passes", type=int, default=10, help="number of passes to train")
parser.add_argument(
"--model_save_dir",
type=str,
required=False,
help=("path to save the trained models."),
default="models")
return parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册