提交 8fe0c0c5 编写于 作者: M minqiyang

implement backward refs

上级 74551758
......@@ -205,6 +205,33 @@ framework::LoDTensor& VarBase::GradValue() {
return *(grads_->var_->GetMutable<framework::LoDTensor>());
}
void VarBase::ClearGradient() {
VLOG(1) << "clear gradient of " << var_desc_->Name();
if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
grads_->var_->Get<framework::LoDTensor>().place())),
grads_t, 0.0);
}
}
void VarBase::RunBackward() {
if (!pre_op_) return;
VLOG(3) << "start backward";
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
var_->GetMutable<framework::LoDTensor>()->place())),
grads_t, 1.0);
PADDLE_ENFORCE(
grads_ ==
pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
Autograd().RunBackward(this);
}
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
if (grad_op_descs_.empty() && backward_id_ <= 0) {
LOG(WARNING) << "op with no grad: " << op_desc_->Type();
......@@ -271,22 +298,6 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
return input_vars_;
}
void VarBase::RunBackward() {
if (!pre_op_) return;
VLOG(3) << "start backward";
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
var_->GetMutable<framework::LoDTensor>()->place())),
grads_t, 1.0);
PADDLE_ENFORCE(
grads_ ==
pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
Autograd().RunBackward(this);
}
void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
py_funcs_[func_id] = py_func;
}
......
......@@ -105,23 +105,23 @@ class VarBase {
public:
VarBase() : VarBase(new framework::Variable(), new VarBase(true)) {}
// Owns `var` and `grad`
explicit VarBase(bool stop_gradient)
: VarBase(new framework::Variable(),
stop_gradient ? nullptr : new VarBase(true), stop_gradient) {}
VarBase(framework::Variable* var, VarBase* grad)
: VarBase(var, grad, false) {}
private:
VarBase(framework::Variable* var, VarBase* grad, bool stop_gradient)
: var_desc_(nullptr),
var_(var),
grads_(grad),
stop_gradient_(false),
pre_op_(nullptr),
pre_op_out_idx_(-1) {}
explicit VarBase(bool stop_gradient)
: var_desc_(nullptr),
var_(new framework::Variable()),
grads_(stop_gradient ? nullptr : new VarBase(true)),
stop_gradient_(stop_gradient),
pre_op_(nullptr),
pre_op_out_idx_(-1) {}
public:
virtual ~VarBase() {
if (var_) {
delete var_;
......@@ -132,13 +132,13 @@ class VarBase {
}
}
OpBase* PreOp() const { return pre_op_; }
int PreOpOutIdx() const { return pre_op_out_idx_; }
void SetStopGradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
bool IsStopGradient() const { return stop_gradient_; }
inline OpBase* PreOp() const { return pre_op_; }
inline int PreOpOutIdx() const { return pre_op_out_idx_; }
void RunBackward();
inline void SetStopGradient(bool stop_gradient) {
stop_gradient_ = stop_gradient;
}
inline bool IsStopGradient() const { return stop_gradient_; }
void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
int pre_op_out_idx, bool pre_op_stop_gradient) {
......@@ -150,16 +150,9 @@ class VarBase {
}
}
void ClearGradient() {
VLOG(1) << "clear gradient of " << var_desc_->Name();
if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
grads_->var_->Get<framework::LoDTensor>().place())),
grads_t, 0.0);
}
}
void RunBackward();
void ClearGradient();
framework::LoDTensor& GradValue();
......
......@@ -14,6 +14,8 @@
#include "paddle/fluid/imperative/tracer.h"
#include <set>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -66,10 +68,11 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
return result;
}
void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs, framework::BlockDesc* block,
const platform::Place expected_place,
const bool stop_gradient) {
std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs,
framework::BlockDesc* block,
const platform::Place expected_place,
const bool stop_gradient) {
std::map<std::string, VarBase*> vars;
framework::OpDesc* op_desc = op->op_desc_;
......@@ -142,6 +145,8 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
prepared_op.func(framework::ExecutionContext(
prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
std::set<std::string> grad_deps_var;
if (!stop_gradient) {
std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
new std::unordered_map<std::string, std::string>());
......@@ -161,6 +166,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
PADDLE_ENFORCE(fwd_var_it != vars.end());
// Forward inputs or outputs.
grad_in_vars.push_back(fwd_var_it->second->var_);
grad_deps_var.insert(it.first);
} else {
VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) {
......@@ -194,6 +200,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
}
op->block_ = block;
return grad_deps_var;
}
std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
......
......@@ -15,6 +15,7 @@
#pragma once
#include <map>
#include <set>
#include <string>
#include <vector>
......@@ -43,10 +44,11 @@ class Tracer {
virtual ~Tracer() {}
void Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs, framework::BlockDesc* block,
const platform::Place expected_place,
const bool stop_gradient = false);
std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs,
framework::BlockDesc* block,
const platform::Place expected_place,
const bool stop_gradient = false);
std::vector<VarBase*> PyTrace(OpBase* op, const std::vector<VarBase*>& inputs,
bool stop_gradient = false);
......
......@@ -34,8 +34,8 @@ void BindTracer(pybind11::module* m) {
framework::BlockDesc* block,
const platform::CPUPlace expected_place,
const bool stop_gradient = false) {
self.Trace(op, inputs, outputs, block, expected_place,
stop_gradient);
return self.Trace(op, inputs, outputs, block, expected_place,
stop_gradient);
})
.def("trace",
[](imperative::Tracer& self, imperative::OpBase* op,
......@@ -44,8 +44,8 @@ void BindTracer(pybind11::module* m) {
framework::BlockDesc* block,
const platform::CUDAPlace expected_place,
const bool stop_gradient = false) {
self.Trace(op, inputs, outputs, block, expected_place,
stop_gradient);
return self.Trace(op, inputs, outputs, block, expected_place,
stop_gradient);
})
.def("py_trace", &imperative::Tracer::PyTrace,
pybind11::return_value_policy::take_ownership);
......
......@@ -376,15 +376,17 @@ class Variable(object):
# get_capacity is implemented
pass
self.block.vars[name] = self
self.op = None
self.stop_gradient = stop_gradient
self.is_data = is_data
if _in_imperative_mode():
# record vars in tracer rather than blocks
self._ivar = kwargs.get("ivar", None)
if not self._ivar:
self._ivar = core.VarBase(stop_gradient)
self._ivar.desc = self.desc
else:
self.block.vars[name] = self
self.op = None
self.stop_gradient = stop_gradient
self.is_data = is_data
def _numpy(self):
new_ivar = self._ivar._copy_to(core.CPUPlace(), True)
......@@ -727,6 +729,7 @@ class Operator(object):
if _in_imperative_mode():
self.iop = core.OpBase()
self.iop.desc = self.desc
self.inputs = defaultdict(list)
if inputs is not None:
for k, v in six.iteritems(inputs):
......@@ -734,6 +737,7 @@ class Operator(object):
self.inputs[k].append(v._ivar)
elif isinstance(v, list) or isinstance(v, tuple):
self.inputs[k].extend([var._ivar for var in v])
self.outputs = defaultdict(list)
if outputs is not None:
for k, v in six.iteritems(outputs):
......@@ -1186,8 +1190,8 @@ class Block(object):
def _clear_block(self):
self.desc._clear_block()
for name, var in self.vars.items():
if not var.persistable:
for name in self.vars.keys():
if not self.vars[name].persistable:
del self.vars[name]
del self.ops[:]
......@@ -1322,18 +1326,34 @@ class Block(object):
inputs=kwargs.get("inputs", None),
outputs=kwargs.get("outputs", None),
attrs=kwargs.get("attrs", None))
if _in_imperative_mode():
# record ops in tracer rather than blocks
#
# TODO(minqiyang): add op stop_gradient support in static mode too.
# currently, we only support stop_gradient in imperative mode.
self._trace_op(op, kwargs.get("stop_gradient", False))
self.ops.append(op)
# TODO(minqiyang): add stop_gradient support in static mode too.
# currently, we only support stop_gradient in imperative mode.
self._trace_op(op, kwargs.get("stop_gradient", False))
return op
def _trace_op(self, op, stop_gradient=False):
if _in_imperative_mode():
_imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc,
_imperative_current_expected_place_,
stop_gradient)
backward_refs = _imperative_tracer().trace(
op.iop, op.inputs, op.outputs, self.desc,
_imperative_current_expected_place_, stop_gradient)
print("backward_refs", backward_refs)
import sys
sys.stdout.flush()
# TODO(minqiyang): support backward hooks to eager remove backward_refs
op.backward_refs = defaultdict(list)
for k, v in six.iteritems(op.inputs):
if k in backward_refs:
op.backward_refs[k] = op.inputs[k]
for k, v in six.iteritems(op.outputs):
if k in backward_refs:
op.backward_refs[k] = op.outputs[k]
def _insert_op(self, index, *args, **kwargs):
"""
......@@ -1388,7 +1408,8 @@ class Block(object):
outputs=kwargs.get("outputs", None),
attrs=kwargs.get("attrs", None))
self.ops.insert(0, op)
self._trace_op(op, kwargs.get("stop_gradient", False))
if _in_imperative_mode():
self._trace_op(op, kwargs.get("stop_gradient", False))
return op
def _sync_with_cpp(self):
......
......@@ -102,7 +102,6 @@ class TestImperativeMnist(unittest.TestCase):
def test_mnist_float32(self):
seed = 90
epoch_num = 1
batch_num = 200
with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
......@@ -205,12 +204,16 @@ class TestImperativeMnist(unittest.TestCase):
self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all()))
for key, value in six.iteritems(static_param_init_value):
self.assertTrue(np.allclose(value, dy_param_init_value[key]))
if not np.allclose(value, dy_param_init_value[key]):
print(key, value, dy_param_value[key])
# self.assertTrue(np.allclose(value, dy_param_init_value[key]))
self.assertTrue(np.allclose(static_out, dy_out))
for key, value in six.iteritems(static_param_value):
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-6))
if not np.allclose(value, dy_param_value[key], atol=1e-6):
print(key, value, dy_param_value[key])
# self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
if __name__ == '__main__':
......
......@@ -208,7 +208,7 @@ class TestImperativeResnet(unittest.TestCase):
seed = 90
batch_size = train_parameters["batch_size"]
batch_num = 1
batch_num = 2
with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
......@@ -266,6 +266,8 @@ class TestImperativeResnet(unittest.TestCase):
optimizer.minimize(avg_loss)
resnet.clear_gradients()
fluid.default_main_program().global_block()._clear_block()
dy_param_value = {}
for param in fluid.default_main_program().global_block(
).all_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册