Commit e0a2b472 authored by minqiyang, committed by ceci3

Move ClearBlock into OpBase and VarBase's destructor

test=develop
Parent 9abf40c9
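Summary of the change: before this commit, imperative (dygraph) training had to call Block._clear_block() periodically, which invoked BlockDesc::Clear() to drop all ops and all non-persistable vars in bulk. This commit moves the cleanup into the C++ side: each VarBase and OpBase now keeps a pointer to its framework::BlockDesc and deregisters itself (RemoveVar / RemoveOp) in its destructor, so the block empties itself as Python releases its references. Below is a minimal, self-contained sketch of that destructor-deregistration pattern; the Registry and Node names are illustrative, not Paddle's.

// Sketch only: an object registers itself in a container on construction
// and removes itself again in its destructor, so no bulk Clear() is needed.
#include <iostream>
#include <map>
#include <string>

struct Registry {
  std::map<std::string, int> entries;
  void RemoveVar(const std::string& name) { entries.erase(name); }
};

class Node {
 public:
  Node(Registry* block, std::string name)
      : block_(block), name_(std::move(name)) {
    block_->entries[name_] = 0;
  }
  ~Node() {
    if (block_) {  // guarded, like VarBase: block_ may never have been set
      block_->RemoveVar(name_);
    }
  }

 private:
  Registry* block_;
  std::string name_;
};

int main() {
  Registry block;
  {
    Node tmp(&block, "tmp_0");
    std::cout << block.entries.size() << "\n";  // 1
  }  // tmp's destructor runs here and deregisters it
  std::cout << block.entries.size() << "\n";  // 0
}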
@@ -163,20 +163,6 @@ std::vector<OpDesc *> BlockDesc::AllOps() const {
   return res;
 }
 
-void BlockDesc::Clear() {
-  // clear all ops
-  ops_.clear();
-
-  // clear all vars which are not persistable
-  for (auto it = vars_.begin(); it != vars_.end();) {
-    if (it->second->Persistable()) {
-      ++it;
-    } else {
-      vars_.erase(it++);
-    }
-  }
-}
-
 void BlockDesc::Flush() {
   for (auto &op_desc : ops_) {
     op_desc->Flush();
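This hunk (apparently paddle/fluid/framework/block_desc.cc) deletes BlockDesc::Clear(): it wiped every op but kept persistable vars (the parameters), erasing only the temporaries. The loop uses the classic map-erase idiom, post-incrementing the iterator so erase() never invalidates the one still in use. A runnable sketch of the same idiom, with illustrative var names:

#include <iostream>
#include <map>
#include <string>

int main() {
  // name -> persistable, mimicking the vars_ map in the removed Clear()
  std::map<std::string, bool> vars = {
      {"fc_0.w_0", true}, {"tmp_0", false}, {"tmp_1", false}};
  for (auto it = vars.begin(); it != vars.end();) {
    if (it->second) {
      ++it;                // keep persistable entries
    } else {
      vars.erase(it++);    // advance first, then erase the old position
    }
  }
  std::cout << vars.size() << "\n";  // 1: only the persistable var survives
}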
@@ -97,8 +97,6 @@ class BlockDesc {
   std::vector<OpDesc *> AllOps() const;
 
-  void Clear();
-
   size_t OpSize() const { return ops_.size(); }
 
   OpDesc *Op(int idx) const { return ops_.at(idx).get(); }
@@ -126,12 +126,19 @@ class VarBase {
       : var_desc_(nullptr),
         var_(var),
         grads_(grad),
+        block_(nullptr),
         stop_gradient_(stop_gradient),
         pre_op_(nullptr),
         pre_op_out_idx_(-1) {}
 
  public:
   virtual ~VarBase() {
+    LOG(ERROR) << "remove var " << name_;
+
+    if (block_) {
+      block_->RemoveVar(name_);
+    }
+
     if (var_) {
       delete var_;
     }
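In the new VarBase destructor, deregistration runs before the underlying Variable is deleted, and it is guarded by the block_ null check since block_ starts out nullptr and is only set later from Python. The LOG(ERROR) trace reads like temporary debugging output (routine destruction is not an error condition); in-tree code would more typically use VLOG for this.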
@@ -189,11 +196,14 @@ class VarBase {
   framework::Variable* var_;
   VarBase* grads_;
 
+  framework::BlockDesc* block_;
+
  private:
   bool stop_gradient_;
   OpBase* pre_op_;
   std::string pre_op_out_name_;
   int pre_op_out_idx_;
+  std::string name_;
 };
 
 /* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its
@@ -212,6 +222,12 @@ class OpBase {
     for (framework::OpDesc* desc : grad_op_descs_) {
       delete desc;
     }
+
+    LOG(ERROR) << "remove op " << op_desc_->Type() << " id " << trace_id_;
+
+    if (block_) {
+      block_->RemoveOp(trace_id_, trace_id_ + 1);
+    }
   }
 
   std::map<std::string, std::vector<VarBase*>> ApplyGrad();
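OpBase's destructor does the same for ops. BlockDesc::RemoveOp appears to take a half-open index range, so RemoveOp(trace_id_, trace_id_ + 1) removes exactly the one op recorded at trace_id_ when the op was traced. Note that the grad-op descs owned by the OpBase are deleted first, independently of the block.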
@@ -189,8 +189,6 @@ void BindBlockDesc(pybind11::module *m) {
            return self.HasVar(name);
          },
          pybind11::return_value_policy::reference)
-      .def("_clear_block", [](pd::BlockDesc &self) { return self.Clear(); },
-           pybind11::return_value_policy::reference)
       .def("_rename_var",
            [](pd::BlockDesc &self, const pybind11::bytes &byte_name,
               const pybind11::bytes &byte_name_new) {
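With BlockDesc::Clear() gone, its Python binding _clear_block is dropped from BindBlockDesc as well; nothing on the Python side can, or needs to, trigger a bulk clear anymore.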
@@ -390,6 +390,8 @@ class Variable(object):
         if _in_imperative_mode():
             # record vars in tracer rather than blocks
             self._ivar = kwargs.get("ivar", None)
+            self._ivar.block = block.desc
+            self._ivar.name = name
             if not self._ivar:
                 self._ivar = core.VarBase(stop_gradient)
             self._ivar.desc = self.desc
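The Python Variable now pushes its block and name down into the C++ VarBase, which is what makes the destructor-side RemoveVar(name_) possible. One caveat, as extracted: the two new assignments sit before the `if not self._ivar:` fallback, so they assume `ivar` was actually passed in; if it can be None here, they would belong after the fallback.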
@@ -1200,15 +1202,6 @@ class Block(object):
         else:
             raise ValueError("Var {0} is not found recursively".format(name))
 
-    def _clear_block(self):
-        assert _in_imperative_mode()
-
-        # TODO(minqiyang): move this to Variable and Operator's __del__
-        self.desc._clear_block()
-
-        assert len(self.vars) == 0
-        assert len(self.ops) == 0
-
     def all_parameters(self):
         return list(self.iter_parameters())
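Block._clear_block() and its asserts disappear; the TODO it carried ("move this to Variable and Operator's __del__") is essentially what this commit implements, with the deregistration living in the C++ destructors rather than in Python __del__.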
@@ -142,8 +142,6 @@ class TestImperativeMnist(unittest.TestCase):
                 sgd.minimize(avg_loss)
                 mnist.clear_gradients()
 
-                fluid.default_main_program().global_block()._clear_block()
-
                 dy_param_value = {}
                 for param in mnist.parameters():
                     dy_param_value[param.name] = param._numpy()
@@ -286,8 +286,6 @@ class TestImperativeResnet(unittest.TestCase):
                 optimizer.minimize(avg_loss)
                 resnet.clear_gradients()
 
-                fluid.default_main_program().global_block()._clear_block()
-
                 dy_param_value = {}
                 for param in resnet.parameters():
                     dy_param_value[param.name] = param._numpy()
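Both imperative tests drop the explicit per-iteration _clear_block() call. Cleanup is now implicit: as each iteration's Python Variables are garbage-collected, their VarBase/OpBase destructors remove the corresponding entries from the block, so the block no longer grows across iterations.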