提交 f6dc09e9 编写于 作者: X Xin Pan

void hurting declarative performance

test=develop
上级 748549b2
......@@ -213,7 +213,7 @@ def _debug_string_(proto, throw_on_error=True):
return proto.__str__()
class Variable(core.VarBase):
class Variable(object):
"""
In Fluid, every input and output of an operator is a variable. In most
cases, variables are used for holding different kinds of data or training
......@@ -277,7 +277,6 @@ class Variable(core.VarBase):
stop_gradient=False,
is_data=False,
**kwargs):
core.VarBase.__init__(self)
self.block = block
self.error_clip = error_clip
......@@ -357,6 +356,9 @@ class Variable(core.VarBase):
self.op = None
self.stop_gradient = stop_gradient
self.is_data = is_data
if _in_imperative_mode():
self._ivar = core.VarBase()
self._ivar.desc = self.desc
def _numpy(self):
scope = _imperative_tracer().get_scope(self.block.desc)
......@@ -365,10 +367,10 @@ class Variable(core.VarBase):
def _backward(self):
scope = _imperative_tracer().get_scope(self.block.desc)
self._run_backward(scope)
self._ivar._run_backward(scope)
def _gradient(self):
return np.array(self._grad())
return np.array(self._ivar._grad())
def __str__(self):
return self.to_string(True)
......@@ -516,7 +518,7 @@ class OpProtoHolder(object):
}
class Operator(core.OpBase):
class Operator(object):
"""
In Fluid, all the operation are represented by Operator, and Operator
is regarded as a build in an instruction of a Block. Users can use the
......@@ -572,7 +574,6 @@ class Operator(core.OpBase):
inputs=None,
outputs=None,
attrs=None):
core.OpBase.__init__(self)
self.block = block
self.desc = desc
# note: not add self.attrs here:
......@@ -612,7 +613,6 @@ class Operator(core.OpBase):
return True
return False
self.inputs = []
if inputs is not None:
for in_proto in proto.inputs:
found = find_name(inputs, in_proto.name)
......@@ -639,13 +639,6 @@ class Operator(core.OpBase):
else:
self.desc.set_input(in_proto.name, [])
for inp in inputs.values():
if isinstance(inp, Variable):
self.inputs.append(inp)
elif isinstance(inp, list) or isinstance(inp, tuple):
self.inputs.extend(inp[:])
self.outputs = []
if outputs is not None:
given = set()
need = set()
......@@ -674,12 +667,6 @@ class Operator(core.OpBase):
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
for out in outputs.values():
if isinstance(out, Variable):
self.outputs.append(out)
elif isinstance(out, list) or isinstance(out, tuple):
self.outputs.extend(out[:])
if op_attrs is not None:
if not isinstance(op_attrs, dict):
raise TypeError("'attrs' should be a dict.")
......@@ -694,6 +681,23 @@ class Operator(core.OpBase):
if self._has_kernel(type):
self.desc.infer_var_type(self.block.desc)
self.desc.infer_shape(self.block.desc)
if _in_imperative_mode():
self.iop = core.OpBase()
self.iop.desc = self.desc
self.inputs = []
if inputs is not None:
for inp in inputs.values():
if isinstance(inp, Variable):
self.inputs.append(inp)
elif isinstance(inp, list) or isinstance(inp, tuple):
self.inputs.extend(inp[:])
self.outputs = []
if outputs is not None:
for out in outputs.values():
if isinstance(out, Variable):
self.outputs.append(out)
elif isinstance(out, list) or isinstance(out, tuple):
self.outputs.extend(out[:])
def _has_kernel(self, op_type):
return op_type not in self.OP_WITHOUT_KERNEL_SET
......@@ -1246,7 +1250,8 @@ class Block(object):
op_desc = self.desc.append_op()
op = Operator(block=self, desc=op_desc, *args, **kwargs)
if _in_imperative_mode():
_imperative_tracer().trace(op, op.inputs, op.outputs, self.desc)
_imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs],
[v._ivar for v in op.outputs], self.desc)
self.ops.append(op)
return op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册