提交 2b3510bc 编写于 作者: M minqiyang

Add imperative python tracer

上级 e9fdf909
......@@ -196,6 +196,7 @@ class OpBase {
: op_desc_(nullptr),
forward_id_(-1),
backward_id_(-1),
trace_id_(-1),
place_(platform::CPUPlace()) {}
virtual ~OpBase() {
......@@ -216,6 +217,7 @@ class OpBase {
// Note: each fwd op corresponds to a vector of bwd ops.
std::vector<framework::OpDesc*> grad_op_descs_;
int backward_id_;
int trace_id_;
platform::Place place_;
......
......@@ -193,6 +193,16 @@ PYBIND11_MODULE(core, m) {
}
},
py::return_value_policy::reference)
.def_property("_trace_id",
[](const imperative::OpBase &self) {
pybind11::gil_scoped_release release;
return self.trace_id_;
},
[](imperative::OpBase &self, int trace_id) {
pybind11::gil_scoped_release release;
self.trace_id_ = trace_id;
},
py::return_value_policy::reference)
.def_property(
"forward_id",
[](const imperative::OpBase &self) { return self.forward_id_; },
......
......@@ -1193,13 +1193,13 @@ class Block(object):
raise ValueError("Var {0} is not found recursively".format(name))
def _clear_block(self):
# TODO(minqiyang): move this to backward_hooks
self.desc._clear_block()
assert _in_imperative_mode()
for name in self.vars.keys():
assert self.vars[name].persistable
# TODO(minqiyang): move this to Variable and Operator's __del__
self.desc._clear_block()
del self.ops[:]
assert len(self.vars) == 0
assert len(self.ops) == 0
def all_parameters(self):
return list(self.iter_parameters())
......@@ -1337,26 +1337,13 @@ class Block(object):
#
# TODO(minqiyang): add op stop_gradient support in static mode too.
# currently, we only support stop_gradient in imperative mode.
self._trace_op(op, kwargs.get("stop_gradient", False))
self.ops.append(op)
_imperative_tracer().trace_op(op,
kwargs.get("stop_gradient", False))
else:
self.ops.append(op)
return op
def _trace_op(self, op, stop_gradient=False):
backward_refs = _imperative_tracer().trace(
op.iop, op.inputs, op.outputs, self.desc,
_imperative_current_expected_place_, stop_gradient)
# TODO(minqiyang): support backward_hooks to eager remove backward_refs
op.backward_refs = defaultdict(list)
for k, v in six.iteritems(op.inputs):
if k in backward_refs:
op.backward_refs[k] = op.inputs[k]
for k, v in six.iteritems(op.outputs):
if k in backward_refs:
op.backward_refs[k] = op.outputs[k]
def _insert_op(self, index, *args, **kwargs):
"""
Insert a Operator according to the giving arguments.
......@@ -1409,9 +1396,11 @@ class Block(object):
inputs=kwargs.get("inputs", None),
outputs=kwargs.get("outputs", None),
attrs=kwargs.get("attrs", None))
self.ops.insert(0, op)
if _in_imperative_mode():
self._trace_op(op, kwargs.get("stop_gradient", False))
_imperative_tracer().trace_op(op,
kwargs.get("stop_gradient", False))
else:
self.ops.insert(0, op)
return op
def _sync_with_cpp(self):
......
......@@ -23,7 +23,11 @@ from .layers import *
from . import nn
from .nn import *
from . import tracer
from .tracer import *
__all__ = []
__all__ += layers.__all__
__all__ += base.__all__
__all__ += nn.__all__
__all__ += tracer.__all__
......@@ -16,6 +16,7 @@ import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
from .tracer import Tracer
__all__ = ['enabled', 'guard', 'to_variable']
......@@ -28,7 +29,7 @@ def enabled():
def guard(place=None):
train = framework.Program()
startup = framework.Program()
tracer = core.Tracer(train.current_block().desc)
tracer = Tracer(train.current_block().desc)
if place is None:
if core.is_compiled_with_cuda():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册