提交 fb32106e 编写于 作者: Y Yu Yang

Make paddle.v2.inference can direct load protobuf

上级 d71190f0
...@@ -25,11 +25,18 @@ class Inference(object): ...@@ -25,11 +25,18 @@ class Inference(object):
:type parameters: paddle.v2.parameters.Parameters :type parameters: paddle.v2.parameters.Parameters
""" """
def __init__(self, output_layer, parameters): def __init__(self, output_layer, parameters, data_types=None):
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
topo = topology.Topology(output_layer) if isinstance(output_layer, str):
gm = api.GradientMachine.createFromConfigProto( gm = api.GradientMachine.createByConfigProtoStr(output_layer)
topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE]) if data_types is None:
raise ValueError("data_types != None when using protobuf bin")
self.__data_types__ = data_types
else:
topo = topology.Topology(output_layer)
gm = api.GradientMachine.createFromConfigProto(
topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
self.__data_types__ = topo.data_type()
for param in gm.getParameters(): for param in gm.getParameters():
val = param.getBuf(api.PARAMETER_VALUE) val = param.getBuf(api.PARAMETER_VALUE)
name = param.getName() name = param.getName()
...@@ -43,7 +50,6 @@ class Inference(object): ...@@ -43,7 +50,6 @@ class Inference(object):
# called here, but it's better to call this function in one place. # called here, but it's better to call this function in one place.
param.setValueUpdated() param.setValueUpdated()
self.__gradient_machine__ = gm self.__gradient_machine__ = gm
self.__data_types__ = topo.data_type()
def iter_infer(self, input, feeding=None): def iter_infer(self, input, feeding=None):
from data_feeder import DataFeeder from data_feeder import DataFeeder
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册