Unverified commit 35c7c835, authored by Jiabin Yang, committed by GitHub

[Eager] Remove full reserved strategy (#42690)

* remove full reserved strategy

* fix inplace error
Parent 6e90ba1b
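In short, this commit drops the full_reserved flag from the eager-mode TensorWrapper path: the generated SetTensorWrapper methods and the egr::TensorWrapper constructor lose the boolean parameter, and every wrapper now takes the same shallow-copy-plus-rebuilt-autograd-meta path. The interface change, as it appears in the diff below:

// Before this commit:
explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
                       bool full_reserved = false,
                       bool no_need_buffer = false);
// After this commit:
explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
                       bool no_need_buffer = false);

// Call sites drop the flag accordingly, e.g. in the updated test:
auto tw0 = egr::TensorWrapper(et1);  // was: egr::TensorWrapper(et1, true)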
@@ -1156,28 +1156,20 @@ static std::string GenerateGradNodeCreationContent(
   for (const auto& iter : op_base_infos) {
     const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
         iter.GetGradInsFwdSlotnameMap();
-    const std::unordered_set<std::string>& no_need_buffer_ins =
-        iter.GetNoNeedBufferInputs();
     for (auto& kv : grad_ins_fwd_slotname_map) {
       const std::string& tensor_wrapper_name = kv.second;
-      std::string full_reserved = "false";
-      if (fwd_outputs_name_pos_map.find(tensor_wrapper_name) ==
-              fwd_outputs_name_pos_map.end() &&
-          !no_need_buffer_ins.count(tensor_wrapper_name)) {
-        full_reserved = "true";
-      }
       const char* SET_TENSOR_WRAPPER_TEMPLATE =
-          " grad_node->SetTensorWrapper%s(%s, %s);\n";
+          " grad_node->SetTensorWrapper%s(%s);\n";
       // Replace output directly with input in inplace op.
       if (!inplace_map.empty() && inplace_map.count(tensor_wrapper_name)) {
         auto inplace_input_name = inplace_map[tensor_wrapper_name];
         grad_node_creation_str += paddle::string::Sprintf(
             SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
-            LegalizeVarName(inplace_input_name), full_reserved);
+            LegalizeVarName(inplace_input_name));
       } else {
         grad_node_creation_str += paddle::string::Sprintf(
             SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
-            LegalizeVarName(tensor_wrapper_name), full_reserved);
+            LegalizeVarName(tensor_wrapper_name));
       }
     }
   }
@@ -2592,7 +2584,6 @@ static std::string GenerateGradNodeHeaderContents(
       std::string tensor_wrapper_arg_str;
       std::string tensor_wrapper_body_str;
-      std::string full_reserved_str = "full_reserved";
       std::string no_need_buffer_str = "false";
       if (no_need_buffer_ins.count(tensor_wrapper_name)) {
         no_need_buffer_str = "true";
@@ -2610,12 +2601,12 @@ static std::string GenerateGradNodeHeaderContents(
        const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
            "for(const auto& eager_tensor : %s) {\n"
-            " %s.emplace_back( egr::TensorWrapper(eager_tensor, %s "
-            "/*full_reserved*/, %s) );\n"
+            " %s.emplace_back( egr::TensorWrapper(eager_tensor "
+            ", %s) );\n"
            " }\n";
        tensor_wrapper_body_str = paddle::string::Sprintf(
            SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name,
-            struct_tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
+            struct_tensor_wrapper_name, no_need_buffer_str);
        const char* CLEAR_TENSOR_WRAPPER_TEMPLATE =
            "for (auto tw: %s) {\n"
@@ -2636,22 +2627,20 @@ static std::string GenerateGradNodeHeaderContents(
            TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
        const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
-            "%s = egr::TensorWrapper(%s, %s /*full_reserved*/, %s);\n";
+            "%s = egr::TensorWrapper(%s, %s);\n";
        tensor_wrapper_body_str = paddle::string::Sprintf(
            SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name,
-            tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
+            tensor_wrapper_name, no_need_buffer_str);
        const char* CLEAR_TENSOR_WRAPPER_TEMPLATE = " %s.clear();\n";
        clear_tensor_wrappers_str += paddle::string::Sprintf(
            CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name);
      }
-      std::string full_reserved_signature_str = "bool full_reserved";
      const char* SET_TENSOR_WRAPPER_TEMPLATE =
-          " void SetTensorWrapper%s(%s, %s) {\n %s\n }\n";
+          " void SetTensorWrapper%s(%s) {\n %s\n }\n";
      set_tensor_wrappers_str += paddle::string::Sprintf(
          SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name,
-          tensor_wrapper_arg_str, full_reserved_signature_str,
-          tensor_wrapper_body_str);
+          tensor_wrapper_arg_str, tensor_wrapper_body_str);
    }
  }
  VLOG(6) << "Generated TensorWrapper";
...
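For orientation, the generator change above only affects the emitted call: the trailing full_reserved argument disappears from the generated node-creation code. A hedged sketch of the output for a hypothetical forward input X (the op and variable names are illustrative, not taken from a real generated file):

// Generated before:  grad_node->SetTensorWrapperX(X, true);
// Generated after:   grad_node->SetTensorWrapperX(X);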
@@ -55,8 +55,8 @@ def ParseArguments():
 ## Code Gen Templates ##
 ########################
 SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
-""" void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
-    {} = egr::TensorWrapper({}, full_reserved, {});
+""" void SetTensorWrapper{}(const paddle::experimental::Tensor& {}) {{
+    {} = egr::TensorWrapper({}, {});
   }}
 """
@@ -69,9 +69,9 @@ CLEAR_TENSOR_WRAPPER_TEMPLATE = \
 """
 SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
-""" void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{
+""" void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}) {{
     for(const auto& eager_tensor : {}) {{
-      {}.emplace_back(egr::TensorWrapper(eager_tensor, full_reserved, {}));
+      {}.emplace_back(egr::TensorWrapper(eager_tensor, {}));
     }};
   }}
 """
@@ -676,9 +676,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
        if is_fwd_input:
            if is_optional:
-                set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
+                set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()));"
            else:
-                set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, true);"
+                set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
            set_input_tensor_wrappers_list.append(set_tensor_wrappers)
        else:
            if num_fwd_outputs > 1:
@@ -688,9 +688,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
                fwd_output_pos = forward_outputs_position_map[name][1]
            if is_optional:
-                set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
+                set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()));"
            else:
-                set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, false);"
+                set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
            set_output_tensor_wrappers_list.append(set_tensor_wrappers)
        set_input_tensor_wrappers_str = "\n".join(
            set_input_tensor_wrappers_list)
...
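To make the template change above concrete: with the full_reserved slot gone, the setter emitted into the generated grad node header shrinks by one parameter. A sketch of the expansion of SET_PLAIN_TENSOR_WRAPPER_TEMPLATE, assuming a forward input named X, a wrapper member X_, and no_need_buffer = false; all of these names are illustrative, not taken from a real generated file:

// Emitted before the change (illustrative names):
//   void SetTensorWrapperX(const paddle::experimental::Tensor& X, bool full_reserved) {
//     X_ = egr::TensorWrapper(X, full_reserved, false);
//   }
// Emitted after the change:
//   void SetTensorWrapperX(const paddle::experimental::Tensor& X) {
//     X_ = egr::TensorWrapper(X, false);
//   }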
@@ -34,7 +34,6 @@ class TensorWrapper {
  public:
   TensorWrapper() = default;
   explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
-                         bool full_reserved = false,
                          bool no_need_buffer = false) {
     // set inplace_version_snapshot_ according to tensor's current inplace
     // version.
@@ -46,32 +45,12 @@ class TensorWrapper {
     }
     /**
-     * Normally, we should fully reserved all non-output or non-leaf fwd tensor
-     * here. And for fwd output tensor, we should not reserve its autogradmeta,
-     * to avoid recursive depends on GradNodeBase
+     * Normally, we should only save data and part of autograd_meta of fwd
+     * tensor, and should not reserve its original grad_node,
+     * to avoid recursive and additional depends on GradNodeBase
      * **/
-    full_reserved_ = full_reserved;
+    auto* tensor_autograd_meta = EagerUtils::nullable_autograd_meta(tensor);
     no_need_buffer_ = no_need_buffer;
-    if (full_reserved_) {
-      VLOG(6) << "Fully reserved tensor: " << tensor.name();
-      intermidiate_tensor_ = tensor;
-      if (no_need_buffer_) {
-        if (phi::DenseTensor::classof(tensor.impl().get())) {
-          // Only Copy Meta
-          phi::DenseTensor* dense_tensor =
-              static_cast<phi::DenseTensor*>(tensor.impl().get());
-          auto tw_dense_tensor =
-              std::make_shared<phi::DenseTensor>(*dense_tensor);
-          tw_dense_tensor->clear();
-          intermidiate_tensor_.set_impl(tw_dense_tensor);
-        } else {
-          PADDLE_THROW(paddle::platform::errors::Fatal(
-              "Unrecognized tensor type for no_need_buffer feature"));
-        }
-      }
-      return;
-    }
     // shallow copy tensor_impl here
     if (no_need_buffer) {
       if (phi::DenseTensor::classof(tensor.impl().get())) {
@@ -89,10 +68,11 @@ class TensorWrapper {
       intermidiate_tensor_.set_impl(tensor.impl());
     }
+    if (VLOG_IS_ON(7)) {
       // TODO(jiabin): This may has server performance issue
       intermidiate_tensor_.set_name(tensor.name() + "@Saved");
+    }
-    auto* tensor_autograd_meta = EagerUtils::nullable_autograd_meta(tensor);
     if (tensor_autograd_meta) {
       auto autograd_meta =
           std::make_shared<AutogradMeta>(*tensor_autograd_meta);
@@ -112,10 +92,6 @@ class TensorWrapper {
     check_inplace_version();
-    // if it's full_reserved just return the full copy of tensor
-    if (full_reserved_) {
-      return intermidiate_tensor_;
-    } else {
     paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;
     std::shared_ptr<GradNodeBase> new_grad_node = weak_grad_node_.lock();
@@ -139,7 +115,6 @@ class TensorWrapper {
     return recovered_tensor;
   }
-  }
   void clear() { intermidiate_tensor_.reset(); }
@@ -179,7 +154,6 @@ }
   }
  private:
-  bool full_reserved_ = false;
   bool no_need_buffer_ = false;
   paddle::experimental::Tensor intermidiate_tensor_;
   std::weak_ptr<egr::GradNodeBase> weak_grad_node_;
...
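With full_reserved_ removed from tensor_wrapper.h, recover() has a single path: copy intermidiate_tensor_, lock weak_grad_node_, and attach a fresh AutogradMeta. The standalone toy below illustrates only the std::weak_ptr idiom behind weak_grad_node_; it uses no Paddle types and is a minimal sketch, not the real class:

#include <iostream>
#include <memory>

struct GradNode { int id = 0; };

// Toy wrapper: holds only a weak reference so the saved tensor does not
// keep the autograd graph alive; Recover() re-attaches the node if it still exists.
struct ToyWrapper {
  std::weak_ptr<GradNode> weak_grad_node;
  std::shared_ptr<GradNode> Recover() const { return weak_grad_node.lock(); }
};

int main() {
  auto node = std::make_shared<GradNode>();
  node->id = 42;
  ToyWrapper tw{node};
  std::cout << (tw.Recover() ? "node alive\n" : "node expired\n");    // node alive
  node.reset();  // the graph is released elsewhere
  std::cout << (tw.Recover() ? "node alive\n" : "node expired\n");    // node expired
  return 0;
}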
@@ -40,9 +40,11 @@ TEST(TensorWrapper, Basic) {
   auto auto_grad0 = std::make_shared<egr::AutogradMeta>(edge0);
   et1.set_autograd_meta(auto_grad0);
   et1.set_name("et1");
-  auto tw0 = egr::TensorWrapper(et1, true);
+  auto tw0 = egr::TensorWrapper(et1);
   auto recover_et1 = tw0.recover();
-  CHECK_EQ(recover_et1.name(), std::string("et1"));
+  if (VLOG_IS_ON(7)) {
+    CHECK_EQ(recover_et1.name(), std::string("et1@saved"));
+  }
   CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et1).first,
            egr::EagerUtils::OutRankInfo(et1).first);
   CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et1).second,
@@ -68,13 +70,15 @@ TEST(TensorWrapper, Basic) {
   et2.set_autograd_meta(auto_grad1);
   auto tw1 = egr::TensorWrapper(et2, false);
   auto recover_et2 = tw1.recover();
+  if (VLOG_IS_ON(7)) {
     CHECK_EQ(recover_et2.name(), std::string("et2@Saved"));
+  }
   CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et2).first,
            egr::EagerUtils::OutRankInfo(et2).first);
   CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et2).second,
            egr::EagerUtils::OutRankInfo(et2).second);
   // Test Raw recover
   paddle::experimental::Tensor et3;
-  auto tw2 = egr::TensorWrapper(et3, true);
+  auto tw2 = egr::TensorWrapper(et3);
   CHECK(tw2.recover().initialized() == false);
 }