未验证 提交 7ffbf7e3 编写于 作者: 张春乔 提交者: GitHub

suppot fp16 in flatten (#50906)

上级 77298931
......@@ -17,6 +17,8 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle
class TestFlattenOp(OpTest):
def setUp(self):
......@@ -64,5 +66,31 @@ class TestFlattenOpSixDims(TestFlattenOp):
self.new_shape = (36, 16)
class TestFlattenOpFP16(unittest.TestCase):
def test_fp16_with_gpu(self):
if paddle.fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input = np.random.random([12, 14]).astype("float16")
x = paddle.static.data(
name="x", shape=[12, 14], dtype="float16"
)
y = paddle.flatten(x)
exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={
"x": input,
},
fetch_list=[y],
)
assert np.array_equal(res[0].shape, [12 * 14])
if __name__ == "__main__":
unittest.main()
......@@ -1510,7 +1510,7 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
Out.shape = (3 * 100 * 100 * 4)
Args:
x (Tensor): A tensor of number of dimentions >= axis. A tensor with data type float32,
x (Tensor): A tensor of number of dimentions >= axis. A tensor with data type float16, float32,
float64, int8, int32, int64, uint8.
start_axis (int): the start axis to flatten
stop_axis (int): the stop axis to flatten
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册