Unverified commit 275a8102, authored by WangZhen, committed by GitHub

Fix run program grad node mem (#55869)

Parent 613beeb6
@@ -20,6 +20,7 @@
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/memory/allocation/allocator.h"
// Filter params without grads in global block. In this case, we will
// tag its AutogradMeta with stop_gradient = True to avoid fault from
@@ -53,6 +54,41 @@ static void clear_no_grad_edges_with_partial_block(
}
}
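// Release the memory of forward outputs that the backward block never
// references, so they are not kept alive in the scope until backward runs.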
static void clear_unused_out_var_in_backward(
const std::vector<paddle::Tensor*>& out,
const paddle::framework::BlockDesc* backward_block,
paddle::framework::Scope* scope) {
std::deque<std::shared_ptr<paddle::memory::Allocation>>* garbages =
new std::deque<std::shared_ptr<paddle::memory::Allocation>>();
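// Moved memory holders are collected here; deleting the deque below drops
// these references so the allocations can be freed.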
for (auto* out_tensor : out) {
if (!backward_block->HasVar(out_tensor->name())) {
auto var = scope->FindVar(out_tensor->name());
if (var == nullptr) {
continue;
}
if (var->IsType<phi::DenseTensor>()) {
garbages->emplace_back(
var->GetMutable<phi::DenseTensor>()->MoveMemoryHolder());
}
}
}
delete garbages;
}
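// Replace forward inputs that the backward block never uses with empty
// "Fake_var" placeholders, so saving them into the grad node does not
// retain their underlying buffers.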
static std::vector<paddle::Tensor> filter_unused_input_var_in_backward(
const std::vector<paddle::Tensor>& x,
const paddle::framework::BlockDesc* backward_block) {
auto filter_x = std::vector<paddle::Tensor>(x);
for (size_t i = 0; i < x.size(); i++) {
if (!backward_block->HasVar(x[i].name())) {
auto fake = paddle::Tensor(std::make_shared<phi::DenseTensor>());
fake.set_name(paddle::framework::kFakeVarName);
filter_x[i] = fake;
}
}
return filter_x;
}
inline void run_program_ad_func(
const std::vector<paddle::Tensor>& x,
const std::vector<paddle::Tensor>& params,
@@ -87,8 +123,18 @@ inline void run_program_ad_func(
// Set Attributes
grad_node->SetAttrMap(attrs);
auto* forward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc*, attrs.at("forward_global_block"));
auto* backward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
// Clear unused x vars
auto filter_x =
filter_unused_input_var_in_backward(x, backward_global_block);
// Set TensorWrappers
grad_node->SetFwdX(x);
grad_node->SetFwdX(filter_x);
// Clear unused out vars
clear_unused_out_var_in_backward(out, backward_global_block, step_scope[0]);
grad_node->SetFwdParams(params);
grad_node->SetStepScope(step_scope);
@@ -97,10 +143,6 @@ inline void run_program_ad_func(
// NOTE(@xiongkun): Not every tensor in x (list of tensors) requires a
// gradient. For example, if x[1] is not used for the output, x[1] is ignored.
auto* forward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc*, attrs.at("forward_global_block"));
auto* backward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
std::vector<const paddle::Tensor*> x_require_grad;
for (size_t i = 0; i < x.size(); ++i) {
auto& name = x[i].name();
......
@@ -130,7 +130,7 @@ static void ShareTensorsIntoScope(const std::vector<Tensor> &tensors,
paddle::framework::Scope *scope) {
for (size_t i = 0; i < tensors.size(); ++i) {
auto name = tensors[i].name();
if (name == "Fake_var") {
if (name == paddle::framework::kFakeVarName) {
continue;
}
auto *var = scope->Var(name);
@@ -159,8 +159,8 @@ static void ShareTensorsFromScope(
// because we can't find them in scope. So we skip sharing these vars or
// var@GRAD if they don't appear in global block.
auto &name = tensors[i]->name();
if (name == paddle::framework::kEmptyVarName || name == "Fake_var" ||
!global_block.HasVar(name)) {
if (name == paddle::framework::kEmptyVarName ||
name == paddle::framework::kFakeVarName || !global_block.HasVar(name)) {
VLOG(2) << "find tensor name is " << name << ", skip it!";
continue;
}
@@ -197,7 +197,8 @@ static void ShareTensorsFromScopeWithPartialBlock(
paddle::framework::Scope *scope) {
for (size_t i = 0; i < tensors.size(); ++i) {
auto &name = tensors[i]->name();
if (name == paddle::framework::kEmptyVarName || name == "Fake_var" ||
if (name == paddle::framework::kEmptyVarName ||
name == paddle::framework::kFakeVarName ||
(!forward_global_block.HasVar(name) &&
!backward_global_block.HasVar(name))) {
VLOG(2) << "find tensor name is " << name << ", skip it!";
@@ -482,8 +483,6 @@ inline void RunProgramAPI(
}
inline void RunProgramGradAPI(
const std::vector<paddle::Tensor> &x UNUSED,
const std::vector<paddle::Tensor> &params UNUSED,
const std::vector<paddle::Tensor> &out_grad,
const std::vector<paddle::framework::Scope *> &step_scope, // NOLINT
const paddle::framework::AttributeMap &attrs,
@@ -701,20 +700,19 @@ class GradNodeRunProgram : public egr::GradNodeBase {
for (size_t i = 0; i < out_grad_names.size(); ++i) {
hooked_grads[0][i].set_name(out_grad_names[i]);
}
RunProgramGradAPI(x_,
params_,
hooked_grads[0],
step_scope_,
attrs_,
x_grad_ptr,
params_grad_ptr);
RunProgramGradAPI(
hooked_grads[0], step_scope_, attrs_, x_grad_ptr, params_grad_ptr);
VLOG(3) << "End Eager Backward Node: GradNodeRunProgram";
executed_ = true;
return {x_grad, params_grad};
}
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
void ClearTensorWrappers() override {
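// Drop the saved forward inputs and params so their memory can be
// reclaimed once the tensor wrappers are cleared.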
x_.clear();
params_.clear();
SetIsTensorWrappersCleared(true);
}
// SetAttrMap
void SetAttrMap(const paddle::framework::AttributeMap &attrs) {
......
@@ -65,6 +65,8 @@ PHI_DECLARE_int32(inner_op_parallelism);
namespace paddle {
namespace framework {
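/// Name used for placeholder variables that hold no data and should be
/// skipped when sharing tensors with the scope.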
constexpr char kFakeVarName[] = "Fake_var";
/// If a variable is an empty variable, that name will be used.
constexpr char kEmptyVarName[] = "@EMPTY@";
......
@@ -143,7 +143,7 @@ static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
const std::vector<std::string> &var_names,
framework::Scope *scope) {
for (size_t i = 0; i < vars.size(); ++i) {
if (var_names[i] == "Fake_var") {
if (var_names[i] == framework::kFakeVarName) {
continue;
}
auto *var = scope->Var(var_names[i]);
@@ -162,7 +162,8 @@ static void ShareVarsFromScope(const std::vector<Variable *> &vars,
// because we can't find them in scope. So we skip sharing these vars or
// var@GRAD if they don't appear in global block.
if (var_names[i] == framework::kEmptyVarName ||
var_names[i] == "Fake_var" || !global_block.HasVar(var_names[i])) {
var_names[i] == framework::kFakeVarName ||
!global_block.HasVar(var_names[i])) {
VLOG(2) << "find variable name is " << var_names[i] << ", skip it!";
continue;
}
......