提交 4274883a 编写于 作者: L Luo Tao

add field "prob" in paddle.infer

上级 aa230bfb
......@@ -83,13 +83,17 @@ def __arguments_to_numpy__(i, arg):
assert isinstance(arg, swig_paddle.Arguments)
value = arg.getSlotValue(i)
ids = arg.getSlotIds(i)
prob = arg.getSlotIn(i)
if value is not None:
assert isinstance(value, swig_paddle.Matrix)
value = value.copyToNumpyMat()
if ids is not None:
assert isinstance(ids, swig_paddle.IVector)
ids = ids.copyToNumpyArray()
return {"value": value, "id": ids}
if prob is not None:
assert isinstance(prob, swig_paddle.Matrix)
prob = prob.copyToNumpyMat()
return {"value": value, "id": ids, "prob": prob}
def __monkeypatch_gradient_machine__():
......
......@@ -81,9 +81,11 @@ def infer(output_layer, parameters, input, feeding=None, field='value'):
:type input: collections.Iterable
:param feeding: Reader dictionary. Default could generate from input
value.
:param field: The prediction field. It should in [`value`, `ids`]. `value`
means return the prediction probabilities, `ids` means return
the prediction labels. Default is `value`
:param field: The prediction field. It should in [`value`, `id`, `prob`].
`value` and `prob` mean return the prediction probabilities,
`id` means return the prediction labels. Default is `value`.
Note that `prob` only used when output_layer is beam_search
or max_id.
:type field: str
:return: a numpy array
:rtype: numpy.ndarray
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册