Unverified commit 733d8168, authored by 0x45f, committed by GitHub

Fix test_reinforcement_learning.py for eager run_program OP (#41018)

* Fix test_reinforcement_learning.py for eager run_program OP

* Add comment
Parent 05f3d48e
@@ -57,7 +57,6 @@ inline void run_program_dygraph_function(
   auto grad_node = std::make_shared<GradNodeRunProgram>(1, 2);
   grad_node->SetFwdOutNames(out_names);
-  grad_node->SetOut(out);
   // Set Attributes
   grad_node->SetAttrMap(attrs);
   // Set TensorWrappers
...
@@ -362,13 +362,16 @@ class GradNodeRunProgram : public egr::GradNodeBase {
       std::vector<std::vector<paddle::experimental::Tensor>> &grads,  // NOLINT
       bool create_graph) override {
     VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram";
-    PADDLE_ENFORCE_EQ(
-        grads.size(), 1,
-        paddle::platform::errors::InvalidArgument(
-            "The out_grads.size() of RunProgramGradOp should be equal to 1."));
+    std::vector<std::vector<paddle::experimental::Tensor>> hooked_grads =
+        GradNodeRunProgram::ApplyGradientHooks(grads);
+    PADDLE_ENFORCE_EQ(hooked_grads.size(), 1,
+                      paddle::platform::errors::InvalidArgument(
+                          "The hooked_grads.size() of RunProgramGradOp should "
+                          "be equal to 1."));
 
-    egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());
-    VLOG(3) << "out_grads[0].size() : " << grads[0].size();
+    egr::EagerUtils::FillZeroForEmptyGradInputs(&hooked_grads,
+                                                this->InputMeta());
+    VLOG(3) << "hooked_grads[0].size() : " << hooked_grads[0].size();
     std::vector<paddle::experimental::Tensor> x_grad;
     std::vector<paddle::experimental::Tensor> params_grad;
     ConstructXGradTensors(x_, &x_grad);
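Reviewer note: the key change in this hunk is that the backward node now passes the incoming gradients through ApplyGradientHooks before consuming them, so user-registered gradient hooks take effect for run_program too. Below is a minimal standalone sketch of that pattern, not Paddle's actual implementation: Tensor and GradientHook are stand-in types, and the hook list is passed in explicitly rather than stored on the grad node.

// Sketch: every registered hook transforms the incoming gradient before
// the backward node uses it. Inputs are copied, never mutated in place.
#include <functional>
#include <iostream>
#include <vector>

struct Tensor {          // stand-in for paddle::experimental::Tensor
  float value = 0.f;
};

using GradientHook = std::function<Tensor(const Tensor &)>;

std::vector<std::vector<Tensor>> ApplyGradientHooks(
    const std::vector<std::vector<Tensor>> &grads,
    const std::vector<GradientHook> &hooks) {
  std::vector<std::vector<Tensor>> hooked = grads;  // work on a copy
  for (auto &slot : hooked) {
    for (auto &g : slot) {
      for (const auto &hook : hooks) {
        g = hook(g);  // each hook may replace the gradient
      }
    }
  }
  return hooked;
}

int main() {
  std::vector<std::vector<Tensor>> grads{{Tensor{1.f}, Tensor{2.f}}};
  // e.g. a hook that scales every gradient by 0.5
  std::vector<GradientHook> hooks{
      [](const Tensor &t) { return Tensor{t.value * 0.5f}; }};
  auto hooked = ApplyGradientHooks(grads, hooks);
  std::cout << hooked[0][0].value << " " << hooked[0][1].value << "\n";  // 0.5 1
  return 0;
}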
@@ -382,21 +385,15 @@ class GradNodeRunProgram : public egr::GradNodeBase {
       params_grad_ptr.emplace_back(&i);
     }
 
-    PADDLE_ENFORCE_EQ(
-        grads[0].size(), fwd_out_names_.size(),
-        paddle::platform::errors::InvalidArgument(
-            "The grads[0].size() and fwd_out_names_.size() should be equal."));
+    PADDLE_ENFORCE_EQ(hooked_grads[0].size(), fwd_out_names_.size(),
+                      paddle::platform::errors::InvalidArgument(
+                          "The hooked_grads[0].size() and "
+                          "fwd_out_names_.size() should be equal."));
     for (size_t i = 0; i < fwd_out_names_.size(); ++i) {
-      auto &out_grad = egr::EagerUtils::unsafe_autograd_meta(*out_[i])->Grad();
-      const_cast<paddle::experimental::Tensor &>(out_grad).set_impl(
-          grads[0][i].impl());
-      const_cast<paddle::experimental::Tensor &>(grads[0][i])
-          .set_name(fwd_out_names_[i] + "@GRAD");
+      hooked_grads[0][i].set_name(fwd_out_names_[i] + "@GRAD");
     }
-
-    RunProgramGradAPI(x_, params_, grads[0], step_scope_, attrs_, x_grad_ptr,
-                      params_grad_ptr);
+    RunProgramGradAPI(x_, params_, hooked_grads[0], step_scope_, attrs_,
+                      x_grad_ptr, params_grad_ptr);
     VLOG(3) << "End Eager Backward Node: GradNodeRunProgram";
     return {x_grad, params_grad};
   }
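Reviewer note: the loop above replaces the old const_cast bookkeeping with a single set_name call, renaming each hooked gradient to its forward output's name plus the "@GRAD" suffix, which is how the underlying static program locates gradient variables in the scope. A small illustrative sketch of that naming convention follows; the scope is modeled as a plain std::map and the variable names are made up for the example.

// Sketch: forward output "X" maps to gradient variable "X@GRAD" in the
// scope, so seeding grads by name lets the grad program find them.
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> fwd_out_names{"softmax_0.tmp_0", "mean_0.tmp_0"};
  std::map<std::string, float> scope;  // hypothetical scope: name -> grad value
  for (const auto &name : fwd_out_names) {
    scope[name + "@GRAD"] = 1.0f;  // seed each output gradient
  }
  // The grad program would now look these variables up by name.
  for (const auto &kv : scope) {
    std::cout << kv.first << " = " << kv.second << "\n";
  }
  return 0;
}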
@@ -428,10 +425,6 @@ class GradNodeRunProgram : public egr::GradNodeBase {
     fwd_out_names_ = out_names;
   }
 
-  void SetOut(const std::vector<paddle::experimental::Tensor *> &out) {
-    out_ = out;
-  }
-
  protected:
   void ConstructXGradTensors(
       const std::vector<paddle::experimental::Tensor> &x,
@@ -454,6 +447,9 @@ class GradNodeRunProgram : public egr::GradNodeBase {
     for (auto &t : param) {
       auto t_meta = egr::EagerUtils::unsafe_autograd_meta(t);
       auto t_grad = egr::EagerUtils::unsafe_autograd_meta(t)->Grad();
+      // In eager mode, the number of param_grad should be the same as
+      // param, so here an empty Tensor is added for the param with
+      // stop_gradient=True
       if (t_meta->StopGradient()) {
         param_grad->emplace_back();
       } else if (t_grad.is_dense_tensor()) {
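Reviewer note: the new comment documents an alignment invariant: param_grad must hold exactly one entry per parameter, so a parameter with stop_gradient=True contributes a default-constructed (empty) tensor instead of being skipped. Here is a self-contained sketch of that padding rule, with stand-in Param/Tensor types instead of Paddle's real classes.

// Sketch: keep param_grad position-aligned with param by padding
// stop_gradient entries with empty tensors.
#include <iostream>
#include <vector>

struct Tensor {          // stand-in: "defined" marks a real gradient
  bool defined = false;
};

struct Param {
  bool stop_gradient = false;
  Tensor grad{true};
};

std::vector<Tensor> ConstructParamGradTensors(const std::vector<Param> &params) {
  std::vector<Tensor> param_grad;
  param_grad.reserve(params.size());
  for (const auto &p : params) {
    if (p.stop_gradient) {
      param_grad.emplace_back();  // empty placeholder keeps indices aligned
    } else {
      param_grad.push_back(p.grad);
    }
  }
  return param_grad;
}

int main() {
  std::vector<Param> params{{false}, {true}, {false}};  // middle one frozen
  auto grads = ConstructParamGradTensors(params);
  for (size_t i = 0; i < grads.size(); ++i) {
    std::cout << "param " << i
              << (grads[i].defined ? " has grad\n" : " has empty grad\n");
  }
  return 0;
}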
@@ -472,7 +468,6 @@ class GradNodeRunProgram : public egr::GradNodeBase {
   std::vector<paddle::framework::Scope *> step_scope_;
   std::vector<std::string> fwd_out_names_;
-  std::vector<paddle::experimental::Tensor *> out_;
 
   // Attribute Map
   paddle::framework::AttributeMap attrs_;
...