提交 e9fdf909 编写于 作者: M minqiyang

Polish code

test=develop
上级 a15a3fc3
...@@ -271,6 +271,22 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -271,6 +271,22 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
return input_vars_; return input_vars_;
} }
void VarBase::RunBackward() {
if (!pre_op_) return;
VLOG(3) << "start backward";
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
var_->GetMutable<framework::LoDTensor>()->place())),
grads_t, 1.0);
PADDLE_ENFORCE(
grads_ ==
pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
Autograd().RunBackward(this);
}
void PyLayer::RegisterFunc(int func_id, const py::object& py_func) { void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
py_funcs_[func_id] = py_func; py_funcs_[func_id] = py_func;
} }
......
...@@ -140,6 +140,8 @@ class VarBase { ...@@ -140,6 +140,8 @@ class VarBase {
} }
inline bool IsStopGradient() const { return stop_gradient_; } inline bool IsStopGradient() const { return stop_gradient_; }
void RunBackward();
void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name, void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
int pre_op_out_idx, bool pre_op_stop_gradient) { int pre_op_out_idx, bool pre_op_stop_gradient) {
pre_op_ = pre_op; pre_op_ = pre_op;
...@@ -150,22 +152,6 @@ class VarBase { ...@@ -150,22 +152,6 @@ class VarBase {
} }
} }
void RunBackward() {
if (!pre_op_) return;
VLOG(3) << "start backward";
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
var_->GetMutable<framework::LoDTensor>()->place())),
grads_t, 1.0);
PADDLE_ENFORCE(
grads_ ==
pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
Autograd().RunBackward(this);
}
void ClearGradient() { void ClearGradient() {
VLOG(1) << "clear gradient of " << var_desc_->Name(); VLOG(1) << "clear gradient of " << var_desc_->Name();
if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) { if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册