diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index fb86c5da6856c51c446455a4112f84e31c4d76a6..0d1d3ab7225227d44d98a11f8a6511d0d8a43c19 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -1249,9 +1249,9 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): is_optional = (name in self.optional_inputs) if is_optional: - tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);" + tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());" else: - tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);" + tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());" grad_api_args[grad_api_position] = transformed_tensor_name get_grad_in_args_list.append(tensor_wrapper_recover_str) diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 0d07f780dda9d1d88cd65c66b1ae0e31ee45cff6..70fc4afa0ac7136bc440ffe69210e170e1447f32 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -87,7 +87,7 @@ class GradSlotMeta { std::shared_ptr meta_ = nullptr; }; -class GradNodeBase { +class GradNodeBase : public std::enable_shared_from_this { public: GradNodeBase() { VLOG(6) << "Construct GradNodeBase"; } GradNodeBase(size_t bwd_in_slot_num, size_t bwd_out_slot_num); diff --git a/paddle/fluid/eager/tensor_wrapper.h b/paddle/fluid/eager/tensor_wrapper.h index e7886339f06b155cecf7ef2a7d562110167edd30..dc4cf379390f1bcf9987b7e331c1e453c48e0de3 100644 --- a/paddle/fluid/eager/tensor_wrapper.h +++ b/paddle/fluid/eager/tensor_wrapper.h @@ -95,18 +95,19 @@ class TensorWrapper { } check_inplace_version(); + // if it's full_reserved just return the full copy of tensor - if (full_reserved_) { - return intermidiate_tensor_; - } else { + paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_; + if (!full_reserved_) { std::shared_ptr new_grad_node = grad_node; auto p_ab_autograd_meta = std::make_shared(Edge(new_grad_node, out_rank_info_)); - intermidiate_tensor_.set_autograd_meta( + recovered_tensor.set_autograd_meta( std::static_pointer_cast( p_ab_autograd_meta)); - return intermidiate_tensor_; } + + return recovered_tensor; } void check_inplace_version() {