From e21b13fbf820388e7009d8b7db82881a4d0d0ad0 Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Mon, 19 Oct 2020 16:00:09 +0800 Subject: [PATCH] [API 2.0: doc] transfer from paddle.fluid.layers.assign() into creation.py (#27999) * transfer from paddle.fluid.layers.assign() into creation.py,test=develop * fix ut fail,add support for paddle.assign,test=develop * fix,test=develop * fix UT coverage,test=coverage * fix UT fail,test=coverage * fix doc,test=develop --- python/paddle/__init__.py | 2 +- .../fluid/tests/unittests/test_assign_op.py | 77 +++++++++++++++++++ python/paddle/tensor/creation.py | 77 ++++++++++++++++++- 3 files changed, 154 insertions(+), 2 deletions(-) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index a7602f15419..21827166d18 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -77,6 +77,7 @@ from .tensor.creation import tril #DEFINE_ALIAS from .tensor.creation import meshgrid #DEFINE_ALIAS from .tensor.creation import empty #DEFINE_ALIAS from .tensor.creation import empty_like #DEFINE_ALIAS +from .tensor.creation import assign #DEFINE_ALIAS from .tensor.linalg import matmul #DEFINE_ALIAS from .tensor.linalg import dot #DEFINE_ALIAS # from .tensor.linalg import einsum #DEFINE_ALIAS @@ -262,7 +263,6 @@ from .fluid.framework import in_dygraph_mode as in_dynamic_mode #DEFINE_ALIAS from .fluid.dygraph.base import no_grad_ as no_grad #DEFINE_ALIAS from .fluid.layers import crop_tensor as crop #DEFINE_ALIAS - from . import jit from . import static from . import amp diff --git a/python/paddle/fluid/tests/unittests/test_assign_op.py b/python/paddle/fluid/tests/unittests/test_assign_op.py index 49c41823055..82ddafb8f95 100644 --- a/python/paddle/fluid/tests/unittests/test_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_assign_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import op_test import numpy as np import unittest +import paddle import paddle.fluid.core as core from paddle.fluid.op import Operator import paddle.fluid as fluid @@ -99,5 +100,81 @@ class TestAssignOpError(unittest.TestCase): self.assertRaises(TypeError, fluid.layers.assign, x5) +class TestAssignOApi(unittest.TestCase): + def test_assign_LoDTensorArray(self): + main_program = Program() + startup_program = Program() + with program_guard(main_program): + x = fluid.data(name='x', shape=[100, 10], dtype='float32') + x.stop_gradient = False + y = fluid.layers.fill_constant( + shape=[100, 10], dtype='float32', value=1) + z = fluid.layers.elementwise_add(x=x, y=y) + i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) + init_array = fluid.layers.array_write(x=z, i=i) + array = paddle.assign(init_array) + sums = fluid.layers.array_read(array=init_array, i=i) + mean = fluid.layers.mean(sums) + append_backward(mean) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + feed_x = np.random.random(size=(100, 10)).astype('float32') + ones = np.ones((100, 10)).astype('float32') + feed_add = feed_x + ones + res = exe.run(main_program, + feed={'x': feed_x}, + fetch_list=[sums.name, x.grad_name]) + self.assertTrue(np.allclose(res[0], feed_add)) + self.assertTrue(np.allclose(res[1], ones / 1000.0)) + + def test_assign_NumpyArray(self): + with fluid.dygraph.guard(): + array = np.random.random(size=(100, 10)).astype(np.bool) + result1 = paddle.zeros(shape=[3, 3], dtype='float32') + paddle.assign(array, result1) + self.assertTrue(np.allclose(result1.numpy(), array)) + + def test_assign_NumpyArray1(self): + with fluid.dygraph.guard(): + array = np.random.random(size=(100, 10)).astype(np.float32) + result1 = paddle.zeros(shape=[3, 3], dtype='float32') + paddle.assign(array, result1) + self.assertTrue(np.allclose(result1.numpy(), array)) + + def test_assign_NumpyArray2(self): + with fluid.dygraph.guard(): + array = np.random.random(size=(100, 10)).astype(np.int32) + result1 = paddle.zeros(shape=[3, 3], dtype='float32') + paddle.assign(array, result1) + self.assertTrue(np.allclose(result1.numpy(), array)) + + def test_assign_NumpyArray3(self): + with fluid.dygraph.guard(): + array = np.random.random(size=(100, 10)).astype(np.int64) + result1 = paddle.zeros(shape=[3, 3], dtype='float32') + paddle.assign(array, result1) + self.assertTrue(np.allclose(result1.numpy(), array)) + + +class TestAssignOpErrorApi(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # The type of input must be Variable or numpy.ndarray. + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CPUPlace()) + self.assertRaises(TypeError, paddle.assign, x1) + # When the type of input is Variable, the dtype of input must be float16, float32, float64, int32, int64, bool. + x3 = fluid.layers.data(name='x3', shape=[4], dtype="uint8") + self.assertRaises(TypeError, paddle.assign, x3) + # When the type of input is numpy.ndarray, the dtype of input must be float32, int32. + x4 = np.array([[2.5, 2.5]], dtype='float64') + self.assertRaises(TypeError, paddle.assign, x4) + x5 = np.array([[2.5, 2.5]], dtype='uint8') + self.assertRaises(TypeError, paddle.assign, x5) + + if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 3eddc1ee1ae..65a33ade27a 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -47,7 +47,8 @@ __all__ = [ 'empty_like', 'triu', 'tril', - 'meshgrid' + 'meshgrid', + 'assign', ] @@ -1106,3 +1107,77 @@ def empty_like(x, dtype=None, name=None): stop_gradient=True) out.stop_gradient = True return out + + +def assign(x, output=None): + """ + + + The OP copies the :attr:`x` to the :attr:`output`. + + Parameters: + x (Tensor|numpy.ndarray): A tensor or numpy ndarray, its data type supports + float16, float32, float64, int32 and int64. + output (Tensor, optional): A tensor. If :attr:`output` is None, a new tensor will + be created as :attr:`output`. Default: None. + + Returns: + Tensor: A tensor with the same shape, data type and value as :attr:`x`. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + data = paddle.full(shape=[3, 2], fill_value=2.5, dtype='float64') # [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]] + array = np.array([[1, 1], + [3, 4], + [1, 3]]).astype(np.int64) + result1 = paddle.zeros(shape=[3, 3], dtype='float32') + paddle.assign(array, result1) # result1 = [[1, 1], [3 4], [1, 3]] + result2 = paddle.assign(data) # result2 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]] + result3 = paddle.assign(np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32')) # result3 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]] + """ + helper = LayerHelper('assign', **locals()) + check_type(x, 'x', (Variable, numpy.ndarray), 'assign') + if isinstance(x, Variable): + check_dtype( + x.dtype, 'x', + ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'], + 'assign', '(When the type of input in assign is Variable.)') + if output is None: + output = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='assign', inputs={'X': [x]}, outputs={'Out': [output]}) + elif isinstance(x, numpy.ndarray): + dtype = convert_np_dtype_to_dtype_(x.dtype) + if dtype == VarDesc.VarType.BOOL: + value_name = "bool_values" + values = [bool(v) for v in x.flat] + elif dtype == VarDesc.VarType.FP32: + value_name = "fp32_values" + values = [float(v) for v in x.flat] + elif dtype == VarDesc.VarType.INT32: + value_name = "int32_values" + values = [int(v) for v in x.flat] + elif dtype == VarDesc.VarType.INT64: + value_name = "int64_values" + values = [int(v) for v in x.flat] + else: + raise TypeError( + "When the type of 'x' in assign is numpy.ndarray, " + "the data type of 'x' must be bool, float32, int32 or int64, but " + "received %s." % convert_dtype(dtype)) + if x.size > 1024 * 1024: + raise ValueError("The size of input is too big. Please consider " + "saving it to file and 'load_op' to load it") + if output is None: + output = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='assign_value', + outputs={'Out': [output]}, + attrs={'dtype': dtype, + 'shape': list(x.shape), + value_name: values}) + + return output -- GitLab