Commit e395f2c6 authored by Xin Pan

polish codes

test=develop

Parent 179363a1
@@ -57,15 +57,15 @@ class Autograd {
   Autograd() {}

   void RunBackward(VarBase* var) {
-    if (var->stop_gradient_) {
+    if (var->IsStopGradient()) {
       return;
     }
     VLOG(3) << "start autograd";

     std::deque<OpBase*> ready;
-    ready.push_back(var->pre_op_);
+    ready.push_back(var->PreOp());

-    std::map<OpBase*, int> dep_counts = ComputeDepCounts(var->pre_op_);
+    std::map<OpBase*, int> dep_counts = ComputeDepCounts(var->PreOp());

     while (!ready.empty()) {
       OpBase* ready_op = ready.front();
@@ -77,7 +77,7 @@ class Autograd {
         const std::vector<VarBase*>& ingrads = it.second;
         for (size_t i = 0; i < ingrads.size(); ++i) {
           if (!ingrads[i]) continue;
-          if (ready_op->input_vars_[it.first][i]->stop_gradient_) {
+          if (ready_op->input_vars_[it.first][i]->IsStopGradient()) {
             continue;
           }
           OpBase* pre_op = ready_op->pre_ops_[it.first][i];
...
@@ -100,20 +100,20 @@ class VarBase {
   // Owns `var` and `grad`
   VarBase(framework::Variable* var, VarBase* grad)
-      : pre_op_(nullptr),
-        pre_op_out_idx_(-1),
-        var_desc_(nullptr),
+      : var_desc_(nullptr),
         var_(var),
         grads_(grad),
-        stop_gradient_(false) {}
+        stop_gradient_(false),
+        pre_op_(nullptr),
+        pre_op_out_idx_(-1) {}

   explicit VarBase(bool stop_gradient)
-      : pre_op_(nullptr),
-        pre_op_out_idx_(-1),
-        var_desc_(nullptr),
+      : var_desc_(nullptr),
         var_(new framework::Variable()),
         grads_(stop_gradient ? nullptr : new VarBase(true)),
-        stop_gradient_(stop_gradient) {}
+        stop_gradient_(stop_gradient),
+        pre_op_(nullptr),
+        pre_op_out_idx_(-1) {}

   virtual ~VarBase() {
     if (var_) {
@@ -125,15 +125,27 @@ class VarBase {
     }
   }

-  void Clear() {
+  OpBase* PreOp() const { return pre_op_; }
+  int PreOpOutIdx() const { return pre_op_out_idx_; }
+
+  void SetStopGradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
+  bool IsStopGradient() const { return stop_gradient_; }
+
+  void RunBackward();
+
+  void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
+                  int pre_op_out_idx, bool stop_gradient) {
+    pre_op_ = pre_op;
+    pre_op_out_name_ = pre_op_out_name;
+    pre_op_out_idx_ = pre_op_out_idx;
+    stop_gradient_ = stop_gradient;
+  }
+
+  void ClearGradient() {
     delete grads_;
     grads_ = new VarBase(true);
-    pre_op_ = nullptr;
-    pre_op_out_name_ = "";
   }

-  void RunBackward();
-
   framework::LoDTensor& GradValue();

   inline std::string GradName() const {
@@ -143,16 +155,16 @@ class VarBase {
     return string::Sprintf("%s@IGrad", var_desc_->Name());
   }

-  OpBase* pre_op_;
-  std::string pre_op_out_name_;
-  int pre_op_out_idx_;
   framework::VarDesc* var_desc_;

   framework::Variable* var_;
   VarBase* grads_;

+ private:
   bool stop_gradient_;
+
+  OpBase* pre_op_;
+  std::string pre_op_out_name_;
+  int pre_op_out_idx_;
 };

 /* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its
...
@@ -63,9 +63,9 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       invars.push_back(inp->var_);
       vars[inp->var_desc_->Name()] = inp;
-      if (inp->pre_op_) {
-        op->pre_ops_[it.first].push_back(inp->pre_op_);
-        op->pre_ops_out_idx_[it.first].push_back(inp->pre_op_out_idx_);
+      if (inp->PreOp()) {
+        op->pre_ops_[it.first].push_back(inp->PreOp());
+        op->pre_ops_out_idx_[it.first].push_back(inp->PreOpOutIdx());
       } else {
         op->pre_ops_[it.first].push_back(nullptr);
       }
@@ -89,10 +89,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       } else {
         LOG(ERROR) << "tracer doesn't support yet";
       }
-      out->stop_gradient_ = stop_gradient;
-      out->pre_op_ = op;
-      out->pre_op_out_name_ = it.first;
-      out->pre_op_out_idx_ = i;
+      out->TrackPreOp(op, it.first, i, stop_gradient);

       VLOG(3) << "output vname " << out->var_desc_->Name() << " "
               << out->var_->IsInitialized();
@@ -167,9 +164,9 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
   op->input_vars_[PyLayer::kFwdInp] = inputs;
   op->output_vars_[PyLayer::kFwdOut] = PyLayer::Apply(op->forward_id_, inputs);
   for (VarBase* inp : inputs) {
-    if (inp->pre_op_) {
-      op->pre_ops_[PyLayer::kFwdInp].push_back(inp->pre_op_);
-      op->pre_ops_out_idx_[PyLayer::kFwdInp].push_back(inp->pre_op_out_idx_);
+    if (inp->PreOp()) {
+      op->pre_ops_[PyLayer::kFwdInp].push_back(inp->PreOp());
+      op->pre_ops_out_idx_[PyLayer::kFwdInp].push_back(inp->PreOpOutIdx());
     } else {
       op->pre_ops_[PyLayer::kFwdInp].push_back(nullptr);
     }
@@ -178,10 +175,7 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
   auto& outputs = op->output_vars_[PyLayer::kFwdOut];
   for (size_t i = 0; i < outputs.size(); ++i) {
     VarBase* out = outputs[i];
-    out->stop_gradient_ = stop_gradient;
-    out->pre_op_ = op;
-    out->pre_op_out_name_ = PyLayer::kFwdOut;
-    out->pre_op_out_idx_ = i;
+    out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient);
   }
   if (!stop_gradient) {
     auto& grad_input_vars =
...
@@ -133,7 +133,7 @@ PYBIND11_MODULE(core, m) {
            [](imperative::VarBase &self) { self.RunBackward(); })
       .def("_grad_name", &imperative::VarBase::GradName)
       .def("_grad_value", &imperative::VarBase::GradValue)
-      .def("_clear", &imperative::VarBase::Clear)
+      .def("_clear_gradient", &imperative::VarBase::ClearGradient)
       .def("_grad_ivar",
            [](const imperative::VarBase &self) { return self.grads_; },
            py::return_value_policy::reference)
@@ -148,9 +148,9 @@ PYBIND11_MODULE(core, m) {
            py::return_value_policy::reference)
       .def_property(
           "stop_gradient",
-          [](const imperative::VarBase &self) { return self.stop_gradient_; },
+          [](const imperative::VarBase &self) { return self.IsStopGradient(); },
          [](imperative::VarBase &self, bool stop_gradient) {
-            self.stop_gradient_ = stop_gradient;
+            self.SetStopGradient(stop_gradient);
           });

   py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
...
@@ -388,8 +388,8 @@ class Variable(object):
     def _gradient(self):
         return np.array(self._ivar._grad_value())

-    def _clear(self):
-        self._ivar._clear()
+    def _clear_gradient(self):
+        self._ivar._clear_gradient()

     def __str__(self):
         return self.to_string(True)
...
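For orientation, a hedged sketch (not part of this commit) of how the renamed Variable helpers behave under the imperative guard. It assumes an op such as fluid.layers.reduce_sum is traceable in imperative mode, as in the framework's own imperative tests; the shapes and values are arbitrary. Note that on the C++ side ClearGradient() now resets only grads_; the pre-op bookkeeping that the old Clear() also wiped is handled by TrackPreOp() instead.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.imperative.base import to_variable

# Illustrative only: reduce_sum and the 2x2 input are assumptions.
with fluid.imperative.guard():
    x = to_variable(np.ones([2, 2], dtype=np.float32))
    x._ivar.stop_gradient = False   # property now backed by Set/IsStopGradient
    y = fluid.layers.reduce_sum(x)
    y._backward()                   # VarBase::RunBackward on the C++ side
    print(x._gradient())            # wraps self._ivar._grad_value() in np.array
    x._clear_gradient()             # formerly x._clear(); grads_ becomes a fresh VarBase
```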
@@ -33,6 +33,10 @@ class Layer(core.Layer):
     def parameters(self):
         return []

+    def clear_gradients(self):
+        for p in self.parameters():
+            p._clear_gradient()
+
     def _build_once(self, inputs):
         pass
...
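The new hook pairs with Variable._clear_gradient(): after an optimizer step a layer can drop its accumulated gradients in one call, which is what the GAN test below now does via discriminator.clear_gradients() and generator.clear_gradients(). A minimal sketch of the call pattern follows; ToyLayer, its hand-rolled parameter, and the ops used are illustrative placeholders, not part of this commit.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.imperative.base import to_variable


class ToyLayer(fluid.imperative.Layer):
    """Illustrative stand-in for a real layer such as the GAN test's Discriminator."""

    def __init__(self):
        super(ToyLayer, self).__init__()
        # A single hand-rolled "parameter" (must be built under the imperative
        # guard); real layers create Parameters through LayerHelper.
        self._w = to_variable(np.ones([2, 2], np.float32))

    def parameters(self):
        # The base class returns [], so subclasses expose their own parameters.
        return [self._w]

    def forward(self, inputs):
        return fluid.layers.reduce_sum(
            fluid.layers.elementwise_mul(inputs, self._w))


with fluid.imperative.guard():
    layer = ToyLayer()
    loss = layer(to_variable(np.ones([2, 2], np.float32)))
    loss._backward()          # gradients accumulate into each parameter's grads_
    layer.clear_gradients()   # p._clear_gradient() on every parameter
```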
@@ -48,6 +48,7 @@ class Conv2D(layers.Layer):
         assert param_attr is not False, "param_attr should not be False here."
         super(Conv2D, self).__init__(name=name, dtype=dtype)

+        # TODO(minqiyang): Move this to the top.
         from ..layer_helper import LayerHelper
         self._helper = LayerHelper(
             type(self).__name__,
...
@@ -133,9 +133,6 @@ class TestImperativeMnist(unittest.TestCase):
             for param in generate_p.global_block().all_parameters():
                 static_params[param.name] = np.array(
                     scope.find_var(param.name).get_tensor())
-                sys.stderr.write(
-                    'static_param_loss: %s: %s\n' %
-                    (param.name, np.sum(static_params[param.name])))

         dy_params = dict()
         with fluid.imperative.guard():
@@ -160,10 +157,8 @@ class TestImperativeMnist(unittest.TestCase):
             d_loss = d_loss_real + d_loss_fake
             d_loss._backward()
             sgd.minimize(d_loss)
-            for p in discriminator.parameters():
-                p._clear()
-            for p in generator.parameters():
-                p._clear()
+            discriminator.clear_gradients()
+            generator.clear_gradients()

             d_fake = discriminator(
                 generator(to_variable(np.ones([2, 2], np.float32))))
...