From ac4a422d5a741093703e0c510a287f7ef8c5c274 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 4 Apr 2022 20:54:16 +0800 Subject: [PATCH] [Eager]Fix tile API final_state and Backward bug (#41385) * [Eager]Fix tile API final_state bug * fix backward bug --- paddle/fluid/eager/backward.cc | 6 +++--- paddle/fluid/pybind/eager_utils.cc | 6 ++++++ python/paddle/tensor/manipulation.py | 5 +++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 3e86ad6f59b..d5397e20e7d 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -580,8 +580,9 @@ std::vector RunBackward( node_input_buffers_dict[grad_node] = std::make_unique(grad_node->InputMeta()); } - - if (grad_tensors.size() > 0) { + bool copy_from_grad_t = + grad_tensors.size() > 0 && grad_tensors[i].initialized(); + if (copy_from_grad_t) { PADDLE_ENFORCE( grad_tensors.size() == tensors.size(), paddle::platform::errors::Fatal( @@ -594,7 +595,6 @@ std::vector RunBackward( // Deep copy node_input_buffers_dict[grad_node]->CopyValueFromTensor( input_info.first, input_info.second, grad_tensors[i]); - } else { VLOG(6) << "Fill grad input tensor " << i << " with 1.0"; // Initialize tensor with 1.0 diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index a6047f36ad9..ef1359ac047 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -213,6 +213,9 @@ std::vector CastPyArg2VectorOfTensor( if (PyObject_IsInstance(item, reinterpret_cast(p_tensor_type))) { result.emplace_back(reinterpret_cast(item)->tensor); + } else if (item == Py_None) { + // emplace empty Tensor for None + result.emplace_back(); } else { PADDLE_THROW(platform::errors::InvalidArgument( "argument (position %d) must be " @@ -229,6 +232,9 @@ std::vector CastPyArg2VectorOfTensor( if (PyObject_IsInstance(item, reinterpret_cast(p_tensor_type))) { result.emplace_back(reinterpret_cast(item)->tensor); + } else if (item == Py_None) { + // emplace empty Tensor for None + result.emplace_back(); } else { PADDLE_THROW(platform::errors::InvalidArgument( "argument (position %d) must be " diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 92fec23c6c7..f1e2938b205 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1751,6 +1751,11 @@ def tile(x, repeat_times, name=None): # [[1, 2, 3, 1, 2, 3]] """ if in_dygraph_mode(): + if isinstance(repeat_times, core.eager.Tensor): + assert (repeat_times.ndim == 1, + "Only support ndim == 1 while repeat_times is a Tensor.") + repeat_times = repeat_times.numpy().tolist() + return _C_ops.final_state_tile(x, repeat_times) if _in_legacy_dygraph(): -- GitLab