提交 9abf40c9 编写于 作者: M minqiyang 提交者: ceci3

Add imperative python tracer

上级 92f3cf42
...@@ -205,6 +205,7 @@ class OpBase { ...@@ -205,6 +205,7 @@ class OpBase {
: op_desc_(nullptr), : op_desc_(nullptr),
forward_id_(-1), forward_id_(-1),
backward_id_(-1), backward_id_(-1),
trace_id_(-1),
place_(platform::CPUPlace()) {} place_(platform::CPUPlace()) {}
virtual ~OpBase() { virtual ~OpBase() {
...@@ -225,6 +226,7 @@ class OpBase { ...@@ -225,6 +226,7 @@ class OpBase {
// Note: each fwd op corresponds to a vector of bwd ops. // Note: each fwd op corresponds to a vector of bwd ops.
std::vector<framework::OpDesc*> grad_op_descs_; std::vector<framework::OpDesc*> grad_op_descs_;
int backward_id_; int backward_id_;
int trace_id_;
platform::Place place_; platform::Place place_;
......
...@@ -193,6 +193,16 @@ PYBIND11_MODULE(core, m) { ...@@ -193,6 +193,16 @@ PYBIND11_MODULE(core, m) {
} }
}, },
py::return_value_policy::reference) py::return_value_policy::reference)
.def_property("_trace_id",
[](const imperative::OpBase &self) {
pybind11::gil_scoped_release release;
return self.trace_id_;
},
[](imperative::OpBase &self, int trace_id) {
pybind11::gil_scoped_release release;
self.trace_id_ = trace_id;
},
py::return_value_policy::reference)
.def_property( .def_property(
"forward_id", "forward_id",
[](const imperative::OpBase &self) { return self.forward_id_; }, [](const imperative::OpBase &self) { return self.forward_id_; },
......
...@@ -1201,13 +1201,13 @@ class Block(object): ...@@ -1201,13 +1201,13 @@ class Block(object):
raise ValueError("Var {0} is not found recursively".format(name)) raise ValueError("Var {0} is not found recursively".format(name))
def _clear_block(self): def _clear_block(self):
# TODO(minqiyang): move this to backward_hooks assert _in_imperative_mode()
self.desc._clear_block()
for name in self.vars.keys(): # TODO(minqiyang): move this to Variable and Operator's __del__
assert self.vars[name].persistable self.desc._clear_block()
del self.ops[:] assert len(self.vars) == 0
assert len(self.ops) == 0
def all_parameters(self): def all_parameters(self):
return list(self.iter_parameters()) return list(self.iter_parameters())
...@@ -1345,26 +1345,13 @@ class Block(object): ...@@ -1345,26 +1345,13 @@ class Block(object):
# #
# TODO(minqiyang): add op stop_gradient support in static mode too. # TODO(minqiyang): add op stop_gradient support in static mode too.
# currently, we only support stop_gradient in imperative mode. # currently, we only support stop_gradient in imperative mode.
self._trace_op(op, kwargs.get("stop_gradient", False)) _imperative_tracer().trace_op(op,
self.ops.append(op) kwargs.get("stop_gradient", False))
else:
self.ops.append(op)
return op return op
def _trace_op(self, op, stop_gradient=False):
backward_refs = _imperative_tracer().trace(
op.iop, op.inputs, op.outputs, self.desc,
_imperative_current_expected_place_, stop_gradient)
# TODO(minqiyang): support backward_hooks to eager remove backward_refs
op.backward_refs = defaultdict(list)
for k, v in six.iteritems(op.inputs):
if k in backward_refs:
op.backward_refs[k] = op.inputs[k]
for k, v in six.iteritems(op.outputs):
if k in backward_refs:
op.backward_refs[k] = op.outputs[k]
def _insert_op(self, index, *args, **kwargs): def _insert_op(self, index, *args, **kwargs):
""" """
Insert a Operator according to the giving arguments. Insert a Operator according to the giving arguments.
...@@ -1417,9 +1404,11 @@ class Block(object): ...@@ -1417,9 +1404,11 @@ class Block(object):
inputs=kwargs.get("inputs", None), inputs=kwargs.get("inputs", None),
outputs=kwargs.get("outputs", None), outputs=kwargs.get("outputs", None),
attrs=kwargs.get("attrs", None)) attrs=kwargs.get("attrs", None))
self.ops.insert(0, op)
if _in_imperative_mode(): if _in_imperative_mode():
self._trace_op(op, kwargs.get("stop_gradient", False)) _imperative_tracer().trace_op(op,
kwargs.get("stop_gradient", False))
else:
self.ops.insert(0, op)
return op return op
def _sync_with_cpp(self): def _sync_with_cpp(self):
......
...@@ -23,7 +23,11 @@ from .layers import * ...@@ -23,7 +23,11 @@ from .layers import *
from . import nn from . import nn
from .nn import * from .nn import *
from . import tracer
from .tracer import *
__all__ = [] __all__ = []
__all__ += layers.__all__ __all__ += layers.__all__
__all__ += base.__all__ __all__ += base.__all__
__all__ += nn.__all__ __all__ += nn.__all__
__all__ += tracer.__all__
...@@ -16,6 +16,7 @@ import numpy as np ...@@ -16,6 +16,7 @@ import numpy as np
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
from .tracer import Tracer
__all__ = ['enabled', 'guard', 'to_variable'] __all__ = ['enabled', 'guard', 'to_variable']
...@@ -28,7 +29,7 @@ def enabled(): ...@@ -28,7 +29,7 @@ def enabled():
def guard(place=None): def guard(place=None):
train = framework.Program() train = framework.Program()
startup = framework.Program() startup = framework.Program()
tracer = core.Tracer(train.current_block().desc) tracer = Tracer(train.current_block().desc)
if place is None: if place is None:
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册