未验证 提交 092afb6c 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #57 from wangxiao1021/api

fix bugs
...@@ -337,7 +337,7 @@ def _write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -337,7 +337,7 @@ def _write_predictions(all_examples, all_features, all_results, n_best_size,
nbest_json = [] nbest_json = []
for (i, entry) in enumerate(nbest): for (i, entry) in enumerate(nbest):
output = collections.OrderedDict() output = collections.OrderedDict()
output["text"] = entry.text output["text"] = entry.text.encode('utf-8').decode('utf-8')
output["probability"] = probs[i] output["probability"] = probs[i]
output["start_logit"] = entry.start_logit output["start_logit"] = entry.start_logit
output["end_logit"] = entry.end_logit output["end_logit"] = entry.end_logit
...@@ -359,7 +359,10 @@ def _write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -359,7 +359,10 @@ def _write_predictions(all_examples, all_features, all_results, n_best_size,
all_nbest_json[example.qas_id] = nbest_json all_nbest_json[example.qas_id] = nbest_json
with open(output_prediction_file, "w") as writer: with open(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4, ensure_ascii=False) + "\n") writer.write(json.dumps(all_predictions, indent=4, ensure_ascii=False) + "\n")
with open(output_nbest_file, "w") as writer: with open(output_nbest_file, "w") as writer:
......
...@@ -22,6 +22,7 @@ import paddle ...@@ -22,6 +22,7 @@ import paddle
from paddle import fluid from paddle import fluid
from paddle.fluid import layers from paddle.fluid import layers
from paddlepalm.distribute import gpu_dev_count, cpu_dev_count from paddlepalm.distribute import gpu_dev_count, cpu_dev_count
import six
dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count
...@@ -35,7 +36,8 @@ def create_feed_batch_process_fn(net_inputs): ...@@ -35,7 +36,8 @@ def create_feed_batch_process_fn(net_inputs):
inputs= net_inputs inputs= net_inputs
for q, var in inputs.items(): for q, var in inputs.items():
if isinstance(var, str) or isinstance(var, unicode):
if isinstance(var, str) or (six.PY3 and isinstance(var, bytes)) or (six.PY2 and isinstance(var, unicode)):
temp[var] = data[q] temp[var] = data[q]
else: else:
temp[var.name] = data[q] temp[var.name] = data[q]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册