Unverified commit f484a61e, authored by W WangZhen, committed by GitHub

[Dy2St]Fix param and out grad names in dy2st for high order grad (#49461)

* Fix param and out grad names in dy2st for high order grad
Parent dc13f7c5
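
For context on the naming problem this commit fixes: under higher-order differentiation, the backward pass may nest gradient variable names under `grad/` prefixes (as the NOTE added in partial_program.py below describes), so a parameter's gradient variable is not always plain `<param>@GRAD`. The snippet below is a minimal standalone sketch (plain Python, no Paddle required; the variable names are invented) of the selection rule the patch uses to pick the right candidate, i.e. the most deeply nested one.

    # Hypothetical grad-variable names as they could appear in a backward block
    # built for a double-grad case; the names are made up for illustration only.
    candidates = [
        'linear_0.w_0@GRAD',
        'grad/linear_0.w_0@GRAD',
        'grad/grad/linear_0.w_0@GRAD',
    ]

    # Pick the candidate with the most 'grad/' prefixes: the gradient actually
    # written by the innermost backward program.
    picked = max(candidates, key=lambda name: name.count('grad/'))
    assert picked == 'grad/grad/linear_0.w_0@GRAD'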
@@ -80,16 +80,10 @@ inline void run_program_ad_func(
       trace_backward, &p_autograd_x, &p_autograd_params);
 
   if (require_any_grad) {
-    std::vector<std::string> out_names;
-    for (auto& t : deref_out) {
-      out_names.emplace_back(t.name());
-    }
     egr::EagerUtils::PassStopGradient(false, &p_autograd_outs);
     // Create GradOpNode (1 means [out_grad], 2 means [x_grad, paramx_grad])
     auto grad_node = std::make_shared<GradNodeRunProgram>(1, 2);
-    grad_node->SetFwdOutNames(out_names);
     // Set Attributes
     grad_node->SetAttrMap(attrs);
     // Set TensorWrappers
...
@@ -791,13 +791,15 @@ class GradNodeRunProgram : public egr::GradNodeBase {
       }
     }
 
+    auto out_grad_names =
+        PADDLE_GET_CONST(std::vector<std::string>, attrs_.at("out_grad_names"));
     PADDLE_ENFORCE_EQ(hooked_grads[0].size(),
-                      fwd_out_names_.size(),
+                      out_grad_names.size(),
                       paddle::platform::errors::InvalidArgument(
                           "The hooked_grads[0].size() and "
-                          "fwd_out_names_.size() should be equal."));
-    for (size_t i = 0; i < fwd_out_names_.size(); ++i) {
-      hooked_grads[0][i].set_name(fwd_out_names_[i] + "@GRAD");
+                          "out_grad_names.size() should be equal."));
+    for (size_t i = 0; i < out_grad_names.size(); ++i) {
+      hooked_grads[0][i].set_name(out_grad_names[i]);
     }
 
     RunProgramGradAPI(x_,
                       params_,
@@ -829,10 +831,6 @@ class GradNodeRunProgram : public egr::GradNodeBase {
     step_scope_ = scopes;
   }
 
-  void SetFwdOutNames(std::vector<std::string> out_names) {
-    fwd_out_names_ = out_names;
-  }
-
  protected:
   void ConstructXGradTensors(
       const std::vector<paddle::experimental::Tensor> &x,
@@ -850,21 +848,30 @@ class GradNodeRunProgram : public egr::GradNodeBase {
   }
 
   void ConstructParamGradTensors(
-      const std::vector<paddle::experimental::Tensor> &param,
-      std::vector<paddle::experimental::Tensor> *param_grad) {
-    for (auto &t : param) {
-      auto t_grad = egr::EagerUtils::unsafe_autograd_meta(t)->Grad();
+      const std::vector<paddle::experimental::Tensor> &params,
+      std::vector<paddle::experimental::Tensor> *param_grads) {
+    auto param_grad_names = PADDLE_GET_CONST(std::vector<std::string>,
+                                             attrs_.at("param_grad_names"));
+    PADDLE_ENFORCE_EQ(params.size(),
+                      param_grad_names.size(),
+                      paddle::platform::errors::InvalidArgument(
+                          "The param.size() and "
+                          "param_grad_names.size() should be equal."));
+    for (size_t i = 0; i < params.size(); ++i) {
+      auto &p = params[i];
+      auto &p_grad = egr::EagerUtils::unsafe_autograd_meta(p)->Grad();
       // In eager mode, the number of param_grad should be the same as
       // param, so here an empty Tensor is added for the param with
       // stop_gradient=True
-      if (!t_grad.defined()) {
-        param_grad->emplace_back();
-      } else if (t_grad.is_dense_tensor()) {
-        param_grad->emplace_back(std::make_shared<phi::DenseTensor>());
-      } else if (t_grad.is_selected_rows()) {
-        param_grad->emplace_back(std::make_shared<phi::SelectedRows>());
+      if (!p_grad.defined()) {
+        param_grads->emplace_back();
+      } else if (p_grad.is_dense_tensor()) {
+        param_grads->emplace_back(std::make_shared<phi::DenseTensor>());
+      } else if (p_grad.is_selected_rows()) {
+        param_grads->emplace_back(std::make_shared<phi::SelectedRows>());
       }
-      param_grad->back().set_name(t.name() + "@GRAD");
+      param_grads->back().set_name(param_grad_names[i]);
     }
   }
@@ -880,8 +887,6 @@ class GradNodeRunProgram : public egr::GradNodeBase {
   std::vector<paddle::experimental::Tensor> params_;
   std::vector<paddle::framework::Scope *> step_scope_;
 
-  std::vector<std::string> fwd_out_names_;
-
   // Attribute Map
   paddle::framework::AttributeMap attrs_;
 };
@@ -130,6 +130,14 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
                         "(BlockDesc *)"
                         "The global block of executed backward program desc.")
         .SetDefault(nullptr);
+    AddAttr<std::vector<std::string>>("param_grad_names",
+                                      "std::vector<std::string>"
+                                      "The names of parameter gradients.")
+        .SetDefault({});
+    AddAttr<std::vector<std::string>>("out_grad_names",
+                                      "std::vector<std::string>"
+                                      "The names of output gradients.")
+        .SetDefault({});
     AddComment(R"DOC(
 RunProgram operator.
...
@@ -135,6 +135,10 @@ class TestRunProgram(unittest.TestCase):
             False,
             'program_id',
             _hash_with_id(program),
+            'param_grad_names',
+            ['Fake_var@GRAD'],
+            'out_grad_names',
+            [out.name + '@GRAD'],
         ]
 
         use_interpretorcore = (
...
@@ -262,6 +262,15 @@ class RunProgramOpTest(unittest.TestCase):
                 )
             )
 
+            self.attrs.extend(
+                (
+                    'param_grad_names',
+                    [p.name + '@GRAD' for p in inputs['Params']],
+                    'out_grad_names',
+                    [out.name + '@GRAD' for out in outputs['Out']],
+                )
+            )
+
             _legacy_C_ops.run_program(
                 inputs['X'],
                 inputs['Params'],
@@ -305,6 +314,15 @@ class RunProgramOpTest(unittest.TestCase):
                 )
             )
 
+            self.attrs.extend(
+                (
+                    'param_grad_names',
+                    [p.name + '@GRAD' for p in inputs['Params']],
+                    'out_grad_names',
+                    [out.name + '@GRAD' for out in outputs['Out']],
+                )
+            )
+
             _legacy_C_ops.run_program(
                 inputs['X'],
                 inputs['Params'],
...
@@ -253,7 +253,7 @@ class PartialProgramLayer:
 
     @switch_to_static_graph
     def _create_forward_backward_train_program(self):
-        whole_program = self._create_program()
+        whole_program = self._train_program
         forward_end_op_index = self._infer_program.desc.block(0).op_size()
         return self._get_forward_backward_program_form(
             whole_program, forward_end_op_index
@@ -261,7 +261,7 @@ class PartialProgramLayer:
 
     @switch_to_static_graph
     def _create_forward_backward_train_amp_program(self):
-        whole_program = self._create_amp_program()
+        whole_program = self._train_amp_program
         forward_end_op_index = self._infer_amp_program.desc.block(0).op_size()
         return self._get_forward_backward_program_form(
             whole_program, forward_end_op_index
@@ -269,7 +269,7 @@ class PartialProgramLayer:
 
     @switch_to_static_graph
     def _create_forward_backward_train_pure_fp16_program(self):
-        whole_program = self._create_pure_fp16_program()
+        whole_program = self._train_pure_fp16_program
         forward_end_op_index = self._infer_pure_fp16_program.desc.block(
             0
         ).op_size()
@@ -404,6 +404,43 @@ class PartialProgramLayer:
     def _infer_pure_fp16_program_id(self):
         return _hash_with_id(self._infer_pure_fp16_program, self)
 
+    @LazyInitialized
+    def _param_grad_names(self):
+        names = []
+        # NOTE: `names` and `self._params` must be in the same order so that
+        # the param grad name can be set correctly in run_program.
+        for param in self._params:
+            candidate = [
+                var_name
+                for var_name in self.backward_program.block(0).vars.keys()
+                if var_name.endswith(param.name + '@GRAD')
+            ]
+            if candidate:
+                names.append(
+                    max(candidate, key=lambda name: name.count('grad/'))
+                )
+            else:
+                names.append(param.name + '@GRAD')
+        return names
+
+    @LazyInitialized
+    def _out_grad_names(self):
+        names = []
+        fwd_end_op_index = self._get_end_op_index()
+        for i in range(
+            fwd_end_op_index + 1,
+            min(
+                fwd_end_op_index + 2 * len(self._outputs.var_ids),
+                len(self.program.block(0).ops),
+            ),
+            2,
+        ):
+            op = self.program.block(0).ops[i]
+            if op.type == 'fill_constant':
+                var_name = op.output('Out')[0]
+                names.append(var_name)
+        return names
+
     @property
     def whole_program_id(self):
         if self.training:
@@ -610,6 +647,18 @@ class PartialProgramLayer:
             'program_id',
             self.program_id,
         ]
+        if self.training:
+            # NOTE: In the higher-order gradient case, a parameter grad may be named
+            # `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`,
+            # so we read the correct parameter grad names from the program. Out grads
+            # are handled similarly.
+            attrs.extend(
+                (
+                    'param_grad_names',
+                    self._param_grad_names,
+                    'out_grad_names',
+                    self._out_grad_names,
+                )
+            )
         if self._cuda_graph_capture_mode:
             attrs.extend(
                 (
...
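
A note on how these attributes travel from Python to the grad node: the `attrs` list passed to `_legacy_C_ops.run_program` is a flat `name, value, name, value, ...` sequence that becomes the op's attribute map, and `GradNodeRunProgram` then reads the lists back via `attrs_.at("out_grad_names")` and `attrs_.at("param_grad_names")`. The sketch below (plain Python, no Paddle; the concrete values are placeholders) only illustrates that flat-list-to-map correspondence.

    # Flat attribute list in the style used by the tests and PartialProgramLayer
    # above; the values are placeholders, not real program state.
    attrs = [
        'program_id', 1234567,
        'param_grad_names', ['grad/grad/linear_0.w_0@GRAD'],
        'out_grad_names', ['tmp_2@GRAD'],
    ]

    # Pair names with values to get the attribute-map view that the backward
    # node queries for the gradient names.
    attr_map = dict(zip(attrs[0::2], attrs[1::2]))
    assert attr_map['out_grad_names'] == ['tmp_2@GRAD']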