diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc
index 7a6705e63b420b71787d2ae0b35791e47afa3cda..0c6707748ef5ab87a368b171a29301acfa12789f 100644
--- a/paddle/fluid/pybind/eager_functions.cc
+++ b/paddle/fluid/pybind/eager_functions.cc
@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/custom_operator.h"
 #include "paddle/fluid/framework/op_meta_info_helper.h"
+#include "paddle/fluid/framework/python_headers.h"
 #include "paddle/fluid/memory/allocation/allocator.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
@@ -35,6 +36,7 @@ limitations under the License. */
 #include "paddle/fluid/pybind/eager.h"
 #include "paddle/fluid/pybind/eager_utils.h"
 #include "paddle/fluid/pybind/exception.h"
+#include "paddle/fluid/pybind/tensor_py.h"
 #include "paddle/phi/api/ext/op_meta_info.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/api/lib/utils/storage.h"
@@ -771,6 +773,7 @@ static PyObject* eager_api_async_write(PyObject* self, PyObject* args,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
 #endif
+
 PyMethodDef variable_functions[] = {
     // TODO(jiabin): Remove scale when we have final state tests
     {"scale", (PyCFunction)(void (*)(void))eager_api_scale,
      METH_VARARGS | METH_KEYWORDS, NULL},
@@ -794,13 +797,13 @@ PyMethodDef variable_functions[] = {
     {"sparse_csr_tensor",
      (PyCFunction)(void (*)(void))eager_api_sparse_csr_tensor,
      METH_VARARGS | METH_KEYWORDS, NULL},
+/**sparse functions**/
 #if defined(PADDLE_WITH_CUDA)
     {"async_read", (PyCFunction)(void (*)(void))eager_api_async_read,
      METH_VARARGS | METH_KEYWORDS, NULL},
     {"async_write", (PyCFunction)(void (*)(void))eager_api_async_write,
     METH_VARARGS | METH_KEYWORDS, NULL},
 #endif
-    /**sparse functions**/
     {NULL, NULL, 0, NULL}};
 
 void BindFunctions(PyObject* module) {
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index b8ed2716fc7d5a70832c2c6300d75becf99d32df..dc1f82d235e31c4c98cf7249e41483e2496d37b0 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -173,9 +173,13 @@ def _test_eager_guard(place=None):
         monkey_patch_math_varbase()

         # Ugly setting
-        from paddle.tensor.manipulation import fill_, zero_
+        from paddle.tensor.manipulation import fill_, zero_, fill_diagonal_, fill_diagonal_tensor_, tolist
         setattr(core.eager.Tensor, 'fill_', fill_)
         setattr(core.eager.Tensor, 'zero_', zero_)
+        setattr(core.eager.Tensor, 'fill_diagonal_', fill_diagonal_)
+        setattr(core.eager.Tensor, 'fill_diagonal_tensor_',
+                fill_diagonal_tensor_)
+        setattr(core.eager.Tensor, 'tolist', tolist)
         _already_patch_eager_tensor = True

     try:
diff --git a/python/paddle/fluid/tests/unittests/test_Tensor_type.py b/python/paddle/fluid/tests/unittests/test_Tensor_type.py
index f1427d29782b969d9571f79c9a7bc62bf4e77070..c40981c07372405a6cd2990e7bd156451b37e32b 100644
--- a/python/paddle/fluid/tests/unittests/test_Tensor_type.py
+++ b/python/paddle/fluid/tests/unittests/test_Tensor_type.py
@@ -18,10 +18,11 @@ import unittest
 import numpy as np
 import paddle
 import paddle.fluid.core as core
+from paddle.fluid.framework import _test_eager_guard


 class TensorTypeTest(unittest.TestCase):
-    def test_type_totensor(self):
+    def func_type_totensor(self):
         paddle.disable_static()
         inx = np.array([1, 2])
         tensorx = paddle.to_tensor(inx)
@@ -29,7 +30,12 @@ class TensorTypeTest(unittest.TestCase):
         expectx = "<class 'paddle.Tensor'>"
         self.assertEqual((typex_str == expectx), True)

-    def test_type_Tensor(self):
+    def test_type_totensor(self):
+        with _test_eager_guard():
+            self.func_type_totensor()
+        self.func_type_totensor()
+
+    def func_type_Tensor(self):
         paddle.disable_static()
         inx = np.array([1, 2])
         tensorx = paddle.Tensor(inx)
@@ -43,7 +49,12 @@ class TensorTypeTest(unittest.TestCase):
         expectx = ""
         self.assertEqual((typex_str == expectx), True)

-    def test_type_core(self):
+    def test_type_Tensor(self):
+        with _test_eager_guard():
+            self.func_type_Tensor()
+        self.func_type_Tensor()
+
+    def func_type_core(self):
         paddle.disable_static()
         inx = np.array([1, 2])
         tensorx = core.VarBase(inx)
@@ -56,6 +67,11 @@ class TensorTypeTest(unittest.TestCase):
         expectx = ""
         self.assertEqual((typex_str == expectx), True)

+    def test_type_core(self):
+        with _test_eager_guard():
+            pass
+        self.func_type_core()
+

 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_adamax_api.py b/python/paddle/fluid/tests/unittests/test_adamax_api.py
index 57cb9d3cb5f7ddef60f6577ba0d8217ab3d16b45..1698ac90a9f2d1dd3becdfbe830ad84ab536b0a7 100644
--- a/python/paddle/fluid/tests/unittests/test_adamax_api.py
+++ b/python/paddle/fluid/tests/unittests/test_adamax_api.py
@@ -19,10 +19,11 @@ import numpy as np
 from op_test import OpTest
 import paddle
 import paddle.fluid as fluid
+from paddle.fluid.framework import _test_eager_guard


 class TestAdamaxAPI(unittest.TestCase):
-    def test_adamax_api_dygraph(self):
+    def func_adamax_api_dygraph(self):
         paddle.disable_static()
         value = np.arange(26).reshape(2, 13).astype("float32")
         a = paddle.to_tensor(value)
@@ -36,7 +37,12 @@ class TestAdamaxAPI(unittest.TestCase):
         adam.step()
         adam.clear_gradients()

-    def test_adamax_api(self):
+    def test_adamax_api_dygraph(self):
+        with _test_eager_guard():
+            self.func_adamax_api_dygraph()
+        self.func_adamax_api_dygraph()
+
+    def func_adamax_api(self):
         paddle.enable_static()
         place = fluid.CPUPlace()
         shape = [2, 3, 8, 8]
@@ -63,9 +69,14 @@ class TestAdamaxAPI(unittest.TestCase):
         rets = exe.run(train_prog, feed={"data": data_np}, fetch_list=[loss])
         assert rets[0] is not None

+    def test_adamax_api(self):
+        with _test_eager_guard():
+            self.func_adamax_api()
+        self.func_adamax_api()
+

 class TestAdamaxAPIGroup(TestAdamaxAPI):
-    def test_adamax_api_dygraph(self):
+    def func_adamax_api_dygraph(self):
         paddle.disable_static()
         value = np.arange(26).reshape(2, 13).astype("float32")
         a = paddle.to_tensor(value)
@@ -89,6 +100,11 @@ class TestAdamaxAPIGroup(TestAdamaxAPI):
         adam.step()
         adam.clear_gradients()

+    def test_adamax_api_dygraph(self):
+        with _test_eager_guard():
+            self.func_adamax_api_dygraph()
+        self.func_adamax_api_dygraph()
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py
index 3beb6a537eca0719a2494aad86b8d97bb0339cd9..ca0c97adedb943c19b9b01e7ae9830da57fccdac 100644
--- a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py
+++ b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py
@@ -17,10 +17,11 @@ import unittest
 import numpy as np
 import six
 import paddle
+from paddle.fluid.framework import _test_eager_guard


 class TensorFillDiagonal_Test(unittest.TestCase):
-    def test_dim2_normal(self):
+    def func_dim2_normal(self):
         expected_np = np.array(
             [[1, 2, 2], [2, 1, 2], [2, 2, 1]]).astype('float32')
         expected_grad = np.array(
@@ -50,7 +51,12 @@ class TensorFillDiagonal_Test(unittest.TestCase):
                 (y.grad.numpy().astype('float32') == expected_grad).all(),
                 True)

-    def test_offset(self):
+    def test_dim2_normal(self):
+        with _test_eager_guard():
+            self.func_dim2_normal()
+        self.func_dim2_normal()
+
+    def func_offset(self):
         expected_np = np.array(
             [[2, 2, 1], [2, 2, 2], [2, 2, 2]]).astype('float32')
         expected_grad = np.array(
@@ -80,7 +86,12 @@ class TensorFillDiagonal_Test(unittest.TestCase):
                 (y.grad.numpy().astype('float32') == expected_grad).all(),
                 True)

-    def test_bool(self):
+    def test_offset(self):
+        with _test_eager_guard():
+            self.func_offset()
+        self.func_offset()
+
+    def func_bool(self):
         expected_np = np.array(
             [[False, True, True], [True, False, True], [True, True, False]])

@@ -101,7 +112,12 @@ class TensorFillDiagonal_Test(unittest.TestCase):

             self.assertEqual((x.numpy() == expected_np).all(), True)

-    def test_dim2_unnormal_wrap(self):
+    def test_bool(self):
+        with _test_eager_guard():
+            self.func_bool()
+        self.func_bool()
+
+    def func_dim2_unnormal_wrap(self):
         expected_np = np.array([[1, 2, 2], [2, 1, 2], [2, 2, 1], [2, 2, 2],
                                 [1, 2, 2], [2, 1, 2],
                                 [2, 2, 1]]).astype('float32')
@@ -133,7 +149,12 @@ class TensorFillDiagonal_Test(unittest.TestCase):
                 (y.grad.numpy().astype('float32') == expected_grad).all(),
                 True)

-    def test_dim2_unnormal_unwrap(self):
+    def test_dim2_unnormal_wrap(self):
+        with _test_eager_guard():
+            self.func_dim2_unnormal_wrap()
+        self.func_dim2_unnormal_wrap()
+
+    def func_dim2_unnormal_unwrap(self):
         expected_np = np.array([[1, 2, 2], [2, 1, 2], [2, 2, 1], [2, 2, 2],
                                 [2, 2, 2], [2, 2, 2],
                                 [2, 2, 2]]).astype('float32')
@@ -165,7 +186,12 @@ class TensorFillDiagonal_Test(unittest.TestCase):
                 (y.grad.numpy().astype('float32') == expected_grad).all(),
                 True)

-    def test_dim_larger2_normal(self):
+    def test_dim2_unnormal_unwrap(self):
+        with _test_eager_guard():
+            self.func_dim2_unnormal_unwrap()
+        self.func_dim2_unnormal_unwrap()
+
+    def func_dim_larger2_normal(self):
         expected_np = np.array([[[1, 2, 2], [2, 2, 2], [2, 2, 2]], [[2, 2, 2], [
             2, 1, 2
         ], [2, 2, 2]], [[2, 2, 2], [2, 2, 2], [2, 2, 1]]]).astype('float32')
@@ -198,6 +224,11 @@ class TensorFillDiagonal_Test(unittest.TestCase):
                 (y.grad.numpy().astype('float32') == expected_grad).all(),
                 True)

+    def test_dim_larger2_normal(self):
+        with _test_eager_guard():
+            self.func_dim_larger2_normal()
+        self.func_dim_larger2_normal()
+

 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor_.py b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor_.py
index 2f37ccf219eb08aa3e8ae1fd9853801a01a893be..81ec1daa6691d3e48c91777e08710ea98b0f3a62 100644
--- a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor_.py
+++ b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor_.py
@@ -18,6 +18,7 @@ import unittest
 import numpy as np
 import six
 import paddle
+from paddle.fluid.framework import _test_eager_guard


 class TensorFillDiagTensor_Test(unittest.TestCase):
@@ -27,7 +28,7 @@ class TensorFillDiagTensor_Test(unittest.TestCase):
         if fluid.core.is_compiled_with_cuda():
             self.places.append(fluid.CUDAPlace(0))

-    def test_dim2(self):
+    def func_dim2(self):
         expected_np = np.array(
             [[1, 2, 2], [2, 1, 2], [2, 2, 1], [2, 2, 2]]).astype('float32')
         expected_grad = np.array(
@@ -54,7 +55,12 @@ class TensorFillDiagTensor_Test(unittest.TestCase):
                 (y.grad.numpy().astype('float32') == expected_grad).all(),
                 True)

-    def test_dim2_offset_1(self):
+    def test_dim2(self):
+        with _test_eager_guard():
+            self.func_dim2()
+        self.func_dim2()
+
+    def func_dim2_offset_1(self):
         expected_np = np.array(
             [[2, 2, 2], [1, 2, 2], [2, 1, 2], [2, 2, 1]]).astype('float32')
         expected_grad = np.array(
@@ -81,7 +87,12 @@ class TensorFillDiagTensor_Test(unittest.TestCase):
                 (y.grad.numpy().astype('float32') == expected_grad).all(),
                 True)

-    def test_dim2_offset1(self):
+    def test_dim2_offset_1(self):
+        with _test_eager_guard():
+            self.func_dim2_offset_1()
+        self.func_dim2_offset_1()
+
+    def func_dim2_offset1(self):
         expected_np = np.array(
             [[2, 1, 2], [2, 2, 1], [2, 2, 2], [2, 2, 2]]).astype('float32')
         expected_grad = np.array(
@@ -108,7 +119,12 @@ class TensorFillDiagTensor_Test(unittest.TestCase):
                 (y.grad.numpy().astype('float32') == expected_grad).all(),
                 True)

-    def test_dim4(self):
+    def test_dim2_offset1(self):
+        with _test_eager_guard():
+            self.func_dim2_offset1()
+        self.func_dim2_offset1()
+
+    def func_dim4(self):
         expected_np = np.array(
             [[[[0, 3], [2, 2], [2, 2]], [[2, 2], [1, 4], [2, 2]],
               [[2, 2], [2, 2], [2, 5]], [[2, 2], [2, 2], [2, 2]]],
@@ -144,7 +160,12 @@ class TensorFillDiagTensor_Test(unittest.TestCase):
                 (y.grad.numpy().astype('float32') == expected_grad).all(),
                 True)

-    def test_largedim(self):
+    def test_func_dim4(self):
+        with _test_eager_guard():
+            self.func_dim4()
+        self.func_dim4()
+
+    def func_largedim(self):
         #large dim only test on gpu because the cpu version is too slow for ci test, and the memory is limited
         if len(self.places) > 1:
             bsdim = 1024
@@ -168,6 +189,11 @@ class TensorFillDiagTensor_Test(unittest.TestCase):
             self.assertEqual((y == expected_pred).all(), True)
             self.assertEqual((y.grad == expected_grad).all(), True)

+    def test_largedim(self):
+        with _test_eager_guard():
+            self.func_largedim()
+        self.func_largedim()
+

 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_tensor_to_list.py b/python/paddle/fluid/tests/unittests/test_tensor_to_list.py
index 73b91297e6fd62e083b340dffa72ddddaf16863c..a78113030ed53b8c4a4933ac99d85d1f2dcf5131 100644
--- a/python/paddle/fluid/tests/unittests/test_tensor_to_list.py
+++ b/python/paddle/fluid/tests/unittests/test_tensor_to_list.py
@@ -17,13 +17,14 @@ import unittest
 import numpy as np
 import six
 import paddle
+from paddle.fluid.framework import _test_eager_guard


 class TensorToListTest(unittest.TestCase):
     def setUp(self):
         self.shape = [11, 25, 32, 43]

-    def test_tensor_tolist(self):
+    def func_tensor_tolist(self):
         places = [fluid.CPUPlace()]
         if fluid.core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -39,6 +40,11 @@ class TensorToListTest(unittest.TestCase):

         self.assertEqual(tensorlist, expectlist)

+    def test_tensor_tolist(self):
+        with _test_eager_guard():
+            self.func_tensor_tolist()
+        self.func_tensor_tolist()
+

 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/optimizer/adamax.py b/python/paddle/optimizer/adamax.py
index de70e2e72a9c678a7655150ef854f82e1c32aa51..4c4a85559c0d95ec7c0f4c2a97c125820bb2131b 100644
--- a/python/paddle/optimizer/adamax.py
+++ b/python/paddle/optimizer/adamax.py
@@ -16,6 +16,7 @@ from .optimizer import Optimizer
 from ..fluid import core
 from ..fluid import framework
 from ..fluid.framework import Variable, name_scope
+from paddle import _C_ops

 __all__ = []

@@ -190,30 +191,38 @@ class Adamax(Optimizer):
                                        param_and_grad[0])
         beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                               param_and_grad[0])
-        # create the adamax optimize op
-        adamax_op = block.append_op(
-            type=self.type,
-            inputs={
-                "Param": param_and_grad[0],
-                "Grad": param_and_grad[1],
-                "LearningRate": self._create_param_lr(param_and_grad),
-                "Moment": moment,
-                "InfNorm": inf_norm,
-                "Beta1Pow": beta1_pow_acc
-            },
-            outputs={
-                "ParamOut": param_and_grad[0],
-                "MomentOut": moment,
-                "InfNormOut": inf_norm
-            },
-            attrs={
-                "beta1": self._beta1,
-                "beta2": self._beta2,
-                "epsilon": self._epsilon
-            },
-            stop_gradient=True)
-        return adamax_op
+        if framework._non_static_mode():
+            _C_ops.adamax(param_and_grad[0], param_and_grad[1],
+                          self._create_param_lr(param_and_grad), moment,
+                          inf_norm, beta1_pow_acc, param_and_grad[0], moment,
+                          inf_norm, "beta1", self._beta1, "beta2", self._beta2,
+                          "epsilon", self._epsilon)
+        else:
+            # create the adamax optimize op
+            adamax_op = block.append_op(
+                type=self.type,
+                inputs={
+                    "Param": param_and_grad[0],
+                    "Grad": param_and_grad[1],
+                    "LearningRate": self._create_param_lr(param_and_grad),
+                    "Moment": moment,
+                    "InfNorm": inf_norm,
+                    "Beta1Pow": beta1_pow_acc
+                },
+                outputs={
+                    "ParamOut": param_and_grad[0],
+                    "MomentOut": moment,
+                    "InfNormOut": inf_norm
+                },
+                attrs={
+                    "beta1": self._beta1,
+                    "beta2": self._beta2,
+                    "epsilon": self._epsilon
+                },
+                stop_gradient=True)
+
+            return adamax_op

     def _finish_update(self, block, parameters_and_grads):
         """Update Beta1 Power accumulator
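
Note: every unit-test file touched by this patch follows the same migration pattern: the original test body is renamed from test_* to func_*, and a thin test_* wrapper runs it twice, once inside _test_eager_guard() (eager dygraph mode) and once outside it (legacy dygraph). A minimal, illustrative sketch of that wrapper pattern is shown below; the class and method names (ExampleEagerMigrationTest, func_example, test_example) are hypothetical and not part of the diff.

    import unittest

    import paddle
    from paddle.fluid.framework import _test_eager_guard


    class ExampleEagerMigrationTest(unittest.TestCase):
        def func_example(self):
            # The actual check; it runs under whichever dygraph mode is active.
            x = paddle.to_tensor([1.0, 2.0, 3.0])
            self.assertEqual(x.tolist(), [1.0, 2.0, 3.0])

        def test_example(self):
            # Run once under the eager guard, then once in legacy dygraph mode.
            with _test_eager_guard():
                self.func_example()
            self.func_example()


    if __name__ == '__main__':
        unittest.main()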