未验证 提交 55f73ba5 编写于 作者: A Aurelius84 提交者: GitHub

[OpAttr]Fix dropout2d/3d static API (#46434)

上级 cbf3f4ba
......@@ -1112,16 +1112,18 @@ class TestDropoutBackward(unittest.TestCase):
class TestDropOutWithProbTensor(unittest.TestCase):
def setUp(self):
shapes = [[10, 10], [10, 10, 10], [10, 10, 10, 10]]
self.inputs = [
np.random.random(shape).astype("float32") for shape in shapes
]
self.init_info()
self.input = np.random.random(self.shape).astype("float32")
self.place = paddle.CUDAPlace(
0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
def init_info(self):
self.shape = [10, 10]
self.api = paddle.nn.functional.dropout
def api_case(self, x):
p = paddle.assign([0.5])
out = paddle.nn.functional.dropout(x=x, p=p, training=True)
out = self.api(x=x, p=p, training=True)
return out
def run_static(self, x):
......@@ -1131,6 +1133,8 @@ class TestDropOutWithProbTensor(unittest.TestCase):
with program_guard(main_program):
input = paddle.static.data(shape=x.shape, name='x', dtype='float32')
out = self.api_case(input)
sgd = paddle.optimizer.SGD(learning_rate=0.1)
sgd.minimize(paddle.mean(out))
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': x}, fetch_list=[out])
......@@ -1144,10 +1148,23 @@ class TestDropOutWithProbTensor(unittest.TestCase):
return out
def test_p_tensor(self):
for x in self.inputs:
static_res = self.run_static(x)
dygraph_res = self.run_dygraph(x)
np.testing.assert_array_equal(static_res, dygraph_res)
static_res = self.run_static(self.input)
dygraph_res = self.run_dygraph(self.input)
np.testing.assert_array_equal(static_res, dygraph_res)
class TestDropOut2DWithProbTensor(TestDropOutWithProbTensor):
def init_info(self):
self.shape = [2, 3, 10, 10]
self.api = paddle.nn.functional.dropout2d
class TestDropOut3DWithProbTensor(TestDropOutWithProbTensor):
def init_info(self):
self.shape = [2, 3, 8, 8, 8]
self.api = paddle.nn.functional.dropout3d
class TestRandomValue(unittest.TestCase):
......
......@@ -1116,7 +1116,7 @@ def dropout(x,
dtype = x.dtype
keep_prob = 1 - p
if training:
if p == 1.:
if in_dynamic_mode() and p == 1.:
return paddle.scale(x, scale=0.)
scale_input = paddle.scale(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册