未验证 提交 de002f92 编写于 作者: L levi131 提交者: GitHub

hide prim2orig in executor (#44255)

* hide prim2orig in executor

* add some test cases without param guard

* fix spell error param into program

* Use absolute path when import paddle.incubate.autograd.prim2orig
上级 fef62298
...@@ -1444,6 +1444,11 @@ class Executor(object): ...@@ -1444,6 +1444,11 @@ class Executor(object):
program._compile(scope, self.place) program._compile(scope, self.place)
ir_graph = framework.IrGraph(program._graph) ir_graph = framework.IrGraph(program._graph)
inner_program = ir_graph.to_program() inner_program = ir_graph.to_program()
else:
from paddle.incubate.autograd import prim_enabled, prim2orig
if prim_enabled() and program == default_main_program():
prim2orig()
program = self._add_feed_fetch_ops( program = self._add_feed_fetch_ops(
program=inner_program, program=inner_program,
feed=feed, feed=feed,
......
...@@ -23,6 +23,117 @@ import config ...@@ -23,6 +23,117 @@ import config
import utils import utils
@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
(('matmul', paddle.matmul,
(np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'), ))
class TestWithoutProgramGuard(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs)
cls._rtol = config.TOLERANCE.get(str(
cls.dtype)).get("first_order_grad").get("rtol")
cls._atol = config.TOLERANCE.get(str(
cls.dtype)).get("first_order_grad").get("atol")
def setUp(self):
paddle.enable_static()
paddle.incubate.autograd.enable_prim()
def tearDown(self):
paddle.incubate.autograd.disable_prim()
paddle.disable_static()
def test_forward_grad_without_program_guard(self):
def with_program_guard():
paddle.incubate.autograd.enable_prim()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
ys_grad = paddle.incubate.autograd.forward_grad(
ys, static_xs, static_v)
paddle.incubate.autograd.prim2orig(mp.block(0))
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=ys_grad)
paddle.incubate.autograd.disable_prim()
return out
def without_program_guard():
paddle.incubate.autograd.enable_prim()
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
ys_grad = paddle.incubate.autograd.forward_grad(
ys, static_xs, static_v)
sp = paddle.fluid.framework.default_startup_program()
mp = paddle.fluid.framework.default_main_program()
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=ys_grad)
paddle.incubate.autograd.disable_prim()
return out
expected = with_program_guard()
actual = without_program_guard()
self.assertEqual(type(actual), type(expected))
np.testing.assert_allclose(np.concatenate(actual),
np.concatenate(expected),
rtol=self._rtol,
atol=self._atol)
def test_grad_without_program_guard(self):
def with_program_guard():
paddle.incubate.autograd.enable_prim()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
xs_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v)
paddle.incubate.autograd.prim2orig(mp.block(0))
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=xs_grad)
paddle.incubate.autograd.disable_prim()
return out
def without_program_guard():
paddle.incubate.autograd.enable_prim()
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
xs_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v)
sp = paddle.fluid.framework.default_startup_program()
mp = paddle.fluid.framework.default_main_program()
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=xs_grad)
paddle.incubate.autograd.disable_prim()
return out
expected = with_program_guard()
actual = without_program_guard()
for i, j in zip(actual, expected):
self.assertEqual(type(i), type(j))
np.testing.assert_allclose(np.concatenate(i),
np.concatenate(j),
rtol=self._rtol,
atol=self._atol)
@utils.place(config.DEVICES) @utils.place(config.DEVICES)
@utils.parameterize( @utils.parameterize(
(utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册