提交 7c1434dd 编写于 作者: J jshower

code style

上级 d9a52223
...@@ -70,8 +70,8 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, ...@@ -70,8 +70,8 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
fluid.layers.embedding( fluid.layers.embedding(
size=[word_dict_len, word_dim], size=[word_dict_len, word_dim],
input=x, input=x,
param_attr=fluid.ParamAttr(name=embedding_name, trainable=False)) param_attr=fluid.ParamAttr(
for x in word_input name=embedding_name, trainable=False)) for x in word_input
] ]
emb_layers.append(predicate_embedding) emb_layers.append(predicate_embedding)
emb_layers.append(mark_embedding) emb_layers.append(mark_embedding)
...@@ -164,7 +164,8 @@ def train(use_cuda, save_dirname=None, is_local=True): ...@@ -164,7 +164,8 @@ def train(use_cuda, save_dirname=None, is_local=True):
crf_cost = fluid.layers.linear_chain_crf( crf_cost = fluid.layers.linear_chain_crf(
input=feature_out, input=feature_out,
label=target, label=target,
param_attr=fluid.ParamAttr(name='crfw', learning_rate=mix_hidden_lr)) param_attr=fluid.ParamAttr(
name='crfw', learning_rate=mix_hidden_lr))
avg_cost = fluid.layers.mean(crf_cost) avg_cost = fluid.layers.mean(crf_cost)
# TODO(qiao) # TODO(qiao)
...@@ -189,7 +190,8 @@ def train(use_cuda, save_dirname=None, is_local=True): ...@@ -189,7 +190,8 @@ def train(use_cuda, save_dirname=None, is_local=True):
num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0))) num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
train_data = paddle.batch( train_data = paddle.batch(
paddle.reader.shuffle(paddle.dataset.conll05.test(), buf_size=8192), paddle.reader.shuffle(
paddle.dataset.conll05.test(), buf_size=8192),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
...@@ -222,25 +224,24 @@ def train(use_cuda, save_dirname=None, is_local=True): ...@@ -222,25 +224,24 @@ def train(use_cuda, save_dirname=None, is_local=True):
exe) exe)
if batch_id % 10 == 0: if batch_id % 10 == 0:
print( print("avg_cost:" + str(cost) + " precision:" + str(
"avg_cost:" + str(cost) + " precision:" + precision) + " recall:" + str(recall) + " f1_score:" +
str(precision) + " recall:" + str(recall) + str(f1_score) + " pass_precision:" + str(
" f1_score:" + str(f1_score) + " pass_precision:" + str( pass_precision) + " pass_recall:" + str(
pass_precision) + " pass_recall:" + str(pass_recall) pass_recall) + " pass_f1_score:" + str(
+ " pass_f1_score:" + str(pass_f1_score)) pass_f1_score))
if batch_id != 0: if batch_id != 0:
print("second per batch: " + str( print("second per batch: " + str((time.time(
(time.time() - start_time) / batch_id)) ) - start_time) / batch_id))
# Set the threshold low to speed up the CI test # Set the threshold low to speed up the CI test
if float(pass_precision) > 0.05: if float(pass_precision) > 0.05:
if save_dirname is not None: if save_dirname is not None:
# TODO(liuyiqun): Change the target to crf_decode # TODO(liuyiqun): Change the target to crf_decode
fluid.io.save_inference_model( fluid.io.save_inference_model(save_dirname, [
save_dirname, [ 'word_data', 'verb_data', 'ctx_n2_data',
'word_data', 'verb_data', 'ctx_n2_data', 'ctx_n1_data', 'ctx_0_data', 'ctx_p1_data',
'ctx_n1_data', 'ctx_0_data', 'ctx_p1_data', 'ctx_p2_data', 'mark_data'
'ctx_p2_data', 'mark_data' ], [feature_out], exe)
], [feature_out], exe)
return return
batch_id = batch_id + 1 batch_id = batch_id + 1
...@@ -320,20 +321,19 @@ def infer(use_cuda, save_dirname=None): ...@@ -320,20 +321,19 @@ def infer(use_cuda, save_dirname=None):
assert feed_target_names[6] == 'ctx_p2_data' assert feed_target_names[6] == 'ctx_p2_data'
assert feed_target_names[7] == 'mark_data' assert feed_target_names[7] == 'mark_data'
results = exe.run( results = exe.run(inference_program,
inference_program, feed={
feed={ feed_target_names[0]: word,
feed_target_names[0]: word, feed_target_names[1]: pred,
feed_target_names[1]: pred, feed_target_names[2]: ctx_n2,
feed_target_names[2]: ctx_n2, feed_target_names[3]: ctx_n1,
feed_target_names[3]: ctx_n1, feed_target_names[4]: ctx_0,
feed_target_names[4]: ctx_0, feed_target_names[5]: ctx_p1,
feed_target_names[5]: ctx_p1, feed_target_names[6]: ctx_p2,
feed_target_names[6]: ctx_p2, feed_target_names[7]: mark
feed_target_names[7]: mark },
}, fetch_list=fetch_targets,
fetch_list=fetch_targets, return_numpy=False)
return_numpy=False)
print(results[0].lod()) print(results[0].lod())
np_data = np.array(results[0]) np_data = np.array(results[0])
print("Inference Shape: ", np_data.shape) print("Inference Shape: ", np_data.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册