未验证 提交 e4fbb286 编写于 作者: L LoneRanger 提交者: GitHub

[fp16] suppot fp16 in tensordot (#50938)

* fix fp16 bug of tensordot

* fix fp16 of tensordot

* fix fp16 of tensordot
上级 b69af7ad
......@@ -224,6 +224,33 @@ class TestTensordotAPI(unittest.TestCase):
np_res = tensordot_np(self.x, self.y, axes)
np.testing.assert_allclose(paddle_res[0], np_res, rtol=1e-6)
def test_fp16_with_gpu(self):
paddle.enable_static()
if paddle.fluid.core.is_compiled_with_cuda():
for axes in self.all_axes:
place = paddle.CUDAPlace(0)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input_x = np.random.random([5, 5, 5, 5]).astype("float16")
x = paddle.static.data(
name="x", shape=[5, 5, 5, 5], dtype="float16"
)
input_y = np.random.random([5, 5, 5, 5]).astype("float16")
y = paddle.static.data(
name="y", shape=[5, 5, 5, 5], dtype="float16"
)
z = paddle.tensordot(x, y, axes)
exe = paddle.static.Executor(place)
paddle_res = exe.run(
feed={'x': input_x, 'y': input_y}, fetch_list=[z]
)
np_res = tensordot_np(input_x, input_y, axes)
np.testing.assert_allclose(paddle_res[0], np_res, rtol=1)
class TestTensordotAPIFloat64(TestTensordotAPI):
def set_dtype(self):
......
......@@ -3940,7 +3940,7 @@ def tensordot(x, y, axes=2, name=None):
This function computes a contraction, which sum the product of elements from two tensors along the given axes.
Args:
x (Tensor): The left tensor for contraction with data type ``float32`` or ``float64``.
x (Tensor): The left tensor for contraction with data type ``float16`` or ``float32`` or ``float64``.
y (Tensor): The right tensor for contraction with the same data type as ``x``.
axes (int|tuple|list|Tensor, optional): The axes to contract for ``x`` and ``y``, defaulted to integer ``2``.
......@@ -4048,7 +4048,7 @@ def tensordot(x, y, axes=2, name=None):
# [28312230., 30496530., 32680830., 34865130.]]
"""
op_type = 'tensordot'
input_dtype = ['float32', 'float64']
input_dtype = ['float16', 'float32', 'float64']
check_variable_and_dtype(x, 'x', input_dtype, op_type)
check_variable_and_dtype(y, 'y', input_dtype, op_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册