From 4274883a75d7e4b9317b0f74abfc4ceca470a132 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Thu, 13 Apr 2017 14:27:20 +0800 Subject: [PATCH] add field "prob" in paddle.infer --- paddle/py_paddle/util.py | 6 +++++- python/paddle/v2/inference.py | 8 +++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/paddle/py_paddle/util.py b/paddle/py_paddle/util.py index 1c9455fab..3ae8dbf96 100644 --- a/paddle/py_paddle/util.py +++ b/paddle/py_paddle/util.py @@ -83,13 +83,17 @@ def __arguments_to_numpy__(i, arg): assert isinstance(arg, swig_paddle.Arguments) value = arg.getSlotValue(i) ids = arg.getSlotIds(i) + prob = arg.getSlotIn(i) if value is not None: assert isinstance(value, swig_paddle.Matrix) value = value.copyToNumpyMat() if ids is not None: assert isinstance(ids, swig_paddle.IVector) ids = ids.copyToNumpyArray() - return {"value": value, "id": ids} + if prob is not None: + assert isinstance(prob, swig_paddle.Matrix) + prob = prob.copyToNumpyMat() + return {"value": value, "id": ids, "prob": prob} def __monkeypatch_gradient_machine__(): diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index 31a5d26e6..2860f18e1 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -81,9 +81,11 @@ def infer(output_layer, parameters, input, feeding=None, field='value'): :type input: collections.Iterable :param feeding: Reader dictionary. Default could generate from input value. - :param field: The prediction field. It should in [`value`, `ids`]. `value` - means return the prediction probabilities, `ids` means return - the prediction labels. Default is `value` + :param field: The prediction field. It should in [`value`, `id`, `prob`]. + `value` and `prob` mean return the prediction probabilities, + `id` means return the prediction labels. Default is `value`. + Note that `prob` only used when output_layer is beam_search + or max_id. :type field: str :return: a numpy array :rtype: numpy.ndarray -- GitLab