Unverified commit a82c56a0, authored by cyber-pioneer, committed by GitHub

[Prim] Add prim backward blacklist (#56320)

* support setting backward prim blacklist

* add test case

* polish prim forward flag

* fix test case
Parent 0ba4a234
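For orientation, here is a minimal usage sketch of the flag this commit adds, pieced together from the test case further down in this diff (it assumes a Paddle build that includes this change):

    import paddle
    from paddle.framework import core

    # Enable primitive-operator decomposition for forward and backward,
    # then exempt the grad ops of softmax and exp from being decomposed.
    core._set_prim_all_enabled(True)
    core._set_prim_backward_blacklist("softmax", "exp")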
@@ -55,6 +55,13 @@ void PrimCommonUtils::AddSkipCompOps(const std::string& op_type) {
   StaticCompositeContext::Instance().AddSkipCompOps(op_type);
 }
 
+void PrimCommonUtils::SetPrimBackwardBlacklist(
+    const std::unordered_set<std::string>& op_types) {
+  for (const auto& item : op_types) {
+    StaticCompositeContext::Instance().AddSkipCompOps(item);
+  }
+}
+
 void PrimCommonUtils::RemoveSkipCompOps(const std::string& op_type) {
   StaticCompositeContext::Instance().RemoveSkipCompOps(op_type);
 }
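The new C++ setter is just a batched form of the existing per-op registration. A rough Python model of its semantics, with a hypothetical skip_comp_ops set standing in for StaticCompositeContext's skip list:

    skip_comp_ops = set()  # hypothetical stand-in for StaticCompositeContext

    def set_prim_backward_blacklist(op_types):
        # Equivalent to calling AddSkipCompOps once per op type.
        for op in op_types:
            skip_comp_ops.add(op)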
......
@@ -30,6 +30,8 @@ class PrimCommonUtils {
   static void SetAllPrimEnabled(bool enabled);
   static size_t CheckSkipCompOps(const std::string& op_type);
   static void AddSkipCompOps(const std::string& op_type);
+  static void SetPrimBackwardBlacklist(
+      const std::unordered_set<std::string>& op_types);
   static void RemoveSkipCompOps(const std::string& op_type);
   static void SetTargetGradName(const std::map<std::string, std::string>& m);
 };
......
@@ -1419,6 +1419,8 @@ All parameter, weight, gradient are variables in Paddle.
                        defalut_val.index() - 1);
                      });
   m.def("_add_skip_comp_ops", &paddle::prim::PrimCommonUtils::AddSkipCompOps);
+  m.def("_set_bwd_prim_blacklist",
+        &paddle::prim::PrimCommonUtils::SetPrimBackwardBlacklist);
   m.def("_remove_skip_comp_ops",
         &paddle::prim::PrimCommonUtils::RemoveSkipCompOps);
   m.def("get_grad_op_desc",
......
@@ -308,8 +308,6 @@ try:
         from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent
         from .libpaddle import _set_current_stream
         from .libpaddle import _get_phi_kernel_name
-        from .libpaddle import _add_skip_comp_ops
-        from .libpaddle import _remove_skip_comp_ops
 
         # prim controller flags
         from .libpaddle import __set_bwd_prim_enabled
@@ -320,6 +318,9 @@ try:
         from .libpaddle import _is_eager_prim_enabled
         from .libpaddle import __set_eager_prim_enabled
         from .libpaddle import _set_prim_target_grad_name
+        from .libpaddle import _add_skip_comp_ops
+        from .libpaddle import _set_bwd_prim_blacklist
+        from .libpaddle import _remove_skip_comp_ops
 
         # custom device
         from .libpaddle import _get_current_custom_device_stream
@@ -487,23 +488,21 @@ ops_contain_none = {
 }
 
 
-def _set_prim_forward_blacklist(ops=None):
-    if ops is None:
-        prim_config["forward_blacklist"] = []
-    elif isinstance(ops, str):
-        prim_config["forward_blacklist"].add(ops)
-    elif isinstance(ops, (list, tuple)):
-        for item in ops:
-            if not isinstance(item, str):
-                raise TypeError(
-                    "ops set in forward_blacklist must belong to [str, str of tuple or list]"
-                )
-            else:
-                prim_config["forward_blacklist"].add(item)
-    else:
-        raise TypeError(
-            "ops set in forward_blacklist must belong to [str, str of tuple or list]"
-        )
+def _set_prim_forward_blacklist(*args):
+    for item in args:
+        if not isinstance(item, str):
+            raise TypeError("ops in forward_blacklist must be str")
+        else:
+            prim_config["forward_blacklist"].add(item)
+    return
+
+
+def _set_prim_backward_blacklist(*args):
+    ops = set(args)
+    for item in ops:
+        if not isinstance(item, str):
+            raise TypeError("all ops in the backward blacklist must be str")
+    _set_bwd_prim_blacklist(ops)
     return
 
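The two wrappers now share a variadic-string signature but keep their state in different places: the forward blacklist accumulates in the Python-side prim_config dict, while the backward blacklist is deduplicated into a set and handed straight to C++. A quick sketch of both calls:

    # Forward: recorded in prim_config["forward_blacklist"] in Python.
    _set_prim_forward_blacklist("softmax", "exp")

    # Backward: forwarded to C++ via the _set_bwd_prim_blacklist binding.
    _set_prim_backward_blacklist("softmax", "exp")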
......
@@ -19,7 +19,7 @@ import numpy as np
 import paddle
 import paddle.nn.functional as F
-from paddle.fluid import core
+from paddle.framework import core
 from paddle.incubate.autograd import primapi
@@ -133,12 +133,49 @@ class TestPrimBlacklistFlags(unittest.TestCase):
         core._set_prim_forward_enabled(False)
         return
 
-    def test_prim_forward_blackward(self):
-        # self.not_in_blacklist()
+    def test_prim_forward_blacklist(self):
+        self.not_in_blacklist()
         core._set_prim_forward_blacklist("softmax")
         self.in_blacklist()
 
 
+class PrimeNet(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x1 = F.softmax(x)
+        x2 = paddle.exp(x1)
+        res = paddle.nn.functional.relu(x2)
+        return res
+
+
+class TestPrimBackwardBlacklistFlags(unittest.TestCase):
+    def train(self):
+        x = paddle.randn([2, 4])
+        x.stop_gradient = False
+        net = PrimeNet()
+        net = paddle.jit.to_static(net)
+        out = net(x)
+        loss = paddle.mean(out)
+        loss.backward()
+        self.check_prim(net)
+        return
+
+    def check_prim(self, net):
+        block = net.forward.program_cache.last()[-1][-1].train_program.block
+        ops = [op.type for op in block(0).ops]
+        # softmax_grad and exp_grad are blacklisted, so they must survive;
+        # relu_grad is not, so it should be decomposed into primitive ops.
+        self.assertTrue('softmax_grad' in ops)
+        self.assertTrue('exp_grad' in ops)
+        self.assertTrue('relu_grad' not in ops)
+
+    def test_prim_backward_blacklist(self):
+        core._set_prim_all_enabled(True)
+        core._set_prim_backward_blacklist("softmax", "exp")
+        self.train()
+        core._set_prim_all_enabled(False)
+
+
 if __name__ == '__main__':
     unittest.main()