diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py
index 19624a704f1848131d909ee73ee0b5984e1a2e07..e80456d9bbeb3c34ac9eab873a84dbf8f06e34df 100644
--- a/python/paddle/v2/inference.py
+++ b/python/paddle/v2/inference.py
@@ -2,6 +2,7 @@ import numpy
 import collections
 import topology
 import minibatch
+import cPickle
 
 __all__ = ['infer', 'Inference']
 
@@ -25,18 +26,23 @@ class Inference(object):
     :type parameters: paddle.v2.parameters.Parameters
     """
 
-    def __init__(self, output_layer, parameters, data_types=None):
+    def __init__(self, parameters, output_layer=None, fileobj=None):
         import py_paddle.swig_paddle as api
-        if isinstance(output_layer, str):
-            gm = api.GradientMachine.createByConfigProtoStr(output_layer)
-            if data_types is None:
-                raise ValueError("data_types != None when using protobuf bin")
-            self.__data_types__ = data_types
-        else:
+
+        if output_layer is not None:
             topo = topology.Topology(output_layer)
             gm = api.GradientMachine.createFromConfigProto(
                 topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
             self.__data_types__ = topo.data_type()
+        elif fileobj is not None:
+            tmp = cPickle.load(fileobj)
+            gm = api.GradientMachine.createByConfigProtoStr(
+                tmp['protobin'], api.CREATE_MODE_TESTING,
+                [api.PARAMETER_VALUE])
+            self.__data_types__ = tmp['data_type']
+        else:
+            raise ValueError("Either output_layer or fileobj must be set")
+
         for param in gm.getParameters():
             val = param.getBuf(api.PARAMETER_VALUE)
             name = param.getName()
diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py
index a20e878d0817d0a75e9c47a44f8765deca99225c..2db66be2505dde38a501edf45984e1f36beb351d 100644
--- a/python/paddle/v2/topology.py
+++ b/python/paddle/v2/topology.py
@@ -18,6 +18,7 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig
 import paddle.trainer_config_helpers as conf_helps
 import layer as v2_layer
 import config_base
+import cPickle
 
 __all__ = ['Topology']
 
@@ -100,6 +101,14 @@ class Topology(object):
                 return layer
         return None
 
+    def serialize_for_inference(self, stream):
+        protobin = self.proto().SerializeToString()
+        data_type = self.data_type()
+        cPickle.dump({
+            'protobin': protobin,
+            'data_type': data_type
+        }, stream, cPickle.HIGHEST_PROTOCOL)
+
 
 def __check_layer_type__(layer):
     if not isinstance(layer, config_base.Layer):
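
For context, a minimal usage sketch of how the two halves of this change are meant to fit together: Topology.serialize_for_inference() pickles the serialized ModelConfig protobuf plus the input data types, and Inference(fileobj=...) rebuilds the GradientMachine from that blob without the Python layer definitions. The file name inference_topology.pkl and the output_layer / parameters variables below are placeholders, not part of this diff.

    from paddle.v2.topology import Topology
    from paddle.v2.inference import Inference

    # Training side: dump the topology's protobuf config and data types
    # so the inference side does not need the network-building code.
    topology = Topology(output_layer)  # output_layer: the trained network's output layer
    with open('inference_topology.pkl', 'wb') as f:
        topology.serialize_for_inference(f)

    # Inference side: load the pickled blob instead of passing output_layer.
    # parameters: the trained paddle.v2.parameters.Parameters object.
    with open('inference_topology.pkl', 'rb') as f:
        inferer = Inference(parameters=parameters, fileobj=f)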