diff --git a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_shuffle_op.py similarity index 84% rename from python/paddle/fluid/tests/unittests/test_pixel_shuffle.py rename to python/paddle/fluid/tests/unittests/test_pixel_shuffle_op.py index 0ef6b3e77824ba8402d7add0b7c0adb4c5bed6f8..3543cce6ad04c3407d5c531625b5eae6dc4c58ea 100644 --- a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_shuffle_op.py @@ -140,6 +140,44 @@ class TestPixelShuffleAPI(unittest.TestCase): assert np.allclose(res_1, self.out_1_np) assert np.allclose(res_2, self.out_2_np) + def test_api_fp16(self): + paddle.enable_static() + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + self.x_1_np = np.random.random([2, 9, 4, 4]).astype("float16") + self.x_2_np = np.random.random([2, 4, 4, 9]).astype("float16") + x_1 = paddle.fluid.data( + name="x", shape=[2, 9, 4, 4], dtype="float16" + ) + x_2 = paddle.fluid.data( + name="x2", shape=[2, 4, 4, 9], dtype="float16" + ) + # init instance + ps_1 = paddle.nn.PixelShuffle(3) + ps_2 = paddle.nn.PixelShuffle(3, "NHWC") + out_1 = ps_1(x_1) + out_2 = ps_2(x_2) + out_1_np = pixel_shuffle_np(self.x_1_np, 3) + out_2_np = pixel_shuffle_np(self.x_2_np, 3, "NHWC") + exe = paddle.static.Executor(place=place) + res_1 = exe.run( + fluid.default_main_program(), + feed={"x": self.x_1_np}, + fetch_list=out_1, + use_prune=True, + ) + res_2 = exe.run( + fluid.default_main_program(), + feed={"x2": self.x_2_np}, + fetch_list=out_2, + use_prune=True, + ) + assert np.allclose(res_1, out_1_np) + assert np.allclose(res_2, out_2_np) + # same test between layer and functional in this op. def test_static_graph_layer(self): for use_cuda in ( diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index 1178928acc2dabfe387f90f25076794027019a44..bbe4161a27ab42cdaecc155dbcec60fb476257a3 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -383,7 +383,7 @@ def pixel_shuffle(x, upscale_factor, data_format="NCHW", name=None): else: helper = LayerHelper("pixel_shuffle", **locals()) check_variable_and_dtype( - x, 'x', ['float32', 'float64'], 'pixel_shuffle' + x, 'x', ['float16', 'float32', 'float64'], 'pixel_shuffle' ) out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op(