未验证 提交 3a0c550f 编写于 作者: S smallv0221 提交者: GitHub

Fix dropout static when axis != None (#37223) (#37589)

* fix dropout static when axis != None

* update dropout test

* add dropout test

* fix test

* Update test_dropout_op.py

* Update test_dropout_op.py

* fix testcase

* fix testcase

* Update test_dropout_op.py

* fix testcase

* fix testcase

* optimize perf

* add new test

* fix testcase
上级 7d9c669f
......@@ -333,7 +333,7 @@ class TestDropoutFAPI(unittest.TestCase):
def check_static_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input = fluid.data(name="input", shape=[40, 40], dtype="float32")
input = fluid.data(name="input", shape=[-1, -1], dtype="float32")
res1 = paddle.nn.functional.dropout(x=input, p=0., training=False)
res2 = paddle.nn.functional.dropout(
x=input, p=0., axis=0, training=True, mode='upscale_in_train')
......@@ -380,7 +380,10 @@ class TestDropoutFAPI(unittest.TestCase):
training=False,
mode='upscale_in_train')
in_np = np.random.random([40, 40]).astype("float32")
res13 = paddle.nn.functional.dropout(
x=input, p=0.7, axis=1, training=True, mode='upscale_in_train')
in_np = np.ones([40, 40]).astype("float32")
res_np = in_np
res_np2 = np.zeros_like(in_np)
......@@ -398,6 +401,9 @@ class TestDropoutFAPI(unittest.TestCase):
feed={"input": in_np},
fetch_list=[res10])
self.assertTrue(np.allclose(fetches2[0], res_np2))
fetches3 = exe.run(fluid.default_main_program(),
feed={"input": in_np},
fetch_list=[res13])
def test_static(self):
for place in self.places:
......@@ -471,6 +477,12 @@ class TestDropoutFAPI(unittest.TestCase):
axis=(0, 1),
training=False,
mode='upscale_in_train')
res13 = paddle.nn.functional.dropout(
x=input,
p=0.5,
axis=1,
training=True,
mode='upscale_in_train')
res_list = [
res1, res2, res3, res4, res5, res6, res7, res8, res9, res11,
......
......@@ -939,6 +939,8 @@ def dropout(x,
#get mask shape
input_shape = x.shape
if not in_dygraph_mode():
input_shape_tensor = paddle.shape(x)
drop_axes = [axis] if isinstance(axis, int) else list(axis)
if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1:
raise ValueError("axis value should be greater than or equal to 0 and less than dimensions of x:{}, but get axis value:{} " \
......@@ -948,6 +950,10 @@ def dropout(x,
"length of axis should not be greater than dimensions of x:{}, but get length of axis: {}".
format(len(input_shape), len(drop_axes)))
mask_shape = [1] * len(input_shape)
if not in_dygraph_mode():
for i in drop_axes:
mask_shape[i] = input_shape_tensor[i]
else:
for i in drop_axes:
mask_shape[i] = input_shape[i]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册