未验证 提交 05d6be7e 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Remove enable_legacy_dygraph setting (#42363)

* [Eager] Remove enable_legacy_dygraph setting

* Add more tests
上级 c3852b08
......@@ -20,9 +20,6 @@ import unittest
from simnet_dygraph_model_v2 import BOW, HingeLoss
from paddle.fluid.framework import _enable_legacy_dygraph
_enable_legacy_dygraph()
SEED = 102
random.seed(SEED)
......
......@@ -23,7 +23,6 @@ import paddle.static as static
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard, _enable_legacy_dygraph
_enable_legacy_dygraph()
import os
from paddle import _C_ops
......@@ -979,6 +978,7 @@ class TestDropoutBackward(unittest.TestCase):
), self.cal_grad_downscale_in_infer(mask.numpy())))
def test_backward_upscale_train(self):
_enable_legacy_dygraph()
for place in self.places:
with fluid.dygraph.guard(place):
......@@ -1010,6 +1010,7 @@ class TestDropoutBackward(unittest.TestCase):
), self.cal_grad_upscale_train(mask.numpy(), prob)))
def test_backward_upscale_train_2(self):
_enable_legacy_dygraph()
for place in self.places:
with fluid.dygraph.guard(place):
......@@ -1025,6 +1026,23 @@ class TestDropoutBackward(unittest.TestCase):
np.allclose(input.gradient(
), self.cal_grad_upscale_train(mask.numpy(), prob)))
def test_backward_upscale_train_2_eager(self):
for place in self.places:
with fluid.dygraph.guard(place):
with _test_eager_guard():
prob = 0.3
input = paddle.uniform([40, 40], dtype="float32")
input.stop_gradient = False
out, mask = _C_ops.final_state_dropout(
input, None, 0.3, False, "upscale_in_train", 0, False)
out.backward()
self.assertTrue(
np.allclose(input.gradient(
), self.cal_grad_upscale_train(mask.numpy(), prob)))
class TestRandomValue(unittest.TestCase):
def test_fixed_random_number(self):
......
......@@ -21,9 +21,6 @@ import paddle.nn.functional as F
from paddle.incubate.optimizer.functional.lbfgs import minimize_lbfgs
from paddle.fluid.framework import _enable_legacy_dygraph
_enable_legacy_dygraph()
np.random.seed(123)
......
......@@ -20,8 +20,6 @@ import os
import sys
import subprocess
import paddle
from paddle.fluid.framework import _enable_legacy_dygraph
_enable_legacy_dygraph()
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册