未验证 提交 ac4a422d 编写于 作者: A Aurelius84 提交者: GitHub

[Eager]Fix tile API final_state and Backward bug (#41385)

* [Eager]Fix tile API final_state bug

* fix backward bug
上级 489b8a88
......@@ -580,8 +580,9 @@ std::vector<paddle::experimental::Tensor> RunBackward(
node_input_buffers_dict[grad_node] =
std::make_unique<GradTensorHolder>(grad_node->InputMeta());
}
if (grad_tensors.size() > 0) {
bool copy_from_grad_t =
grad_tensors.size() > 0 && grad_tensors[i].initialized();
if (copy_from_grad_t) {
PADDLE_ENFORCE(
grad_tensors.size() == tensors.size(),
paddle::platform::errors::Fatal(
......@@ -594,7 +595,6 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// Deep copy
node_input_buffers_dict[grad_node]->CopyValueFromTensor(
input_info.first, input_info.second, grad_tensors[i]);
} else {
VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
// Initialize tensor with 1.0
......
......@@ -213,6 +213,9 @@ std::vector<paddle::experimental::Tensor> CastPyArg2VectorOfTensor(
if (PyObject_IsInstance(item,
reinterpret_cast<PyObject*>(p_tensor_type))) {
result.emplace_back(reinterpret_cast<TensorObject*>(item)->tensor);
} else if (item == Py_None) {
// emplace empty Tensor for None
result.emplace_back();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
......@@ -229,6 +232,9 @@ std::vector<paddle::experimental::Tensor> CastPyArg2VectorOfTensor(
if (PyObject_IsInstance(item,
reinterpret_cast<PyObject*>(p_tensor_type))) {
result.emplace_back(reinterpret_cast<TensorObject*>(item)->tensor);
} else if (item == Py_None) {
// emplace empty Tensor for None
result.emplace_back();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
......
......@@ -1751,6 +1751,11 @@ def tile(x, repeat_times, name=None):
# [[1, 2, 3, 1, 2, 3]]
"""
if in_dygraph_mode():
if isinstance(repeat_times, core.eager.Tensor):
assert (repeat_times.ndim == 1,
"Only support ndim == 1 while repeat_times is a Tensor.")
repeat_times = repeat_times.numpy().tolist()
return _C_ops.final_state_tile(x, repeat_times)
if _in_legacy_dygraph():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册