未验证 提交 8c14b02b 编写于 作者: C cyber-pioneer 提交者: GitHub

[feature] use prim flag in shell (#50309)

* add flag

* change flag

* use prim flag

* fix code

* fix softmax prim flag

* set case timeout
上级 9268f392
...@@ -78,7 +78,7 @@ class TestCompositeSoftmax(unittest.TestCase): ...@@ -78,7 +78,7 @@ class TestCompositeSoftmax(unittest.TestCase):
def cal_composite_grad(self, inputs): def cal_composite_grad(self, inputs):
paddle.enable_static() paddle.enable_static()
core._set_prim_all_enabled(True) core._set_prim_forward_enabled(True)
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
main_program = paddle.static.Program() main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
...@@ -109,7 +109,7 @@ class TestCompositeSoftmax(unittest.TestCase): ...@@ -109,7 +109,7 @@ class TestCompositeSoftmax(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
paddle.disable_static() paddle.disable_static()
core._set_prim_all_enabled(False) core._set_prim_forward_enabled(False)
return res return res
def compare_backward(self): def compare_backward(self):
...@@ -142,12 +142,13 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase): ...@@ -142,12 +142,13 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
def setUp(self): def setUp(self):
core._set_prim_backward_enabled(True) core._set_prim_backward_enabled(True)
self.dtypes = ["float32"] self.dtypes = ["float32", "float64"]
self.shapes = [[2, 3, 4], [2, 3]] self.shapes = [[2, 3, 4], [2, 3]]
self.axes = [-1, 0, 1] self.axes = [-1, 0, 1]
def cal_composite_grad(self, inputs): def cal_composite_grad(self, inputs):
paddle.enable_static() paddle.enable_static()
core._set_prim_all_enabled(True)
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
main_program = paddle.static.Program() main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
...@@ -164,6 +165,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase): ...@@ -164,6 +165,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
paddle.disable_static() paddle.disable_static()
core._set_prim_all_enabled(False)
return res return res
def compare_backward(self): def compare_backward(self):
......
...@@ -7,3 +7,8 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") ...@@ -7,3 +7,8 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach() endforeach()
if(WITH_CINN)
set_tests_properties(test_prim_flags_case PROPERTIES LABELS "RUN_TYPE=CINN")
set_tests_properties(test_prim_flags_case PROPERTIES TIMEOUT 300)
endif()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
import unittest
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(net, build_strategy=build_strategy)
class PrimeNet(paddle.nn.Layer):
def __init__(self):
super(PrimeNet, self).__init__()
def forward(self, x):
out = F.softmax(x)
res = paddle.exp(out)
return res
class TestPrimForwardAndBackward(unittest.TestCase):
"""
Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
"""
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.x.stop_gradient = False
self.flag = None
def reset_env_flag(self):
os.environ["FLAGS_prim_backward"] = "False"
os.environ["FLAGS_prim_forward"] = "False"
if os.getenv("FLAGS_prim_all"):
del os.environ["FLAGS_prim_all"]
core.check_and_set_prim_all_enabled()
def train(self, use_cinn):
net = PrimeNet()
net = apply_to_static(net, use_cinn)
out = net(self.x)
loss = paddle.mean(out)
loss.backward()
self.check_prim(net)
return
def check_prim(self, net):
ops = [
op.type
for op in net.forward.program_cache.last()[-1][-1]
.train_program.block(0)
.ops
]
if self.flag in ["prim_all", "cinn_prim_all"]:
self.assertTrue('softmax' not in ops)
self.assertTrue('exp_grad' not in ops)
elif self.flag in ["prim_forward", "cinn_prim_forward"]:
self.assertTrue('softmax' not in ops)
self.assertTrue('exp_grad' in ops)
elif self.flag in ["prim_backward", "cinn_prim_backward"]:
self.assertTrue('softmax' in ops)
self.assertTrue('exp_grad' not in ops)
elif self.flag == "cinn":
self.assertTrue('softmax' in ops)
self.assertTrue('exp_grad' in ops)
else:
raise TypeError
def test_cinn_prim_all(self):
"""cinn + prim forward + prim backward"""
self.reset_env_flag()
os.environ["FLAGS_prim_all"] = "True"
self.flag = "cinn_prim_all"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
def test_prim_all(self):
"""prim forward + prim backward"""
self.reset_env_flag()
os.environ["FLAGS_prim_all"] = "True"
self.flag = "prim_all"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False)
else:
pass
def test_cinn_prim_forward(self):
"""cinn + prim forward"""
self.reset_env_flag()
os.environ["FLAGS_prim_forward"] = "True"
self.flag = "cinn_prim_forward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
def test_prim_forward(self):
"""only prim forward"""
self.reset_env_flag()
os.environ["FLAGS_prim_forward"] = "True"
self.flag = "prim_forward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False)
else:
pass
def test_cinn_prim_backward(self):
"""cinn + prim_backward"""
self.reset_env_flag()
os.environ["FLAGS_prim_backward"] = "True"
self.flag = "cinn_prim_backward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
def test_prim_backward(self):
"""only prim backward"""
self.reset_env_flag()
os.environ["FLAGS_prim_backward"] = "True"
self.flag = "prim_backward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False)
else:
pass
def test_cinn(self):
"""only cinn"""
self.reset_env_flag()
self.flag = "cinn"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
if __name__ == '__main__':
unittest.main()
...@@ -571,9 +571,8 @@ class PartialProgramLayer: ...@@ -571,9 +571,8 @@ class PartialProgramLayer:
targets.append(program.global_block().var(out.name)) targets.append(program.global_block().var(out.name))
if targets: if targets:
if self._build_strategy.build_cinn_pass: # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
# TODO(Jiabin): Change this to True if we need this to be default option core.check_and_set_prim_all_enabled()
core.check_and_set_prim_all_enabled()
backward.gradients(targets=targets, inputs=[]) backward.gradients(targets=targets, inputs=[])
start_idx = len(main_program.block(0).ops) + len(self._outputs.tolist()) start_idx = len(main_program.block(0).ops) + len(self._outputs.tolist())
......
...@@ -1140,9 +1140,8 @@ class ProgramCache: ...@@ -1140,9 +1140,8 @@ class ProgramCache:
enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
# NOTE(xiongkun): Need a global FLAGS to enable/disable fallback # NOTE(xiongkun): Need a global FLAGS to enable/disable fallback
enable_fallback = enable_prim enable_fallback = enable_prim
if enable_prim: # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
# TODO(Jiabin): Change this to True if we need this to be default option core.check_and_set_prim_all_enabled()
core.check_and_set_prim_all_enabled()
try: try:
concrete_program = ConcreteProgram.from_func_spec( concrete_program = ConcreteProgram.from_func_spec(
func_spec=cache_key.function_spec, func_spec=cache_key.function_spec,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册