diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 1511eea68cbb2c4df655bdb6ae13c6be5c6412a9..4bf0a456b5200eb93c2bcd947aac58283ac7b41a 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -213,7 +213,7 @@ def _debug_string_(proto, throw_on_error=True):
     return proto.__str__()
 
 
-class Variable(core.VarBase):
+class Variable(object):
     """
     In Fluid, every input and output of an operator is a variable. In most
     cases, variables are used for holding different kinds of data or training
@@ -277,7 +277,6 @@ class Variable(core.VarBase):
                  stop_gradient=False,
                  is_data=False,
                  **kwargs):
-        core.VarBase.__init__(self)
         self.block = block
         self.error_clip = error_clip
 
@@ -357,6 +356,9 @@ class Variable(core.VarBase):
         self.op = None
         self.stop_gradient = stop_gradient
         self.is_data = is_data
+        if _in_imperative_mode():
+            self._ivar = core.VarBase()
+            self._ivar.desc = self.desc
 
     def _numpy(self):
         scope = _imperative_tracer().get_scope(self.block.desc)
@@ -365,10 +367,10 @@ class Variable(core.VarBase):
 
     def _backward(self):
         scope = _imperative_tracer().get_scope(self.block.desc)
-        self._run_backward(scope)
+        self._ivar._run_backward(scope)
 
     def _gradient(self):
-        return np.array(self._grad())
+        return np.array(self._ivar._grad())
 
     def __str__(self):
         return self.to_string(True)
@@ -516,7 +518,7 @@ class OpProtoHolder(object):
         }
 
 
-class Operator(core.OpBase):
+class Operator(object):
     """
     In Fluid, all the operation are represented by Operator, and Operator
     is regarded as a build in an instruction of a Block. Users can use the
@@ -572,7 +574,6 @@ class Operator(core.OpBase):
                  inputs=None,
                  outputs=None,
                  attrs=None):
-        core.OpBase.__init__(self)
         self.block = block
         self.desc = desc
         # note: not add self.attrs here:
@@ -612,7 +613,6 @@ class Operator(core.OpBase):
                     return True
             return False
 
-        self.inputs = []
         if inputs is not None:
             for in_proto in proto.inputs:
                 found = find_name(inputs, in_proto.name)
@@ -639,13 +639,6 @@ class Operator(core.OpBase):
                 else:
                     self.desc.set_input(in_proto.name, [])
 
-            for inp in inputs.values():
-                if isinstance(inp, Variable):
-                    self.inputs.append(inp)
-                elif isinstance(inp, list) or isinstance(inp, tuple):
-                    self.inputs.extend(inp[:])
-
-        self.outputs = []
         if outputs is not None:
             given = set()
             need = set()
@@ -674,12 +667,6 @@ class Operator(core.OpBase):
                     arg.op = self
                 self.desc.set_output(out_proto.name, out_arg_names)
 
-            for out in outputs.values():
-                if isinstance(out, Variable):
-                    self.outputs.append(out)
-                elif isinstance(out, list) or isinstance(out, tuple):
-                    self.outputs.extend(out[:])
-
         if op_attrs is not None:
             if not isinstance(op_attrs, dict):
                 raise TypeError("'attrs' should be a dict.")
@@ -694,6 +681,23 @@ class Operator(core.OpBase):
         if self._has_kernel(type):
             self.desc.infer_var_type(self.block.desc)
             self.desc.infer_shape(self.block.desc)
+        if _in_imperative_mode():
+            self.iop = core.OpBase()
+            self.iop.desc = self.desc
+            self.inputs = []
+            if inputs is not None:
+                for inp in inputs.values():
+                    if isinstance(inp, Variable):
+                        self.inputs.append(inp)
+                    elif isinstance(inp, list) or isinstance(inp, tuple):
+                        self.inputs.extend(inp[:])
+            self.outputs = []
+            if outputs is not None:
+                for out in outputs.values():
+                    if isinstance(out, Variable):
+                        self.outputs.append(out)
+                    elif isinstance(out, list) or isinstance(out, tuple):
+                        self.outputs.extend(out[:])
 
     def _has_kernel(self, op_type):
         return op_type not in self.OP_WITHOUT_KERNEL_SET
@@ -1246,7 +1250,8 @@ class Block(object):
         op_desc = self.desc.append_op()
         op = Operator(block=self, desc=op_desc, *args, **kwargs)
         if _in_imperative_mode():
-            _imperative_tracer().trace(op, op.inputs, op.outputs, self.desc)
+            _imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs],
+                                       [v._ivar for v in op.outputs], self.desc)
         self.ops.append(op)
         return op
 
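
The patch above replaces inheritance with composition: Variable no longer extends core.VarBase and Operator no longer extends core.OpBase; instead, when _in_imperative_mode() is true, each Variable wraps a core.VarBase in self._ivar, each Operator wraps a core.OpBase in self.iop, and the tracer is handed those wrapped C++-side objects directly. The sketch below is a minimal, self-contained illustration of that wrap-and-delegate pattern; _FakeVarBase and MyVariable are hypothetical stand-ins for the real core bindings and framework.Variable, not part of the patch, and only the method names (_run_backward, _grad, _backward, _gradient) mirror the diff.

    # Minimal sketch of the wrap-and-delegate pattern used in the patch.
    # _FakeVarBase stands in for core.VarBase so the flow runs without the
    # C++ extension; nothing here is the real Paddle API.
    import numpy as np


    class _FakeVarBase(object):
        """Stand-in for core.VarBase (hypothetical)."""

        def _run_backward(self, scope):
            print("backward pass traced in scope:", scope)

        def _grad(self):
            return [[1.0, 2.0, 3.0]]


    class MyVariable(object):
        """Stand-in for framework.Variable after the patch (hypothetical)."""

        def __init__(self, in_imperative_mode=True):
            if in_imperative_mode:
                # mirrors: self._ivar = core.VarBase()
                self._ivar = _FakeVarBase()

        def _backward(self, scope=None):
            # delegate to the wrapped object instead of inheriting from it
            self._ivar._run_backward(scope)

        def _gradient(self):
            return np.array(self._ivar._grad())


    v = MyVariable()
    v._backward(scope="block-0 scope")
    print(v._gradient())  # -> [[1. 2. 3.]]

Composition keeps the Python-side Variable/Operator free of C++ base-class state when imperative mode is off, and confines the imperative-only bookkeeping (inputs/outputs lists, iop, _ivar) to the _in_imperative_mode() branch.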