未验证 提交 f25fda37 编写于 作者: S smallv0221 提交者: GitHub

Fix dropout static when axis != None (#37223)

* fix dropout static when axis != None

* update dropout test

* add dropout test

* fix test

* Update test_dropout_op.py

* Update test_dropout_op.py

* fix testcase

* fix testcase

* Update test_dropout_op.py

* fix testcase

* fix testcase

* optimize perf

* add new test

* fix testcase
上级 e2fdb080
...@@ -333,7 +333,7 @@ class TestDropoutFAPI(unittest.TestCase): ...@@ -333,7 +333,7 @@ class TestDropoutFAPI(unittest.TestCase):
def check_static_result(self, place): def check_static_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
input = fluid.data(name="input", shape=[40, 40], dtype="float32") input = fluid.data(name="input", shape=[-1, -1], dtype="float32")
res1 = paddle.nn.functional.dropout(x=input, p=0., training=False) res1 = paddle.nn.functional.dropout(x=input, p=0., training=False)
res2 = paddle.nn.functional.dropout( res2 = paddle.nn.functional.dropout(
x=input, p=0., axis=0, training=True, mode='upscale_in_train') x=input, p=0., axis=0, training=True, mode='upscale_in_train')
...@@ -380,7 +380,10 @@ class TestDropoutFAPI(unittest.TestCase): ...@@ -380,7 +380,10 @@ class TestDropoutFAPI(unittest.TestCase):
training=False, training=False,
mode='upscale_in_train') mode='upscale_in_train')
in_np = np.random.random([40, 40]).astype("float32") res13 = paddle.nn.functional.dropout(
x=input, p=0.7, axis=1, training=True, mode='upscale_in_train')
in_np = np.ones([40, 40]).astype("float32")
res_np = in_np res_np = in_np
res_np2 = np.zeros_like(in_np) res_np2 = np.zeros_like(in_np)
...@@ -398,6 +401,9 @@ class TestDropoutFAPI(unittest.TestCase): ...@@ -398,6 +401,9 @@ class TestDropoutFAPI(unittest.TestCase):
feed={"input": in_np}, feed={"input": in_np},
fetch_list=[res10]) fetch_list=[res10])
self.assertTrue(np.allclose(fetches2[0], res_np2)) self.assertTrue(np.allclose(fetches2[0], res_np2))
fetches3 = exe.run(fluid.default_main_program(),
feed={"input": in_np},
fetch_list=[res13])
def test_static(self): def test_static(self):
for place in self.places: for place in self.places:
...@@ -471,6 +477,12 @@ class TestDropoutFAPI(unittest.TestCase): ...@@ -471,6 +477,12 @@ class TestDropoutFAPI(unittest.TestCase):
axis=(0, 1), axis=(0, 1),
training=False, training=False,
mode='upscale_in_train') mode='upscale_in_train')
res13 = paddle.nn.functional.dropout(
x=input,
p=0.5,
axis=1,
training=True,
mode='upscale_in_train')
res_list = [ res_list = [
res1, res2, res3, res4, res5, res6, res7, res8, res9, res11, res1, res2, res3, res4, res5, res6, res7, res8, res9, res11,
......
...@@ -940,6 +940,8 @@ def dropout(x, ...@@ -940,6 +940,8 @@ def dropout(x,
#get mask shape #get mask shape
input_shape = x.shape input_shape = x.shape
if not in_dygraph_mode():
input_shape_tensor = paddle.shape(x)
drop_axes = [axis] if isinstance(axis, int) else list(axis) drop_axes = [axis] if isinstance(axis, int) else list(axis)
if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1: if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1:
raise ValueError("axis value should be greater than or equal to 0 and less than dimensions of x:{}, but get axis value:{} " \ raise ValueError("axis value should be greater than or equal to 0 and less than dimensions of x:{}, but get axis value:{} " \
...@@ -949,8 +951,12 @@ def dropout(x, ...@@ -949,8 +951,12 @@ def dropout(x,
"length of axis should not be greater than dimensions of x:{}, but get length of axis: {}". "length of axis should not be greater than dimensions of x:{}, but get length of axis: {}".
format(len(input_shape), len(drop_axes))) format(len(input_shape), len(drop_axes)))
mask_shape = [1] * len(input_shape) mask_shape = [1] * len(input_shape)
for i in drop_axes: if not in_dygraph_mode():
mask_shape[i] = input_shape[i] for i in drop_axes:
mask_shape[i] = input_shape_tensor[i]
else:
for i in drop_axes:
mask_shape[i] = input_shape[i]
#get mask #get mask
random_tensor = paddle.uniform( random_tensor = paddle.uniform(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册