未验证 提交 f1c28236 编写于 作者: C chentianyu03 提交者: GitHub

layer.to api support numpy.dtype and paddle.dtype (#38018)

* layer.to api support numpy.dtype and paddle.dtype

* skip the layer to eaxmple code executing for env not support
上级 a3c8abc7
......@@ -37,6 +37,7 @@ from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.core import VarDesc
from paddle.fluid.dygraph import no_grad
import paddle.utils.deprecated as deprecated
......@@ -1495,7 +1496,7 @@ class Layer(object):
If None, the device is the same with the original Tensor. If device is string, it can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
index of the GPUs or XPUs. Default: None.
dtype(str|core.VarDesc.VarType|None, optional): The type of the data. If None, the dtype is the same with the original Tensor. Default: None.
dtype(str|numpy.dtype|paddle.dtype|None, optional): The type of the data. If None, the dtype is the same with the original Tensor. Default: None.
blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
asynchronous with respect to the host. Otherwise, the argument has no effect. If None, the blocking is set True. Default: None.
......@@ -1506,7 +1507,7 @@ class Layer(object):
Examples:
.. code-block:: python
# required: gpu
# required: skip
import paddle
linear=paddle.nn.Linear(2, 2)
......@@ -1563,7 +1564,7 @@ class Layer(object):
if dtype is None:
dtype = t.dtype
if type(dtype) is str:
if type(dtype) is not VarDesc.VarType:
dtype = convert_np_dtype_to_dtype_(dtype)
# 1. gpu place need to determine whether the memory is sufficient for allocation:
......
......@@ -403,6 +403,52 @@ class TestLayerTo(unittest.TestCase):
self.assertRaises(AssertionError, self.linear.to, blocking=1)
def test_to_api_paddle_dtype(self):
self.linear.to(dtype=paddle.float64)
self.assertEqual(self.linear.weight.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertEqual(self.linear.buf_name.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(
np.allclose(self.linear.weight.grad.numpy(), self.new_grad))
self.assertEqual(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.linear.to()
self.assertEqual(self.linear.weight.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertEqual(self.linear.buf_name.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(
np.allclose(self.linear.weight.grad.numpy(), self.new_grad))
self.assertEqual(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
for p in self.linear.parameters():
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))
def test_to_api_numpy_dtype(self):
self.linear.to(dtype=np.float64)
self.assertEqual(self.linear.weight.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertEqual(self.linear.buf_name.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(
np.allclose(self.linear.weight.grad.numpy(), self.new_grad))
self.assertEqual(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.linear.to()
self.assertEqual(self.linear.weight.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertEqual(self.linear.buf_name.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(
np.allclose(self.linear.weight.grad.numpy(), self.new_grad))
self.assertEqual(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
for p in self.linear.parameters():
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册