diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py index 9b6e5db7953565c1289800a550e7b9dca7e9b399..d8b373133280f274713bf97596bc00273a55e647 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py @@ -78,7 +78,7 @@ class TestCompositeSoftmax(unittest.TestCase): def cal_composite_grad(self, inputs): paddle.enable_static() - core._set_prim_all_enabled(True) + core._set_prim_forward_enabled(True) startup_program = paddle.static.Program() main_program = paddle.static.Program() with paddle.static.program_guard(main_program, startup_program): @@ -109,7 +109,7 @@ class TestCompositeSoftmax(unittest.TestCase): exe.run(startup_program) res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) paddle.disable_static() - core._set_prim_all_enabled(False) + core._set_prim_forward_enabled(False) return res def compare_backward(self): @@ -142,12 +142,13 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase): def setUp(self): core._set_prim_backward_enabled(True) - self.dtypes = ["float32"] + self.dtypes = ["float32", "float64"] self.shapes = [[2, 3, 4], [2, 3]] self.axes = [-1, 0, 1] def cal_composite_grad(self, inputs): paddle.enable_static() + core._set_prim_all_enabled(True) startup_program = paddle.static.Program() main_program = paddle.static.Program() with paddle.static.program_guard(main_program, startup_program): @@ -164,6 +165,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase): exe.run(startup_program) res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) paddle.disable_static() + core._set_prim_all_enabled(False) return res def compare_backward(self): diff --git a/python/paddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt index 72c6bbd7d05e8fdf99fce350ad15c216dcac5c92..e57c6138d22f0d21751948fc59c407ebc9d58670 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt @@ -7,3 +7,8 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) endforeach() + +if(WITH_CINN) + set_tests_properties(test_prim_flags_case PROPERTIES LABELS "RUN_TYPE=CINN") + set_tests_properties(test_prim_flags_case PROPERTIES TIMEOUT 300) +endif() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags_case.py b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags_case.py new file mode 100644 index 0000000000000000000000000000000000000000..309959747e06471c2f80e9daac6cf31d463b31f0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags_case.py @@ -0,0 +1,173 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import platform +import unittest + +import paddle +import paddle.nn.functional as F +from paddle.fluid import core + + +def apply_to_static(net, use_cinn): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = use_cinn + return paddle.jit.to_static(net, build_strategy=build_strategy) + + +class PrimeNet(paddle.nn.Layer): + def __init__(self): + super(PrimeNet, self).__init__() + + def forward(self, x): + out = F.softmax(x) + res = paddle.exp(out) + return res + + +class TestPrimForwardAndBackward(unittest.TestCase): + """ + Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph + """ + + def setUp(self): + paddle.seed(2022) + self.x = paddle.randn([2, 4]) + self.x.stop_gradient = False + self.flag = None + + def reset_env_flag(self): + os.environ["FLAGS_prim_backward"] = "False" + os.environ["FLAGS_prim_forward"] = "False" + if os.getenv("FLAGS_prim_all"): + del os.environ["FLAGS_prim_all"] + core.check_and_set_prim_all_enabled() + + def train(self, use_cinn): + net = PrimeNet() + net = apply_to_static(net, use_cinn) + + out = net(self.x) + loss = paddle.mean(out) + loss.backward() + + self.check_prim(net) + + return + + def check_prim(self, net): + ops = [ + op.type + for op in net.forward.program_cache.last()[-1][-1] + .train_program.block(0) + .ops + ] + + if self.flag in ["prim_all", "cinn_prim_all"]: + self.assertTrue('softmax' not in ops) + self.assertTrue('exp_grad' not in ops) + elif self.flag in ["prim_forward", "cinn_prim_forward"]: + self.assertTrue('softmax' not in ops) + self.assertTrue('exp_grad' in ops) + elif self.flag in ["prim_backward", "cinn_prim_backward"]: + self.assertTrue('softmax' in ops) + self.assertTrue('exp_grad' not in ops) + elif self.flag == "cinn": + self.assertTrue('softmax' in ops) + self.assertTrue('exp_grad' in ops) + else: + raise TypeError + + def test_cinn_prim_all(self): + """cinn + prim forward + prim backward""" + self.reset_env_flag() + os.environ["FLAGS_prim_all"] = "True" + self.flag = "cinn_prim_all" + plat = platform.system() + if plat == "Linux": + _ = self.train(use_cinn=True) + else: + pass + + def test_prim_all(self): + """prim forward + prim backward""" + self.reset_env_flag() + os.environ["FLAGS_prim_all"] = "True" + self.flag = "prim_all" + plat = platform.system() + if plat == "Linux": + _ = self.train(use_cinn=False) + else: + pass + + def test_cinn_prim_forward(self): + """cinn + prim forward""" + + self.reset_env_flag() + + os.environ["FLAGS_prim_forward"] = "True" + self.flag = "cinn_prim_forward" + plat = platform.system() + if plat == "Linux": + _ = self.train(use_cinn=True) + else: + pass + + def test_prim_forward(self): + """only prim forward""" + self.reset_env_flag() + os.environ["FLAGS_prim_forward"] = "True" + self.flag = "prim_forward" + plat = platform.system() + if plat == "Linux": + _ = self.train(use_cinn=False) + else: + pass + + def test_cinn_prim_backward(self): + """cinn + prim_backward""" + self.reset_env_flag() + os.environ["FLAGS_prim_backward"] = "True" + self.flag = "cinn_prim_backward" + plat = platform.system() + if plat == "Linux": + _ = self.train(use_cinn=True) + else: + pass + + def test_prim_backward(self): + """only prim backward""" + self.reset_env_flag() + os.environ["FLAGS_prim_backward"] = "True" + self.flag = "prim_backward" + plat = platform.system() + if plat == "Linux": + _ = self.train(use_cinn=False) + else: + pass + + def test_cinn(self): + """only cinn""" + self.reset_env_flag() + self.flag = "cinn" + plat = platform.system() + if plat == "Linux": + _ = self.train(use_cinn=True) + else: + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py index fd509d74e532abd82eaf5c0e0ba8c6132b5fef86..626bdab2f5a1a680b50e921a0b26d3c90bd3fa82 100644 --- a/python/paddle/jit/dy2static/partial_program.py +++ b/python/paddle/jit/dy2static/partial_program.py @@ -571,9 +571,8 @@ class PartialProgramLayer: targets.append(program.global_block().var(out.name)) if targets: - if self._build_strategy.build_cinn_pass: - # TODO(Jiabin): Change this to True if we need this to be default option - core.check_and_set_prim_all_enabled() + # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch. + core.check_and_set_prim_all_enabled() backward.gradients(targets=targets, inputs=[]) start_idx = len(main_program.block(0).ops) + len(self._outputs.tolist()) diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index f32168858b9d3a1120647921ec2ae5242004e527..37b92bbd4dd12a7128a583985960ff5a66107724 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -1140,9 +1140,8 @@ class ProgramCache: enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass # NOTE(xiongkun): Need a global FLAGS to enable/disable fallback enable_fallback = enable_prim - if enable_prim: - # TODO(Jiabin): Change this to True if we need this to be default option - core.check_and_set_prim_all_enabled() + # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch. + core.check_and_set_prim_all_enabled() try: concrete_program = ConcreteProgram.from_func_spec( func_spec=cache_key.function_spec,