From fb32106e246695fc91a63186fb22a68c66f98a33 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 11 Sep 2017 17:14:01 -0700 Subject: [PATCH] Make paddle.v2.inference can direct load protobuf --- python/paddle/v2/inference.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index 8acea6155c..19624a704f 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -25,11 +25,18 @@ class Inference(object): :type parameters: paddle.v2.parameters.Parameters """ - def __init__(self, output_layer, parameters): + def __init__(self, output_layer, parameters, data_types=None): import py_paddle.swig_paddle as api - topo = topology.Topology(output_layer) - gm = api.GradientMachine.createFromConfigProto( - topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE]) + if isinstance(output_layer, str): + gm = api.GradientMachine.createByConfigProtoStr(output_layer) + if data_types is None: + raise ValueError("data_types != None when using protobuf bin") + self.__data_types__ = data_types + else: + topo = topology.Topology(output_layer) + gm = api.GradientMachine.createFromConfigProto( + topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE]) + self.__data_types__ = topo.data_type() for param in gm.getParameters(): val = param.getBuf(api.PARAMETER_VALUE) name = param.getName() @@ -43,7 +50,6 @@ class Inference(object): # called here, but it's better to call this function in one place. param.setValueUpdated() self.__gradient_machine__ = gm - self.__data_types__ = topo.data_type() def iter_infer(self, input, feeding=None): from data_feeder import DataFeeder -- GitLab