提交 0399449f 编写于 作者: W wuzewu

Update task finetune api to return run states

上级 93bad059
......@@ -168,7 +168,8 @@ if __name__ == '__main__':
]
index = 0
results = elmo_task.predict(data=data)
run_states = elmo_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for batch_result in results:
# get predict index
batch_result = np.argmax(batch_result, axis=2)[0]
......
......@@ -76,7 +76,8 @@ def predict(args):
label_map = dataset.label_dict()
index = 0
# get classification result
results = task.predict(data=data)
run_states = task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for batch_result in results:
# get predict index
batch_result = np.argmax(batch_result, axis=2)[0]
......
......@@ -99,7 +99,8 @@ if __name__ == '__main__':
]
index = 0
results = multi_label_cls_task.predict(data=data)
run_states = multi_label_cls_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for result in results:
# get predict index
label_ids = []
......
......@@ -60,7 +60,8 @@ if __name__ == '__main__':
data = ["这家餐厅很好吃", "这部电影真的很差劲"]
results = cls_task.predict(data=data)
run_states = cls_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
index = 0
for batch_result in results:
batch_result = np.argmax(batch_result, axis=2)[0]
......
......@@ -96,7 +96,8 @@ if __name__ == '__main__':
["不过重在晋趣,略增明人气息,妙在集古有道、不露痕迹罢了。"],
]
results = seq_label_task.predict(data=data)
run_states = seq_label_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for num_batch, batch_results in enumerate(results):
infers = batch_results[0].reshape([-1]).astype(np.int32).tolist()
......
......@@ -97,7 +97,8 @@ if __name__ == '__main__':
]
index = 0
results = cls_task.predict(data=data)
run_states = cls_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for batch_result in results:
# get predict index
batch_result = np.argmax(batch_result, axis=2)[0]
......
......@@ -498,7 +498,7 @@ class BasicTask(object):
self.exe, dirname=dirname, main_program=self.main_program)
def finetune_and_eval(self):
self.finetune(do_eval=True)
return self.finetune(do_eval=True)
def finetune(self, do_eval=False):
# Start to finetune
......@@ -519,6 +519,7 @@ class BasicTask(object):
self.eval(phase="test")
self._finetune_end_event(run_states)
return run_states
def eval(self, phase="dev"):
with self.phase_guard(phase=phase):
......@@ -526,6 +527,7 @@ class BasicTask(object):
self._eval_start_event()
run_states = self._run()
self._eval_end_event(run_states)
return run_states
def predict(self, data, load_best_model=True):
with self.phase_guard(phase="predict"):
......@@ -539,7 +541,7 @@ class BasicTask(object):
run_states = self._run()
self._predict_end_event(run_states)
self._predict_data = None
return [run_state.run_results for run_state in run_states]
return run_states
def _run(self, do_eval=False):
with fluid.program_guard(self.main_program, self.startup_program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册