Unverified commit 279d2db3, authored by pangyoki, committed by GitHub

Cherry-pick PR41720, support no_need_buffer in eager_fluid state (#41720) (#41956)

* support no_need_buffer in eager_fluid state

* change no_need_buffer info from fwd_info to bwd_info

* fix CI failure: gru_unit does not use no_need_buffer

* fix conflict between no_need_buffer and dispensable

* use tensor.defined() in dispensable checks

* solve conflict

* solve conflict
Parent 968bf46e
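For orientation: in the eager "fluid state", saving a forward input for the backward pass goes through a per-op `SetTensorWrapper...` method emitted by the auto code generator. This PR threads a `no_need_buffer` flag through those methods, so inputs whose buffers the grad kernel never reads are stored without their data buffer. A minimal sketch of what the updated template (see the last hunks below) expands to for a non-duplicable no_need_buffer input — the grad-node, method, and member names here are illustrative, not taken from this diff:

```cpp
// Sketch of generated code, assuming a hypothetical op "foo" whose input
// "X" is marked no_need_buffer. Real names are emitted per-op.
void GradNodeFoo::SetTensorWrapperX(const paddle::experimental::Tensor& X,
                                    bool full_reserved) {
  // The trailing `true` is what no_need_buffer_str substitutes into the
  // template: the wrapper keeps X's metadata (shape, dtype, place) but
  // does not hold on to its underlying data buffer.
  X_ = egr::TensorWrapper(X, full_reserved /*full_reserved*/, true);
}
```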
@@ -217,6 +217,13 @@ class GradNodeGenerationInfo {
      return &grad_attrs_;
    }

+   const std::unordered_set<std::string>& GetNoNeedBufferInputs() const {
+     return no_need_buffer_ins_;
+   }
+
+   std::unordered_set<std::string>* GetMutableNoNeedBufferInputs() {
+     return &no_need_buffer_ins_;
+   }
+
   private:
    std::string op_base_type_;
    std::map<std::string, std::string> grad_outs_slotname_map_;

@@ -229,6 +236,7 @@ class GradNodeGenerationInfo {
        std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
        grad_outs_;
    paddle::framework::AttributeMap grad_attrs_;
+   std::unordered_set<std::string> no_need_buffer_ins_;
  };

 public:
@@ -958,6 +966,12 @@ static bool CollectGradInformationFromOpInfo(
       VLOG(6) << "GradOuts Name: " << it.first;
     }
   }
+
+  auto& inferer = op_base.Info().NoNeedBufferVarsInferer();
+  if (inferer && !special_no_need_buffer_op_set.count(op_type)) {
+    *(*op_base_infos)[index].GetMutableNoNeedBufferInputs() =
+        inferer(g_ins, g_outs, *op_base_grad_attrs);
+  }
 }

 /* ------ Slot Name Matching ---- */
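The inferer queried in this hunk is the one fluid ops already register for static graph mode; the eager generator simply reuses it. As a hedged sketch of how an op typically opts in — the op and inferer names below are hypothetical; the macro itself is the one declared in `paddle/fluid/framework/no_need_buffer_vars_inference.h`:

```cpp
// Illustrative declaration in some foo_op.cc: the grad kernel of "foo"
// only needs X's metadata, so X's buffer can be dropped after forward.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(FooGradNoNeedBufferVarsInferer, "X");

// Attaching the inferer in the REGISTER_OPERATOR(foo_grad, ...) call is
// what makes op_base.Info().NoNeedBufferVarsInferer() non-null above.
```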
@@ -1129,11 +1143,14 @@ static std::string GenerateGradNodeCreationContent(
     for (const auto& iter : op_base_infos) {
       const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
           iter.GetGradInsFwdSlotnameMap();
+      const std::unordered_set<std::string>& no_need_buffer_ins =
+          iter.GetNoNeedBufferInputs();
       for (auto& kv : grad_ins_fwd_slotname_map) {
         const std::string& tensor_wrapper_name = kv.second;
         std::string full_reserved = "false";
         if (fwd_outputs_name_pos_map.find(tensor_wrapper_name) ==
-            fwd_outputs_name_pos_map.end()) {
+                fwd_outputs_name_pos_map.end() &&
+            !no_need_buffer_ins.count(tensor_wrapper_name)) {
           full_reserved = "true";
         }
         const char* SET_TENSOR_WRAPPER_TEMPLATE =
@@ -2064,7 +2081,7 @@ static std::string GenerateSingleOpBase(
       } else {
         const char* DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE =
             "  auto %s = egr::EagerUtils::RecoverTensorWrapper(&this->%s);\n"
-            "  if(%s.initialized()) %s[\"%s\"] = "
+            "  if(%s.defined()) %s[\"%s\"] = "
             "  egr::EagerUtils::TrySyncToVars(%s);\n";
         generated_grad_function_body += paddle::string::Sprintf(
             DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE, grad_input_name,
@@ -2190,7 +2207,7 @@ static std::string GenerateSingleOpBase(
             grad_output_name, fwd_input_position);
       } else {
         const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE =
-            "  if(%s.initialized()) %s[\"%s\"] = "
+            "  if(%s.defined()) %s[\"%s\"] = "
             "{std::make_shared<egr::EagerVariable>(egr::Controller::"
             "Instance().GenerateUniqueName())};\n";
         generated_grad_function_body += paddle::string::Sprintf(
@@ -2532,6 +2549,8 @@ static std::string GenerateGradNodeHeaderContents(
   for (const auto& iter : op_base_infos) {
     const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
         iter.GetGradInsFwdSlotnameMap();
+    const std::unordered_set<std::string>& no_need_buffer_ins =
+        iter.GetNoNeedBufferInputs();
     for (const auto& kv : grad_ins_fwd_slotname_map) {
       const std::string& tensor_wrapper_name = kv.second;
@@ -2540,6 +2559,10 @@ static std::string GenerateGradNodeHeaderContents(
       std::string tensor_wrapper_arg_str;
       std::string tensor_wrapper_body_str;
       std::string full_reserved_str = "full_reserved";
+      std::string no_need_buffer_str = "false";
+      if (no_need_buffer_ins.count(tensor_wrapper_name)) {
+        no_need_buffer_str = "true";
+      }
       if (duplicable_tensors.count(tensor_wrapper_name)) {
         const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE =
             "const std::vector<paddle::experimental::Tensor>& %s";
@@ -2553,12 +2576,12 @@ static std::string GenerateGradNodeHeaderContents(
         const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
             "for(const auto& eager_tensor : %s) {\n"
-            "    %s.emplace_back( egr::TensorWrapper(eager_tensor, true "
-            "/*full_reserved*/) );\n"
+            "    %s.emplace_back( egr::TensorWrapper(eager_tensor, %s "
+            "/*full_reserved*/, %s) );\n"
             "  }\n";
         tensor_wrapper_body_str = paddle::string::Sprintf(
             SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name,
-            struct_tensor_wrapper_name);
+            struct_tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
         const char* CLEAR_TENSOR_WRAPPER_TEMPLATE =
             "for (auto tw: %s) {\n"
@@ -2579,10 +2602,10 @@ static std::string GenerateGradNodeHeaderContents(
             TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
         const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
-            "%s = egr::TensorWrapper(%s, %s /*full_reserved*/);\n";
+            "%s = egr::TensorWrapper(%s, %s /*full_reserved*/, %s);\n";
         tensor_wrapper_body_str = paddle::string::Sprintf(
             SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name,
-            tensor_wrapper_name, full_reserved_str);
+            tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
         const char* CLEAR_TENSOR_WRAPPER_TEMPLATE = "  %s.clear();\n";
         clear_tensor_wrappers_str += paddle::string::Sprintf(
...
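The two template changes above cover both wrapper-construction sites. For a duplicable (vector) input, the first template now expands to roughly the following — the input name is illustrative, assuming `no_need_buffer_str` resolved to `"true"`:

```cpp
// Sketch of the expanded duplicable-input case for a hypothetical
// no_need_buffer input "Xs"; full_reserved is the method parameter.
for (const auto& eager_tensor : Xs) {
  Xs_.emplace_back(
      egr::TensorWrapper(eager_tensor, full_reserved /*full_reserved*/, true));
}
```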
@@ -276,3 +276,11 @@ std::set<std::string> special_inplace_op_set = {
     "sum",     // `sum` op has duplicate input
     "assign",  // output of `assign` op is in `op_passing_outs_map`
 };
+
+// NOTE(pangyoki): Special no_need_buffer ops that are temporarily not
+// supported. Getting no_need_buffer info for the `sequence_conv` op
+// raises an error during compilation.
+std::set<std::string> special_no_need_buffer_op_set = {
+    "sequence_conv",
+};
@@ -510,5 +510,19 @@ class TestContinuouslyInplace(unittest.TestCase):
         self.func_test_continuously_inplace()

+
+class TestGetitemBeforeInplace(unittest.TestCase):
+    def test_getitem_before_inplace(self):
+        with _test_eager_guard():
+            a = paddle.ones(shape=[4, 2, 3], dtype="float32")
+            a.stop_gradient = False
+            b = a**2
+            b[0] = 3
+            # getitem has no_need_buffer input
+            c = b[0:2]
+            loss = c.sum()
+            b[1] = 2
+            loss.backward()
+

 if __name__ == '__main__':
     unittest.main()
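The new test pins down the user-visible behavior: `getitem` (slice) declares its forward input as no_need_buffer, so writing to `b` in place after slicing cannot corrupt what the grad node saved. A condensed version of the same scenario, runnable outside the test harness on a build where eager mode is enabled by default (otherwise wrap it in `_test_eager_guard()` as the test does):

```python
import paddle

a = paddle.ones(shape=[4, 2, 3], dtype="float32")
a.stop_gradient = False
b = a**2
c = b[0:2]         # slice saves `b` without its buffer (no_need_buffer)
loss = c.sum()
b[1] = 2           # in-place write after getitem: safe, buffer was not saved
loss.backward()    # the case TestGetitemBeforeInplace guards
print(a.grad.shape)  # [4, 2, 3]
```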