未验证 提交 30122212 编写于 作者: F feifei-111 提交者: GitHub

[API]Support static branch in paddle.to_tensor (#45164)

* fix_shape
上级 bb6bd223
......@@ -129,9 +129,6 @@ def is_to_variable(node):
if utils.is_dygraph_api(node):
return api_name.endswith("to_variable")
if utils.is_paddle_api(node):
return api_name.endswith("to_tensor")
return False
......
......@@ -556,6 +556,7 @@ class PartialProgramLayer:
var_base = core.eager.Tensor(var_desc.dtype(), var_desc.shape(),
var_desc.name(), var_desc.type(),
False)
var_base.stop_gradient = var.stop_gradient
out_varbase_map[var_desc.name()] = var_base
return var_base
......
......@@ -29,7 +29,7 @@ from paddle.jit import to_static
def dyfunc_generator():
for i in range(100):
yield paddle.to_tensor([i] * 10)
yield paddle.fluid.dygraph.to_variable([i] * 10)
def main_func():
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# to_tensor api will create 1 less op now, this test was changed
from __future__ import print_function
import numpy as np
......@@ -298,7 +300,7 @@ class TestTensorShapeBasic2(TestTensorShapeBasic):
self.dygraph_func = dyfunc_tensor_shape_2
def _set_expected_op_num(self):
self.expected_op_num = 2
self.expected_op_num = 1
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
......@@ -347,7 +349,7 @@ class TestTupleShape1(TestTensorShapeBasic):
self.dygraph_func = dyfunc_tuple_shape_1
def _set_expected_op_num(self):
self.expected_op_num = 5
self.expected_op_num = 4
self.expected_shape_op_num = 1
self.expected_slice_op_num = 2
......@@ -362,7 +364,7 @@ class TestTupleShape2(TestTensorShapeBasic):
self.dygraph_func = dyfunc_tuple_shape_2
def _set_expected_op_num(self):
self.expected_op_num = 5
self.expected_op_num = 4
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
......@@ -375,7 +377,7 @@ class TestTupleShape3(TestTensorShapeBasic):
self.dygraph_func = dyfunc_tuple_shape_3
def _set_expected_op_num(self):
self.expected_op_num = 5
self.expected_op_num = 4
self.expected_shape_op_num = 1
self.expected_slice_op_num = 2
......@@ -388,7 +390,7 @@ class TestPaddleShapeApi(TestTensorShapeBasic):
self.dygraph_func = dyfunc_paddle_shape_api
def _set_expected_op_num(self):
self.expected_op_num = 6
self.expected_op_num = 5
self.expected_shape_op_num = 2
self.expected_slice_op_num = 2
......@@ -490,7 +492,7 @@ class TestTensorShapeInWhile4(TestTensorShapeBasic):
self.dygraph_func = dyfunc_with_while_4
def _set_expected_op_num(self):
self.expected_op_num = 5
self.expected_op_num = 4
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
......@@ -554,7 +556,7 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
self.dygraph_func = dyfunc_tuple_shape_1
def _set_expected_op_num(self):
self.expected_op_num = 5
self.expected_op_num = 4
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
......@@ -602,7 +604,7 @@ class TestChangeShapeAfterAssign(TestTensorShapeBasic):
self.dygraph_func = dyfunc_change_shape_after_assign
def _set_expected_op_num(self):
self.expected_op_num = 6
self.expected_op_num = 5
self.expected_shape_op_num = 1
self.expected_slice_op_num = 1
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy
import paddle
import unittest
import os
import tempfile
import paddle.inference as paddle_infer
from paddle.fluid.framework import program_guard, Program
import numpy as np
from paddle.fluid import core
def case0(x):
a = paddle.to_tensor([1.0, 2.0, 3.0], dtype="int64")
return a
def case1(x):
paddle.set_default_dtype("float64")
a = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
return a
def case2(x):
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
else:
place = paddle.CPUPlace()
a = paddle.to_tensor([1.0, 2.0, 3.0],
place=place,
dtype="int64",
stop_gradient=False)
return a
def case3(x):
paddle.set_default_dtype("float64")
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
else:
place = paddle.CPUPlace()
a = paddle.to_tensor([1.0, 2.0, 3.0], place=place)
return a
class TestToTensorReturnVal(unittest.TestCase):
def test_to_tensor_badreturn(self):
paddle.disable_static()
x = paddle.to_tensor([3])
a = paddle.jit.to_static(case0)(x)
b = case0(x)
self.assertTrue(a.dtype == b.dtype)
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))
a = paddle.jit.to_static(case1)(x)
b = case1(x)
self.assertTrue(a.dtype == b.dtype)
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))
a = paddle.jit.to_static(case2)(x)
b = case2(x)
self.assertTrue(a.dtype == b.dtype)
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))
a = paddle.jit.to_static(case3)(x)
b = case3(x)
self.assertTrue(a.dtype == b.dtype)
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))
class TestStatic(unittest.TestCase):
def test_static(self):
paddle.enable_static()
main_prog = Program()
starup_prog = Program()
with program_guard(main_prog, starup_prog):
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
else:
place = paddle.CPUPlace()
x = paddle.to_tensor(paddle.randn([5, 2]),
dtype='float64',
stop_gradient=False,
place=place)
out = paddle.static.nn.fc(x, 1)
sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(out))
exe = paddle.static.Executor()
exe.run(starup_prog)
res = exe.run(fetch_list=[x, out])
if __name__ == '__main__':
unittest.main()
......@@ -270,65 +270,7 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None):
return out
@dygraph_only
def to_tensor(data, dtype=None, place=None, stop_gradient=True):
r"""
Constructs a ``paddle.Tensor`` from ``data`` ,
which can be scalar, tuple, list, numpy\.ndarray, paddle\.Tensor.
If the ``data`` is already a Tensor, copy will be performed and return a new tensor.
If you only want to change stop_gradient property, please call ``Tensor.stop_gradient = stop_gradient`` directly.
Args:
data(scalar|tuple|list|ndarray|Tensor): Initial data for the tensor.
Can be a scalar, list, tuple, numpy\.ndarray, paddle\.Tensor.
dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' ,
'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8',
'complex64' , 'complex128'. Default: None, infers dtype from ``data``
except for python float number which gets dtype from ``get_default_type`` .
place(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate Tensor. Can be
CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place. If ``place`` is
string, It can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, where ``x`` is the index of the GPUs.
stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.
Returns:
Tensor: A Tensor constructed from ``data`` .
Examples:
.. code-block:: python
import paddle
type(paddle.to_tensor(1))
# <class 'paddle.Tensor'>
paddle.to_tensor(1)
# Tensor(shape=[1], dtype=int64, place=CPUPlace, stop_gradient=True,
# [1])
x = paddle.to_tensor(1, stop_gradient=False)
print(x)
# Tensor(shape=[1], dtype=int64, place=CPUPlace, stop_gradient=False,
# [1])
paddle.to_tensor(x) # A new tensor will be created with default stop_gradient=True
# Tensor(shape=[1], dtype=int64, place=CPUPlace, stop_gradient=True,
# [1])
paddle.to_tensor([[0.1, 0.2], [0.3, 0.4]], place=paddle.CPUPlace(), stop_gradient=False)
# Tensor(shape=[2, 2], dtype=float32, place=CPUPlace, stop_gradient=False,
# [[0.10000000, 0.20000000],
# [0.30000001, 0.40000001]])
type(paddle.to_tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64'))
# <class 'paddle.Tensor'>
paddle.to_tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64')
# Tensor(shape=[2, 2], dtype=complex64, place=CPUPlace, stop_gradient=True,
# [[(1+1j), (2+0j)],
# [(3+2j), (4+0j)]])
"""
def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True):
place = _get_paddle_place(place)
if place is None:
place = _current_expected_place()
......@@ -417,6 +359,124 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
stop_gradient=stop_gradient)
def to_tensor(data, dtype=None, place=None, stop_gradient=True):
r"""
Constructs a ``paddle.Tensor`` from ``data`` ,
which can be scalar, tuple, list, numpy\.ndarray, paddle\.Tensor.
If the ``data`` is already a Tensor, copy will be performed and return a new tensor.
If you only want to change stop_gradient property, please call ``Tensor.stop_gradient = stop_gradient`` directly.
Args:
data(scalar|tuple|list|ndarray|Tensor): Initial data for the tensor.
Can be a scalar, list, tuple, numpy\.ndarray, paddle\.Tensor.
dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' ,
'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8',
'complex64' , 'complex128'. Default: None, infers dtype from ``data``
except for python float number which gets dtype from ``get_default_type`` .
place(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate Tensor. Can be
CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place. If ``place`` is
string, It can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, where ``x`` is the index of the GPUs.
stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.
Returns:
Tensor: A Tensor constructed from ``data`` .
Examples:
.. code-block:: python
import paddle
type(paddle.to_tensor(1))
# <class 'paddle.Tensor'>
paddle.to_tensor(1)
# Tensor(shape=[1], dtype=int64, place=CPUPlace, stop_gradient=True,
# [1])
x = paddle.to_tensor(1, stop_gradient=False)
print(x)
# Tensor(shape=[1], dtype=int64, place=CPUPlace, stop_gradient=False,
# [1])
paddle.to_tensor(x) # A new tensor will be created with default stop_gradient=True
# Tensor(shape=[1], dtype=int64, place=CPUPlace, stop_gradient=True,
# [1])
paddle.to_tensor([[0.1, 0.2], [0.3, 0.4]], place=paddle.CPUPlace(), stop_gradient=False)
# Tensor(shape=[2, 2], dtype=float32, place=CPUPlace, stop_gradient=False,
# [[0.10000000, 0.20000000],
# [0.30000001, 0.40000001]])
type(paddle.to_tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64'))
# <class 'paddle.Tensor'>
paddle.to_tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64')
# Tensor(shape=[2, 2], dtype=complex64, place=CPUPlace, stop_gradient=True,
# [[(1+1j), (2+0j)],
# [(3+2j), (4+0j)]])
"""
if _non_static_mode():
return _to_tensor_non_static(data, dtype, place, stop_gradient)
# call assign for static graph
else:
def call_assign(data, dtype=None, stop_grandient=None):
if isinstance(data,
(Variable, core.VarBase)) and (dtype is None or dtype
== data.dtype):
output = data
else:
if dtype:
target_dtype = convert_dtype(dtype)
elif hasattr(data, 'dtype'):
target_dtype = convert_dtype(data.dtype)
else:
target_dtype = convert_dtype(paddle.get_default_dtype())
if not isinstance(data, np.ndarray):
if np.isscalar(data) and not isinstance(data, str):
data = np.array([data])
elif isinstance(data, (list, tuple)):
if any(isinstance(x, Variable) for x in data):
to_stack_list = [None] * len(data)
for idx, d in enumerate(data):
to_stack_list[idx] = call_assign(
d, dtype, stop_gradient)
data = paddle.stack(to_stack_list)
data = paddle.squeeze(data, -1)
output = assign(data)
if target_dtype is not None and convert_dtype(
output.dtype) != target_dtype:
output = paddle.cast(output, target_dtype)
output.stop_gradient = stop_gradient
return output
place = _get_paddle_place(place)
if place is None:
place = _current_expected_place()
elif not isinstance(
place,
(core.Place, core.CPUPlace, core.CUDAPinnedPlace, core.CUDAPlace,
core.NPUPlace, core.XPUPlace, core.MLUPlace, core.CustomPlace)):
raise ValueError(
"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.MLUPlace, paddle.CustomPlace"
)
import re
re_exp = re.compile(r'[(](.*?)[)]', re.S)
place_str = re.findall(re_exp, str(place))[0]
with paddle.static.device_guard(place_str):
return call_assign(data, dtype, stop_gradient)
def full_like(x, fill_value, dtype=None, name=None):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册