提交 af523df4 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #4023 from reyoung/feature/serialize_protobuf

Make paddle.v2.inference can direct load protobuf
...@@ -2,6 +2,7 @@ import numpy ...@@ -2,6 +2,7 @@ import numpy
import collections import collections
import topology import topology
import minibatch import minibatch
import cPickle
__all__ = ['infer', 'Inference'] __all__ = ['infer', 'Inference']
...@@ -25,11 +26,23 @@ class Inference(object): ...@@ -25,11 +26,23 @@ class Inference(object):
:type parameters: paddle.v2.parameters.Parameters :type parameters: paddle.v2.parameters.Parameters
""" """
def __init__(self, output_layer, parameters): def __init__(self, parameters, output_layer=None, fileobj=None):
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
if output_layer is not None:
topo = topology.Topology(output_layer) topo = topology.Topology(output_layer)
gm = api.GradientMachine.createFromConfigProto( gm = api.GradientMachine.createFromConfigProto(
topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE]) topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
self.__data_types__ = topo.data_type()
elif fileobj is not None:
tmp = cPickle.load(fileobj)
gm = api.GradientMachine.createByConfigProtoStr(
tmp['protobin'], api.CREATE_MODE_TESTING,
[api.PARAMETER_VALUE])
self.__data_types__ = tmp['data_type']
else:
raise ValueError("Either output_layer or fileobj must be set")
for param in gm.getParameters(): for param in gm.getParameters():
val = param.getBuf(api.PARAMETER_VALUE) val = param.getBuf(api.PARAMETER_VALUE)
name = param.getName() name = param.getName()
...@@ -43,7 +56,6 @@ class Inference(object): ...@@ -43,7 +56,6 @@ class Inference(object):
# called here, but it's better to call this function in one place. # called here, but it's better to call this function in one place.
param.setValueUpdated() param.setValueUpdated()
self.__gradient_machine__ = gm self.__gradient_machine__ = gm
self.__data_types__ = topo.data_type()
def iter_infer(self, input, feeding=None): def iter_infer(self, input, feeding=None):
from data_feeder import DataFeeder from data_feeder import DataFeeder
......
...@@ -18,6 +18,7 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig ...@@ -18,6 +18,7 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig
import paddle.trainer_config_helpers as conf_helps import paddle.trainer_config_helpers as conf_helps
import layer as v2_layer import layer as v2_layer
import config_base import config_base
import cPickle
__all__ = ['Topology'] __all__ = ['Topology']
...@@ -100,6 +101,14 @@ class Topology(object): ...@@ -100,6 +101,14 @@ class Topology(object):
return layer return layer
return None return None
def serialize_for_inference(self, stream):
protobin = self.proto().SerializeToString()
data_type = self.data_type()
cPickle.dump({
'protobin': protobin,
'data_type': data_type
}, stream, cPickle.HIGHEST_PROTOCOL)
def __check_layer_type__(layer): def __check_layer_type__(layer):
if not isinstance(layer, config_base.Layer): if not isinstance(layer, config_base.Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册