Commit e395f2c6 authored by Xin Pan

polish codes

test=develop
Parent 179363a1
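
In summary, the hunks below make VarBase's tracing bookkeeping (pre_op_, pre_op_out_name_, pre_op_out_idx_, stop_gradient_) private behind accessors, fold the tracer's four per-output field writes into a single TrackPreOp() call, rename Clear() to ClearGradient() (which now resets only the gradient holder, not the recorded producer), and add a Layer.clear_gradients() helper that the unit test at the bottom uses instead of clearing each parameter by hand. The sketch below is not Paddle code; it is a minimal, self-contained illustration of the same encapsulation pattern, using hypothetical TracedVar/TracedOp names.

// Minimal sketch of the encapsulation pattern (hypothetical TracedVar/TracedOp
// names, not Paddle's real classes): tracing state is private, written only
// through TrackPreOp()/SetStopGradient() and read only through the getters.
#include <iostream>
#include <string>

struct TracedOp {
  std::string type;
};

class TracedVar {
 public:
  TracedVar()
      : grad_(nullptr),
        stop_gradient_(false),
        pre_op_(nullptr),
        pre_op_out_idx_(-1) {}
  ~TracedVar() { delete grad_; }

  TracedOp* PreOp() const { return pre_op_; }
  int PreOpOutIdx() const { return pre_op_out_idx_; }
  bool IsStopGradient() const { return stop_gradient_; }
  void SetStopGradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }

  // One call records everything the backward pass needs to know about the
  // producer, replacing four separate field assignments at every call site.
  void TrackPreOp(TracedOp* pre_op, const std::string& out_name, int out_idx,
                  bool stop_gradient) {
    pre_op_ = pre_op;
    pre_op_out_name_ = out_name;
    pre_op_out_idx_ = out_idx;
    stop_gradient_ = stop_gradient;
  }

  // Drops only the accumulated gradient; the recorded producer is kept.
  void ClearGradient() {
    delete grad_;
    grad_ = new TracedVar();
  }

 private:
  TracedVar* grad_;
  bool stop_gradient_;
  TracedOp* pre_op_;
  std::string pre_op_out_name_;
  int pre_op_out_idx_;
};

int main() {
  TracedOp mul{"mul"};
  TracedVar out;
  out.TrackPreOp(&mul, "Out", 0, /*stop_gradient=*/false);
  out.ClearGradient();  // reset the gradient between iterations...
  // ...but the recorded trace survives: prints "mul 0 0".
  std::cout << out.PreOp()->type << " " << out.PreOpOutIdx() << " "
            << out.IsStopGradient() << std::endl;
  return 0;
}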
@@ -57,15 +57,15 @@ class Autograd {
   Autograd() {}

   void RunBackward(VarBase* var) {
-    if (var->stop_gradient_) {
+    if (var->IsStopGradient()) {
       return;
     }
     VLOG(3) << "start autograd";

     std::deque<OpBase*> ready;
-    ready.push_back(var->pre_op_);
+    ready.push_back(var->PreOp());

-    std::map<OpBase*, int> dep_counts = ComputeDepCounts(var->pre_op_);
+    std::map<OpBase*, int> dep_counts = ComputeDepCounts(var->PreOp());

     while (!ready.empty()) {
       OpBase* ready_op = ready.front();
@@ -77,7 +77,7 @@ class Autograd {
         const std::vector<VarBase*>& ingrads = it.second;
         for (size_t i = 0; i < ingrads.size(); ++i) {
           if (!ingrads[i]) continue;
-          if (ready_op->input_vars_[it.first][i]->stop_gradient_) {
+          if (ready_op->input_vars_[it.first][i]->IsStopGradient()) {
             continue;
           }
           OpBase* pre_op = ready_op->pre_ops_[it.first][i];
......
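
For context on the Autograd hunks above: RunBackward seeds a ready queue with the output variable's PreOp() and walks producers breadth-first, releasing an op only once every op that consumes its outputs has been processed; this change only swaps direct field reads for the new accessors, the traversal itself is untouched. Below is a small self-contained toy of that dependency-count / ready-queue pattern, with a hypothetical Node type and no gradient math; the real code keys producers by input slot name and index.

// Toy version of the ready-queue traversal used by Autograd::RunBackward
// (hypothetical Node type, simplified: no slots, no gradients computed).
#include <deque>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node*> pre_ops;  // producers of this node's inputs
};

int main() {
  Node a{"a", {}}, b{"b", {&a}}, c{"c", {&a, &b}};

  // Count, for every reachable producer, how many consumers point at it.
  std::map<Node*, int> dep_counts;
  std::deque<Node*> to_visit{&c};
  while (!to_visit.empty()) {
    Node* n = to_visit.back();
    to_visit.pop_back();
    for (Node* pre : n->pre_ops) {
      if (dep_counts.find(pre) == dep_counts.end()) to_visit.push_back(pre);
      ++dep_counts[pre];
    }
  }

  // Walk in reverse topological order: pop a ready node, release its producers.
  std::deque<Node*> ready{&c};
  while (!ready.empty()) {
    Node* n = ready.front();
    ready.pop_front();
    std::cout << "backward through " << n->name << "\n";  // prints c, b, a
    for (Node* pre : n->pre_ops) {
      if (--dep_counts[pre] == 0) ready.push_back(pre);
    }
  }
  return 0;
}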
@@ -100,20 +100,20 @@ class VarBase {
   // Owns `var` and `grad`
   VarBase(framework::Variable* var, VarBase* grad)
-      : pre_op_(nullptr),
-        pre_op_out_idx_(-1),
-        var_desc_(nullptr),
+      : var_desc_(nullptr),
         var_(var),
         grads_(grad),
-        stop_gradient_(false) {}
+        stop_gradient_(false),
+        pre_op_(nullptr),
+        pre_op_out_idx_(-1) {}

   explicit VarBase(bool stop_gradient)
-      : pre_op_(nullptr),
-        pre_op_out_idx_(-1),
-        var_desc_(nullptr),
+      : var_desc_(nullptr),
         var_(new framework::Variable()),
         grads_(stop_gradient ? nullptr : new VarBase(true)),
-        stop_gradient_(stop_gradient) {}
+        stop_gradient_(stop_gradient),
+        pre_op_(nullptr),
+        pre_op_out_idx_(-1) {}

   virtual ~VarBase() {
     if (var_) {
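
A note on the constructor hunk above: pre_op_ and pre_op_out_idx_ move to the end of the member-initializer lists because, later in this diff, those fields are re-declared after stop_gradient_ at the bottom of the class. C++ always initializes members in declaration order no matter how the list is written, and a list that disagrees with that order typically draws a -Wreorder warning, so the lists are rewritten to match. A tiny standalone illustration with hypothetical names:

// Minimal illustration (hypothetical struct, not Paddle code): members are
// initialized in declaration order, so keeping the initializer list in that
// order avoids -Wreorder warnings and ordering surprises.
#include <iostream>

struct Reordered {
  int first_;
  int second_;
  // List written in declaration order: no warning, and second_ can safely
  // rely on first_ already being initialized.
  Reordered() : first_(1), second_(first_ + 1) {}
  // Writing ": second_(first_ + 1), first_(1)" would still initialize first_
  // before second_ (declaration order wins), but compilers warn about it.
};

int main() {
  Reordered r;
  std::cout << r.first_ << " " << r.second_ << "\n";  // prints "1 2"
  return 0;
}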
@@ -125,15 +125,27 @@ class VarBase {
     }
   }

-  void Clear() {
+  OpBase* PreOp() const { return pre_op_; }
+  int PreOpOutIdx() const { return pre_op_out_idx_; }
+
+  void SetStopGradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
+  bool IsStopGradient() const { return stop_gradient_; }
+
+  void RunBackward();
+
+  void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
+                  int pre_op_out_idx, bool stop_gradient) {
+    pre_op_ = pre_op;
+    pre_op_out_name_ = pre_op_out_name;
+    pre_op_out_idx_ = pre_op_out_idx;
+    stop_gradient_ = stop_gradient;
+  }
+
+  void ClearGradient() {
     delete grads_;
     grads_ = new VarBase(true);
-    pre_op_ = nullptr;
-    pre_op_out_name_ = "";
   }

-  void RunBackward();
-
   framework::LoDTensor& GradValue();

   inline std::string GradName() const {
@@ -143,16 +155,16 @@ class VarBase {
     return string::Sprintf("%s@IGrad", var_desc_->Name());
   }

-  OpBase* pre_op_;
-  std::string pre_op_out_name_;
-  int pre_op_out_idx_;
-
   framework::VarDesc* var_desc_;

   framework::Variable* var_;
   VarBase* grads_;

+ private:
   bool stop_gradient_;
+  OpBase* pre_op_;
+  std::string pre_op_out_name_;
+  int pre_op_out_idx_;
 };

 /* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its
......
@@ -63,9 +63,9 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       invars.push_back(inp->var_);
       vars[inp->var_desc_->Name()] = inp;
-      if (inp->pre_op_) {
-        op->pre_ops_[it.first].push_back(inp->pre_op_);
-        op->pre_ops_out_idx_[it.first].push_back(inp->pre_op_out_idx_);
+      if (inp->PreOp()) {
+        op->pre_ops_[it.first].push_back(inp->PreOp());
+        op->pre_ops_out_idx_[it.first].push_back(inp->PreOpOutIdx());
       } else {
         op->pre_ops_[it.first].push_back(nullptr);
       }
@@ -89,10 +89,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       } else {
         LOG(ERROR) << "tracer doesn't support yet";
       }
-      out->stop_gradient_ = stop_gradient;
-      out->pre_op_ = op;
-      out->pre_op_out_name_ = it.first;
-      out->pre_op_out_idx_ = i;
+      out->TrackPreOp(op, it.first, i, stop_gradient);

       VLOG(3) << "output vname " << out->var_desc_->Name() << " "
               << out->var_->IsInitialized();
@@ -167,9 +164,9 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
   op->input_vars_[PyLayer::kFwdInp] = inputs;
   op->output_vars_[PyLayer::kFwdOut] = PyLayer::Apply(op->forward_id_, inputs);
   for (VarBase* inp : inputs) {
-    if (inp->pre_op_) {
-      op->pre_ops_[PyLayer::kFwdInp].push_back(inp->pre_op_);
-      op->pre_ops_out_idx_[PyLayer::kFwdInp].push_back(inp->pre_op_out_idx_);
+    if (inp->PreOp()) {
+      op->pre_ops_[PyLayer::kFwdInp].push_back(inp->PreOp());
+      op->pre_ops_out_idx_[PyLayer::kFwdInp].push_back(inp->PreOpOutIdx());
     } else {
       op->pre_ops_[PyLayer::kFwdInp].push_back(nullptr);
     }
@@ -178,10 +175,7 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
   auto& outputs = op->output_vars_[PyLayer::kFwdOut];
   for (size_t i = 0; i < outputs.size(); ++i) {
     VarBase* out = outputs[i];
-    out->stop_gradient_ = stop_gradient;
-    out->pre_op_ = op;
-    out->pre_op_out_name_ = PyLayer::kFwdOut;
-    out->pre_op_out_idx_ = i;
+    out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient);
   }

   if (!stop_gradient) {
     auto& grad_input_vars =
......
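
Tying the tracer hunks above together: on the input side, both Trace and PyTrace now read the producer through PreOp()/PreOpOutIdx() and still record a nullptr entry for leaf inputs that have no producer, which is what lets the backward walk stop at leaves; on the output side, the four field writes collapse into one TrackPreOp() call per produced variable. A small self-contained sketch of the input wiring, using hypothetical Var/Op stand-ins rather than the real tracer types (and, mirroring the hunk, recording no output index for leaves):

#include <iostream>
#include <vector>

// Hypothetical stand-ins for the traced types (not Paddle's classes).
struct Op;
struct Var {
  Op* PreOp() const { return pre_op_; }
  int PreOpOutIdx() const { return pre_op_out_idx_; }
  Op* pre_op_ = nullptr;  // nullptr means "leaf": nothing produced this var
  int pre_op_out_idx_ = -1;
};

struct Op {
  std::vector<Op*> pre_ops;  // producer per input, nullptr for leaves
  std::vector<int> pre_ops_out_idx;
};

// Wire a new op's inputs into the graph the way Trace/PyTrace do: keep the
// producer and its output index when there is one, keep a nullptr otherwise.
void WireInputs(Op* op, const std::vector<Var*>& inputs) {
  for (Var* inp : inputs) {
    if (inp->PreOp()) {
      op->pre_ops.push_back(inp->PreOp());
      op->pre_ops_out_idx.push_back(inp->PreOpOutIdx());
    } else {
      op->pre_ops.push_back(nullptr);
    }
  }
}

int main() {
  Var leaf;                      // e.g. a parameter: no producer
  Op producer;
  Var produced;
  produced.pre_op_ = &producer;  // pretend an earlier op made this var
  produced.pre_op_out_idx_ = 0;

  Op consumer;
  WireInputs(&consumer, {&leaf, &produced});
  std::cout << consumer.pre_ops.size() << " inputs wired, first is leaf: "
            << (consumer.pre_ops[0] == nullptr) << "\n";  // "2 ... leaf: 1"
  return 0;
}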
@@ -133,7 +133,7 @@ PYBIND11_MODULE(core, m) {
            [](imperative::VarBase &self) { self.RunBackward(); })
       .def("_grad_name", &imperative::VarBase::GradName)
       .def("_grad_value", &imperative::VarBase::GradValue)
-      .def("_clear", &imperative::VarBase::Clear)
+      .def("_clear_gradient", &imperative::VarBase::ClearGradient)
       .def("_grad_ivar",
            [](const imperative::VarBase &self) { return self.grads_; },
            py::return_value_policy::reference)
@@ -148,9 +148,9 @@ PYBIND11_MODULE(core, m) {
            py::return_value_policy::reference)
       .def_property(
           "stop_gradient",
-          [](const imperative::VarBase &self) { return self.stop_gradient_; },
+          [](const imperative::VarBase &self) { return self.IsStopGradient(); },
           [](imperative::VarBase &self, bool stop_gradient) {
-            self.stop_gradient_ = stop_gradient;
+            self.SetStopGradient(stop_gradient);
           });

   py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
......
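
With stop_gradient_ private, the binding above routes the Python stop_gradient attribute through the new accessors; the getter/setter lambdas handed to def_property are standard pybind11. A minimal self-contained module sketch of the same pattern, with a hypothetical Widget class and toy module name rather than Paddle's real binding:

// Hypothetical pybind11 module: expose a private flag through accessors,
// mirroring the "stop_gradient" property binding in the hunk above.
#include <pybind11/pybind11.h>

namespace py = pybind11;

class Widget {
 public:
  bool IsStopGradient() const { return stop_gradient_; }
  void SetStopGradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }

 private:
  bool stop_gradient_ = false;
};

PYBIND11_MODULE(toy, m) {
  py::class_<Widget>(m, "Widget")
      .def(py::init<>())
      .def_property(
          "stop_gradient",
          [](const Widget &self) { return self.IsStopGradient(); },
          [](Widget &self, bool stop_gradient) {
            self.SetStopGradient(stop_gradient);
          });
}

From Python this still reads as a plain attribute: w = toy.Widget(); w.stop_gradient = True ends up in SetStopGradient().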
@@ -388,8 +388,8 @@ class Variable(object):
     def _gradient(self):
         return np.array(self._ivar._grad_value())

-    def _clear(self):
-        self._ivar._clear()
+    def _clear_gradient(self):
+        self._ivar._clear_gradient()

     def __str__(self):
         return self.to_string(True)
......
@@ -33,6 +33,10 @@ class Layer(core.Layer):
     def parameters(self):
         return []

+    def clear_gradients(self):
+        for p in self.parameters():
+            p._clear_gradient()
+
     def _build_once(self, inputs):
         pass
......
@@ -48,6 +48,7 @@ class Conv2D(layers.Layer):
         assert param_attr is not False, "param_attr should not be False here."
         super(Conv2D, self).__init__(name=name, dtype=dtype)

+        # TODO(minqiyang): Move this to the top.
         from ..layer_helper import LayerHelper
         self._helper = LayerHelper(
             type(self).__name__,
......
@@ -133,9 +133,6 @@ class TestImperativeMnist(unittest.TestCase):
             for param in generate_p.global_block().all_parameters():
                 static_params[param.name] = np.array(
                     scope.find_var(param.name).get_tensor())
-                sys.stderr.write(
-                    'static_param_loss: %s: %s\n' %
-                    (param.name, np.sum(static_params[param.name])))

         dy_params = dict()
         with fluid.imperative.guard():
@@ -160,10 +157,8 @@ class TestImperativeMnist(unittest.TestCase):
             d_loss = d_loss_real + d_loss_fake
             d_loss._backward()
             sgd.minimize(d_loss)
-            for p in discriminator.parameters():
-                p._clear()
-            for p in generator.parameters():
-                p._clear()
+            discriminator.clear_gradients()
+            generator.clear_gradients()

             d_fake = discriminator(
                 generator(to_variable(np.ones([2, 2], np.float32))))
......