未验证 提交 066633a5 编写于 作者: W wanghuancoder 提交者: GitHub

fix inference not have python (#46085)

上级 3d7e2118
...@@ -27,9 +27,11 @@ ...@@ -27,9 +27,11 @@
#pragma once #pragma once
#include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/saved_tensors_hooks.h"
#include "paddle/fluid/eager/utils.h" #include "paddle/fluid/eager/utils.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
#ifndef PADDLE_NO_PYTHON
#include "paddle/fluid/eager/saved_tensors_hooks.h"
#endif
namespace egr { namespace egr {
class TensorWrapper { class TensorWrapper {
...@@ -70,6 +72,7 @@ class TensorWrapper { ...@@ -70,6 +72,7 @@ class TensorWrapper {
"Unrecognized tensor type for no_need_buffer feature")); "Unrecognized tensor type for no_need_buffer feature"));
} }
} else { } else {
#ifndef PADDLE_NO_PYTHON
if (SavedTensorsHooks::GetInstance().IsEnable() && if (SavedTensorsHooks::GetInstance().IsEnable() &&
tensor.is_dense_tensor()) { tensor.is_dense_tensor()) {
phi::DenseTensor* dense_tensor = phi::DenseTensor* dense_tensor =
...@@ -82,8 +85,11 @@ class TensorWrapper { ...@@ -82,8 +85,11 @@ class TensorWrapper {
unpack_hook_ = SavedTensorsHooks::GetInstance().GetUnPackHook(); unpack_hook_ = SavedTensorsHooks::GetInstance().GetUnPackHook();
packed_value_ = reinterpret_cast<PyObject*>((*pack_hook)(tensor)); packed_value_ = reinterpret_cast<PyObject*>((*pack_hook)(tensor));
} else { } else {
#endif
intermidiate_tensor_.set_impl(tensor.impl()); intermidiate_tensor_.set_impl(tensor.impl());
#ifndef PADDLE_NO_PYTHON
} }
#endif
} }
if (VLOG_IS_ON(7)) { if (VLOG_IS_ON(7)) {
...@@ -99,7 +105,7 @@ class TensorWrapper { ...@@ -99,7 +105,7 @@ class TensorWrapper {
weak_grad_node_ = tensor_autograd_meta->GetMutableGradNode(); weak_grad_node_ = tensor_autograd_meta->GetMutableGradNode();
} }
} }
#ifndef PADDLE_NO_PYTHON
TensorWrapper(const TensorWrapper& other) { TensorWrapper(const TensorWrapper& other) {
no_need_buffer_ = other.no_need_buffer_; no_need_buffer_ = other.no_need_buffer_;
intermidiate_tensor_ = other.intermidiate_tensor_; intermidiate_tensor_ = other.intermidiate_tensor_;
...@@ -122,7 +128,7 @@ class TensorWrapper { ...@@ -122,7 +128,7 @@ class TensorWrapper {
} }
~TensorWrapper() { Py_XDECREF(packed_value_); } ~TensorWrapper() { Py_XDECREF(packed_value_); }
#endif
paddle::experimental::Tensor recover() { paddle::experimental::Tensor recover() {
VLOG(6) << "Recover tensor: " << intermidiate_tensor_.name() VLOG(6) << "Recover tensor: " << intermidiate_tensor_.name()
<< " for wrapper"; << " for wrapper";
...@@ -130,7 +136,7 @@ class TensorWrapper { ...@@ -130,7 +136,7 @@ class TensorWrapper {
VLOG(6) << "Return NULL tensor Here. "; VLOG(6) << "Return NULL tensor Here. ";
return paddle::experimental::Tensor(); return paddle::experimental::Tensor();
} }
#ifndef PADDLE_NO_PYTHON
if (packed_value_ && unpack_hook_) { if (packed_value_ && unpack_hook_) {
auto tensor_unpacked = auto tensor_unpacked =
(*unpack_hook_)(reinterpret_cast<void*>(packed_value_)); (*unpack_hook_)(reinterpret_cast<void*>(packed_value_));
...@@ -139,8 +145,11 @@ class TensorWrapper { ...@@ -139,8 +145,11 @@ class TensorWrapper {
static_cast<phi::DenseTensor*>(intermidiate_tensor_.impl().get()) static_cast<phi::DenseTensor*>(intermidiate_tensor_.impl().get())
->ResetHolder(src_dense_tensor->MoveMemoryHolder()); ->ResetHolder(src_dense_tensor->MoveMemoryHolder());
} else { } else {
#endif
check_inplace_version(); check_inplace_version();
#ifndef PADDLE_NO_PYTHON
} }
#endif
paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_; paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;
...@@ -214,7 +223,12 @@ class TensorWrapper { ...@@ -214,7 +223,12 @@ class TensorWrapper {
paddle::experimental::Tensor intermidiate_tensor_; paddle::experimental::Tensor intermidiate_tensor_;
std::weak_ptr<egr::GradNodeBase> weak_grad_node_; std::weak_ptr<egr::GradNodeBase> weak_grad_node_;
uint32_t inplace_version_snapshot_ = 0; uint32_t inplace_version_snapshot_ = 0;
#ifndef PADDLE_NO_PYTHON
PyObject* packed_value_{nullptr}; PyObject* packed_value_{nullptr};
std::shared_ptr<UnPackHookBase> unpack_hook_; std::shared_ptr<UnPackHookBase> unpack_hook_;
#else
void* packed_value_{nullptr};
std::shared_ptr<void> unpack_hook_;
#endif
}; };
} // namespace egr } // namespace egr
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册