未验证 提交 96c6dde1 编写于 作者: F feifei-111 提交者: GitHub

[ Dy2Static ] Fix dy2staic: cpu, cuda, assign([Var, Var, ]) (#44731)

* fix dy2staic: cpu, cuda, assign([Var, Var, ])

* fix1

* fix2
Co-authored-by: Nxiongkun <xiongkun03@baidu.com>
上级 5eaa55da
......@@ -128,6 +128,24 @@ def monkey_patch_variable():
var.stop_gradient = True
return var
@static_only
def cpu(self):
"""
Variable should not have cpu() and cuda() interface.
But this interface can greatly facilitate dy2static.
We do nothing here.
"""
return self
@static_only
def cuda(self):
"""
Variable should not have cpu() and cuda() interface.
But this interface can greatly facilitate dy2static.
We do nothing here.
"""
return self
def astype(self, dtype):
"""
**Notes**:
......@@ -368,6 +386,8 @@ def monkey_patch_variable():
# b=-a
('__neg__', _neg_),
('astype', astype),
('cpu', cpu),
('cuda', cuda),
('append', append),
('dim', lambda x: len(x.shape)),
('ndimension', lambda x: len(x.shape)),
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
from paddle import fluid
import numpy as np
class TestCpuCuda(unittest.TestCase):
def test_cpu_cuda(self):
def func(x):
x = paddle.to_tensor([1, 2, 3, 4])
x = x.cuda()
x = x.cpu()
return x
x = paddle.to_tensor([3])
print(paddle.jit.to_static(func).code)
print(paddle.jit.to_static(func)(x))
class TestToTensor(unittest.TestCase):
def test_to_tensor_with_variable_list(self):
def func(x):
ones = paddle.to_tensor([1])
twos = paddle.to_tensor([2])
x = paddle.to_tensor([ones, twos, 3, 4])
return x
x = paddle.to_tensor([3])
print(paddle.jit.to_static(func).code)
self.assertTrue(
np.allclose(
paddle.jit.to_static(func)(x).numpy(), np.array([1, 2, 3, 4])))
class TestToTensor1(unittest.TestCase):
def test_to_tensor_with_variable_list(self):
def func(x):
ones = paddle.to_tensor([1])
twos = paddle.to_tensor([2])
""" we ignore the [3] and [4], they will be assign to a variable, and is regard as scalar.
TODO: deal with this case after 0-dim tensor is developed.
"""
x = paddle.to_tensor([ones, twos, [3], [4]])
return x
x = paddle.to_tensor([3])
print(paddle.jit.to_static(func).code)
self.assertTrue(
np.allclose(
paddle.jit.to_static(func)(x).numpy(), np.array([1, 2, 3, 4])))
class TestToTensor2(unittest.TestCase):
def test_to_tensor_with_variable_list(self):
def func(x):
x = paddle.to_tensor([[1], [2], [3], [4]])
return x
x = paddle.to_tensor([3])
print(paddle.jit.to_static(func).code)
self.assertTrue(
np.allclose(
paddle.jit.to_static(func)(x).numpy(),
np.array([[1], [2], [3], [4]])))
if __name__ == '__main__':
unittest.main()
......@@ -1535,11 +1535,32 @@ def assign(x, output=None):
inputs={'X': [input]},
outputs={'Out': [output]})
elif isinstance(input, np.ndarray):
# Not support [var, var, ...] currently.
# We now support the form of [var, VAR...] if the Var.shape=[1,]
if len(input.shape) > 0 and any(isinstance(x, Variable) for x in input):
# We only deal with the case where the list is nested one level, convert all scalars into variables, and then use stack to process. It is necessary to ensure the consistency of types.
if not all(
[x.shape == (1, ) for x in input if isinstance(x, Variable)]):
raise TypeError(
"Unsupport paddle.assign([Variable, Variable...]) with non-scalar variable."
)
def convert_scalar(x):
if not isinstance(x, Variable):
return assign(x)
return x
to_stack_list = list(map(convert_scalar, input))
ret = paddle.stack(to_stack_list)
ret = paddle.squeeze(ret, -1)
return ret
if input.dtype == 'object':
""" may be this form [[Var], [Var], [3], [4]], we reject them.
"""
raise TypeError(
"Required type(input) numpy.ndarray, but found `list(Variable)` in input."
"The type of received input == `object`, it is not supported to convert to tensor, such as [[Var], [Var], [3], [4]]"
)
dtype = convert_np_dtype_to_dtype_(input.dtype)
if dtype == core.VarDesc.VarType.FP64:
# Setting FP64 numpy data is not supported in Paddle, so we
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册