From a5e00bb7239956e10766c3b89d1919416af9c646 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Sat, 2 Apr 2022 16:54:31 +0800 Subject: [PATCH] [DoubleGrad PR #6] Fixed issues with TensorWrapper::recover() interface (#41287) --- .../final_state_generator/eager_gen.py | 4 ++-- paddle/fluid/eager/grad_node_info.h | 2 +- paddle/fluid/eager/tensor_wrapper.h | 11 ++++++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index fb86c5da68..0d1d3ab722 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -1249,9 +1249,9 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): is_optional = (name in self.optional_inputs) if is_optional: - tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);" + tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());" else: - tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);" + tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());" grad_api_args[grad_api_position] = transformed_tensor_name get_grad_in_args_list.append(tensor_wrapper_recover_str) diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 0d07f780dd..70fc4afa0a 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -87,7 +87,7 @@ class GradSlotMeta { std::shared_ptr meta_ = nullptr; }; -class GradNodeBase { +class GradNodeBase : public std::enable_shared_from_this { public: GradNodeBase() { VLOG(6) << "Construct GradNodeBase"; } GradNodeBase(size_t bwd_in_slot_num, size_t bwd_out_slot_num); diff --git a/paddle/fluid/eager/tensor_wrapper.h b/paddle/fluid/eager/tensor_wrapper.h index e7886339f0..dc4cf37939 100644 --- a/paddle/fluid/eager/tensor_wrapper.h +++ b/paddle/fluid/eager/tensor_wrapper.h @@ -95,18 +95,19 @@ class TensorWrapper { } check_inplace_version(); + // if it's full_reserved just return the full copy of tensor - if (full_reserved_) { - return intermidiate_tensor_; - } else { + paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_; + if (!full_reserved_) { std::shared_ptr new_grad_node = grad_node; auto p_ab_autograd_meta = std::make_shared(Edge(new_grad_node, out_rank_info_)); - intermidiate_tensor_.set_autograd_meta( + recovered_tensor.set_autograd_meta( std::static_pointer_cast( p_ab_autograd_meta)); - return intermidiate_tensor_; } + + return recovered_tensor; } void check_inplace_version() { -- GitLab