From ceb3382bc31c3748bd5077274bde976c1ed11210 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 5 Apr 2022 10:03:58 +0800 Subject: [PATCH] [Eager] Fix empty tensor Initializer bug with shape=[] (#41374) * [Eager] Fix empty tensor Initializer bug with shape=[] * [Eager] Fix empty tensor Initializer bug with shape=[] * ignore two unittest * fix unittest --- paddle/fluid/pybind/eager.cc | 19 ++++++++++++++----- paddle/fluid/pybind/eager_method.cc | 1 + 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc index e39a9199b1c..1f72af8d79d 100644 --- a/paddle/fluid/pybind/eager.cc +++ b/paddle/fluid/pybind/eager.cc @@ -72,11 +72,20 @@ void EmptyTensorInitializer(TensorObject* self, const std::string& name, } if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) { // TODO(jiabin): Maybe support LOD later - std::shared_ptr dense_tensor = - std::make_shared( - phi::make_intrusive(place), - phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype), - ddims)); + std::shared_ptr dense_tensor = nullptr; + if (dims.empty()) { + std::shared_ptr allocation_ptr = nullptr; + dense_tensor = std::make_shared( + allocation_ptr, + phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype), + ddims)); + } else { + // TODO(dev): we need enhance check for ddims. + dense_tensor = std::make_shared( + phi::make_intrusive(place), + phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype), + ddims)); + } self->tensor.set_impl(dense_tensor); } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) { std::shared_ptr tensor = diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 74b866355f0..9f75b5c70b2 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -125,6 +125,7 @@ class PyTensorVoidHook : public egr::TensorVoidHook { extern void InitTensorWithNumpyValue(TensorObject* self, const pybind11::object& array, + const paddle::platform::Place& place, bool zero_copy); extern PyTypeObject* p_tensor_type; -- GitLab