提交 5fddd288 编写于 作者: F fengjiayi

Merge branch 'feature/add_persistable_in_var_desc' into dev_opdesc_in_python

import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
import collections import collections
__all__ = ['Block', 'Variable', 'Program', 'Operator'] __all__ = ['Block', 'Variable', 'Program', 'Operator']
def get_all_op_protos():
"""
Get all registered op proto from PaddlePaddle C++ end.
:return: A list of registered OpProto.
"""
protostrs = core.get_all_op_protos()
ret_values = []
for pbstr in protostrs:
op_proto = framework_pb2.OpProto.FromString(str(pbstr))
ret_values.append(op_proto)
return ret_values
class OpProtoHolder(object):
@classmethod
def instance(cls):
if not hasattr(cls, '_instance'):
cls._instance = cls()
return cls._instance
def __init__(self):
assert not hasattr(
self.__class__,
'_instance'), 'Please use `instance()` to get OpProtoHolder opject!'
op_protos = get_all_op_protos()
self.op_proto_map = {}
for proto in op_protos:
sefl.op_proto_map[proto.type] = proto
def get_op_proto(self, type):
assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type
return self.op_proto_map[type]
class Variable(object): class Variable(object):
def __init__(self, block, name=None, shape=None, dtype=None, def __init__(self, block, name=None, shape=None, dtype=None,
lod_level=None): lod_level=None):
...@@ -36,27 +71,50 @@ class Variable(object): ...@@ -36,27 +71,50 @@ class Variable(object):
class Operator(object): class Operator(object):
def __init__(self, def __init__(self, block, desc, type, inputs=None, outputs=None,
block,
proto,
type=None,
inputs=None,
outputs=None,
attrs=None): attrs=None):
self.block = block self.block = block
self.proto = proto self.desc = desc
if type is not None: self.proto = OpProtoHolder.instance().get_op_proto(type)
# TODO. self.desc.set_type(type)
pass
if inputs is not None: if inputs is not None:
# TODO for in_proto in self.proto.inputs:
pass in_argus = inputs[in_proto.name]
if not isinstance(in_argus, list):
in_argus = [in_argus]
if not in_proto.duplicable and len(in_argus) > 1:
raise ValueError(
"Input %s expects only one input, but %d are given." %
(in_proto.name, len(in_argus)))
in_argu_names = []
for argu in in_argus:
in_argu_names.append(argu.name())
self.desc.set_input(in_proto.name, in_argu_names)
if outputs is not None: if outputs is not None:
# TODO for out_proto in self.proto.outputs:
pass out_argus = outputs[out_proto.name]
if not isinstance(out_argus, list):
out_argus = [out_argus]
if not out_proto.duplicable and len(out_argus) > 1:
raise ValueError(
"Output %s expects only one output, but %d are given." %
(out_proto.name, len(out_argus)))
out_argu_names = []
for argu in out_argus:
out_argu_names.append(argu.name())
self.desc.set_output(out_proto.name, out_argu_names)
if attrs is not None: if attrs is not None:
# TODO for attr in self.proto.attrs:
pass attr_name = attr.name
if not attr_name in attrs:
continue
if not isinstance(attrs[attr_name], Block):
self.desc.set_attr(attr_name, attrs[attr_name])
else:
self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
# TODO: Getters # TODO: Getters
...@@ -80,8 +138,8 @@ class Block(object): ...@@ -80,8 +138,8 @@ class Block(object):
return Variable(self, *args, **kwargs) return Variable(self, *args, **kwargs)
def append_op(self, *args, **kwargs): def append_op(self, *args, **kwargs):
op_proto = self.proto.append_op() op_desc = self.proto.append_op()
op = Operator(self, op_proto, *args, **kwargs) op = Operator(self, op_desc, *args, **kwargs)
self.ops.append(op) self.ops.append(op)
return op return op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册