From 4c73240d51f27a8f5a6390bcd66fc2f4e48c0854 Mon Sep 17 00:00:00 2001 From: caoying03 Date: Wed, 17 May 2017 17:30:44 +0800 Subject: [PATCH] follow comments. --- demo/semantic_role_labeling/api_train_v2.py | 134 ++++++++++++-------- 1 file changed, 83 insertions(+), 51 deletions(-) diff --git a/demo/semantic_role_labeling/api_train_v2.py b/demo/semantic_role_labeling/api_train_v2.py index f9c1fa80d..3af636aef 100644 --- a/demo/semantic_role_labeling/api_train_v2.py +++ b/demo/semantic_role_labeling/api_train_v2.py @@ -6,6 +6,8 @@ import paddle.v2.dataset.conll05 as conll05 import paddle.v2.evaluator as evaluator import paddle.v2 as paddle +logger = logging.getLogger('paddle') + word_dict, verb_dict, label_dict = conll05.get_dict() word_dict_len = len(word_dict) label_dict_len = len(label_dict) @@ -120,19 +122,7 @@ def load_parameter(file_name, h, w): return np.fromfile(f, dtype=np.float32).reshape(h, w) -def test_a_batch(inferer, test_data, tag_dict): - probs = inferer.infer(input=test_data, field='id') - assert len(probs) == sum(len(x[0]) for x in test_data) - for test_sample in test_data: - start_id = 0 - pre_lab = [ - tag_dict[probs[start_id + i]] for i in xrange(len(test_sample[0])) - ] - print pre_lab - start_id += len(test_sample[0]) - - -def main(is_predict=False): +def train(): paddle.init(use_gpu=False, trainer_count=1) # define network topology @@ -189,12 +179,12 @@ def main(is_predict=False): def event_handler(event): if isinstance(event, paddle.event.EndIteration): if event.batch_id % 100 == 0: - print "Pass %d, Batch %d, Cost %f, %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics) - if event.batch_id % 1000 == 0: + logger.info("Pass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics)) + if event.batch_id and event.batch_id % 1000 == 0: result = trainer.test(reader=reader, feeding=feeding) - print "\nTest with Pass %d, Batch %d, %s" % ( - event.pass_id, event.batch_id, result.metrics) + logger.info("\nTest with Pass %d, Batch %d, %s" % + (event.pass_id, event.batch_id, result.metrics)) if isinstance(event, paddle.event.EndPass): # save parameters @@ -202,44 +192,86 @@ def main(is_predict=False): parameters.to_tar(f) result = trainer.test(reader=reader, feeding=feeding) - print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) - - if not is_predict: - trainer.train( - reader=reader, - event_handler=event_handler, - num_passes=10, - feeding=feeding) - else: - labels_reverse = {} - for (k, v) in label_dict.items(): - labels_reverse[v] = k - test_creator = paddle.dataset.conll05.test() + logger.info("\nTest with Pass %d, %s" % + (event.pass_id, result.metrics)) + + trainer.train( + reader=reader, + event_handler=event_handler, + num_passes=10, + feeding=feeding) - predict = paddle.layer.crf_decoding( - size=label_dict_len, - input=feature_out, - param_attr=paddle.attr.Param(name='crfw')) - test_pass = 0 - with gzip.open('params_pass_%d.tar.gz' % (test_pass)) as f: - parameters = paddle.parameters.Parameters.from_tar(f) - inferer = paddle.inference.Inference( - output_layer=predict, parameters=parameters) +def infer_a_batch(inferer, test_data, word_dict, pred_dict, label_dict): + probs = inferer.infer(input=test_data, field='id') + assert len(probs) == sum(len(x[0]) for x in test_data) - # prepare test data - test_data = [] - test_batch_size = 50 + for idx, test_sample in enumerate(test_data): + start_id = 0 + pred_str = "%s\t" % (pred_dict[test_sample[6][0]]) - for idx, item in enumerate(test_creator()): - test_data.append(item[0:8]) + for w, tag in zip(test_sample[0], + probs[start_id:start_id + len(test_sample[0])]): + pred_str += "%s[%s] " % (word_dict[w], label_dict[tag]) + print(pred_str.strip()) + start_id += len(test_sample[0]) - if idx and (not idx % test_batch_size): - test_a_batch(inferer, test_data, labels_reverse) - test_data = [] - test_a_batch(inferer, test_data, labels_reverse) - test_data = [] + +def infer(): + label_dict_reverse = dict((value, key) + for key, value in label_dict.iteritems()) + word_dict_reverse = dict((value, key) + for key, value in word_dict.iteritems()) + pred_dict_reverse = dict((value, key) + for key, value in verb_dict.iteritems()) + + test_creator = paddle.dataset.conll05.test() + + paddle.init(use_gpu=False, trainer_count=1) + + # define network topology + feature_out = db_lstm() + predict = paddle.layer.crf_decoding( + size=label_dict_len, + input=feature_out, + param_attr=paddle.attr.Param(name='crfw')) + + test_pass = 0 + with gzip.open('params_pass_%d.tar.gz' % (test_pass)) as f: + parameters = paddle.parameters.Parameters.from_tar(f) + inferer = paddle.inference.Inference( + output_layer=predict, parameters=parameters) + + # prepare test data + test_data = [] + test_batch_size = 50 + + for idx, item in enumerate(test_creator()): + test_data.append(item[0:8]) + + if idx and (not idx % test_batch_size): + infer_a_batch( + inferer, + test_data, + word_dict_reverse, + pred_dict_reverse, + label_dict_reverse, ) + test_data = [] + infer_a_batch( + inferer, + test_data, + word_dict_reverse, + pred_dict_reverse, + label_dict_reverse, ) + test_data = [] + + +def main(is_inferring=False): + if is_inferring: + infer() + else: + train() if __name__ == '__main__': - main(is_predict=False) + main(is_inferring=False) -- GitLab