未验证 提交 d4dabe3e 编写于 作者: Y Yang Yang(Tony) 提交者: GitHub

framework.py enhancement (#8471)

* framework.py enhancement

* polish

* clean up

* enforce the inputs of Operator __init__ of type Variable

* python2 assert

* reverse assert
上级 7a9098a6
......@@ -152,7 +152,7 @@ class Variable(object):
shape(tuple|list|None): The shape of variable. -1 means the batch size.
Some kinds of variable do not contain shape, just set it to None.
dtype(np.dtype|core.VarDesc.VarType|str): The data type of variable.
lod_level(int): The level of lod tensor. 0 means there is not a time
lod_level(int): The level of lod tensor. 0 means it is not a time
series data.
persistable(bool): True if the variable should be saved as check point.
Defaults to False.
......@@ -346,7 +346,7 @@ class OpProtoHolder(object):
def __init__(self):
assert not hasattr(
self.__class__,
'_instance'), 'Please use `instance()` to get OpProtoHolder opject!'
'_instance'), 'Please use `instance()` to get OpProtoHolder object!'
op_protos = get_all_op_protos()
self.op_proto_map = {}
for proto in op_protos:
......@@ -368,8 +368,8 @@ class OpProtoHolder(object):
class Operator(object):
"""
Python Operator class. The operator represents the build in instructs in a
Block. Users can use the build in instructs to describe their neural
Python Operator class. The operator represents the build in instructions in a
Block. Users can use the build in instructions to describe their neural
network.
"""
......@@ -478,7 +478,7 @@ class Operator(object):
raise TypeError("'attrs' should be a dict.")
for attr in proto.attrs:
attr_name = attr.name
if (not attr_name in attrs) or (attrs[attr_name] is None):
if (attr_name not in attrs) or (attrs[attr_name] is None):
continue
if isinstance(attrs[attr_name], Block):
self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
......@@ -751,7 +751,7 @@ class Block(object):
if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs):
var = Variable(self, *args, **kwargs)
var = Variable(block=self, *args, **kwargs)
if 'initializer' in kwargs:
kwargs['initializer'](var, self)
return var
......@@ -822,13 +822,13 @@ class Block(object):
def append_op(self, *args, **kwargs):
op_desc = self.desc.append_op()
op = Operator(self, op_desc, *args, **kwargs)
op = Operator(block=self, desc=op_desc, *args, **kwargs)
self.ops.append(op)
return op
def delete_ops(self, ops):
# remove from cpp
# FIXME(typhoonzero): remove only the first occuracy.
# FIXME(typhoonzero): remove only the first occurrence.
try:
start = list(self.ops).index(ops[0])
end = list(self.ops).index(ops[-1])
......@@ -846,6 +846,11 @@ class Block(object):
return op
def sync_with_cpp(self):
"""
Sync with the desc on the c++ end.
This method is used to synchronize the c++ desc instance generated by backward.
"""
# sync variables from cpp
for var in self.desc.all_vars():
if not self.has_var(var.name()):
......@@ -891,9 +896,9 @@ class Block(object):
def copy_param_info_from(self, other):
"""
Copy the information of parameters from other block
Copy the information of parameters from the other block
Args:
other(Block): other block
other(Block): the other block
Returns:
None
......@@ -1239,6 +1244,6 @@ def get_var(name, program=None):
if program is None:
program = default_main_program()
assert isinstance(name, str)
assert isinstance(name, Program)
assert isinstance(program, Program)
return program.global_block().var(name)
......@@ -104,7 +104,7 @@ def fc(input,
* :math:`X_i`: The input tensor.
* :math:`W`: The weights created by this layer.
* :math:`b`: The bias parameter created by this layer (if needed).
* :math:`Act`: The activation funtion.
* :math:`Act`: The activation function.
* :math:`Out`: The output tensor.
Args:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册