提交 f6dc09e9 编写于 作者: X Xin Pan

void hurting declarative performance

test=develop
上级 748549b2
...@@ -213,7 +213,7 @@ def _debug_string_(proto, throw_on_error=True): ...@@ -213,7 +213,7 @@ def _debug_string_(proto, throw_on_error=True):
return proto.__str__() return proto.__str__()
class Variable(core.VarBase): class Variable(object):
""" """
In Fluid, every input and output of an operator is a variable. In most In Fluid, every input and output of an operator is a variable. In most
cases, variables are used for holding different kinds of data or training cases, variables are used for holding different kinds of data or training
...@@ -277,7 +277,6 @@ class Variable(core.VarBase): ...@@ -277,7 +277,6 @@ class Variable(core.VarBase):
stop_gradient=False, stop_gradient=False,
is_data=False, is_data=False,
**kwargs): **kwargs):
core.VarBase.__init__(self)
self.block = block self.block = block
self.error_clip = error_clip self.error_clip = error_clip
...@@ -357,6 +356,9 @@ class Variable(core.VarBase): ...@@ -357,6 +356,9 @@ class Variable(core.VarBase):
self.op = None self.op = None
self.stop_gradient = stop_gradient self.stop_gradient = stop_gradient
self.is_data = is_data self.is_data = is_data
if _in_imperative_mode():
self._ivar = core.VarBase()
self._ivar.desc = self.desc
def _numpy(self): def _numpy(self):
scope = _imperative_tracer().get_scope(self.block.desc) scope = _imperative_tracer().get_scope(self.block.desc)
...@@ -365,10 +367,10 @@ class Variable(core.VarBase): ...@@ -365,10 +367,10 @@ class Variable(core.VarBase):
def _backward(self): def _backward(self):
scope = _imperative_tracer().get_scope(self.block.desc) scope = _imperative_tracer().get_scope(self.block.desc)
self._run_backward(scope) self._ivar._run_backward(scope)
def _gradient(self): def _gradient(self):
return np.array(self._grad()) return np.array(self._ivar._grad())
def __str__(self): def __str__(self):
return self.to_string(True) return self.to_string(True)
...@@ -516,7 +518,7 @@ class OpProtoHolder(object): ...@@ -516,7 +518,7 @@ class OpProtoHolder(object):
} }
class Operator(core.OpBase): class Operator(object):
""" """
In Fluid, all the operation are represented by Operator, and Operator In Fluid, all the operation are represented by Operator, and Operator
is regarded as a build in an instruction of a Block. Users can use the is regarded as a build in an instruction of a Block. Users can use the
...@@ -572,7 +574,6 @@ class Operator(core.OpBase): ...@@ -572,7 +574,6 @@ class Operator(core.OpBase):
inputs=None, inputs=None,
outputs=None, outputs=None,
attrs=None): attrs=None):
core.OpBase.__init__(self)
self.block = block self.block = block
self.desc = desc self.desc = desc
# note: not add self.attrs here: # note: not add self.attrs here:
...@@ -612,7 +613,6 @@ class Operator(core.OpBase): ...@@ -612,7 +613,6 @@ class Operator(core.OpBase):
return True return True
return False return False
self.inputs = []
if inputs is not None: if inputs is not None:
for in_proto in proto.inputs: for in_proto in proto.inputs:
found = find_name(inputs, in_proto.name) found = find_name(inputs, in_proto.name)
...@@ -639,13 +639,6 @@ class Operator(core.OpBase): ...@@ -639,13 +639,6 @@ class Operator(core.OpBase):
else: else:
self.desc.set_input(in_proto.name, []) self.desc.set_input(in_proto.name, [])
for inp in inputs.values():
if isinstance(inp, Variable):
self.inputs.append(inp)
elif isinstance(inp, list) or isinstance(inp, tuple):
self.inputs.extend(inp[:])
self.outputs = []
if outputs is not None: if outputs is not None:
given = set() given = set()
need = set() need = set()
...@@ -674,12 +667,6 @@ class Operator(core.OpBase): ...@@ -674,12 +667,6 @@ class Operator(core.OpBase):
arg.op = self arg.op = self
self.desc.set_output(out_proto.name, out_arg_names) self.desc.set_output(out_proto.name, out_arg_names)
for out in outputs.values():
if isinstance(out, Variable):
self.outputs.append(out)
elif isinstance(out, list) or isinstance(out, tuple):
self.outputs.extend(out[:])
if op_attrs is not None: if op_attrs is not None:
if not isinstance(op_attrs, dict): if not isinstance(op_attrs, dict):
raise TypeError("'attrs' should be a dict.") raise TypeError("'attrs' should be a dict.")
...@@ -694,6 +681,23 @@ class Operator(core.OpBase): ...@@ -694,6 +681,23 @@ class Operator(core.OpBase):
if self._has_kernel(type): if self._has_kernel(type):
self.desc.infer_var_type(self.block.desc) self.desc.infer_var_type(self.block.desc)
self.desc.infer_shape(self.block.desc) self.desc.infer_shape(self.block.desc)
if _in_imperative_mode():
self.iop = core.OpBase()
self.iop.desc = self.desc
self.inputs = []
if inputs is not None:
for inp in inputs.values():
if isinstance(inp, Variable):
self.inputs.append(inp)
elif isinstance(inp, list) or isinstance(inp, tuple):
self.inputs.extend(inp[:])
self.outputs = []
if outputs is not None:
for out in outputs.values():
if isinstance(out, Variable):
self.outputs.append(out)
elif isinstance(out, list) or isinstance(out, tuple):
self.outputs.extend(out[:])
def _has_kernel(self, op_type): def _has_kernel(self, op_type):
return op_type not in self.OP_WITHOUT_KERNEL_SET return op_type not in self.OP_WITHOUT_KERNEL_SET
...@@ -1246,7 +1250,8 @@ class Block(object): ...@@ -1246,7 +1250,8 @@ class Block(object):
op_desc = self.desc.append_op() op_desc = self.desc.append_op()
op = Operator(block=self, desc=op_desc, *args, **kwargs) op = Operator(block=self, desc=op_desc, *args, **kwargs)
if _in_imperative_mode(): if _in_imperative_mode():
_imperative_tracer().trace(op, op.inputs, op.outputs, self.desc) _imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs],
[v._ivar for v in op.outputs], self.desc)
self.ops.append(op) self.ops.append(op)
return op return op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册