diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index 8acea6155c588f2e8e5ad009cd8f0a0c09afb92b..19624a704f1848131d909ee73ee0b5984e1a2e07 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -25,11 +25,18 @@ class Inference(object): :type parameters: paddle.v2.parameters.Parameters """ - def __init__(self, output_layer, parameters): + def __init__(self, output_layer, parameters, data_types=None): import py_paddle.swig_paddle as api - topo = topology.Topology(output_layer) - gm = api.GradientMachine.createFromConfigProto( - topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE]) + if isinstance(output_layer, str): + gm = api.GradientMachine.createByConfigProtoStr(output_layer) + if data_types is None: + raise ValueError("data_types != None when using protobuf bin") + self.__data_types__ = data_types + else: + topo = topology.Topology(output_layer) + gm = api.GradientMachine.createFromConfigProto( + topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE]) + self.__data_types__ = topo.data_type() for param in gm.getParameters(): val = param.getBuf(api.PARAMETER_VALUE) name = param.getName() @@ -43,7 +50,6 @@ class Inference(object): # called here, but it's better to call this function in one place. param.setValueUpdated() self.__gradient_machine__ = gm - self.__data_types__ = topo.data_type() def iter_infer(self, input, feeding=None): from data_feeder import DataFeeder