Unverified commit 46d13090, authored by feifei-111, committed by GitHub

[API] fixed a bug in to_tensor which occurs when input is complex and improved code style (#45305)

* fixed1

* fix 2

* fixre

* test complext var

* delete logic no needed

* fix to_tensor_static

* code style

* del
Parent 23d2b079
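Reviewer note: the complex-input behavior this commit touches is documented in the to_tensor docstring excerpt further down in this diff. A minimal dynamic-graph sketch of that behavior, with the literal values taken from the docstring (the complex64 result assumes the default dtype is float32):

import paddle

# to_tensor with complex input; values come from the docstring excerpt in
# this diff. Under the default float32 dtype the result should be complex64.
t = paddle.to_tensor([[1 + 1j, 2], [3 + 2j, 4]])
print(t)  # expected: complex64 tensor [[(1+1j), (2+0j)], [(3+2j), (4+0j)]]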
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# to_tensor api will create 1 less op now, this test was changed
 from __future__ import print_function
 import numpy as np
...
@@ -62,6 +62,19 @@ def case3(x):
     return a


+def case4(x):
+    paddle.set_default_dtype("float64")
+    if core.is_compiled_with_cuda():
+        place = paddle.CUDAPlace(0)
+    else:
+        place = paddle.CPUPlace()
+    a = paddle.to_tensor([1.0], place=place, dtype="float64")
+    b = paddle.to_tensor([2], place=place, stop_gradient=False, dtype="int64")
+    c = paddle.to_tensor([a, b, [1]], dtype="float32")
+    return c
+
+
 class TestToTensorReturnVal(unittest.TestCase):

     def test_to_tensor_badreturn(self):
@@ -92,6 +105,12 @@ class TestToTensorReturnVal(unittest.TestCase):
         self.assertTrue(a.stop_gradient == b.stop_gradient)
         self.assertTrue(a.place._equals(b.place))

+        a = paddle.jit.to_static(case4)(x)
+        b = case4(x)
+        self.assertTrue(a.dtype == b.dtype)
+        self.assertTrue(a.stop_gradient == b.stop_gradient)
+        self.assertTrue(a.place._equals(b.place))
+

 class TestStatic(unittest.TestCase):
...
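Reviewer note: the new case4 builds one tensor from a mix of existing tensors and a plain Python list. A minimal sketch of that usage outside the test harness; note the test only compares dtype, stop_gradient, and place between the dygraph and to_static runs, so the printed dtype below is an expectation, not something the test asserts:

import paddle

# Mixed input: two existing tensors plus a plain list, cast to float32.
a = paddle.to_tensor([1.0], dtype="float64")
b = paddle.to_tensor([2], dtype="int64", stop_gradient=False)
c = paddle.to_tensor([a, b, [1]], dtype="float32")
print(c.dtype)  # expected: paddle.float32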
@@ -15,6 +15,7 @@
 from __future__ import print_function
 import numpy as np
 import math
+import re
 from paddle.common_ops_import import fill_constant
 from ..fluid.layers import utils
 from ..static import Variable, device_guard
@@ -271,16 +272,6 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None):


 def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True):
-    place = _get_paddle_place(place)
-    if place is None:
-        place = _current_expected_place()
-    elif not isinstance(
-            place,
-            (core.Place, core.CPUPlace, core.CUDAPinnedPlace, core.CUDAPlace,
-             core.NPUPlace, core.XPUPlace, core.MLUPlace, core.CustomPlace)):
-        raise ValueError(
-            "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.MLUPlace, paddle.CustomPlace"
-        )
     if not isinstance(data, np.ndarray):
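Reviewer note: the block removed above duplicated place normalization that to_tensor now performs once at its entry point (see the last hunk of this file). A minimal sketch of that step, assuming these internal helpers keep their paddle.fluid.framework import path at the time of this PR:

from paddle.fluid.framework import _current_expected_place, _get_paddle_place

# _get_paddle_place maps a string such as "cpu" or "gpu:0" to a Place object;
# None falls through to the current expected place. The import path is an
# assumption about internal Paddle layout, not part of this diff.
place = _get_paddle_place("cpu")
if place is None:
    place = _current_expected_place()
print(place)  # expected: a CPUPlace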
@@ -359,6 +350,51 @@ def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True):
                        stop_gradient=stop_gradient)


+def _to_tensor_static(data, dtype=None, stop_gradient=None):
+
+    if isinstance(data, Variable) and (dtype is None or dtype == data.dtype):
+        output = data
+    else:
+        if dtype:
+            target_dtype = dtype
+        elif hasattr(data, 'dtype'):
+            target_dtype = data.dtype
+        else:
+            target_dtype = paddle.get_default_dtype()
+        target_dtype = convert_dtype(target_dtype)
+
+        if not isinstance(data, np.ndarray):
+            if np.isscalar(data) and not isinstance(data, str):
+                data = np.array([data])
+            elif isinstance(data, (list, tuple)):
+                data = np.array(data)
+
+        if isinstance(data, np.ndarray) and len(data.shape) > 0 and any(
+                isinstance(x, Variable) for x in data):
+            if not all(
+                [x.shape == (1, ) for x in data if isinstance(x, Variable)]):
+                raise TypeError(
+                    "Unsupport paddle.to_tensor([Variable, Variable...]) with non-scalar variable."
+                )
+            to_stack_list = [None] * data.shape[0]
+            for idx, d in enumerate(data):
+                to_stack_list[idx] = _to_tensor_static(d, dtype, stop_gradient)
+            data = paddle.stack(to_stack_list)
+            data = paddle.squeeze(data, -1)
+
+        if not isinstance(data, Variable):
+            output = assign(data)
+        else:
+            output = data
+        if convert_dtype(output.dtype) != target_dtype:
+            output = paddle.cast(output, target_dtype)
+
+    output.stop_gradient = stop_gradient
+    return output
+
+
 def to_tensor(data, dtype=None, place=None, stop_gradient=True):
     r"""
     Constructs a ``paddle.Tensor`` from ``data`` ,
@@ -417,64 +453,20 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
         #        [[(1+1j), (2+0j)],
         #         [(3+2j), (4+0j)]])
     """
+    place = _get_paddle_place(place)
+    if place is None:
+        place = _current_expected_place()
     if _non_static_mode():
         return _to_tensor_non_static(data, dtype, place, stop_gradient)
     # call assign for static graph
     else:
-
-        def call_assign(data, dtype=None, stop_grandient=None):
-            if isinstance(data,
-                          (Variable, core.VarBase)) and (dtype is None or dtype
-                                                         == data.dtype):
-                output = data
-            else:
-                if dtype:
-                    target_dtype = convert_dtype(dtype)
-                elif hasattr(data, 'dtype'):
-                    target_dtype = convert_dtype(data.dtype)
-                else:
-                    target_dtype = convert_dtype(paddle.get_default_dtype())
-
-                if not isinstance(data, np.ndarray):
-                    if np.isscalar(data) and not isinstance(data, str):
-                        data = np.array([data])
-                    elif isinstance(data, (list, tuple)):
-                        if any(isinstance(x, Variable) for x in data):
-                            to_stack_list = [None] * len(data)
-                            for idx, d in enumerate(data):
-                                to_stack_list[idx] = call_assign(
-                                    d, dtype, stop_gradient)
-                            data = paddle.stack(to_stack_list)
-                            data = paddle.squeeze(data, -1)
-
-                output = assign(data)
-                if target_dtype is not None and convert_dtype(
-                        output.dtype) != target_dtype:
-                    output = paddle.cast(output, target_dtype)
-
-            output.stop_gradient = stop_gradient
-            return output
-
-        place = _get_paddle_place(place)
-        if place is None:
-            place = _current_expected_place()
-        elif not isinstance(
-                place,
-                (core.Place, core.CPUPlace, core.CUDAPinnedPlace,
-                 core.CUDAPlace, core.NPUPlace, core.XPUPlace, core.MLUPlace,
-                 core.CustomPlace)):
-            raise ValueError(
-                "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.MLUPlace, paddle.CustomPlace"
-            )
-
-        import re
-        re_exp = re.compile(r'[(](.*?)[)]', re.S)
+        re_exp = re.compile(r'[(](.+?)[)]', re.S)
         place_str = re.findall(re_exp, str(place))[0]
         with paddle.static.device_guard(place_str):
-            return call_assign(data, dtype, stop_gradient)
+            return _to_tensor_static(data, dtype, stop_gradient)


 def full_like(x, fill_value, dtype=None, name=None):
...
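Reviewer note: after this refactor, the static-graph branch of to_tensor parses the place string once with re_exp and delegates to _to_tensor_static under a device_guard. A minimal sketch exercising that path through the public API; the printed dtype is an expectation based on the cast in _to_tensor_static:

import numpy as np
import paddle

# Static-graph path: to_tensor assigns the ndarray, then casts it because
# the requested float32 differs from the data's int64.
paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    t = paddle.to_tensor(np.arange(4, dtype="int64"), dtype="float32")
    print(t.dtype)  # expected: paddle.float32
paddle.disable_static()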