提交 2e7f71d9 编写于 作者: Y Yibing Liu

Enable ce for sequence_tagging_for_ner_ce

上级 062478b9
###!/bin/bash
####This file is only used for continuous evaluation.
export CE_MODE_X=1
python train.py | python _ce.py
####this file is only used for continuous evaluation test!
import os
import sys
sys.path.append(os.environ['ceroot'])
from kpi import CostKpi, DurationKpi, AccKpi
#### NOTE kpi.py should shared in models in some way!!!!
train_acc_kpi = AccKpi('train_precision', 0.005, actived=True)
test_acc_kpi = CostKpi('test_precision', 0.005, actived=True)
train_duration_kpi = DurationKpi('train_duration', 0.05, actived=True)
tracking_kpis = [
train_acc_kpi,
test_acc_kpi,
train_duration_kpi,
]
def parse_log(log):
for line in log.split('\n'):
fs = line.strip().split('\t')
print(fs)
if len(fs) == 3 and fs[0] == 'kpis':
print("-----%s" % fs)
kpi_name = fs[1]
kpi_value = float(fs[2])
yield kpi_name, kpi_value
def log_to_ce(log):
kpi_tracker = {}
for kpi in tracking_kpis:
kpi_tracker[kpi.name] = kpi
for (kpi_name, kpi_value) in parse_log(log):
print(kpi_name, kpi_value)
kpi_tracker[kpi_name].add_record(kpi_value)
kpi_tracker[kpi_name].persist()
if __name__ == '__main__':
log = sys.stdin.read()
print("*****")
print(log)
print("****")
log_to_ce(log)
import os import os
import math import math
import time
import numpy as np import numpy as np
import paddle import paddle
...@@ -65,21 +66,31 @@ def main(train_data_file, ...@@ -65,21 +66,31 @@ def main(train_data_file,
test_target = chunk_evaluator.metrics + chunk_evaluator.states test_target = chunk_evaluator.metrics + chunk_evaluator.states
inference_program = fluid.io.get_inference_program(test_target) inference_program = fluid.io.get_inference_program(test_target)
train_reader = paddle.batch( if "CE_MODE_X" not in os.environ:
paddle.reader.shuffle( train_reader = paddle.batch(
paddle.reader.shuffle(
reader.data_reader(train_data_file, word_dict, label_dict),
buf_size=20000),
batch_size=batch_size)
test_reader = paddle.batch(
paddle.reader.shuffle(
reader.data_reader(test_data_file, word_dict, label_dict),
buf_size=20000),
batch_size=batch_size)
else:
train_reader = paddle.batch(
reader.data_reader(train_data_file, word_dict, label_dict), reader.data_reader(train_data_file, word_dict, label_dict),
buf_size=20000), batch_size=batch_size)
batch_size=batch_size) test_reader = paddle.batch(
test_reader = paddle.batch(
paddle.reader.shuffle(
reader.data_reader(test_data_file, word_dict, label_dict), reader.data_reader(test_data_file, word_dict, label_dict),
buf_size=20000), batch_size=batch_size)
batch_size=batch_size)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=[word, mark, target], place=place) feeder = fluid.DataFeeder(feed_list=[word, mark, target], place=place)
exe = fluid.Executor(place) exe = fluid.Executor(place)
if "CE_MODE_X" in os.environ:
fluid.default_startup_program().random_seed = 110
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
embedding_name = 'emb' embedding_name = 'emb'
...@@ -114,6 +125,13 @@ def main(train_data_file, ...@@ -114,6 +125,13 @@ def main(train_data_file,
fluid.io.save_inference_model(save_dirname, ['word', 'mark', 'target'], fluid.io.save_inference_model(save_dirname, ['word', 'mark', 'target'],
crf_decode, exe) crf_decode, exe)
if ("CE_MODE_X" in os.environ) and (pass_id % 50 == 0):
if pass_id > 0:
print("kpis train_precision %f" % pass_precision)
print("kpis test_precision %f" % test_pass_precision)
print("kpis train_duration %f" % (time.time() - time_begin))
time_begin = time.time()
if __name__ == "__main__": if __name__ == "__main__":
main( main(
...@@ -123,7 +141,7 @@ if __name__ == "__main__": ...@@ -123,7 +141,7 @@ if __name__ == "__main__":
target_file="data/target.txt", target_file="data/target.txt",
emb_file="data/wordVectors.txt", emb_file="data/wordVectors.txt",
model_save_dir="models", model_save_dir="models",
num_passes=100, num_passes=1000,
batch_size=1, batch_size=1,
use_gpu=False, use_gpu=False,
parallel=False) parallel=False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册