From 46371c534ce89e6e94357e9fb5eb182ed4598a3b Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Sun, 4 Dec 2022 19:51:02 +0800 Subject: [PATCH] [Eager] fix set_value logic when input's dtype is different (#48519) * [Eager] fix set_value logic when input's dtype is different * value_tensor * fix set_value logic when input's dtype is different --- paddle/fluid/pybind/eager_method.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 6f0bd5fb16..8c7b6296eb 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1227,7 +1227,6 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self, Py_TYPE(value_obj))); } } - { // Release gil and do tracing py::gil_scoped_release release; @@ -1242,6 +1241,9 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self, self->tensor.name(), self->tensor, amp_dtype, "set_value"); value_tensor = egr::EagerAmpAutoCast( value_tensor.name(), value_tensor, amp_dtype, "set_value"); + if (self->tensor.dtype() != value_tensor.dtype()) { + value_tensor = cast_ad_func(value_tensor, self->tensor.dtype()); + } } self->tensor = set_value__dygraph_function( self->tensor, value_tensor, {}, {}, {}, attrs); -- GitLab