From b927ce81a7a7728d00704bc46bb57516c10d6889 Mon Sep 17 00:00:00 2001
From: Jiabin Yang <360788950@qq.com>
Date: Tue, 17 Jan 2023 13:19:35 +0800
Subject: [PATCH] add test for composite with dy2st (#49873)

---
 .../utils/static/composite_grad_desc_maker.h  |  3 +
 python/paddle/fluid/backward.py               | 12 ++--
 .../prim/vjp/static/test_comp_add_grad.py     | 51 ++++++++++++++---
 .../vjp/static/test_comp_add_tanh_grad.py     | 55 +++++++++++++++++--
 .../prim/vjp/static/test_comp_div_grad.py     | 45 ++++++++++++++-
 .../prim/vjp/static/test_comp_sqrt_grad.py    | 43 ++++++++++++++-
 .../prim/vjp/static/test_comp_sub_grad.py     | 45 ++++++++++++++-
 .../prim/vjp/static/test_comp_tanh_grad.py    | 43 ++++++++++++++-
 8 files changed, 269 insertions(+), 28 deletions(-)

diff --git a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h
index e053d1465e..c2e7ca4ec5 100644
--- a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h
+++ b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h
@@ -477,6 +477,9 @@ class GradCompositeOpMakerBase {
   void RecoverOutputName(const paddle::experimental::Tensor& output,
                          const std::string& origin_name) {
     if (origin_name == framework::kEmptyVarName) return;
+    VLOG(4) << "Recover: "
+            << static_cast<prim::DescTensor*>(output.impl().get())->Name()
+            << " To: " << origin_name;
     prim::StaticCompositeContext::Instance().GetBlock()->RenameVar(
         static_cast<prim::DescTensor*>(output.impl().get())->Name(),
         origin_name);
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 19da1ccdff..76401d5c47 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -1492,11 +1492,15 @@ def _append_backward_ops_(
     )

     # remove some backward ops
-    not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set)
+    # TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem
+    if not core.is_prim_enabled():
+        not_need_ops = _find_not_need_ops(
+            grad_op_descs, ops, input_grad_names_set
+        )

-    grad_op_descs = [
-        op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
-    ]
+        grad_op_descs = [
+            op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
+        ]

     # append op_desc in grad_op_descs to target_block
     op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
index a946447018..b7d7969d9a 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
@@ -21,6 +21,23 @@ import paddle
 from paddle.fluid import core


+def apply_to_static(net, use_cinn):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(net, build_strategy=build_strategy)
+
+
+class PrimeNet(paddle.nn.Layer):
+    def __init__(self):
+        super(PrimeNet, self).__init__()
+        self.fc = paddle.nn.Linear(4, 4)
+
+    def forward(self, x, y):
+        tmp = self.fc(x)
+        out = paddle.add(tmp, y)
+        return out
+
+
 @param.parameterized_class(
     ('primal0', 'primal1', 'dtype'),
     [
@@ -57,11 +74,33 @@ class TestAddGradComp(unittest.TestCase):
         cls.primal0 = cls.primal0.astype(cls.dtype)
         cls.primal1 = cls.primal1.astype(cls.dtype)

-    def setUp(self):
-        paddle.enable_static()
+    def train(self, use_prim, use_cinn):
+        paddle.seed(2022)
+        self.x = paddle.randn([2, 4])
+        self.y = paddle.randn([2, 4])
+        self.x.stop_gradient = False
+        self.y.stop_gradient = False
+        net = PrimeNet()
+        core.set_prim_enabled(use_prim)
+        net = apply_to_static(net, use_cinn)
+        out = net(self.x, self.y)
+        res = paddle.autograd.grad(out, [self.x, self.y])
+
+        return res

-    def tearDown(self):
+    def test_cinn(self):
         paddle.disable_static()
+        dy_res = self.train(use_prim=False, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+
+        for i in range(len(dy_res)):
+            np.testing.assert_allclose(
+                comp_st_cinn_res[i].numpy(),
+                dy_res[i].numpy(),
+                rtol=1e-7,
+                atol=1e-7,
+            )
+        paddle.enable_static()

     def test_tanh_grad_comp(self):
         def actual(primal0, primal1):
@@ -73,8 +112,7 @@ class TestAddGradComp(unittest.TestCase):
             x.stop_gradient = False
             y.stop_gradient = False
             z = paddle.add(x, y)
-            out = paddle.tanh(z)
-            res = paddle.static.gradients([out], [x, y])
+            res = paddle.static.gradients([z], [x, y])
             exe = paddle.static.Executor()
             exe.run(sp)
             out = exe.run(
@@ -100,8 +138,7 @@ class TestAddGradComp(unittest.TestCase):
             x.stop_gradient = False
             y.stop_gradient = False
             z = paddle.add(x, y)
-            out = paddle.tanh(z)
-            res = paddle.static.gradients([out], [x, y])
+            res = paddle.static.gradients([z], [x, y])
             exe = paddle.static.Executor()
             exe.run(sp)
             out = exe.run(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
index 0325768917..45cae351a7 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
@@ -21,6 +21,24 @@ import paddle
 from paddle.fluid import core


+def apply_to_static(net, use_cinn):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(net, build_strategy=build_strategy)
+
+
+class PrimeNet(paddle.nn.Layer):
+    def __init__(self):
+        super(PrimeNet, self).__init__()
+        self.fc = paddle.nn.Linear(4, 4)
+
+    def forward(self, x, y):
+        tmp = self.fc(x)
+        out = paddle.add(tmp, y)
+        res = paddle.tanh(out)
+        return res
+
+
 @param.parameterized_class(
     ('primal0', 'primal1', 'dtype'),
     [
@@ -57,13 +75,37 @@ class TestDivGradComp(unittest.TestCase):
         cls.primal0 = cls.primal0.astype(cls.dtype)
         cls.primal1 = cls.primal1.astype(cls.dtype)

-    def setUp(self):
-        paddle.enable_static()
+    def train(self, use_prim, use_cinn):
+        paddle.seed(2022)
+        self.x = paddle.randn([2, 4])
+        self.y = paddle.randn([2, 4])
+        self.x.stop_gradient = False
+        self.y.stop_gradient = False
+        net = PrimeNet()
+        core.set_prim_enabled(use_prim)
+        net = apply_to_static(net, use_cinn)
+        out = net(self.x, self.y)
+        res = paddle.autograd.grad(out, [self.x, self.y])
+
+        return res

-    def tearDown(self):
+    def test_cinn(self):
         paddle.disable_static()
+        dy_res = self.train(use_prim=False, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+
+        for i in range(len(dy_res)):
+            np.testing.assert_allclose(
+                comp_st_cinn_res[i].numpy(),
+                dy_res[i].numpy(),
+                rtol=1e-7,
+                atol=1e-7,
+            )
+        paddle.enable_static()

     def test_tanh_grad_comp(self):
+        paddle.enable_static()
+
         def actual(primal0, primal1):
             core.set_prim_enabled(True)
             mp, sp = paddle.static.Program(), paddle.static.Program()
@@ -73,7 +115,8 @@ class TestDivGradComp(unittest.TestCase):
             x.stop_gradient = False
             y.stop_gradient = False
             z = paddle.add(x, y)
-            res = paddle.static.gradients([z], [x, y])
+            out = paddle.tanh(z)
+            res = paddle.static.gradients([out], [x, y])
             exe = paddle.static.Executor()
             exe.run(sp)
             out = exe.run(
@@ -99,7 +142,8 @@ class TestDivGradComp(unittest.TestCase):
             x.stop_gradient = False
             y.stop_gradient = False
             z = paddle.add(x, y)
-            res = paddle.static.gradients([z], [x, y])
+            out = paddle.tanh(z)
+            res = paddle.static.gradients([out], [x, y])
             exe = paddle.static.Executor()
             exe.run(sp)
             out = exe.run(
@@ -129,6 +173,7 @@ class TestDivGradComp(unittest.TestCase):
             atol=0,
         )
         core.set_prim_enabled(False)
+        paddle.disable_static()


 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
index fde1f4549d..1d675e8bd0 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
@@ -21,6 +21,23 @@ import paddle
 from paddle.fluid import core


+def apply_to_static(net, use_cinn):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(net, build_strategy=build_strategy)
+
+
+class PrimeNet(paddle.nn.Layer):
+    def __init__(self):
+        super(PrimeNet, self).__init__()
+        self.fc = paddle.nn.Linear(4, 4)
+
+    def forward(self, x, y):
+        tmp = self.fc(x)
+        out = paddle.divide(tmp, y)
+        return out
+
+
 @param.parameterized_class(
     ('primal0', 'primal1', 'dtype'),
     [
@@ -57,11 +74,33 @@ class TestDivGradComp(unittest.TestCase):
         cls.primal0 = cls.primal0.astype(cls.dtype)
         cls.primal1 = cls.primal1.astype(cls.dtype)

-    def setUp(self):
-        paddle.enable_static()
+    def train(self, use_prim, use_cinn):
+        paddle.seed(2022)
+        self.x = paddle.randn([2, 4])
+        self.y = paddle.randn([2, 4])
+        self.x.stop_gradient = False
+        self.y.stop_gradient = False
+        net = PrimeNet()
+        core.set_prim_enabled(use_prim)
+        net = apply_to_static(net, use_cinn)
+        out = net(self.x, self.y)
+        res = paddle.autograd.grad(out, [self.x, self.y])
+
+        return res

-    def tearDown(self):
+    def test_cinn(self):
         paddle.disable_static()
+        dy_res = self.train(use_prim=False, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+
+        for i in range(len(dy_res)):
+            np.testing.assert_allclose(
+                comp_st_cinn_res[i].numpy(),
+                dy_res[i].numpy(),
+                rtol=1e-6,
+                atol=1e-6,
+            )
+        paddle.enable_static()

     def test_tanh_grad_comp(self):
         def actual(primal0, primal1):
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
index 2eae9c86e2..505a439113 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
@@ -26,6 +26,23 @@ import parameterized as param
 import paddle


+def apply_to_static(net, use_cinn):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(net, build_strategy=build_strategy)
+
+
+class PrimeNet(paddle.nn.Layer):
+    def __init__(self):
+        super(PrimeNet, self).__init__()
+        self.fc = paddle.nn.Linear(4, 4)
+
+    def forward(self, x):
+        tmp = self.fc(x)
+        out = paddle.sqrt(tmp)
+        return out
+
+
 @param.parameterized_class(
     ('primal', 'cotangent', 'dtype'),
     [
@@ -38,11 +55,31 @@ class TestSqrtGradComp(unittest.TestCase):
         cls.primal = cls.primal.astype(cls.dtype)
         cls.cotangent = cls.cotangent.astype(cls.dtype)

-    def setUp(self):
-        paddle.enable_static()
+    def train(self, use_prim, use_cinn):
+        paddle.seed(2022)
+        self.x = paddle.randn([2, 4])
+        self.x.stop_gradient = False
+        net = PrimeNet()
+        core.set_prim_enabled(use_prim)
+        net = apply_to_static(net, use_cinn)
+        out = net(self.x)
+        res = paddle.autograd.grad(out, [self.x])
+
+        return res

-    def tearDown(self):
+    def test_cinn(self):
         paddle.disable_static()
+        dy_res = self.train(use_prim=False, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+
+        for i in range(len(dy_res)):
+            np.testing.assert_allclose(
+                comp_st_cinn_res[i].numpy(),
+                dy_res[i].numpy(),
+                rtol=1e-7,
+                atol=1e-7,
+            )
+        paddle.enable_static()

     def test_sqrt_grad_comp(self):
         def actual(primal, cotangent):
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
index 8baf91ba0d..f98a6af621 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
@@ -21,6 +21,23 @@ import paddle
 from paddle.fluid import core


+def apply_to_static(net, use_cinn):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(net, build_strategy=build_strategy)
+
+
+class PrimeNet(paddle.nn.Layer):
+    def __init__(self):
+        super(PrimeNet, self).__init__()
+        self.fc = paddle.nn.Linear(4, 4)
+
+    def forward(self, x, y):
+        tmp = self.fc(x)
+        out = paddle.subtract(tmp, y)
+        return out
+
+
 @param.parameterized_class(
     ('primal0', 'primal1', 'dtype'),
     [
@@ -58,11 +75,33 @@ class TestDivGradComp(unittest.TestCase):
         cls.primal0 = cls.primal0.astype(cls.dtype)
         cls.primal1 = cls.primal1.astype(cls.dtype)

-    def setUp(self):
-        paddle.enable_static()
+    def train(self, use_prim, use_cinn):
+        paddle.seed(2022)
+        self.x = paddle.randn([2, 4])
+        self.y = paddle.randn([2, 4])
+        self.x.stop_gradient = False
+        self.y.stop_gradient = False
+        net = PrimeNet()
+        core.set_prim_enabled(use_prim)
+        net = apply_to_static(net, use_cinn)
+        out = net(self.x, self.y)
+        res = paddle.autograd.grad(out, [self.x, self.y])
+
+        return res

-    def tearDown(self):
+    def test_cinn(self):
         paddle.disable_static()
+        dy_res = self.train(use_prim=False, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+
+        for i in range(len(dy_res)):
+            np.testing.assert_allclose(
+                comp_st_cinn_res[i].numpy(),
+                dy_res[i].numpy(),
+                rtol=1e-7,
+                atol=1e-7,
+            )
+        paddle.enable_static()

     def test_tanh_grad_comp(self):
         def actual(primal0, primal1):
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
index 445b371b0a..c7c9109eea 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
@@ -26,6 +26,23 @@ import parameterized as param
 import paddle


+def apply_to_static(net, use_cinn):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(net, build_strategy=build_strategy)
+
+
+class PrimeNet(paddle.nn.Layer):
+    def __init__(self):
+        super(PrimeNet, self).__init__()
+        self.fc = paddle.nn.Linear(4, 4)
+
+    def forward(self, x):
+        tmp = self.fc(x)
+        out = paddle.tanh(tmp)
+        return out
+
+
 @param.parameterized_class(
     ('primal', 'cotangent', 'dtype'),
     [
@@ -38,11 +55,31 @@ class TestTanhGradComp(unittest.TestCase):
         cls.primal = cls.primal.astype(cls.dtype)
         cls.cotangent = cls.cotangent.astype(cls.dtype)

-    def setUp(self):
-        paddle.enable_static()
+    def train(self, use_prim, use_cinn):
+        paddle.seed(2022)
+        self.x = paddle.randn([2, 4])
+        self.x.stop_gradient = False
+        net = PrimeNet()
+        core.set_prim_enabled(use_prim)
+        net = apply_to_static(net, use_cinn)
+        out = net(self.x)
+        res = paddle.autograd.grad(out, [self.x])
+
+        return res

-    def tearDown(self):
+    def test_cinn(self):
         paddle.disable_static()
+        dy_res = self.train(use_prim=False, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+
+        for i in range(len(dy_res)):
+            np.testing.assert_allclose(
+                comp_st_cinn_res[i].numpy(),
+                dy_res[i].numpy(),
+                rtol=1e-7,
+                atol=1e-7,
+            )
+        paddle.enable_static()

     def test_tanh_grad_comp(self):
         def actual(primal, cotangent):
-- 
GitLab