From 13d01e6e27e8673967b54fa32317f2bc5528efac Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Fri, 15 Jul 2022 18:30:00 +0800 Subject: [PATCH] [Eager] eager variable back sync (#44343) * eager variable back sync --- paddle/fluid/eager/eager_tensor.h | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/paddle/fluid/eager/eager_tensor.h b/paddle/fluid/eager/eager_tensor.h index d61a55b6dea..8026b8e3684 100644 --- a/paddle/fluid/eager/eager_tensor.h +++ b/paddle/fluid/eager/eager_tensor.h @@ -209,6 +209,7 @@ class EagerVariable final { if (tensor.defined()) { if (tensor.is_dense_tensor()) { ConstructVariableFromTensor(tensor); + src_tensor_ = tensor.impl(); } else if (tensor.is_selected_rows()) { ConstructVariableFromTensor(tensor); } else if (IsVariableCompatTensor(tensor) && @@ -229,6 +230,19 @@ class EagerVariable final { } } + ~EagerVariable() { + if (src_tensor_) { + auto* framework_tensor = var_.GetMutable(); + auto tensor_dense = static_cast(src_tensor_.get()); + if (framework_tensor->memory_size() > 0 && + (!paddle::platform::is_same_place(framework_tensor->place(), + tensor_dense->place()) || + framework_tensor->dtype() != tensor_dense->dtype())) { + tensor_dense->ShareBufferWith(*framework_tensor); + } + } + } + /** Part 11: Construct paddle::framework::Variable with phi::Tensor **/ std::shared_ptr GetTensorBase() { // Construct allocation only once. @@ -304,5 +318,6 @@ class EagerVariable final { private: std::string name_{""}; paddle::framework::Variable var_; + std::shared_ptr src_tensor_; }; } // namespace egr -- GitLab