提交 b25c5124 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #1780 from luotao1/in

add field "prob" in paddle.infer
...@@ -83,13 +83,17 @@ def __arguments_to_numpy__(i, arg): ...@@ -83,13 +83,17 @@ def __arguments_to_numpy__(i, arg):
assert isinstance(arg, swig_paddle.Arguments) assert isinstance(arg, swig_paddle.Arguments)
value = arg.getSlotValue(i) value = arg.getSlotValue(i)
ids = arg.getSlotIds(i) ids = arg.getSlotIds(i)
prob = arg.getSlotIn(i)
if value is not None: if value is not None:
assert isinstance(value, swig_paddle.Matrix) assert isinstance(value, swig_paddle.Matrix)
value = value.copyToNumpyMat() value = value.copyToNumpyMat()
if ids is not None: if ids is not None:
assert isinstance(ids, swig_paddle.IVector) assert isinstance(ids, swig_paddle.IVector)
ids = ids.copyToNumpyArray() ids = ids.copyToNumpyArray()
return {"value": value, "id": ids} if prob is not None:
assert isinstance(prob, swig_paddle.Matrix)
prob = prob.copyToNumpyMat()
return {"value": value, "id": ids, "prob": prob}
def __monkeypatch_gradient_machine__(): def __monkeypatch_gradient_machine__():
......
...@@ -81,9 +81,11 @@ def infer(output_layer, parameters, input, feeding=None, field='value'): ...@@ -81,9 +81,11 @@ def infer(output_layer, parameters, input, feeding=None, field='value'):
:type input: collections.Iterable :type input: collections.Iterable
:param feeding: Reader dictionary. Default could generate from input :param feeding: Reader dictionary. Default could generate from input
value. value.
:param field: The prediction field. It should in [`value`, `ids`]. `value` :param field: The prediction field. It should in [`value`, `id`, `prob`].
means return the prediction probabilities, `ids` means return `value` and `prob` mean return the prediction probabilities,
the prediction labels. Default is `value` `id` means return the prediction labels. Default is `value`.
Note that `prob` only used when output_layer is beam_search
or max_id.
:type field: str :type field: str
:return: a numpy array :return: a numpy array
:rtype: numpy.ndarray :rtype: numpy.ndarray
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册