提交 c48b1c51 编写于 作者: J jshower

rm reduntant code

上级 6b8b92b2
......@@ -64,7 +64,7 @@ def infer(model_path, batch_size, test_data_file, vocab_file, target_file,
if __name__ == "__main__":
infer(
model_path="models/params_pass_16",
model_path="models/params_pass_0",
batch_size=6,
test_data_file="data/test",
vocab_file="data/vocab.txt",
......
......@@ -65,7 +65,7 @@ def main(train_data_file,
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
(precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = fluid.layers.chunk_eval(
num_correct_chunks) = fluid.layers.chunk_eval(
input=crf_decode,
label=target,
chunk_scheme="IOB",
......@@ -74,14 +74,9 @@ def main(train_data_file,
inference_program = fluid.default_main_program().clone(for_test=True)
test_fetch_list = [num_infer_chunks, num_label_chunks, num_correct_chunks]
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-2)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
sgd_optimizer.minimize(avg_cost)
#print(inference_program)
#exit(0)
#test_fetch_list = [num_infer_chunks, num_label_chunks, num_correct_chunks]
if "CE_MODE_X" not in os.environ:
train_reader = paddle.batch(
paddle.reader.shuffle(
......@@ -115,7 +110,6 @@ def main(train_data_file,
for pass_id in six.moves.xrange(num_passes):
chunk_evaluator.reset()
for batch_id, data in enumerate(train_reader()):
#print(data)
cost_var, nums_infer, nums_label, nums_correct = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
......@@ -151,8 +145,8 @@ def main(train_data_file,
if __name__ == "__main__":
main(
train_data_file="data/train.txt",
test_data_file="data/test.txt",
train_data_file="data/train",
test_data_file="data/test",
vocab_file="data/vocab.txt",
target_file="data/target.txt",
emb_file="data/wordVectors.txt",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册