提交 e9fdf909 编写于 作者: M minqiyang

Polish code

test=develop
上级 a15a3fc3
......@@ -271,6 +271,22 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
return input_vars_;
}
void VarBase::RunBackward() {
if (!pre_op_) return;
VLOG(3) << "start backward";
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
var_->GetMutable<framework::LoDTensor>()->place())),
grads_t, 1.0);
PADDLE_ENFORCE(
grads_ ==
pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
Autograd().RunBackward(this);
}
void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
py_funcs_[func_id] = py_func;
}
......
......@@ -140,6 +140,8 @@ class VarBase {
}
inline bool IsStopGradient() const { return stop_gradient_; }
void RunBackward();
void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
int pre_op_out_idx, bool pre_op_stop_gradient) {
pre_op_ = pre_op;
......@@ -150,22 +152,6 @@ class VarBase {
}
}
void RunBackward() {
if (!pre_op_) return;
VLOG(3) << "start backward";
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
var_->GetMutable<framework::LoDTensor>()->place())),
grads_t, 1.0);
PADDLE_ENFORCE(
grads_ ==
pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
Autograd().RunBackward(this);
}
void ClearGradient() {
VLOG(1) << "clear gradient of " << var_desc_->Name();
if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册