diff --git a/python/paddle/fluid/tests/unittests/test_tensordot.py b/python/paddle/fluid/tests/unittests/test_tensordot.py index a8c4dbaed47305e92df60ebcf50a7d82307b1d8f..523464f42574ecab470c16858243c5959314e396 100644 --- a/python/paddle/fluid/tests/unittests/test_tensordot.py +++ b/python/paddle/fluid/tests/unittests/test_tensordot.py @@ -224,6 +224,33 @@ class TestTensordotAPI(unittest.TestCase): np_res = tensordot_np(self.x, self.y, axes) np.testing.assert_allclose(paddle_res[0], np_res, rtol=1e-6) + def test_fp16_with_gpu(self): + paddle.enable_static() + if paddle.fluid.core.is_compiled_with_cuda(): + for axes in self.all_axes: + place = paddle.CUDAPlace(0) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + input_x = np.random.random([5, 5, 5, 5]).astype("float16") + x = paddle.static.data( + name="x", shape=[5, 5, 5, 5], dtype="float16" + ) + + input_y = np.random.random([5, 5, 5, 5]).astype("float16") + y = paddle.static.data( + name="y", shape=[5, 5, 5, 5], dtype="float16" + ) + + z = paddle.tensordot(x, y, axes) + exe = paddle.static.Executor(place) + + paddle_res = exe.run( + feed={'x': input_x, 'y': input_y}, fetch_list=[z] + ) + np_res = tensordot_np(input_x, input_y, axes) + np.testing.assert_allclose(paddle_res[0], np_res, rtol=1) + class TestTensordotAPIFloat64(TestTensordotAPI): def set_dtype(self): diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 91fa35c05981443cdfaa79ce0ecdcaaa3a2f136e..cd45c33d74ea5b302572f3542cf2fa5058b5d181 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -3940,7 +3940,7 @@ def tensordot(x, y, axes=2, name=None): This function computes a contraction, which sum the product of elements from two tensors along the given axes. Args: - x (Tensor): The left tensor for contraction with data type ``float32`` or ``float64``. + x (Tensor): The left tensor for contraction with data type ``float16`` or ``float32`` or ``float64``. y (Tensor): The right tensor for contraction with the same data type as ``x``. axes (int|tuple|list|Tensor, optional): The axes to contract for ``x`` and ``y``, defaulted to integer ``2``. @@ -4048,7 +4048,7 @@ def tensordot(x, y, axes=2, name=None): # [28312230., 30496530., 32680830., 34865130.]] """ op_type = 'tensordot' - input_dtype = ['float32', 'float64'] + input_dtype = ['float16', 'float32', 'float64'] check_variable_and_dtype(x, 'x', input_dtype, op_type) check_variable_and_dtype(y, 'y', input_dtype, op_type)