Unverified commit 2592805b · authored by Zhanlue Yang · committed by GitHub

Fixed auto codegen for intermediate tensors (#39797)

* Refactored GradNodeAccumulation data structure and behaviour

* Fixed CI issues

* Fixed compilation issues

* Fixed minor issues

* Reverted changes for intermediate and OverwriteOutput

* Fixed minor issue

* Fixed auto codegen for intermediate tensors

* Removed restriction on AccumulationNode modification

* Fixed CI Coverage issues

* Adjusted Log contents

* Fixed CI issues
Parent eb7c211a
@@ -52,49 +52,44 @@ void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
   }
 }

-static void RetainGradForRegularNode(
-    const paddle::experimental::Tensor& tensor) {
-  AutogradMeta* meta = EagerUtils::unsafe_autograd_meta(tensor);
-  if (meta->RetainGrads()) {
+void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
+  if (IsLeafTensor(tensor)) {
+    // Leaf tensor's grad will always be retained
+    // Refer to implementation of AccumulationNode for more details
     return;
   } else {
-    meta->SetRetainGrads(true);
-  }
+    AutogradMeta* meta = EagerUtils::unsafe_autograd_meta(tensor);
+    if (meta->RetainGrads()) {
+      return;
+    } else {
+      meta->SetRetainGrads(true);
+    }

-  std::weak_ptr<paddle::experimental::Tensor> weak_grad_tensor =
-      meta->WeakGrad();
+    std::weak_ptr<paddle::experimental::Tensor> weak_grad_tensor =
+        meta->WeakGrad();

-  // Define Hook
-  auto hook = [weak_grad_tensor](const paddle::experimental::Tensor& t) {
-    if (!weak_grad_tensor.expired()) {
-      auto grad_tensor = weak_grad_tensor.lock();
-      if (t.defined()) {
-        VLOG(7) << "Set impl for RetainGrad Hook for tensor: " << t.name();
-        // Simply Copy impl() to grad_tensor
-        grad_tensor->set_impl(t.impl());
-        return *grad_tensor.get();
-      } else {
-        VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
-        return paddle::experimental::Tensor();
-      }
-    } else {
-      VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
-      return paddle::experimental::Tensor();
-    }
-  };
+    // Define Hook
+    auto hook = [weak_grad_tensor](const paddle::experimental::Tensor& t) {
+      if (!weak_grad_tensor.expired()) {
+        auto grad_tensor = weak_grad_tensor.lock();
+        if (t.defined()) {
+          VLOG(7) << "Set impl for RetainGrad Hook for tensor: " << t.name();
+          // Simply Copy impl() to grad_tensor
+          grad_tensor->set_impl(t.impl());
+          return *grad_tensor.get();
+        } else {
+          VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
+          return paddle::experimental::Tensor();
+        }
+      } else {
+        VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
+        return paddle::experimental::Tensor();
+      }
+    };

-  // Append to GradientHooks
-  RegisterGradientHookForTensor(tensor,
-                                std::make_shared<egr::CppTensorHook>(hook));
-}
-
-void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
-  if (IsLeafTensor(tensor)) {
-    // Leaf tensor's grad will always be retained
-    // Refer to implementation of AccumulationNode for more details
-    return;
-  } else {
-    RetainGradForRegularNode(tensor);
+    // Append to GradientHooks
+    RegisterGradientHookForTensor(tensor,
+                                  std::make_shared<egr::CppTensorHook>(hook));
   }
 }
......
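For illustration only, not part of this commit: a minimal standalone C++ sketch of the weak_ptr-capture pattern the hook above relies on. The lambda checks lock() before writing through the pointer, so a retained hook never dereferences a gradient buffer that has already been released. All names below are hypothetical.

#include <functional>
#include <iostream>
#include <memory>

struct Grad { int value = 0; };

int main() {
  auto grad = std::make_shared<Grad>();
  std::weak_ptr<Grad> weak_grad = grad;

  // The hook holds a weak_ptr so it does not keep the grad buffer alive.
  std::function<void(int)> hook = [weak_grad](int incoming) {
    if (auto locked = weak_grad.lock()) {
      locked->value = incoming;  // copy the incoming result into the retained grad
    } else {
      std::cout << "grad buffer already released, hook is a no-op\n";
    }
  };

  hook(42);
  std::cout << grad->value << "\n";  // prints 42

  grad.reset();  // release the owning reference
  hook(7);       // safely detects expiration instead of touching freed memory
}

Capturing a weak_ptr rather than a shared_ptr keeps the hook from extending the gradient tensor's lifetime, which is why the expiration branch has to exist.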
@@ -1156,11 +1156,13 @@ static std::string GenerateGradNodeCreationContent(
       grad_node_creation_str += paddle::string::Sprintf(
           SET_OUT_RANK_TEMPLATE, output_autograd_name, output_position);

-      const char* SET_HISTORY_TEMPLATE =
-          " egr::EagerUtils::SetHistory(&%s, grad_node);\n";
-      grad_node_creation_str +=
-          paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
-
+      // Intermediate Tensor does not require SetHistory
+      if (!output.intermediate()) {
+        const char* SET_HISTORY_TEMPLATE =
+            " egr::EagerUtils::SetHistory(&%s, grad_node);\n";
+        grad_node_creation_str +=
+            paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
+      }
       const char* SET_GRAD_IN_META_TEMPLATE =
           " grad_node->SetGradInMeta(&%s, %d);\n";
       grad_node_creation_str += paddle::string::Sprintf(
@@ -1173,17 +1175,20 @@ static std::string GenerateGradNodeCreationContent(
       grad_node_creation_str += paddle::string::Sprintf(
           SET_OUT_RANK_TEMPLATE, output_autograd_name, output_position);

-      const char* SET_HISTORY_TEMPLATE =
-          " egr::EagerUtils::SetHistory(%s, grad_node);\n";
-      grad_node_creation_str +=
-          paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
-
+      // Intermediate Tensor does not require SetHistory
+      if (!output.intermediate()) {
+        const char* SET_HISTORY_TEMPLATE =
+            " egr::EagerUtils::SetHistory(%s, grad_node);\n";
+        grad_node_creation_str +=
+            paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
+      }
       const char* SET_GRAD_IN_META_TEMPLATE =
           " grad_node->SetGradInMeta(%s, %d);\n";
       grad_node_creation_str += paddle::string::Sprintf(
           SET_GRAD_IN_META_TEMPLATE, output_autograd_name, output_position);
     }

+    // Intermediate Tensor does not require CheckAndRetainGrad
     if (!output.intermediate()) {
       VLOG(6) << "Generated Call RetainGradForTensor";
       const char* RETAIN_GRAD_TEMPLATE =
......
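For illustration only, not part of this commit: a self-contained sketch of the guarded-generation pattern above, where SetHistory code is emitted only for outputs that are not marked intermediate. OutputInfo, GenerateSetHistory, and the names p_autograd_out / p_autograd_tmp are hypothetical stand-ins for the generator's real data structures.

#include <cstdio>
#include <string>

struct OutputInfo {
  std::string autograd_name;  // e.g. "p_autograd_out" (hypothetical)
  bool intermediate;
};

// Mirrors the guarded Sprintf above: intermediate outputs get no SetHistory call.
std::string GenerateSetHistory(const OutputInfo& output) {
  std::string code;
  if (!output.intermediate) {
    char buf[256];
    std::snprintf(buf, sizeof(buf),
                  "  egr::EagerUtils::SetHistory(&%s, grad_node);\n",
                  output.autograd_name.c_str());
    code += buf;
  }
  return code;
}

int main() {
  std::printf("%s", GenerateSetHistory({"p_autograd_out", false}).c_str());
  // Intermediate output: nothing is generated.
  std::printf("%s", GenerateSetHistory({"p_autograd_tmp", true}).c_str());
}

Running it prints one generated SetHistory line for the non-intermediate output and nothing for the intermediate one, which mirrors why intermediate tensors no longer have a grad-node history attached.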
@@ -221,10 +221,11 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
               << " 's name is: " << grad_output_tensor.name();

       auto* next_node = next_node_shared.get();

       if (!node_input_buffers_dict.count(next_node)) {
-        node_input_buffers_dict[next_node] =
-            std::make_unique<GradTensorHolder>(next_node->InputMeta());
+        const auto& input_meta = next_node->InputMeta();
+        auto grad_tensor_holder =
+            std::make_unique<GradTensorHolder>(input_meta);
+        node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
       }
       VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
               << ", rank: " << edge_rank.second;
......
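For illustration only, not part of this commit: a standalone sketch of the lazy-initialization pattern above, where an input buffer is built the first time an edge reaches a node and then moved into the map. Node and Holder are hypothetical simplifications of GradNodeBase and GradTensorHolder.

#include <iostream>
#include <memory>
#include <unordered_map>
#include <vector>

struct Node { std::vector<int> input_meta{1, 2, 3}; };

struct Holder {
  explicit Holder(const std::vector<int>& meta) : slots(meta.size()) {}
  std::vector<double> slots;
};

int main() {
  Node node;
  std::unordered_map<Node*, std::unique_ptr<Holder>> buffers;

  // Build the holder only once per node, keyed by the raw node pointer.
  if (!buffers.count(&node)) {
    const auto& input_meta = node.input_meta;
    auto holder = std::make_unique<Holder>(input_meta);
    buffers[&node] = std::move(holder);
  }

  std::cout << buffers[&node]->slots.size() << "\n";  // prints 3
}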
@@ -244,7 +244,7 @@ GradNodeBase::ApplyGradientHooks(
       if (!out.defined() || !out.initialized()) {
         out = (*hook)(tensors[slot_id][rank]);
       } else {
-        // If more than one hook is registered, the input to the next hook func
+        // If more than one hook is registered, the input to the next hook func
         // should be the output of the previous hook
         out = (*hook)(out);
       }
......
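For illustration only, not part of this commit: a standalone sketch of the chaining rule stated in the comment above — when several hooks are registered on one slot, each hook consumes the previous hook's output rather than the raw gradient. Names and types are hypothetical.

#include <functional>
#include <iostream>
#include <vector>

int main() {
  using Hook = std::function<double(const double&)>;
  std::vector<Hook> hooks = {
      [](const double& g) { return g * 2.0; },  // first registered hook
      [](const double& g) { return g + 1.0; },  // second hook sees 2*g, not g
  };

  double grad = 3.0;
  double out = grad;
  bool first = true;
  for (const auto& hook : hooks) {
    // The input to the next hook is the output of the previous one.
    out = first ? hook(grad) : hook(out);
    first = false;
  }
  std::cout << out << "\n";  // (3 * 2) + 1 = 7
}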
@@ -122,12 +122,21 @@ paddle::experimental::Tensor* EagerUtils::mutable_grad(
 void EagerUtils::SetHistory(std::vector<AutogradMeta*>* autograd_metas,
                             const std::shared_ptr<GradNodeBase>& grad_node) {
   for (const auto& autograd_meta : *autograd_metas) {
+    if (dynamic_cast<GradNodeAccumulation*>(autograd_meta->GradNode())) {
+      VLOG(6) << "Warning: Reseting GradNodeAccumulation for leaf tensor is "
+                 "detected";
+    }
     autograd_meta->SetGradNode(grad_node);
   }
 }

 void EagerUtils::SetHistory(AutogradMeta* autograd_meta,
                             const std::shared_ptr<GradNodeBase>& grad_node) {
+  if (dynamic_cast<GradNodeAccumulation*>(autograd_meta->GradNode())) {
+    VLOG(6)
+        << "Warning: Reseting GradNodeAccumulation for leaf tensor is detected";
+  }
   autograd_meta->SetGradNode(grad_node);
 }
......
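For illustration only, not part of this commit: a standalone sketch of the dynamic_cast check added above, which detects that a leaf tensor's accumulation node is about to be replaced and logs a warning before overwriting it. The class and function names are simplified stand-ins for GradNodeBase, GradNodeAccumulation, and EagerUtils::SetHistory.

#include <iostream>
#include <memory>

struct GradNode { virtual ~GradNode() = default; };
struct AccumulationNode : GradNode {};  // the node type leaf tensors normally keep
struct RegularOpNode : GradNode {};

void SetHistory(std::shared_ptr<GradNode>* current,
                const std::shared_ptr<GradNode>& new_node) {
  // dynamic_cast on the stored pointer succeeds only for AccumulationNode instances.
  if (dynamic_cast<AccumulationNode*>(current->get())) {
    std::cout << "Warning: resetting an accumulation node on a leaf tensor\n";
  }
  *current = new_node;  // overwrite the history regardless, as in the diff
}

int main() {
  std::shared_ptr<GradNode> node = std::make_shared<AccumulationNode>();
  SetHistory(&node, std::make_shared<RegularOpNode>());  // triggers the warning
  SetHistory(&node, std::make_shared<RegularOpNode>());  // no warning this time
}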