diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
index 726e049e611509d5b403fa699a7633fdd90ab37b..307f8fae31597626fd24840b94bbfbddc3606683 100644
--- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc
+++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -217,6 +217,13 @@ class GradNodeGenerationInfo {
       return &grad_attrs_;
     }
 
+    const std::unordered_set<std::string>& GetNoNeedBufferInputs() const {
+      return no_need_buffer_ins_;
+    }
+    std::unordered_set<std::string>* GetMutableNoNeedBufferInputs() {
+      return &no_need_buffer_ins_;
+    }
+
    private:
     std::string op_base_type_;
     std::map<std::string, std::string> grad_outs_slotname_map_;
@@ -229,6 +236,7 @@ class GradNodeGenerationInfo {
              std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
         grad_outs_;
     paddle::framework::AttributeMap grad_attrs_;
+    std::unordered_set<std::string> no_need_buffer_ins_;
   };
 
  public:
@@ -958,6 +966,12 @@ static bool CollectGradInformationFromOpInfo(
       VLOG(6) << "GradOuts Name: " << it.first;
     }
   }
+
+  auto& inferer = op_base.Info().NoNeedBufferVarsInferer();
+  if (inferer && !special_no_need_buffer_op_set.count(op_type)) {
+    *(*op_base_infos)[index].GetMutableNoNeedBufferInputs() =
+        inferer(g_ins, g_outs, *op_base_grad_attrs);
+  }
 }
 
 /* ------ Slot Name Matching ---- */
@@ -1129,11 +1143,14 @@ static std::string GenerateGradNodeCreationContent(
   for (const auto& iter : op_base_infos) {
     const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
         iter.GetGradInsFwdSlotnameMap();
+    const std::unordered_set<std::string>& no_need_buffer_ins =
+        iter.GetNoNeedBufferInputs();
     for (auto& kv : grad_ins_fwd_slotname_map) {
       const std::string& tensor_wrapper_name = kv.second;
       std::string full_reserved = "false";
       if (fwd_outputs_name_pos_map.find(tensor_wrapper_name) ==
-          fwd_outputs_name_pos_map.end()) {
+              fwd_outputs_name_pos_map.end() &&
+          !no_need_buffer_ins.count(tensor_wrapper_name)) {
         full_reserved = "true";
       }
       const char* SET_TENSOR_WRAPPER_TEMPLATE =
@@ -2064,7 +2081,7 @@ static std::string GenerateSingleOpBase(
       } else {
         const char* DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE =
             "  auto %s = egr::EagerUtils::RecoverTensorWrapper(&this->%s);\n"
-            "  if(%s.initialized()) %s[\"%s\"] = "
+            "  if(%s.defined()) %s[\"%s\"] = "
             "  egr::EagerUtils::TrySyncToVars(%s);\n";
         generated_grad_function_body += paddle::string::Sprintf(
             DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE, grad_input_name,
@@ -2190,7 +2207,7 @@ static std::string GenerateSingleOpBase(
             grad_output_name, fwd_input_position);
       } else {
         const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE =
-            "  if(%s.initialized()) %s[\"%s\"] = "
+            "  if(%s.defined()) %s[\"%s\"] = "
             "{std::make_shared<paddle::imperative::VarBase>(egr::Controller::"
             "Instance().GenerateUniqueName())};\n";
         generated_grad_function_body += paddle::string::Sprintf(
@@ -2532,6 +2549,8 @@ static std::string GenerateGradNodeHeaderContents(
   for (const auto& iter : op_base_infos) {
     const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
         iter.GetGradInsFwdSlotnameMap();
+    const std::unordered_set<std::string>& no_need_buffer_ins =
+        iter.GetNoNeedBufferInputs();
 
     for (const auto& kv : grad_ins_fwd_slotname_map) {
       const std::string& tensor_wrapper_name = kv.second;
@@ -2540,6 +2559,10 @@ static std::string GenerateGradNodeHeaderContents(
       std::string tensor_wrapper_arg_str;
       std::string tensor_wrapper_body_str;
       std::string full_reserved_str = "full_reserved";
+      std::string no_need_buffer_str = "false";
+      if (no_need_buffer_ins.count(tensor_wrapper_name)) {
+        no_need_buffer_str = "true";
+      }
       if (duplicable_tensors.count(tensor_wrapper_name)) {
         const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE =
             "const std::vector<paddle::experimental::Tensor>& %s";
@@ -2553,12 +2576,12 @@ static std::string GenerateGradNodeHeaderContents(
 
         const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
             "for(const auto& eager_tensor : %s) {\n"
-            "    %s.emplace_back( egr::TensorWrapper(eager_tensor, true "
-            "/*full_reserved*/) );\n"
+            "    %s.emplace_back( egr::TensorWrapper(eager_tensor, %s "
+            "/*full_reserved*/, %s) );\n"
             "  }\n";
         tensor_wrapper_body_str = paddle::string::Sprintf(
             SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name,
-            struct_tensor_wrapper_name);
+            struct_tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
 
         const char* CLEAR_TENSOR_WRAPPER_TEMPLATE =
             "for (auto tw: %s) {\n"
@@ -2579,10 +2602,10 @@ static std::string GenerateGradNodeHeaderContents(
             TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
 
         const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
-            "%s = egr::TensorWrapper(%s, %s /*full_reserved*/);\n";
+            "%s = egr::TensorWrapper(%s, %s /*full_reserved*/, %s);\n";
         tensor_wrapper_body_str = paddle::string::Sprintf(
             SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name,
-            tensor_wrapper_name, full_reserved_str);
+            tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
 
         const char* CLEAR_TENSOR_WRAPPER_TEMPLATE = "  %s.clear();\n";
         clear_tensor_wrappers_str += paddle::string::Sprintf(
diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h
index f1e9c7e8f491b64df48858d3cdebc6d7bd82aa67..7b128bd3b0e4d0ca11900df9e5cfcf656f36221b 100644
--- a/paddle/fluid/pybind/op_function_generator.h
+++ b/paddle/fluid/pybind/op_function_generator.h
@@ -276,3 +276,11 @@ std::set<std::string> special_inplace_op_set = {
     "sum",     // `sum` op has duplicate input
     "assign",  // output of `assign` op is in `op_passing_outs_map`
 };
+
+// NOTE(pangyoki): Special no_need_buffer ops that are not supported
+// temporarily.
+// The sequence_conv op raises an error when getting no_need_buffer info
+// during compilation.
+std::set<std::string> special_no_need_buffer_op_set = {
+    "sequence_conv",
+};
diff --git a/python/paddle/fluid/tests/unittests/test_inplace.py b/python/paddle/fluid/tests/unittests/test_inplace.py
index c54d3f02d43f07ab49a6ea2178ba9e786ba54011..99873eaa98870f07c299df222f6cdc5cf6d6629f 100644
--- a/python/paddle/fluid/tests/unittests/test_inplace.py
+++ b/python/paddle/fluid/tests/unittests/test_inplace.py
@@ -510,5 +510,19 @@ class TestContinuouslyInplace(unittest.TestCase):
         self.func_test_continuously_inplace()
 
 
+class TestGetitemBeforeInplace(unittest.TestCase):
+    def test_getitem_before_inplace(self):
+        with _test_eager_guard():
+            a = paddle.ones(shape=[4, 2, 3], dtype="float32")
+            a.stop_gradient = False
+            b = a**2
+            b[0] = 3
+            # getitem has no_need_buffer input
+            c = b[0:2]
+            loss = c.sum()
+            b[1] = 2
+            loss.backward()
+
+
 if __name__ == '__main__':
     unittest.main()
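
Note: below is a minimal, self-contained sketch (plain standard C++, not Paddle's actual API) of the flag selection that the generated grad-node code performs after this change. The names fwd_outputs_name_pos_map and no_need_buffer_ins mirror the variables in the diff above; the example op input/output names and the main() harness are purely illustrative.

// Illustrative only: mirrors how the generator chooses full_reserved and
// no_need_buffer per tensor wrapper (see GenerateGradNodeCreationContent and
// GenerateGradNodeHeaderContents in the diff above).
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

int main() {
  // Forward outputs of the op (name -> slot position), as in the diff.
  std::unordered_map<std::string, size_t> fwd_outputs_name_pos_map = {{"Out", 0}};
  // Inputs reported by NoNeedBufferVarsInferer (hypothetical example value).
  std::unordered_set<std::string> no_need_buffer_ins = {"Input"};

  for (const std::string tensor_wrapper_name : {"Input", "Filter", "Out"}) {
    bool no_need_buffer = no_need_buffer_ins.count(tensor_wrapper_name) > 0;
    // full_reserved stays true only for forward inputs whose buffer is still
    // needed; no_need_buffer is forwarded into the tensor wrapper so it can
    // drop the underlying allocation.
    bool full_reserved =
        fwd_outputs_name_pos_map.find(tensor_wrapper_name) ==
            fwd_outputs_name_pos_map.end() &&
        !no_need_buffer;
    std::cout << tensor_wrapper_name << ": full_reserved=" << full_reserved
              << " no_need_buffer=" << no_need_buffer << "\n";
  }
  return 0;
}

For this toy input, "Input" is marked no_need_buffer and not fully reserved, "Filter" is fully reserved, and "Out" is neither; the new TestGetitemBeforeInplace unit test exercises the no_need_buffer path through getitem.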