Commit 82cbf987 authored by 0x45f

Fix run program grad node mem

Parent: ae93930f
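This change trims the memory held by GradNodeRunProgram in eager mode. Previously every forward input in x was saved as a TensorWrapper on the grad node, pinning its buffer until the node was destroyed. Now any input the backward block never references is replaced with an empty placeholder tensor before being wrapped, output vars unused by backward are erased from the step scope, and the already-UNUSED x/params arguments are dropped from RunProgramGradAPI and its call site.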
@@ -88,7 +88,26 @@ inline void run_program_ad_func(
   // Set Attributes
   grad_node->SetAttrMap(attrs);
   // Set TensorWrappers
-  grad_node->SetFwdX(x);
+  auto* forward_global_block = PADDLE_GET_CONST(
+      paddle::framework::BlockDesc*, attrs.at("forward_global_block"));
+  auto* backward_global_block = PADDLE_GET_CONST(
+      paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
+  // Clear unused x vars
+  auto temp_x = std::vector<paddle::Tensor>(x);
+  for (size_t i = 0; i < x.size(); i++) {
+    if (!backward_global_block->HasVar(x[i].name())) {
+      auto fake = paddle::Tensor(std::make_shared<phi::DenseTensor>());
+      fake.set_name("Fake_var");
+      temp_x[i] = fake;
+    }
+  }
+  grad_node->SetFwdX(temp_x);
+  // Clear unused out vars
+  for (size_t i = 0; i < out.size(); i++) {
+    if (!backward_global_block->HasVar(out[i]->name())) {
+      step_scope[0]->EraseVars({out[i]->name()});
+    }
+  }
   grad_node->SetFwdParams(params);
   grad_node->SetStepScope(step_scope);
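The saving comes from dropping shared ownership: a TensorWrapper keeps a reference to the forward tensor's allocation alive, so swapping in an empty Fake_var lets the real buffer be released as soon as the caller lets go of it. Below is a minimal standalone sketch of that ownership idea, using hypothetical stand-in types (Tensor, GradNode, BackwardUses) rather than Paddle's real classes:

// Minimal stand-ins, for illustration only -- not Paddle's real
// Tensor/TensorWrapper classes.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Tensor {
  std::string name;
  std::shared_ptr<std::vector<float>> buffer;  // shared allocation
};

struct GradNode {
  std::vector<Tensor> saved_x;  // plays the role of SetFwdX's TensorWrappers
};

// In the real code this check is backward_global_block->HasVar(name).
bool BackwardUses(const std::string& name) { return name == "x0"; }

int main() {
  std::vector<Tensor> x = {
      {"x0", std::make_shared<std::vector<float>>(1 << 20)},  // ~4 MB
      {"x1", std::make_shared<std::vector<float>>(1 << 20)},  // ~4 MB
  };

  GradNode node;
  // Copy the inputs, but swap any tensor backward never touches for an
  // empty placeholder so the node does not co-own its buffer.
  auto temp_x = x;
  for (auto& t : temp_x) {
    if (!BackwardUses(t.name)) {
      t = Tensor{"Fake_var", std::make_shared<std::vector<float>>()};
    }
  }
  node.saved_x = std::move(temp_x);

  x.clear();  // caller releases its references
  // x1's ~4 MB buffer is freed here; x0's survives because only the node
  // still needs it for backward.
  std::cout << "x0 buffer use_count: " << node.saved_x[0].buffer.use_count()
            << "\n";  // prints 1
}

The EraseVars loop applies the same idea to the outputs: vars the backward block never reads are dropped from the forward step scope instead of lingering there until the scope is torn down.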
@@ -97,10 +116,6 @@ inline void run_program_ad_func(
   // NOTE(@xiongkun): Not every tensor in x(list of tensor) is required
   // gradient. for example: x[1] is not used for output, the x[1] is ignored.
-  auto* forward_global_block = PADDLE_GET_CONST(
-      paddle::framework::BlockDesc*, attrs.at("forward_global_block"));
-  auto* backward_global_block = PADDLE_GET_CONST(
-      paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
   std::vector<const paddle::Tensor*> x_require_grad;
   for (size_t i = 0; i < x.size(); ++i) {
     auto& name = x[i].name();
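Note that the two PADDLE_GET_CONST lookups removed here are the same ones added in the hunk at line 88: the block-desc pointers were hoisted above SetFwdX so that the new pruning loops and this existing require-grad filtering share one fetch.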
@@ -480,8 +480,6 @@ inline void RunProgramAPI(
 }
 
 inline void RunProgramGradAPI(
-    const std::vector<paddle::Tensor> &x UNUSED,
-    const std::vector<paddle::Tensor> &params UNUSED,
     const std::vector<paddle::Tensor> &out_grad,
     const std::vector<paddle::framework::Scope *> &step_scope,  // NOLINT
     const paddle::framework::AttributeMap &attrs,
@@ -694,13 +692,8 @@ class GradNodeRunProgram : public egr::GradNodeBase {
     for (size_t i = 0; i < out_grad_names.size(); ++i) {
       hooked_grads[0][i].set_name(out_grad_names[i]);
     }
-    RunProgramGradAPI(x_,
-                      params_,
-                      hooked_grads[0],
-                      step_scope_,
-                      attrs_,
-                      x_grad_ptr,
-                      params_grad_ptr);
+    RunProgramGradAPI(
+        hooked_grads[0], step_scope_, attrs_, x_grad_ptr, params_grad_ptr);
     VLOG(3) << "End Eager Backward Node: GradNodeRunProgram";
     executed_ = true;
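Since x and params were already marked UNUSED in RunProgramGradAPI, the last two hunks are pure cleanup: the dead parameters disappear from the signature, and the call in GradNodeRunProgram::operator() shrinks to match. The grad node still stores x_ (now with placeholders) and params_ via SetFwdX/SetFwdParams; they are just no longer threaded through the gradient API.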