Commit ca3e9774 authored by kinghuin, committed by wuzewu

Kinghuin optimpred (#280)

* finish coding predict

* optimize predict pr

* optimize _postprocessing

* optim namedtuple
Parent 6ee63f7e
@@ -41,7 +41,7 @@ args = parser.parse_args()
 if __name__ == '__main__':
     # Load Paddlehub BERT pretrained model
-    module = hub.Module(name="ernie_eng_base.hub_module")
+    module = hub.Module(name="ernie_v2_eng_base")
     inputs, outputs, program = module.context(
         trainable=True, max_seq_len=args.max_seq_len)
@@ -97,12 +97,14 @@ if __name__ == '__main__':
-    index = 0
     run_states = multi_label_cls_task.predict(data=data)
-    results = [run_state.run_results for run_state in run_states]
-    for result in results:
-        # get predict index
-        label_ids = []
-        for i in range(dataset.num_labels):
-            label_val = np.argmax(result[i])
-            label_ids.append(label_val)
-        print("%s\tpredict=%s" % (data[index][0], label_ids))
-        index += 1
+    all_result = []
+    for batch_state in run_states:
+        batch_result = batch_state.run_results
+        for sample_id in range(len(batch_result[0])):
+            sample_result = []
+            for category_id in range(dataset.num_labels):
+                sample_category_prob = batch_result[category_id][sample_id]
+                sample_result.append(np.argmax(sample_category_prob))
+            all_result.append(sample_result)
+    print(all_result)
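
Note: with the return_result switch added to BaseTask.predict below, the manual
decoding in this demo collapses to a one-liner (a sketch, reusing the demo's
multi_label_cls_task and data; the output layout matches the loop above):

    all_result = multi_label_cls_task.predict(data=data, return_result=True)
    print(all_result)  # one list of per-category label ids per input sample
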
@@ -767,7 +767,7 @@ class BaseTask(object):
             self._eval_end_event(run_states)
             return run_states

-    def predict(self, data, load_best_model=True):
+    def predict(self, data, load_best_model=True, return_result=False):
         with self.phase_guard(phase="predict"):
             if load_best_model:
                 self.init_if_load_best_model()
@@ -778,8 +778,17 @@ class BaseTask(object):
             run_states = self._run()
             self._predict_end_event(run_states)
             self._predict_data = None
+            if return_result:
+                return self._postprocessing(run_states)
             return run_states

+    def _postprocessing(self, run_states):
+        results = []
+        for batch_state in run_states:
+            batch_result = batch_state.run_results[0]
+            results += [result[0] for result in batch_result]
+        return results
+
     def _run(self, do_eval=False):
         with fluid.program_guard(self.main_program, self.startup_program):
             if self.config.use_pyreader:
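
Note: this is a template-method hook. predict(..., return_result=True) routes
the raw run states through _postprocessing, and the task classes below each
override it with their own decoding; the default simply unpacks the first
fetched tensor into a flat list of scalars. A toy check (FakeState is a
stand-in for PaddleHub's RunState, values invented):

    import numpy as np

    class FakeState:
        def __init__(self, arr):
            self.run_results = [arr]  # list of fetched tensors for the batch

    run_states = [FakeState(np.array([[0.3], [0.7]]))]
    results = []
    for batch_state in run_states:
        batch_result = batch_state.run_results[0]
        results += [result[0] for result in batch_result]
    print(results)  # [0.3, 0.7]
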
@@ -134,6 +134,22 @@ class ClassifierTask(BaseTask):
         return scores, avg_loss, run_speed

+    def _postprocessing(self, run_states):
+        try:
+            id2label = {
+                val: key
+                for key, val in self._base_data_reader.label_map.items()
+            }
+        except:
+            raise Exception(
+                "image-classification does not support return_result now")
+        results = []
+        for batch_state in run_states:
+            batch_result = batch_state.run_results
+            batch_infer = np.argmax(batch_result, axis=2)[0]
+            results += [id2label[sample_infer] for sample_infer in batch_infer]
+        return results
+

 ImageClassifierTask = ClassifierTask
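
Note: the try/except guards against readers that have no label_map (the
image-classification reader), not against bad input; catching AttributeError
explicitly would be tighter than a bare except:, which also masks unrelated
failures. A toy walk-through of the decoding (label_map maps label -> id, as
PaddleHub's text readers store it; run_results holds one fetched tensor of
shape [batch_size, num_classes]):

    import numpy as np

    label_map = {"neg": 0, "pos": 1}
    id2label = {val: key for key, val in label_map.items()}
    batch_result = [np.array([[0.2, 0.8], [0.9, 0.1]])]  # one fetch, batch of 2
    batch_infer = np.argmax(batch_result, axis=2)[0]     # -> array([1, 0])
    print([id2label[int(i)] for i in batch_infer])       # ['pos', 'neg']
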
@@ -301,3 +317,16 @@ class MultiLabelClassifierTask(ClassifierTask):
         if self.is_train_phase or self.is_test_phase:
             return [metric.name for metric in self.metrics] + [self.loss.name]
         return self.outputs
+
+    def _postprocessing(self, run_states):
+        results = []
+        for batch_state in run_states:
+            batch_result = batch_state.run_results
+            for sample_id in range(len(batch_result[0])):
+                sample_result = []
+                for category_id in range(
+                        self._base_data_reader.dataset.num_labels):
+                    sample_category_prob = batch_result[category_id][sample_id]
+                    sample_result.append(np.argmax(sample_category_prob))
+                results.append(sample_result)
+        return results
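
Note: each category contributes its own fetched tensor, typically of shape
[batch_size, 2] for a per-category binary decision, so the decoder takes one
argmax per (sample, category) pair. Toy example with two samples and two
categories (probabilities invented):

    import numpy as np

    batch_result = [
        np.array([[0.9, 0.1], [0.2, 0.8]]),  # category 0, batch of 2
        np.array([[0.3, 0.7], [0.6, 0.4]]),  # category 1, batch of 2
    ]
    all_result = []
    for sample_id in range(len(batch_result[0])):
        all_result.append([int(np.argmax(batch_result[c][sample_id]))
                           for c in range(len(batch_result))])
    print(all_result)  # [[0, 1], [1, 0]]
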
@@ -171,18 +171,15 @@ def get_final_text(pred_text, orig_text, do_lower_case, is_english):
     return output_text

-def write_predictions(all_examples, all_features, all_results, n_best_size,
-                      max_answer_length, do_lower_case, output_prediction_file,
-                      output_nbest_file, output_null_log_odds_file,
-                      version_2_with_negative, null_score_diff_threshold,
-                      is_english):
+def get_predictions(all_examples, all_features, all_results, n_best_size,
+                    max_answer_length, do_lower_case, version_2_with_negative,
+                    null_score_diff_threshold, is_english):
     _PrelimPrediction = collections.namedtuple("PrelimPrediction", [
         "feature_index", "start_index", "end_index", "start_logit", "end_logit"
     ])
     _NbestPrediction = collections.namedtuple(
         "NbestPrediction", ["text", "start_logit", "end_logit"])

     example_index_to_features = collections.defaultdict(list)
     for feature in all_features:
         example_index_to_features[feature.example_index].append(feature)
@@ -363,25 +360,8 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
             all_predictions[example.qas_id] = best_non_null_entry.text
         all_nbest_json[example.qas_id] = nbest_json

-    """Write final predictions to the json file and log-odds of null if needed."""
-    with open(output_prediction_file, "w") as writer:
-        logger.info("Writing predictions to: %s" % (output_prediction_file))
-        writer.write(
-            json.dumps(all_predictions, indent=4, ensure_ascii=is_english) +
-            "\n")
-    with open(output_nbest_file, "w") as writer:
-        logger.info("Writing nbest to: %s" % (output_nbest_file))
-        writer.write(
-            json.dumps(all_nbest_json, indent=4, ensure_ascii=is_english) +
-            "\n")
-    if version_2_with_negative:
-        logger.info("Writing null_log_odds to: %s" % (output_nbest_file))
-        with open(output_null_log_odds_file, "w") as writer:
-            writer.write(
-                json.dumps(scores_diff_json, indent=4, ensure_ascii=is_english)
-                + "\n")
+    return all_predictions, all_nbest_json, scores_diff_json


 class ReadingComprehensionTask(BaseTask):
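
Note: get_predictions now hands the three dicts back to the caller instead of
writing them to disk (the deleted block even logged the null-odds path using
output_nbest_file). Callers that still want the old JSON artifacts can
recreate them; a sketch mirroring the removed writing code (the helper and
its file-name parameters are illustrative, not part of this PR):

    import json

    def dump_predictions(all_predictions, all_nbest_json, scores_diff_json,
                         prediction_file, nbest_file, null_odds_file,
                         is_english=True, version_2_with_negative=False):
        # Mirrors the file-writing block deleted above.
        with open(prediction_file, "w") as writer:
            writer.write(json.dumps(all_predictions, indent=4,
                                    ensure_ascii=is_english) + "\n")
        with open(nbest_file, "w") as writer:
            writer.write(json.dumps(all_nbest_json, indent=4,
                                    ensure_ascii=is_english) + "\n")
        if version_2_with_negative:
            with open(null_odds_file, "w") as writer:
                writer.write(json.dumps(scores_diff_json, indent=4,
                                        ensure_ascii=is_english) + "\n")
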
@@ -419,6 +399,8 @@ class ReadingComprehensionTask(BaseTask):
         self.null_score_diff_threshold = null_score_diff_threshold
         self.n_best_size = n_best_size
         self.max_answer_length = max_answer_length
+        self.RawResult = collections.namedtuple(
+            "RawResult", ["unique_id", "start_logits", "end_logits"])
@@ -522,24 +504,15 @@ class ReadingComprehensionTask(BaseTask):
         scores = OrderedDict()
         # If none of metrics has been implemented, loss will be used to evaluate.
         if self.is_test_phase:
-            output_prediction_file = os.path.join(self.config.checkpoint_dir,
-                                                  "predictions.json")
-            output_nbest_file = os.path.join(self.config.checkpoint_dir,
-                                             "nbest_predictions.json")
-            output_null_log_odds_file = os.path.join(self.config.checkpoint_dir,
-                                                     "null_odds.json")
             all_examples = self.data_reader.all_examples[self.phase]
             all_features = self.data_reader.all_features[self.phase]
-            write_predictions(
+            all_predictions, all_nbest_json, scores_diff_json = get_predictions(
                 all_examples=all_examples,
                 all_features=all_features,
                 all_results=all_results,
                 n_best_size=self.n_best_size,
                 max_answer_length=self.max_answer_length,
                 do_lower_case=True,
-                output_prediction_file=output_prediction_file,
-                output_nbest_file=output_nbest_file,
-                output_null_log_odds_file=output_null_log_odds_file,
                 version_2_with_negative=self.version_2_with_negative,
                 null_score_diff_threshold=self.null_score_diff_threshold,
                 is_english=self.is_english)
@@ -558,25 +531,17 @@ class ReadingComprehensionTask(BaseTask):
             else:
                 raise Exception("Error phase: %s when runing _calculate_metrics"
                                 % self.phase)
-            with open(
-                    output_prediction_file, 'r',
-                    encoding="utf8") as prediction_file:
-                predictions = json.load(prediction_file)
             if self.sub_task == "squad":
-                scores = squad1_evaluate.evaluate(dataset, predictions)
+                scores = squad1_evaluate.evaluate(dataset, all_predictions)
             elif self.sub_task == "squad2.0":
-                with open(
-                        output_null_log_odds_file, 'r',
-                        encoding="utf8") as odds_file:
-                    na_probs = json.load(odds_file)
-                scores = squad2_evaluate.evaluate(dataset, predictions,
-                                                  na_probs)
+                scores = squad2_evaluate.evaluate(dataset, all_predictions,
+                                                  scores_diff_json)
             elif self.sub_task in ["cmrc2018", "drcd"]:
-                scores = cmrc2018_evaluate.get_eval(dataset, predictions)
+                scores = cmrc2018_evaluate.get_eval(dataset, all_predictions)
         return scores, avg_loss, run_speed

-    def _default_predict_end_event(self, run_states):
+    def _postprocessing(self, run_states):
         all_results = []
         for run_state in run_states:
             np_unique_ids = run_state.run_results[0]
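
Note: evaluation now consumes all_predictions and scores_diff_json directly
from memory rather than re-reading the JSON files that write_predictions had
just flushed to checkpoint_dir, removing the filesystem round-trip from the
test phase.
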
@@ -591,29 +556,16 @@ class ReadingComprehensionTask(BaseTask):
                         unique_id=unique_id,
                         start_logits=start_logits,
                         end_logits=end_logits))
-        # If none of metrics has been implemented, loss will be used to evaluate.
-        output_prediction_file = os.path.join(self.config.checkpoint_dir,
-                                              "predict_predictions.json")
-        output_nbest_file = os.path.join(self.config.checkpoint_dir,
-                                         "predict_nbest_predictions.json")
-        output_null_log_odds_file = os.path.join(self.config.checkpoint_dir,
-                                                 "predict_null_odds.json")
         all_examples = self.data_reader.all_examples[self.phase]
         all_features = self.data_reader.all_features[self.phase]
-        write_predictions(
+        all_predictions, all_nbest_json, scores_diff_json = get_predictions(
             all_examples=all_examples,
             all_features=all_features,
             all_results=all_results,
             n_best_size=self.n_best_size,
             max_answer_length=self.max_answer_length,
             do_lower_case=True,
-            output_prediction_file=output_prediction_file,
-            output_nbest_file=output_nbest_file,
-            output_null_log_odds_file=output_null_log_odds_file,
             version_2_with_negative=self.version_2_with_negative,
             null_score_diff_threshold=self.null_score_diff_threshold,
             is_english=self.is_english)
-        logger.info("PaddleHub predict finished.")
-        logger.info("You can see the prediction in %s" % output_prediction_file)
+        return all_predictions
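
Note: for reading comprehension, _postprocessing returns the qas_id -> answer
text mapping, so prediction files are no longer written here either. Usage
sketch (the task and data variables are assumed to be set up as in
PaddleHub's demo scripts):

    answers = reading_comprehension_task.predict(data=data, return_result=True)
    for qas_id, text in answers.items():
        print("%s\t%s" % (qas_id, text))
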
@@ -120,3 +120,10 @@ class RegressionTask(BaseTask):
         else:
             raise ValueError("Not Support Metric: \"%s\"" % metric)
         return scores, avg_loss, run_speed
+
+    def _postprocessing(self, run_states):
+        results = []
+        for batch_state in run_states:
+            batch_result = batch_state.run_results[0]
+            results += [result[0] for result in batch_result]
+        return results
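
Note: this override is identical to the BaseTask default shown earlier (a
flat list of per-sample scalars); keeping an explicit copy on RegressionTask
documents that regression intentionally returns raw floats rather than
decoded labels.
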
@@ -216,3 +216,22 @@ class SequenceLabelTask(BaseTask):
         elif self.is_predict_phase:
             return [self.ret_infers.name] + [self.seq_len.name]
         return [output.name for output in self.outputs]
+
+    def _postprocessing(self, run_states):
+        id2label = {
+            val: key
+            for key, val in self._base_data_reader.label_map.items()
+        }
+        results = []
+        for batch_states in run_states:
+            batch_results = batch_states.run_results
+            batch_infers = batch_results[0].reshape([-1]).astype(
+                np.int32).tolist()
+            seq_lens = batch_results[1].reshape([-1]).astype(np.int32).tolist()
+            current_id = 0
+            for length in seq_lens:
+                seq_infers = batch_infers[current_id:current_id + length]
+                seq_result = list(map(id2label.get, seq_infers[1:-1]))
+                current_id += int(length)
+                results.append(seq_result)
+        return results
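
Note: the two fetched tensors are the flattened label ids and the per-sample
sequence lengths (matching the predict-phase fetch list above), and
seq_infers[1:-1] drops the first and last positions, which hold [CLS] and
[SEP]. A toy walk-through (tag set invented for illustration):

    id2label = {0: "B-PER", 1: "I-PER", 2: "O"}
    batch_infers = [2, 0, 1, 2, 2]  # one sequence: [CLS] tok tok tok [SEP]
    seq_lens = [5]
    current_id = 0
    for length in seq_lens:
        seq_infers = batch_infers[current_id:current_id + length]
        print(list(map(id2label.get, seq_infers[1:-1])))  # ['B-PER', 'I-PER', 'O']
        current_id += length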