Unverified commit 187cffd0, authored by Qiyang Min, committed by GitHub

Merge pull request #15928 from velconia/imperative_backward_hooks

Imperative backward hooks
@@ -13,7 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/block_desc.h"
#include <queue>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"

@@ -155,6 +159,16 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
  ops_.erase(ops_.begin() + s, ops_.begin() + e);
}
void BlockDesc::RemoveOpInternal(const OpDesc *op_desc) {
  // TODO(minqiyang): make this faster
  for (auto it = ops_.begin(); it != ops_.end(); ++it) {
    if (it->get() == op_desc) {
      ops_.erase(it);
      break;
    }
  }
}
std::vector<OpDesc *> BlockDesc::AllOps() const {
  std::vector<OpDesc *> res;
  for (const auto &op : ops_) {
@@ -163,20 +177,6 @@ std::vector<OpDesc *> BlockDesc::AllOps() const {
  return res;
}
void BlockDesc::Clear() {
  // clear all ops
  ops_.clear();

  // clear all vars which are not persistable
  for (auto it = vars_.begin(); it != vars_.end();) {
    if (it->second->Persistable()) {
      ++it;
    } else {
      vars_.erase(it++);
    }
  }
}
void BlockDesc::Flush() {
  for (auto &op_desc : ops_) {
    op_desc->Flush();
......
@@ -93,12 +93,12 @@ class BlockDesc {
   */
  void RemoveOp(size_t s, size_t e);

  void RemoveOpInternal(const OpDesc *op_desc);

  void RemoveVar(const std::string &name) { vars_.erase(name); }

  std::vector<OpDesc *> AllOps() const;

  void Clear();

  size_t OpSize() const { return ops_.size(); }

  OpDesc *Op(int idx) const { return ops_.at(idx).get(); }
......
@@ -24,3 +24,11 @@ limitations under the License. */
#pragma pop_macro("_XOPEN_SOURCE")
#pragma pop_macro("_POSIX_C_SOURCE")
#if !defined(PYBIND11_HIDDEN)
#ifdef _WIN32
#define PYBIND11_HIDDEN __declspec(dllexport)
#else
#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
#endif
#endif
@@ -18,6 +18,7 @@
#include <limits>
#include <map>
#include <random>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/lod_tensor.h"

@@ -139,6 +140,8 @@ class Autograd {
          }
        }
      }

      ready_op->InvokeBackwardHooks();
    }
  }
@@ -156,8 +159,10 @@ class Autograd {
      for (auto it : candidate->pre_ops_) {
        for (OpBase* pre_op : it.second) {
          if (!pre_op) continue;
          VLOG(5) << "op dep " << candidate->op_desc_->Type() << " trace id "
                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
                  << pre_op->op_desc_->Type() << " trace id "
                  << pre_op->trace_id_;
          if (visited.find(pre_op) == visited.end()) {
            visited.insert(pre_op);
            queue.push_back(pre_op);
@@ -211,6 +216,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
    return {};
  }

  VLOG(3) << "apply op grad: " << op_desc_->Type();
  std::vector<framework::VariableValueMap> grad_outputs;
  if (backward_id_ > 0) {
    VLOG(3) << "py_layer_grad";

@@ -272,6 +278,22 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
  return input_vars_;
}
void OpBase::InvokeBackwardHooks() {
  VLOG(3) << "call backward hooks, hooks num: " << backward_hooks_.size();

  // call backward hooks
  for (py::object& callable : backward_hooks_) {
    callable(this);
  }
}

void OpBase::RegisterBackwardHooks(const py::object& callable) {
  VLOG(3) << "Register backward hooks " << trace_id_;

  // TODO(minqiyang): check the callable format
  backward_hooks_.push_back(callable);
}
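From the Python side, any callable registered through the `register_backward_hooks` binding added later in this diff ends up in `backward_hooks_` and is invoked by the loop above once the op's gradients have been applied. A rough, hypothetical sketch of such a hook; only `op.iop`, `register_backward_hooks`, and `_trace_id` come from this PR, the hook itself is illustrative:

```python
# Illustrative hook: runs right after the corresponding op's backward pass.
# The callable receives the C++ imperative::OpBase that was traced for `op`.
def log_backward(iop):
    print("backward finished for trace id", iop._trace_id)

# `op` stands for a framework.Operator traced in imperative mode,
# so `op.iop` is its underlying core.OpBase.
op.iop.register_backward_hooks(log_backward)
```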
void VarBase::RunBackward() {
  if (!pre_op_) return;
......
@@ -123,22 +123,32 @@ class VarBase {
 private:
  VarBase(framework::Variable* var, VarBase* grad, bool stop_gradient)
      : name_(),
        var_desc_(nullptr),
        var_(var),
        grads_(grad),
        block_(nullptr),
        persistable_(false),
        stop_gradient_(stop_gradient),
        pre_op_(nullptr),
        pre_op_out_name_(),
        pre_op_out_idx_(-1) {}

 public:
  virtual ~VarBase() {
    // TODO(minqiyang): remove var desc from block desc
    if (var_) {
      delete var_;
      var_ = nullptr;
    }

    if (grads_) {
      delete grads_;
      grads_ = nullptr;
    }

    pre_op_ = nullptr;
    pre_op_out_idx_ = -1;
  }

  inline OpBase* PreOp() const { return pre_op_; }
@@ -151,6 +161,14 @@ class VarBase {
  void RunBackward();

  inline void ResetPreOp(OpBase* op) {
    if (op == pre_op_) {
      // clear pre_op info when op equals to var's pre_op
      pre_op_ = nullptr;
      pre_op_out_idx_ = -1;
    }
  }

  void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
                  int pre_op_out_idx, bool pre_op_stop_gradient) {
    pre_op_ = pre_op;

@@ -184,11 +202,15 @@ class VarBase {
    return string::Sprintf("%s@IGrad", var_desc_->Name());
  }
  std::string name_;
  framework::VarDesc* var_desc_;

  framework::Variable* var_;
  VarBase* grads_;

  framework::BlockDesc* block_;
  bool persistable_;

 private:
  bool stop_gradient_;
  OpBase* pre_op_;
@@ -199,15 +221,27 @@ class VarBase {
/* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its
 * gradient. This object should be managed totally by Python interpreter.
 */
class PYBIND11_HIDDEN OpBase {
 public:
  OpBase()
      : op_desc_(nullptr),
        forward_id_(-1),
        backward_id_(-1),
        trace_id_(-1),
        place_(platform::CPUPlace()),
        backward_hooks_() {}
  virtual ~OpBase() {
    // TODO(minqiyang): remove op_desc from block_desc in tracer
    //
    // reset all output vars' pre op
    for (auto iter : output_vars_) {
      for (VarBase* var : iter.second) {
        var->ResetPreOp(this);
      }
    }

    // release resource
    for (framework::OpDesc* desc : grad_op_descs_) {
      delete desc;
    }
@@ -215,6 +249,10 @@ class OpBase {

  std::map<std::string, std::vector<VarBase*>> ApplyGrad();

  void RegisterBackwardHooks(const py::object& callable);

  void InvokeBackwardHooks();

  // One of `op_desc_` or `forward_id_` is set, not both.
  // For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_.
  framework::OpDesc* op_desc_;

@@ -225,6 +263,7 @@ class OpBase {
  // Note: each fwd op corresponds to a vector of bwd ops.
  std::vector<framework::OpDesc*> grad_op_descs_;
  int backward_id_;
  int trace_id_;

  platform::Place place_;

@@ -239,6 +278,8 @@ class OpBase {
  std::vector<framework::VariableValueMap> grad_output_vars_;

  framework::BlockDesc* block_;

  std::vector<py::object> backward_hooks_;
};

class Layer {
......
@@ -14,7 +14,10 @@
#include "paddle/fluid/imperative/tracer.h"

#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"

@@ -110,7 +113,8 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
  std::map<std::string, VarBase*> vars;

  framework::OpDesc* op_desc = op->op_desc_;
  VLOG(3) << "tracer tracing " << op_desc->Type() << " trace id "
          << op->trace_id_;
  op_desc->InferShape(*block);
  op_desc->InferVarType(block);
@@ -133,11 +137,13 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
      if (inp->PreOp() && !inp->IsStopGradient()) {
        op->pre_ops_[it.first].push_back(inp->PreOp());
        op->pre_ops_out_idx_[it.first].push_back(inp->PreOpOutIdx());
        VLOG(3) << "add pre op " << inp->PreOp()->op_desc_->Type();
      } else {
        op->pre_ops_[it.first].push_back(nullptr);
      }
      VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
              << inp->var_->IsInitialized() << " stop_gradient "
              << inp->IsStopGradient();
    }
  }
@@ -189,6 +195,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
    op->grad_input_vars_.resize(op->grad_op_descs_.size());
    op->grad_output_vars_.resize(op->grad_op_descs_.size());

    for (size_t i = 0; i < op->grad_op_descs_.size(); ++i) {
      framework::OpDesc* grad_op_desc = op->grad_op_descs_[i];
      for (auto it : grad_op_desc->Inputs()) {
@@ -201,7 +208,6 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
          PADDLE_ENFORCE(fwd_var_it != vars.end());
          // Forward inputs or outputs.
          grad_in_vars.push_back(fwd_var_it->second->var_);
          vars_saved_for_backward.insert(it.first);
        } else {
          VarBase* var = vars[var_it->second];
          if (!var->grads_->var_->IsInitialized()) {
@@ -211,6 +217,8 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
            // Douts.
            grad_in_vars.push_back(var->grads_->var_);
          }

          vars_saved_for_backward.insert(it.first);
        }
      }
......
@@ -33,7 +33,7 @@ class Layer : public imperative::Layer {
  }
};

class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase {
 public:
  using imperative::OpBase::OpBase;  // Inherit constructors
};
......
@@ -189,8 +189,6 @@ void BindBlockDesc(pybind11::module *m) {
             return self.HasVar(name);
           },
           pybind11::return_value_policy::reference)
      .def("_clear_block", [](pd::BlockDesc &self) { return self.Clear(); },
           pybind11::return_value_policy::reference)
      .def("_rename_var",
           [](pd::BlockDesc &self, const pybind11::bytes &byte_name,
              const pybind11::bytes &byte_name_new) {
......
@@ -177,6 +177,23 @@ PYBIND11_MODULE(core, m) {
           py::return_value_policy::take_ownership)
      .def("value", [](const imperative::VarBase &self) { return self.var_; },
           py::return_value_policy::reference)
      .def_property("name",
                    [](const imperative::VarBase &self) { return self.name_; },
                    [](imperative::VarBase &self, const std::string &name) {
                      self.name_ = name;
                    })
      .def_property("block",
                    [](const imperative::VarBase &self) { return self.block_; },
                    [](imperative::VarBase &self, framework::BlockDesc *block) {
                      self.block_ = block;
                    },
                    py::return_value_policy::reference)
      .def_property(
          "persistable",
          [](const imperative::VarBase &self) { return self.persistable_; },
          [](imperative::VarBase &self, const bool persistable) {
            self.persistable_ = persistable;
          })
      .def_property(
          "desc",
          [](const imperative::VarBase &self) { return self.var_desc_; },

@@ -193,6 +210,10 @@ PYBIND11_MODULE(core, m) {
  py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
      .def(py::init<>())
      .def("register_backward_hooks",
           [](imperative::OpBase &self, const py::object &callable) {
             self.RegisterBackwardHooks(callable);
           })
      .def_property(
          "desc", [](const imperative::OpBase &self) { return self.op_desc_; },
          [](imperative::OpBase &self, framework::OpDesc *op_desc) {
@@ -201,6 +222,16 @@ PYBIND11_MODULE(core, m) {
            }
          },
          py::return_value_policy::reference)
      .def_property("_trace_id",
                    [](const imperative::OpBase &self) {
                      pybind11::gil_scoped_release release;
                      return self.trace_id_;
                    },
                    [](imperative::OpBase &self, int trace_id) {
                      pybind11::gil_scoped_release release;
                      self.trace_id_ = trace_id;
                    },
                    py::return_value_policy::reference)
      .def_property(
          "forward_id",
          [](const imperative::OpBase &self) { return self.forward_id_; },
......
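For context, the new bindings above are consumed from Python as plain attributes. A hedged sketch of how they fit together (assumes a `framework.Variable` named `var` and a traced `framework.Operator` named `op` created in imperative mode; the variable names are illustrative):

```python
# The core.VarBase behind a Python Variable now carries name/block/persistable.
ivar = var._ivar
ivar.name = var.name
ivar.block = var.block.desc
ivar.persistable = var.persistable

# The core.OpBase behind a traced Operator exposes the trace id assigned
# by the Python-side Tracer (see tracer.py below).
iop = op.iop
print(iop._trace_id)
```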
@@ -393,6 +393,9 @@ class Variable(object):
            if not self._ivar:
                self._ivar = core.VarBase(stop_gradient)
            self._ivar.desc = self.desc
            self._ivar.block = block.desc
            self._ivar.name = name
            self._ivar.persistable = persistable
            if persistable:
                self.block.vars[name] = self
            else:
@@ -721,6 +724,8 @@ class Operator(object):
                out_arg_names = []
                for arg in out_args:
                    out_arg_names.append(cpt.to_text(arg.name))
                    # TODO(minqiyang): could we remove variable's op in static mode?
                    if not _in_imperative_mode():
                        arg.op = self
                self.desc.set_output(out_proto.name, out_arg_names)
@@ -1200,15 +1205,6 @@ class Block(object):
        else:
            raise ValueError("Var {0} is not found recursively".format(name))

    def _clear_block(self):
        # TODO(minqiyang): move this to backward_hooks
        self.desc._clear_block()

        for name in self.vars.keys():
            assert self.vars[name].persistable

        del self.ops[:]

    def all_parameters(self):
        return list(self.iter_parameters())
@@ -1345,26 +1341,13 @@ class Block(object):
            #
            # TODO(minqiyang): add op stop_gradient support in static mode too.
            # currently, we only support stop_gradient in imperative mode.
            _imperative_tracer().trace_op(op,
                                          kwargs.get("stop_gradient", False))
        else:
            self.ops.append(op)
        return op
    def _trace_op(self, op, stop_gradient=False):
        backward_refs = _imperative_tracer().trace(
            op.iop, op.inputs, op.outputs, self.desc,
            _imperative_current_expected_place_, stop_gradient)

        # TODO(minqiyang): support backward_hooks to eager remove backward_refs
        op.backward_refs = defaultdict(list)
        for k, v in six.iteritems(op.inputs):
            if k in backward_refs:
                op.backward_refs[k] = op.inputs[k]

        for k, v in six.iteritems(op.outputs):
            if k in backward_refs:
                op.backward_refs[k] = op.outputs[k]
    def _insert_op(self, index, *args, **kwargs):
        """
        Insert a Operator according to the giving arguments.

@@ -1417,9 +1400,11 @@ class Block(object):
            inputs=kwargs.get("inputs", None),
            outputs=kwargs.get("outputs", None),
            attrs=kwargs.get("attrs", None))
        self.ops.insert(0, op)
        if _in_imperative_mode():
            _imperative_tracer().trace_op(op,
                                          kwargs.get("stop_gradient", False))
        else:
            self.ops.insert(0, op)
        return op

    def _sync_with_cpp(self):
......
@@ -23,7 +23,11 @@ from .layers import *
from . import nn
from .nn import *

from . import tracer
from .tracer import *

__all__ = []
__all__ += layers.__all__
__all__ += base.__all__
__all__ += nn.__all__
__all__ += tracer.__all__
@@ -16,6 +16,7 @@ import numpy as np
from paddle.fluid import core
from paddle.fluid import framework

from .tracer import Tracer

__all__ = ['enabled', 'guard', 'to_variable']

@@ -28,7 +29,7 @@ def enabled():
def guard(place=None):
    train = framework.Program()
    startup = framework.Program()
    tracer = Tracer(train.current_block().desc)

    if place is None:
        if core.is_compiled_with_cuda():
......
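The `guard()` context now installs the Python-side `Tracer` wrapper instead of a bare `core.Tracer`, so every op appended in imperative mode flows through `Tracer.trace_op` (defined in the new file below). A rough usage sketch; `to_variable` comes from this module, while `_numpy()` is the imperative-era accessor assumed here, as seen in the updated tests:

```python
import numpy as np
import paddle.fluid as fluid

with fluid.imperative.guard():            # installs the Python-side Tracer
    x = fluid.imperative.to_variable(np.ones([2, 2], dtype='float32'))
    y = fluid.layers.reduce_sum(x)        # append_op routes through Tracer.trace_op
    print(y._numpy())                     # eager result, no executor involved
```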
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import six
from collections import defaultdict
from paddle.fluid import core
from paddle.fluid import framework
__all__ = ['Tracer']
def release_op(op):
    del framework._imperative_tracer()._ops[op._trace_id]


class Tracer(core.Tracer):
    """
    Python wrapper of imperative tracer
    """

    def __init__(self, block):
        super(Tracer, self).__init__(block)

        self._ops = defaultdict()
        self._trace_id = 0

    def trace_op(self, op, stop_gradient=False):
        # record op's trace id
        op.iop._trace_id = self._trace_id

        # trace op and save it
        backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.block.desc,
                                   framework._current_expected_place(),
                                   stop_gradient)

        if not stop_gradient:
            self._trace_id += 1
            self._ops[op.iop._trace_id] = op

            # register backward hooks and variables if needed
            if len(backward_refs) > 0:
                op.iop.register_backward_hooks(release_op)

                # TODO(minqiyang): remove all inputs and outputs after separate
                # var and grad
                op.backward_refs = defaultdict(list)
                for k, v in six.iteritems(op.inputs):
                    if k in backward_refs:
                        op.backward_refs[k] = op.inputs[k]

                for k, v in six.iteritems(op.outputs):
                    if k in backward_refs:
                        op.backward_refs[k] = op.outputs[k]
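Putting the pieces together: `Block.append_op` calls `Tracer.trace_op`, which assigns a trace id, keeps the Python `Operator` alive in `self._ops`, and registers `release_op` as a backward hook on the C++ `OpBase`; once `Autograd` finishes that op's gradients it calls `InvokeBackwardHooks`, which runs `release_op` and drops the entry, releasing the forward references that were only kept for backward. This replaces the `_clear_block()` calls removed from the tests below. A minimal sketch of the intended effect (illustrative only; `_backward()` is the assumed backward entry point of this era, not part of this diff):

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import framework

with fluid.imperative.guard():
    x = fluid.imperative.to_variable(np.ones([2, 2], dtype='float32'))
    y = fluid.layers.reduce_sum(x)

    tracer = framework._imperative_tracer()
    ops_before = len(tracer._ops)   # traced ops retained for backward

    y._backward()                   # release_op hooks fire after each op's grad

    ops_after = len(tracer._ops)    # finished ops have been released
    assert ops_after <= ops_before
```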
@@ -19,6 +19,7 @@ import numpy as np
from .wrapped_decorator import signature_safe_contextmanager
from .core import VarDesc
from . import unique_name
from .imperative import base as imperative_base

__all__ = [
    'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',

@@ -165,6 +166,7 @@ class ConstantInitializer(Initializer):
                'force_cpu': self._force_cpu or force_init_on_cpu()
            },
            stop_gradient=True)
        if not imperative_base.enabled():
            var.op = op
        return op

@@ -244,6 +246,7 @@ class UniformInitializer(Initializer):
                attrs={"in_dtype": out_var.dtype,
                       "out_dtype": var.dtype})
        if not imperative_base.enabled():
            var.op = op
        return op

@@ -322,6 +325,7 @@ class NormalInitializer(Initializer):
                outputs={"Out": var},
                attrs={"in_dtype": out_var.dtype,
                       "out_dtype": var.dtype})
        if not imperative_base.enabled():
            var.op = op
        return op

@@ -400,6 +404,7 @@ class TruncatedNormalInitializer(Initializer):
                outputs={"Out": var},
                attrs={"in_dtype": out_var.dtype,
                       "out_dtype": var.dtype})
        if not imperative_base.enabled():
            var.op = op
        return op

@@ -505,6 +510,7 @@ class XavierInitializer(Initializer):
                "seed": self._seed
            },
            stop_gradient=True)
        if not imperative_base.enabled():
            var.op = op
        return op

@@ -605,6 +611,7 @@ class MSRAInitializer(Initializer):
                "seed": self._seed
            },
            stop_gradient=True)
        if not imperative_base.enabled():
            var.op = op
        return op

@@ -703,6 +710,7 @@ class BilinearInitializer(Initializer):
                'shape': list(shape),
                value_name: values
            })
        if not imperative_base.enabled():
            var.op = op
        return op

@@ -761,6 +769,7 @@ class NumpyArrayInitializer(Initializer):
                value_name: values
            },
            stop_gradient=True)
        if not imperative_base.enabled():
            var.op = op
        return op
......
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import contextlib
import unittest
import numpy as np

@@ -142,8 +144,6 @@ class TestImperativeMnist(unittest.TestCase):
                sgd.minimize(avg_loss)
                mnist.clear_gradients()

                fluid.default_main_program().global_block()._clear_block()

                dy_param_value = {}
                for param in mnist.parameters():
                    dy_param_value[param.name] = param._numpy()
......
@@ -243,7 +243,9 @@ class TestImperativePtbRnn(unittest.TestCase):
            dy_loss = None
            last_hidden = None
            last_cell = None

            batch_num = 50
            for i in range(batch_num):
                x_data = np.arange(12).reshape(4, 3).astype('int64')
                y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
                x_data = x_data.reshape((-1, num_steps, 1))

@@ -302,7 +304,7 @@ class TestImperativePtbRnn(unittest.TestCase):
            static_loss_value = None
            static_last_cell_value = None
            static_last_hidden_value = None
            for i in range(batch_num):
                x_data = np.arange(12).reshape(4, 3).astype('int64')
                y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
                x_data = x_data.reshape((-1, num_steps, 1))
......
@@ -231,7 +231,7 @@ class TestImperativeResnet(unittest.TestCase):
        seed = 90

        batch_size = train_parameters["batch_size"]
        batch_num = 20
        with fluid.imperative.guard():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

@@ -286,8 +286,6 @@ class TestImperativeResnet(unittest.TestCase):
                optimizer.minimize(avg_loss)
                resnet.clear_gradients()

                fluid.default_main_program().global_block()._clear_block()

                dy_param_value = {}
                for param in resnet.parameters():
                    dy_param_value[param.name] = param._numpy()
......