Unverified commit 35c7c835, authored by Jiabin Yang, committed by GitHub

[Eager] Remove full reserved strategy (#42690)

* remove full reserved strategy

* fix inplace error
Parent commit: 6e90ba1b
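In short, this change removes the full_reserved flag from egr::TensorWrapper and from every generated SetTensorWrapper method: each saved forward tensor now takes the same path of copying the data, cloning its autograd meta, and holding the original grad node only through a weak pointer. For orientation, a minimal before/after sketch of the constructor signature, condensed from the TensorWrapper hunk further down:

// Before this commit (taken from the removed parameter list below):
explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
                       bool full_reserved = false,
                       bool no_need_buffer = false);

// After this commit: the full_reserved knob is gone entirely.
explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
                       bool no_need_buffer = false);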
@@ -1156,28 +1156,20 @@ static std::string GenerateGradNodeCreationContent(
for (const auto& iter : op_base_infos) {
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
iter.GetGradInsFwdSlotnameMap();
const std::unordered_set<std::string>& no_need_buffer_ins =
iter.GetNoNeedBufferInputs();
for (auto& kv : grad_ins_fwd_slotname_map) {
const std::string& tensor_wrapper_name = kv.second;
std::string full_reserved = "false";
if (fwd_outputs_name_pos_map.find(tensor_wrapper_name) ==
fwd_outputs_name_pos_map.end() &&
!no_need_buffer_ins.count(tensor_wrapper_name)) {
full_reserved = "true";
}
const char* SET_TENSOR_WRAPPER_TEMPLATE =
" grad_node->SetTensorWrapper%s(%s, %s);\n";
" grad_node->SetTensorWrapper%s(%s);\n";
// Replace output directly with input in inplace op.
if (!inplace_map.empty() && inplace_map.count(tensor_wrapper_name)) {
auto inplace_input_name = inplace_map[tensor_wrapper_name];
grad_node_creation_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
LegalizeVarName(inplace_input_name), full_reserved);
LegalizeVarName(inplace_input_name));
} else {
grad_node_creation_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
LegalizeVarName(tensor_wrapper_name), full_reserved);
LegalizeVarName(tensor_wrapper_name));
}
}
}
@@ -2592,7 +2584,6 @@ static std::string GenerateGradNodeHeaderContents(
std::string tensor_wrapper_arg_str;
std::string tensor_wrapper_body_str;
std::string full_reserved_str = "full_reserved";
std::string no_need_buffer_str = "false";
if (no_need_buffer_ins.count(tensor_wrapper_name)) {
no_need_buffer_str = "true";
@@ -2610,12 +2601,12 @@ static std::string GenerateGradNodeHeaderContents(
const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"for(const auto& eager_tensor : %s) {\n"
" %s.emplace_back( egr::TensorWrapper(eager_tensor, %s "
"/*full_reserved*/, %s) );\n"
" %s.emplace_back( egr::TensorWrapper(eager_tensor "
", %s) );\n"
" }\n";
tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name,
struct_tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
struct_tensor_wrapper_name, no_need_buffer_str);
const char* CLEAR_TENSOR_WRAPPER_TEMPLATE =
"for (auto tw: %s) {\n"
@@ -2636,22 +2627,20 @@ static std::string GenerateGradNodeHeaderContents(
TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"%s = egr::TensorWrapper(%s, %s /*full_reserved*/, %s);\n";
"%s = egr::TensorWrapper(%s, %s);\n";
tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name,
tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
tensor_wrapper_name, no_need_buffer_str);
const char* CLEAR_TENSOR_WRAPPER_TEMPLATE = " %s.clear();\n";
clear_tensor_wrappers_str += paddle::string::Sprintf(
CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name);
}
std::string full_reserved_signature_str = "bool full_reserved";
const char* SET_TENSOR_WRAPPER_TEMPLATE =
" void SetTensorWrapper%s(%s, %s) {\n %s\n }\n";
" void SetTensorWrapper%s(%s) {\n %s\n }\n";
set_tensor_wrappers_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name,
tensor_wrapper_arg_str, full_reserved_signature_str,
tensor_wrapper_body_str);
tensor_wrapper_arg_str, tensor_wrapper_body_str);
}
}
VLOG(6) << "Generated TensorWrapper";
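To make the template edits above concrete, here is roughly what a generated grad-node member looks like once the legacy generator applies the new templates. GradNodeFoo and the input name X are hypothetical placeholders, not names this diff emits:

// Illustrative sketch only; real generated nodes derive from egr::GradNodeBase
// and carry additional members and overrides.
class GradNodeFoo {
 public:
  // The "bool full_reserved" parameter has been dropped from the setter...
  void SetTensorWrapperX(const paddle::experimental::Tensor& X) {
    // ...and the wrapper is now built with only the no_need_buffer flag,
    // which the generator still emits as a literal "true" or "false".
    X_ = egr::TensorWrapper(X, false /*no_need_buffer*/);
  }

 private:
  egr::TensorWrapper X_;
};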
@@ -55,8 +55,8 @@ def ParseArguments():
## Code Gen Templates ##
########################
SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
""" void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
{} = egr::TensorWrapper({}, full_reserved, {});
""" void SetTensorWrapper{}(const paddle::experimental::Tensor& {}) {{
{} = egr::TensorWrapper({}, {});
}}
"""
@@ -69,9 +69,9 @@ CLEAR_TENSOR_WRAPPER_TEMPLATE = \
"""
SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
""" void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{
""" void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}) {{
for(const auto& eager_tensor : {}) {{
{}.emplace_back(egr::TensorWrapper(eager_tensor, full_reserved, {}));
{}.emplace_back(egr::TensorWrapper(eager_tensor, {}));
}};
}}
"""
@@ -676,9 +676,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
if is_fwd_input:
if is_optional:
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()));"
else:
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, true);"
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
set_input_tensor_wrappers_list.append(set_tensor_wrappers)
else:
if num_fwd_outputs > 1:
@@ -688,9 +688,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
fwd_output_pos = forward_outputs_position_map[name][1]
if is_optional:
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()));"
else:
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, false);"
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
set_output_tensor_wrappers_list.append(set_tensor_wrappers)
set_input_tensor_wrappers_str = "\n".join(
set_input_tensor_wrappers_list)
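The final-state generator changes mirror the legacy ones: the emitted call sites simply stop passing the trailing boolean. A hedged sketch of what a generated forward function now emits when saving tensors for backward, with x standing in for an arbitrary input name:

// Plain input (previously ended with ", true" for inputs, ", false" for outputs):
grad_node->SetTensorWrapperx(x);
// Optional input: still guarded, saved only when actually present.
if (x.get_ptr() != nullptr) grad_node->SetTensorWrapperx(*(x.get_ptr()));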
@@ -34,7 +34,6 @@ class TensorWrapper {
public:
TensorWrapper() = default;
explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
bool full_reserved = false,
bool no_need_buffer = false) {
// set inplace_version_snapshot_ according to tensor's current inplace
// version.
@@ -46,32 +45,12 @@
}
/**
* Normally, we should fully reserved all non-output or non-leaf fwd tensor
* here. And for fwd output tensor, we should not reserve its autogradmeta,
* to avoid recursive depends on GradNodeBase
* Normally, we should only save the data and part of the autograd_meta of a
* fwd tensor, and should not reserve its original grad_node,
* to avoid recursive and additional dependencies on GradNodeBase
* **/
full_reserved_ = full_reserved;
auto* tensor_autograd_meta = EagerUtils::nullable_autograd_meta(tensor);
no_need_buffer_ = no_need_buffer;
if (full_reserved_) {
VLOG(6) << "Fully reserved tensor: " << tensor.name();
intermidiate_tensor_ = tensor;
if (no_need_buffer_) {
if (phi::DenseTensor::classof(tensor.impl().get())) {
// Only Copy Meta
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(tensor.impl().get());
auto tw_dense_tensor =
std::make_shared<phi::DenseTensor>(*dense_tensor);
tw_dense_tensor->clear();
intermidiate_tensor_.set_impl(tw_dense_tensor);
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Unrecognized tensor type for no_need_buffer feature"));
}
}
return;
}
// shallow copy tensor_impl here
if (no_need_buffer) {
if (phi::DenseTensor::classof(tensor.impl().get())) {
@@ -89,10 +68,11 @@ class TensorWrapper {
intermidiate_tensor_.set_impl(tensor.impl());
}
// TODO(jiabin): This may have a severe performance issue
intermidiate_tensor_.set_name(tensor.name() + "@Saved");
if (VLOG_IS_ON(7)) {
// TODO(jiabin): This may have a severe performance issue
intermidiate_tensor_.set_name(tensor.name() + "@Saved");
}
auto* tensor_autograd_meta = EagerUtils::nullable_autograd_meta(tensor);
if (tensor_autograd_meta) {
auto autograd_meta =
std::make_shared<AutogradMeta>(*tensor_autograd_meta);
@@ -112,33 +92,28 @@ class TensorWrapper {
check_inplace_version();
// if it's full_reserved just return the full copy of tensor
if (full_reserved_) {
return intermidiate_tensor_;
paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;
std::shared_ptr<GradNodeBase> new_grad_node = weak_grad_node_.lock();
if (new_grad_node) {
VLOG(3) << "Recovered TensorWrapper with GradNode "
<< new_grad_node->name() << " addr: " << new_grad_node.get();
} else {
paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;
VLOG(3) << "Recovered TensorWrapper with Empty GradNode";
}
auto* intermediate_autograd_meta =
EagerUtils::nullable_autograd_meta(intermidiate_tensor_);
std::shared_ptr<GradNodeBase> new_grad_node = weak_grad_node_.lock();
if (intermediate_autograd_meta) {
auto p_ab_autograd_meta =
std::make_shared<AutogradMeta>(*intermediate_autograd_meta);
if (new_grad_node) {
VLOG(3) << "Recovered TensorWrapper with GradNode "
<< new_grad_node->name() << " addr: " << new_grad_node.get();
} else {
VLOG(3) << "Recovered TensorWrapper with Empty GradNode";
p_ab_autograd_meta->SetGradNode(new_grad_node);
}
auto* intermediate_autograd_meta =
EagerUtils::nullable_autograd_meta(intermidiate_tensor_);
if (intermediate_autograd_meta) {
auto p_ab_autograd_meta =
std::make_shared<AutogradMeta>(*intermediate_autograd_meta);
if (new_grad_node) {
p_ab_autograd_meta->SetGradNode(new_grad_node);
}
recovered_tensor.set_autograd_meta(p_ab_autograd_meta);
}
return recovered_tensor;
recovered_tensor.set_autograd_meta(p_ab_autograd_meta);
}
return recovered_tensor;
}
void clear() { intermidiate_tensor_.reset(); }
@@ -179,7 +154,6 @@ class TensorWrapper {
}
private:
bool full_reserved_ = false;
bool no_need_buffer_ = false;
paddle::experimental::Tensor intermidiate_tensor_;
std::weak_ptr<egr::GradNodeBase> weak_grad_node_;
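With the full_reserved fast path gone, every recover() call now follows a single route: copy the saved tensor and, if it carried autograd meta, clone that meta and re-point it at whatever grad node the weak pointer can still lock. A condensed free-function sketch of that flow (simplified from the hunk above, not a verbatim copy; assumes the eager autograd headers providing AutogradMeta, EagerUtils and GradNodeBase are included):

paddle::experimental::Tensor RecoverSketch(
    const paddle::experimental::Tensor& saved,
    const std::weak_ptr<egr::GradNodeBase>& weak_grad_node) {
  paddle::experimental::Tensor recovered = saved;
  std::shared_ptr<egr::GradNodeBase> grad_node = weak_grad_node.lock();
  auto* saved_meta = egr::EagerUtils::nullable_autograd_meta(saved);
  if (saved_meta) {
    // Clone the saved autograd meta so the recovered tensor does not alias
    // the wrapper's copy, then restore the grad node if it is still alive.
    auto meta = std::make_shared<egr::AutogradMeta>(*saved_meta);
    if (grad_node) meta->SetGradNode(grad_node);
    recovered.set_autograd_meta(meta);
  }
  return recovered;
}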
@@ -40,9 +40,11 @@ TEST(TensorWrapper, Basic) {
auto auto_grad0 = std::make_shared<egr::AutogradMeta>(edge0);
et1.set_autograd_meta(auto_grad0);
et1.set_name("et1");
auto tw0 = egr::TensorWrapper(et1, true);
auto tw0 = egr::TensorWrapper(et1);
auto recover_et1 = tw0.recover();
CHECK_EQ(recover_et1.name(), std::string("et1"));
if (VLOG_IS_ON(7)) {
CHECK_EQ(recover_et1.name(), std::string("et1@Saved"));
}
CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et1).first,
egr::EagerUtils::OutRankInfo(et1).first);
CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et1).second,
@@ -68,13 +70,15 @@ TEST(TensorWrapper, Basic) {
et2.set_autograd_meta(auto_grad1);
auto tw1 = egr::TensorWrapper(et2, false);
auto recover_et2 = tw1.recover();
CHECK_EQ(recover_et2.name(), std::string("et2@Saved"));
if (VLOG_IS_ON(7)) {
CHECK_EQ(recover_et2.name(), std::string("et2@Saved"));
}
CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et2).first,
egr::EagerUtils::OutRankInfo(et2).first);
CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et2).second,
egr::EagerUtils::OutRankInfo(et2).second);
// Test Raw recover
paddle::experimental::Tensor et3;
auto tw2 = egr::TensorWrapper(et3, true);
auto tw2 = egr::TensorWrapper(et3);
CHECK(tw2.recover().initialized() == false);
}
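For completeness, the remaining two-argument form covers only the buffer-free case; a minimal usage sketch, assuming t is an initialized paddle::experimental::Tensor:

// Keep only the tensor's metadata: the underlying buffer is dropped, so the
// recovered tensor carries shape and dtype but no allocation.
auto tw = egr::TensorWrapper(t, /*no_need_buffer=*/true);
auto recovered = tw.recover();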