diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index 24ec364efb3b6259b5792a74f7092ef3306920ce..8387123ae11fe09e3b5664e6f1fb0cdb23a7c95a 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -75,6 +75,47 @@ static bool IsNumpyType(PyObject* obj) { type_name == "numpy.int32" || type_name == "numpy.int16"; } +static bool IsNumpyArray(PyObject* obj) { + auto type_name = std::string(Py_TYPE(obj)->tp_name); + return type_name == "numpy.ndarray"; +} + +void InitTensorWithNumpyValue(const py::object& array, + const paddle::platform::Place& place, + Tensor* self, + bool zero_copy = false) { + PADDLE_ENFORCE_EQ( + self->defined(), + true, + paddle::platform::errors::Fatal( + "Calling InitTensorWithNumpyValue of Eager Tensor without " + "EmptyTensorInitializer is " + "forbidden. Please check your code and make sure you new a " + "eager tensor before init it with NumPy.")); + phi::DenseTensor* impl_ptr = + static_cast(self->impl().get()); + if (platform::is_cpu_place(place)) { + SetTensorFromPyArray(impl_ptr, array, place, zero_copy); + } else if (platform::is_xpu_place(place)) { + SetTensorFromPyArray(impl_ptr, array, place, zero_copy); + } else if (platform::is_gpu_place(place)) { + SetTensorFromPyArray( + impl_ptr, array, place, zero_copy); + } else if (platform::is_cuda_pinned_place(place)) { + SetTensorFromPyArray( + impl_ptr, array, place, zero_copy); + } else if (platform::is_npu_place(place)) { + SetTensorFromPyArray(impl_ptr, array, place, zero_copy); + } else if (platform::is_custom_place(place)) { + SetTensorFromPyArray( + impl_ptr, array, place, zero_copy); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Place should be one of " + "CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace/CustomPlace")); + } +} + std::set _supported_int_dtype_{DataType::UINT8, DataType::INT8, DataType::INT16, @@ -192,7 +233,13 @@ static PyObject* tensor__add__method(TensorObject* self, // 2. create or get tensor for other_obj paddle::experimental::Tensor other_tensor; - if (!PyCheckTensor(other_obj)) { + if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__add__", 0); { @@ -200,8 +247,6 @@ static PyObject* tensor__add__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -289,7 +334,13 @@ static PyObject* tensor__sub__method(TensorObject* self, } // 2. create or get tensor for other_obj paddle::experimental::Tensor other_tensor; - if (!PyCheckTensor(other_obj)) { + if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__sub__", 0); { @@ -297,8 +348,6 @@ static PyObject* tensor__sub__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -382,7 +431,13 @@ static PyObject* tensor__rsub__method(TensorObject* self, // 2. create or get tensor for other_obj paddle::experimental::Tensor other_tensor; - if (!PyCheckTensor(other_obj)) { + if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__rsub__", 0); { @@ -390,8 +445,6 @@ static PyObject* tensor__rsub__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -477,7 +530,13 @@ static PyObject* tensor__mul__method(TensorObject* self, // 2. create or get tensor for other_obj paddle::experimental::Tensor other_tensor; - if (!PyCheckTensor(other_obj)) { + if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__mul__", 0); if (PyComplex_Check(other_obj)) { @@ -489,8 +548,6 @@ static PyObject* tensor__mul__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -579,7 +636,13 @@ static PyObject* tensor__div__method(TensorObject* self, // 2. create or get tensor for other_obj paddle::experimental::Tensor other_tensor; - if (!PyCheckTensor(other_obj)) { + if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__div__", 0); if (PyComplex_Check(other_obj)) { @@ -591,8 +654,6 @@ static PyObject* tensor__div__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -695,7 +756,13 @@ static PyObject* tensor__rdiv__method(TensorObject* self, phi::Scalar(other_double), self_tensor.dtype(), place); - } else if (!PyCheckTensor(other_obj)) { + } else if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__rdiv__", 0); if (PyComplex_Check(other_obj)) { @@ -707,8 +774,6 @@ static PyObject* tensor__rdiv__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -809,7 +874,13 @@ static PyObject* tensor__gt__method(TensorObject* self, phi::Scalar(other_double), self_tensor.dtype(), place); - } else if (!PyCheckTensor(other_obj)) { + } else if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__gt__", 0); if (PyComplex_Check(other_obj)) { @@ -821,8 +892,6 @@ static PyObject* tensor__gt__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -895,7 +964,13 @@ static PyObject* tensor__ge__method(TensorObject* self, phi::Scalar(other_double), self_tensor.dtype(), place); - } else if (!PyCheckTensor(other_obj)) { + } else if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__ge__", 0); if (PyComplex_Check(other_obj)) { @@ -907,8 +982,6 @@ static PyObject* tensor__ge__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -982,7 +1055,13 @@ static PyObject* tensor__mod__method(TensorObject* self, phi::Scalar(other_double), self_tensor.dtype(), self_tensor.place()); - } else if (!PyCheckTensor(other_obj)) { + } else if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__mod__", 0); if (PyComplex_Check(other_obj)) { @@ -994,8 +1073,6 @@ static PyObject* tensor__mod__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -1068,7 +1145,13 @@ static PyObject* tensor__matmul__method(TensorObject* self, phi::Scalar(other_double), self_tensor.dtype(), self_tensor.place()); - } else if (!PyCheckTensor(other_obj)) { + } else if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__matmul__", 0); if (PyComplex_Check(other_obj)) { @@ -1080,8 +1163,6 @@ static PyObject* tensor__matmul__method(TensorObject* self, other_tensor = full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -1172,7 +1253,13 @@ static PyObject* tensor__lt__method(TensorObject* self, phi::Scalar(other_double), self_tensor.dtype(), self_tensor.place()); - } else if (!PyCheckTensor(other_obj)) { + } else if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__lt__", 0); if (PyComplex_Check(other_obj)) { @@ -1184,8 +1271,6 @@ static PyObject* tensor__lt__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -1258,7 +1343,13 @@ static PyObject* tensor__le__method(TensorObject* self, phi::Scalar(other_double), self_tensor.dtype(), self_tensor.place()); - } else if (!PyCheckTensor(other_obj)) { + } else if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__le__", 0); if (PyComplex_Check(other_obj)) { @@ -1270,8 +1361,6 @@ static PyObject* tensor__le__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -1345,7 +1434,13 @@ static PyObject* tensor__floordiv__method(TensorObject* self, phi::Scalar(other_double), self_tensor.dtype(), self_tensor.place()); - } else if (!PyCheckTensor(other_obj)) { + } else if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__floordiv__", 0); if (PyComplex_Check(other_obj)) { @@ -1357,8 +1452,6 @@ static PyObject* tensor__floordiv__method(TensorObject* self, other_tensor = full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -1430,7 +1523,13 @@ static PyObject* tensor__pow__method(TensorObject* self, // 2. create or get tensor for other_obj paddle::experimental::Tensor other_tensor; - if (!PyCheckTensor(other_obj)) { + if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__pow__", 0); if (PyComplex_Check(other_obj)) { @@ -1442,8 +1541,6 @@ static PyObject* tensor__pow__method(TensorObject* self, other_tensor = full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -1518,7 +1615,13 @@ static PyObject* tensor__rpow__method(TensorObject* self, phi::Scalar(other_double), self_tensor.dtype(), self_tensor.place()); - } else if (!PyCheckTensor(other_obj)) { + } else if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__rpow__", 0); if (PyComplex_Check(other_obj)) { @@ -1530,8 +1633,6 @@ static PyObject* tensor__rpow__method(TensorObject* self, other_tensor = full_ad_func( self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -1604,7 +1705,13 @@ static PyObject* tensor__ne__method(TensorObject* self, phi::Scalar(other_double), self_tensor.dtype(), self_tensor.place()); - } else if (!PyCheckTensor(other_obj)) { + } else if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__ne__", 0); if (PyComplex_Check(other_obj)) { @@ -1616,8 +1723,6 @@ static PyObject* tensor__ne__method(TensorObject* self, other_tensor = full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var @@ -1690,7 +1795,13 @@ static PyObject* tensor__eq__method(TensorObject* self, phi::Scalar(other_double), self_tensor.dtype(), self_tensor.place()); - } else if (!PyCheckTensor(other_obj)) { + } else if (PyCheckTensor(other_obj)) { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } else if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::experimental::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__eq__", 0); if (PyComplex_Check(other_obj)) { @@ -1702,8 +1813,6 @@ static PyObject* tensor__eq__method(TensorObject* self, other_tensor = full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); } - } else { - other_tensor = CastPyArg2Tensor(other_obj, 0); } // 3. promote types or unify right var type to left var diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index d9057ee4ca6ab5e1f20ce7de6a91204fcf4d71ec..79c5bdda4337b57c6b09d46e30ddb2e8567eebb6 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -739,6 +739,32 @@ class TestBoolAddFloatElementwiseAddop(unittest.TestCase): self.func_dygraph_add() +class TestElementwiseAddop1(unittest.TestCase): + def func_dygraph_add(self): + paddle.disable_static() + + np_a = np.random.random((2, 3, 4)).astype(np.float32) + np_b = np.random.random((2, 3, 4)).astype(np.float32) + + tensor_a = paddle.to_tensor(np_a, dtype="float32") + tensor_b = paddle.to_tensor(np_b, dtype="float32") + + # normal case: nparray + tenor + expect_out = np_a + np_b + actual_out = np_a + tensor_b + np.testing.assert_allclose(actual_out, expect_out) + + # normal case: tensor + nparray + actual_out = tensor_a + np_b + np.testing.assert_allclose(actual_out, expect_out) + + paddle.enable_static() + + def test_dygraph_add(self): + with _test_eager_guard(): + self.func_dygraph_add() + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py index 7a0c5d09fbffcfd39aace38707e04448ed9e9576..9f37a456b7441e0f4be33b799907ba85c852e45b 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py @@ -18,6 +18,7 @@ from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 import paddle from paddle import fluid from paddle.fluid import core +from paddle.fluid.framework import _test_eager_guard class ElementwiseDivOp(OpTest): @@ -436,6 +437,34 @@ class TestRealComplexElementwiseDivOp(TestComplexElementwiseDivOp): self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y) +class TestElementwiseDivop(unittest.TestCase): + def func_dygraph_div(self): + paddle.disable_static() + + np_a = np.random.random((2, 3, 4)).astype(np.float32) + np_b = np.random.random((2, 3, 4)).astype(np.float32) + np_a[np.abs(np_a) < 0.0005] = 0.002 + np_b[np.abs(np_b) < 0.0005] = 0.002 + + tensor_a = paddle.to_tensor(np_a, dtype="float32") + tensor_b = paddle.to_tensor(np_b, dtype="float32") + + # normal case: nparray / tenor + expect_out = np_a / np_b + actual_out = np_a / tensor_b + np.testing.assert_allclose(actual_out, expect_out) + + # normal case: tensor / nparray + actual_out = tensor_a / np_b + np.testing.assert_allclose(actual_out, expect_out) + + paddle.enable_static() + + def test_dygraph_div(self): + with _test_eager_guard(): + self.func_dygraph_div() + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py index 263fb8a998182819c14488358b6c62a936c15f1c..c72728cfe951b3d5c9539068b110865950207f11 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py @@ -19,6 +19,7 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid import Program, program_guard +from paddle.fluid.framework import _test_eager_guard from paddle.fluid.tests.unittests.op_test import ( OpTest, @@ -386,6 +387,32 @@ class TestRealComplexElementwiseMulOp(TestComplexElementwiseMulOp): self.grad_y = self.grad_out * np.conj(self.x) +class TestElementwiseMulop(unittest.TestCase): + def func_dygraph_mul(self): + paddle.disable_static() + + np_a = np.random.random((2, 3, 4)).astype(np.float32) + np_b = np.random.random((2, 3, 4)).astype(np.float32) + + tensor_a = paddle.to_tensor(np_a, dtype="float32") + tensor_b = paddle.to_tensor(np_b, dtype="float32") + + # normal case: nparray * tenor + expect_out = np_a * np_b + actual_out = np_a * tensor_b + np.testing.assert_allclose(actual_out, expect_out) + + # normal case: tensor * nparray + actual_out = tensor_a * np_b + np.testing.assert_allclose(actual_out, expect_out) + + paddle.enable_static() + + def test_dygraph_mul(self): + with _test_eager_guard(): + self.func_dygraph_mul() + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py index d89b3b22aa3bb8030747ca3062fc540299c89fd4..d2ad1d90f0846c4fd94249357510920536cc7cb2 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py @@ -450,6 +450,36 @@ class TestFloatElementwiseSubop(unittest.TestCase): self.func_dygraph_sub() +class TestFloatElementwiseSubop1(unittest.TestCase): + def func_dygraph_sub(self): + paddle.disable_static() + + np_a = np.random.random((2, 3, 4)).astype(np.float32) + np_b = np.random.random((2, 3, 4)).astype(np.float32) + + tensor_a = paddle.to_tensor(np_a, dtype="float32") + tensor_b = paddle.to_tensor(np_b, dtype="float32") + + # normal case: nparray - tenor + expect_out = np_a - np_b + actual_out = np_a - tensor_b + np.testing.assert_allclose( + actual_out, expect_out, rtol=1e-07, atol=1e-07 + ) + + # normal case: tenor - nparray + actual_out = tensor_a - np_b + np.testing.assert_allclose( + actual_out, expect_out, rtol=1e-07, atol=1e-07 + ) + + paddle.enable_static() + + def test_dygraph_sub(self): + with _test_eager_guard(): + self.func_dygraph_sub() + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py index 9efdb268a4b693e22e5bb34ada0df1f58b4b0d8e..61c843e9780c7bc7e29d7f325912a45b51a77292 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py @@ -380,5 +380,107 @@ class TestMathOpPatches(unittest.TestCase): np.testing.assert_allclose(a_np @ b_np, c_np, rtol=1e-05) +class TestDygraphMathOpPatches(unittest.TestCase): + def init_data(self): + self.np_a = np.random.random((2, 3, 4)).astype(np.float32) + self.np_b = np.random.random((2, 3, 4)).astype(np.float32) + self.np_a[np.abs(self.np_a) < 0.0005] = 0.002 + self.np_b[np.abs(self.np_b) < 0.0005] = 0.002 + + self.tensor_a = paddle.to_tensor(self.np_a, dtype="float32") + self.tensor_b = paddle.to_tensor(self.np_b, dtype="float32") + + def test_dygraph_greater_than(self): + paddle.disable_static() + self.init_data() + # normal case: tenor > nparray + expect_out = self.np_a > self.np_b + actual_out = self.tensor_a > self.np_b + np.testing.assert_equal(actual_out, expect_out) + paddle.enable_static() + + def test_dygraph_greater_equal(self): + paddle.disable_static() + self.init_data() + # normal case: tenor >= nparray + expect_out = self.np_a >= self.np_b + actual_out = self.tensor_a >= self.np_b + np.testing.assert_equal(actual_out, expect_out) + paddle.enable_static() + + def test_dygraph_reminder(self): + paddle.disable_static() + self.init_data() + # normal case: tenor % nparray + expect_out = self.np_a % self.np_b + actual_out = self.tensor_a % self.np_b + np.testing.assert_allclose(actual_out, expect_out, rtol=1e-7, atol=1e-7) + paddle.enable_static() + + def test_dygraph_less_than(self): + paddle.disable_static() + self.init_data() + # normal case: tenor < nparray + expect_out = self.np_a < self.np_b + actual_out = self.tensor_a < self.np_b + np.testing.assert_equal(actual_out, expect_out) + paddle.enable_static() + + def test_dygraph_less_equal(self): + paddle.disable_static() + self.init_data() + # normal case: tenor <= nparray + expect_out = self.np_a <= self.np_b + actual_out = self.tensor_a <= self.np_b + np.testing.assert_equal(actual_out, expect_out) + paddle.enable_static() + + def test_dygraph_floor_divide(self): + paddle.disable_static() + np_a = np.random.random((2, 3, 4)).astype(np.int32) + np_b = np.random.random((2, 3, 4)).astype(np.int32) + np_b[np.abs(np_b) < 1] = 2 + # normal case: tenor // nparray + tensor_a = paddle.to_tensor(np_a, dtype="int32") + tensor_b = paddle.to_tensor(np_b, dtype="int32") + expect_out = np_a // np_b + actual_out = tensor_a // np_b + np.testing.assert_equal(actual_out, expect_out) + paddle.enable_static() + + def test_dygraph_elementwise_pow(self): + paddle.disable_static() + self.init_data() + # normal case: tenor ** nparray + expect_out = self.np_a**self.np_b + actual_out = self.tensor_a**self.np_b + np.testing.assert_allclose(actual_out, expect_out, rtol=1e-7, atol=1e-7) + + # normal case: nparray ** tensor + expect_out = self.np_a**self.np_b + actual_out = self.np_a**self.tensor_b + np.testing.assert_allclose(actual_out, expect_out, rtol=1e-7, atol=1e-7) + + paddle.enable_static() + + def test_dygraph_not_equal(self): + paddle.disable_static() + self.init_data() + # normal case: tenor != nparray + expect_out = self.np_a != self.np_b + actual_out = self.tensor_a != self.np_b + np.testing.assert_equal(actual_out, expect_out) + paddle.enable_static() + + def test_dygraph_equal(self): + paddle.disable_static() + self.init_data() + # normal case: tenor == nparray + expect_out = self.np_a == self.np_b + actual_out = self.tensor_a == self.np_b + np.testing.assert_equal(actual_out, expect_out) + paddle.enable_static() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py index 9af6d6598d29a96c10cd5a4bf9cd0dddddd46672..a7c199bb4b3fd08b92a5f14177f7adf5ae841872 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py @@ -714,6 +714,28 @@ class TestMatMulTypePromotion(TestComplexMatMulOp): self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out) +class TestMatmulop(unittest.TestCase): + def func_dygraph_matmul(self): + paddle.disable_static() + + np_a = np.random.random((2, 4)).astype(np.float32) + np_b = np.random.random((4, 2)).astype(np.float32) + + tensor_a = paddle.to_tensor(np_a, dtype="float32") + tensor_b = paddle.to_tensor(np_b, dtype="float32") + + # normal case: tensor @ nparray + expect_out = np_a @ np_b + actual_out = tensor_a @ np_b + np.testing.assert_allclose(actual_out, expect_out) + + paddle.enable_static() + + def func_dygraph_matmul(self): + with _test_eager_guard(): + self.func_dygraph_matmul() + + if __name__ == "__main__": paddle.enable_static() unittest.main()