未验证 提交 13d01e6e 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] eager variable back sync (#44343)

* eager variable back sync
上级 87443831
......@@ -209,6 +209,7 @@ class EagerVariable final {
if (tensor.defined()) {
if (tensor.is_dense_tensor()) {
ConstructVariableFromTensor<phi::DenseTensor>(tensor);
src_tensor_ = tensor.impl();
} else if (tensor.is_selected_rows()) {
ConstructVariableFromTensor<phi::SelectedRows>(tensor);
} else if (IsVariableCompatTensor(tensor) &&
......@@ -229,6 +230,19 @@ class EagerVariable final {
}
}
~EagerVariable() {
if (src_tensor_) {
auto* framework_tensor = var_.GetMutable<phi::DenseTensor>();
auto tensor_dense = static_cast<phi::DenseTensor*>(src_tensor_.get());
if (framework_tensor->memory_size() > 0 &&
(!paddle::platform::is_same_place(framework_tensor->place(),
tensor_dense->place()) ||
framework_tensor->dtype() != tensor_dense->dtype())) {
tensor_dense->ShareBufferWith(*framework_tensor);
}
}
}
/** Part 11: Construct paddle::framework::Variable with phi::Tensor **/
std::shared_ptr<phi::TensorBase> GetTensorBase() {
// Construct allocation only once.
......@@ -304,5 +318,6 @@ class EagerVariable final {
private:
std::string name_{""};
paddle::framework::Variable var_;
std::shared_ptr<phi::TensorBase> src_tensor_;
};
} // namespace egr
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册