diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index bf10e07ba0d6fcb271d1ee23383bae2c6a9e597f..cb72248b155fdc548d3e7899baaadb63f8b78f01 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -333,7 +333,7 @@ class TestDropoutFAPI(unittest.TestCase): def check_static_result(self, place): with fluid.program_guard(fluid.Program(), fluid.Program()): - input = fluid.data(name="input", shape=[40, 40], dtype="float32") + input = fluid.data(name="input", shape=[-1, -1], dtype="float32") res1 = paddle.nn.functional.dropout(x=input, p=0., training=False) res2 = paddle.nn.functional.dropout( x=input, p=0., axis=0, training=True, mode='upscale_in_train') @@ -380,7 +380,10 @@ class TestDropoutFAPI(unittest.TestCase): training=False, mode='upscale_in_train') - in_np = np.random.random([40, 40]).astype("float32") + res13 = paddle.nn.functional.dropout( + x=input, p=0.7, axis=1, training=True, mode='upscale_in_train') + + in_np = np.ones([40, 40]).astype("float32") res_np = in_np res_np2 = np.zeros_like(in_np) @@ -398,6 +401,9 @@ class TestDropoutFAPI(unittest.TestCase): feed={"input": in_np}, fetch_list=[res10]) self.assertTrue(np.allclose(fetches2[0], res_np2)) + fetches3 = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res13]) def test_static(self): for place in self.places: @@ -471,6 +477,12 @@ class TestDropoutFAPI(unittest.TestCase): axis=(0, 1), training=False, mode='upscale_in_train') + res13 = paddle.nn.functional.dropout( + x=input, + p=0.5, + axis=1, + training=True, + mode='upscale_in_train') res_list = [ res1, res2, res3, res4, res5, res6, res7, res8, res9, res11, diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 217f8cd4125518a5a5f85cb6138f16e3046fff0d..98019ceb480a01a55c0c3092f800c59e15d0c77e 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -940,6 +940,8 @@ def dropout(x, #get mask shape input_shape = x.shape + if not in_dygraph_mode(): + input_shape_tensor = paddle.shape(x) drop_axes = [axis] if isinstance(axis, int) else list(axis) if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1: raise ValueError("axis value should be greater than or equal to 0 and less than dimensions of x:{}, but get axis value:{} " \ @@ -949,8 +951,12 @@ def dropout(x, "length of axis should not be greater than dimensions of x:{}, but get length of axis: {}". format(len(input_shape), len(drop_axes))) mask_shape = [1] * len(input_shape) - for i in drop_axes: - mask_shape[i] = input_shape[i] + if not in_dygraph_mode(): + for i in drop_axes: + mask_shape[i] = input_shape_tensor[i] + else: + for i in drop_axes: + mask_shape[i] = input_shape[i] #get mask random_tensor = paddle.uniform(