Unverified commit d4dabe3e, authored by Yang Yang(Tony), committed by GitHub

framework.py enhancement (#8471)

* framework.py enhancement

* polish

* clean up

* enforce that the inputs of Operator __init__ are of type Variable (a sketch follows this header)

* python2 assert

* reverse assert
Parent: 7a9098a6
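
The last three commit-message bullets describe a type check on Operator inputs that does not appear in the hunks excerpted below. A minimal sketch of what such a check might look like (a hypothetical helper, written to be Python 2 compatible per the "python2 assert" bullet; not the actual hunk):

    # Hypothetical sketch of the input check the commit message describes;
    # the real hunk is not included in this excerpt.
    def _assert_inputs_are_variables(inputs):
        for name, args in inputs.items():
            # an input slot may hold one Variable or a list of Variables
            if not isinstance(args, (list, tuple)):
                args = [args]
            for each in args:
                if not isinstance(each, Variable):
                    raise TypeError(
                        "Operator input '%s' must be a Variable, got %s" %
                        (name, type(each)))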
@@ -152,7 +152,7 @@ class Variable(object):
         shape(tuple|list|None): The shape of variable. -1 means the batch size.
             Some kinds of variable do not contain shape, just set it to None.
         dtype(np.dtype|core.VarDesc.VarType|str): The data type of variable.
-        lod_level(int): The level of lod tensor. 0 means there is not a time
+        lod_level(int): The level of lod tensor. 0 means it is not a time
             series data.
         persistable(bool): True if the variable should be saved as check point.
             Defaults to False.
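
As a usage note for the lod_level field documented above: dense batch data stays at level 0, while variable-length sequences use a nonzero level. A hedged sketch, assuming the fluid.layers.data API of this era (earlier releases used the paddle.v2.fluid namespace):

    import paddle.fluid as fluid

    # dense image batch: not time-series data, so lod_level stays 0 (default)
    img = fluid.layers.data(name='img', shape=[784], dtype='float32')

    # variable-length word-id sequences: one level of LoD
    words = fluid.layers.data(
        name='words', shape=[1], dtype='int64', lod_level=1)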
@@ -346,7 +346,7 @@ class OpProtoHolder(object):
     def __init__(self):
         assert not hasattr(
             self.__class__,
-            '_instance'), 'Please use `instance()` to get OpProtoHolder opject!'
+            '_instance'), 'Please use `instance()` to get OpProtoHolder object!'
         op_protos = get_all_op_protos()
         self.op_proto_map = {}
         for proto in op_protos:
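
The assert above guards a singleton; the `instance()` accessor it points users to is conventionally the lazy-construction classmethod below (a sketch of the usual pattern, not quoted from this file):

    @classmethod
    def instance(cls):
        # create the shared OpProtoHolder lazily on first access
        if not hasattr(cls, '_instance'):
            cls._instance = cls()
        return cls._instance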
@@ -368,8 +368,8 @@ class OpProtoHolder(object):
 class Operator(object):
     """
-    Python Operator class. The operator represents the build in instructs in a
-    Block. Users can use the build in instructs to describe their neural
+    Python Operator class. The operator represents the build in instructions in a
+    Block. Users can use the build in instructions to describe their neural
     network.
     """
@@ -478,7 +478,7 @@ class Operator(object):
                 raise TypeError("'attrs' should be a dict.")
             for attr in proto.attrs:
                 attr_name = attr.name
-                if (not attr_name in attrs) or (attrs[attr_name] is None):
+                if (attr_name not in attrs) or (attrs[attr_name] is None):
                     continue
                 if isinstance(attrs[attr_name], Block):
                     self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
@@ -751,7 +751,7 @@ class Block(object):
                     if isinstance(item[1], Parameter))

     def create_var(self, *args, **kwargs):
-        var = Variable(self, *args, **kwargs)
+        var = Variable(block=self, *args, **kwargs)
         if 'initializer' in kwargs:
             kwargs['initializer'](var, self)
         return var
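
Passing block as a keyword here keeps the forwarded positional *args from silently binding to the wrong Variable parameter. A hedged usage sketch with illustrative names:

    # create a persistable weight variable in this block
    w = block.create_var(
        name='fc_w', shape=[784, 10], dtype='float32', persistable=True)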
@@ -822,13 +822,13 @@ class Block(object):
     def append_op(self, *args, **kwargs):
         op_desc = self.desc.append_op()
-        op = Operator(self, op_desc, *args, **kwargs)
+        op = Operator(block=self, desc=op_desc, *args, **kwargs)
         self.ops.append(op)
         return op

     def delete_ops(self, ops):
         # remove from cpp
-        # FIXME(typhoonzero): remove only the first occuracy.
+        # FIXME(typhoonzero): remove only the first occurrence.
         try:
             start = list(self.ops).index(ops[0])
             end = list(self.ops).index(ops[-1])
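
The start/end indices computed above drive a contiguous slice removal. A hedged sketch of how the method plausibly finishes, under the assumption that desc.remove_op takes a half-open index range (not necessarily the file's verbatim remainder):

    # drop ops[0]..ops[-1] as one contiguous slice, mirroring the removal
    # on the C++ desc; per the FIXME, only the first occurrence is handled
    self.desc.remove_op(start, end + 1)
    self.ops = self.ops[:start] + self.ops[end + 1:]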
@@ -846,6 +846,11 @@ class Block(object):
         return op

     def sync_with_cpp(self):
+        """
+        Sync with the desc on the c++ end.
+
+        This method is used to synchronize the c++ desc instance generated by backward.
+        """
         # sync variables from cpp
         for var in self.desc.all_vars():
             if not self.has_var(var.name()):
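
The loop shown above continues by materializing each missing variable on the Python side; a hedged sketch of that continuation (argument list assumed from the Variable constructor):

    # pull variables that exist only on the C++ desc into this block
    for var in self.desc.all_vars():
        if not self.has_var(var.name()):
            self.create_var(name=var.name(), desc=var, type=var.type())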
@@ -891,9 +896,9 @@ class Block(object):
     def copy_param_info_from(self, other):
         """
-        Copy the information of parameters from other block
+        Copy the information of parameters from the other block

         Args:
-            other(Block): other block
+            other(Block): the other block

         Returns:
             None
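
Usage is a one-liner once the destination block has been built with the same topology as the source; a hedged sketch with illustrative names:

    # carry parameter metadata (trainable, regularizer, ...) across blocks
    dst_block.copy_param_info_from(src_block)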
@@ -1239,6 +1244,6 @@ def get_var(name, program=None):
     if program is None:
         program = default_main_program()
     assert isinstance(name, str)
-    assert isinstance(name, Program)
+    assert isinstance(program, Program)
     return program.global_block().var(name)
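
With the corrected assert, both arguments are validated before the lookup. A hedged usage sketch, assuming get_var is importable from the framework module of this era:

    from paddle.fluid.framework import default_main_program, get_var

    prog = default_main_program()
    # fetch the Variable named 'fc_w' from prog's global block; the asserts
    # fire if the name is not a str or prog is not a Program
    w = get_var('fc_w', program=prog)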
@@ -104,7 +104,7 @@ def fc(input,
     * :math:`X_i`: The input tensor.
     * :math:`W`: The weights created by this layer.
     * :math:`b`: The bias parameter created by this layer (if needed).
-    * :math:`Act`: The activation funtion.
+    * :math:`Act`: The activation function.
     * :math:`Out`: The output tensor.

     Args:
......
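
A hedged usage sketch of the fc layer whose docstring is fixed above, matching the documented formula Out = Act(X W + b):

    import paddle.fluid as fluid

    img = fluid.layers.data(name='img', shape=[784], dtype='float32')
    # W and b are created by the layer; 'softmax' is the Act in the formula
    out = fluid.layers.fc(input=img, size=10, act='softmax')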