提交 0399449f 编写于 作者: W wuzewu

Update task finetune api to return run states

上级 93bad059
...@@ -168,7 +168,8 @@ if __name__ == '__main__': ...@@ -168,7 +168,8 @@ if __name__ == '__main__':
] ]
index = 0 index = 0
results = elmo_task.predict(data=data) run_states = elmo_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for batch_result in results: for batch_result in results:
# get predict index # get predict index
batch_result = np.argmax(batch_result, axis=2)[0] batch_result = np.argmax(batch_result, axis=2)[0]
......
...@@ -76,7 +76,8 @@ def predict(args): ...@@ -76,7 +76,8 @@ def predict(args):
label_map = dataset.label_dict() label_map = dataset.label_dict()
index = 0 index = 0
# get classification result # get classification result
results = task.predict(data=data) run_states = task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for batch_result in results: for batch_result in results:
# get predict index # get predict index
batch_result = np.argmax(batch_result, axis=2)[0] batch_result = np.argmax(batch_result, axis=2)[0]
......
...@@ -99,7 +99,8 @@ if __name__ == '__main__': ...@@ -99,7 +99,8 @@ if __name__ == '__main__':
] ]
index = 0 index = 0
results = multi_label_cls_task.predict(data=data) run_states = multi_label_cls_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for result in results: for result in results:
# get predict index # get predict index
label_ids = [] label_ids = []
......
...@@ -60,7 +60,8 @@ if __name__ == '__main__': ...@@ -60,7 +60,8 @@ if __name__ == '__main__':
data = ["这家餐厅很好吃", "这部电影真的很差劲"] data = ["这家餐厅很好吃", "这部电影真的很差劲"]
results = cls_task.predict(data=data) run_states = cls_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
index = 0 index = 0
for batch_result in results: for batch_result in results:
batch_result = np.argmax(batch_result, axis=2)[0] batch_result = np.argmax(batch_result, axis=2)[0]
......
...@@ -96,7 +96,8 @@ if __name__ == '__main__': ...@@ -96,7 +96,8 @@ if __name__ == '__main__':
["不过重在晋趣,略增明人气息,妙在集古有道、不露痕迹罢了。"], ["不过重在晋趣,略增明人气息,妙在集古有道、不露痕迹罢了。"],
] ]
results = seq_label_task.predict(data=data) run_states = seq_label_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for num_batch, batch_results in enumerate(results): for num_batch, batch_results in enumerate(results):
infers = batch_results[0].reshape([-1]).astype(np.int32).tolist() infers = batch_results[0].reshape([-1]).astype(np.int32).tolist()
......
...@@ -97,7 +97,8 @@ if __name__ == '__main__': ...@@ -97,7 +97,8 @@ if __name__ == '__main__':
] ]
index = 0 index = 0
results = cls_task.predict(data=data) run_states = cls_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for batch_result in results: for batch_result in results:
# get predict index # get predict index
batch_result = np.argmax(batch_result, axis=2)[0] batch_result = np.argmax(batch_result, axis=2)[0]
......
...@@ -498,7 +498,7 @@ class BasicTask(object): ...@@ -498,7 +498,7 @@ class BasicTask(object):
self.exe, dirname=dirname, main_program=self.main_program) self.exe, dirname=dirname, main_program=self.main_program)
def finetune_and_eval(self): def finetune_and_eval(self):
self.finetune(do_eval=True) return self.finetune(do_eval=True)
def finetune(self, do_eval=False): def finetune(self, do_eval=False):
# Start to finetune # Start to finetune
...@@ -519,6 +519,7 @@ class BasicTask(object): ...@@ -519,6 +519,7 @@ class BasicTask(object):
self.eval(phase="test") self.eval(phase="test")
self._finetune_end_event(run_states) self._finetune_end_event(run_states)
return run_states
def eval(self, phase="dev"): def eval(self, phase="dev"):
with self.phase_guard(phase=phase): with self.phase_guard(phase=phase):
...@@ -526,6 +527,7 @@ class BasicTask(object): ...@@ -526,6 +527,7 @@ class BasicTask(object):
self._eval_start_event() self._eval_start_event()
run_states = self._run() run_states = self._run()
self._eval_end_event(run_states) self._eval_end_event(run_states)
return run_states
def predict(self, data, load_best_model=True): def predict(self, data, load_best_model=True):
with self.phase_guard(phase="predict"): with self.phase_guard(phase="predict"):
...@@ -539,7 +541,7 @@ class BasicTask(object): ...@@ -539,7 +541,7 @@ class BasicTask(object):
run_states = self._run() run_states = self._run()
self._predict_end_event(run_states) self._predict_end_event(run_states)
self._predict_data = None self._predict_data = None
return [run_state.run_results for run_state in run_states] return run_states
def _run(self, do_eval=False): def _run(self, do_eval=False):
with fluid.program_guard(self.main_program, self.startup_program): with fluid.program_guard(self.main_program, self.startup_program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册