From 2b352212c27ccdccb94a2878d823b2150d74bf00 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 11 Sep 2017 17:42:20 -0700 Subject: [PATCH] Add serialize to file for topology and read file obj for inference --- python/paddle/v2/inference.py | 20 +++++++++++++------- python/paddle/v2/topology.py | 9 +++++++++ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index 19624a704f1..e80456d9bbe 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -2,6 +2,7 @@ import numpy import collections import topology import minibatch +import cPickle __all__ = ['infer', 'Inference'] @@ -25,18 +26,23 @@ class Inference(object): :type parameters: paddle.v2.parameters.Parameters """ - def __init__(self, output_layer, parameters, data_types=None): + def __init__(self, parameters, output_layer=None, fileobj=None): import py_paddle.swig_paddle as api - if isinstance(output_layer, str): - gm = api.GradientMachine.createByConfigProtoStr(output_layer) - if data_types is None: - raise ValueError("data_types != None when using protobuf bin") - self.__data_types__ = data_types - else: + + if output_layer is not None: topo = topology.Topology(output_layer) gm = api.GradientMachine.createFromConfigProto( topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE]) self.__data_types__ = topo.data_type() + elif fileobj is not None: + tmp = cPickle.load(fileobj) + gm = api.GradientMachine.createByConfigProtoStr( + tmp['protobin'], api.CREATE_MODE_TESTING, + [api.PARAMETER_VALUE]) + self.__data_types__ = tmp['data_type'] + else: + raise ValueError("Either output_layer or fileobj must be set") + for param in gm.getParameters(): val = param.getBuf(api.PARAMETER_VALUE) name = param.getName() diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py index a20e878d081..2db66be2505 100644 --- a/python/paddle/v2/topology.py +++ b/python/paddle/v2/topology.py @@ -18,6 +18,7 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig import paddle.trainer_config_helpers as conf_helps import layer as v2_layer import config_base +import cPickle __all__ = ['Topology'] @@ -100,6 +101,14 @@ class Topology(object): return layer return None + def serialize_for_inference(self, stream): + protobin = self.proto().SerializeToString() + data_type = self.data_type() + cPickle.dump({ + 'protobin': protobin, + 'data_type': data_type + }, stream, cPickle.HIGHEST_PROTOCOL) + def __check_layer_type__(layer): if not isinstance(layer, config_base.Layer): -- GitLab