未验证 提交 a82c56a0 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] Add prim backward blacklist (#56320)

* support setting backward prim blacklist

* add test case

* polish prim forward flag

* fix test case
上级 0ba4a234
......@@ -55,6 +55,13 @@ void PrimCommonUtils::AddSkipCompOps(const std::string& op_type) {
StaticCompositeContext::Instance().AddSkipCompOps(op_type);
}
void PrimCommonUtils::SetPrimBackwardBlacklist(
const std::unordered_set<std::string>& op_types) {
for (const auto& item : op_types) {
StaticCompositeContext::Instance().AddSkipCompOps(item);
}
}
void PrimCommonUtils::RemoveSkipCompOps(const std::string& op_type) {
StaticCompositeContext::Instance().RemoveSkipCompOps(op_type);
}
......
......@@ -30,6 +30,8 @@ class PrimCommonUtils {
static void SetAllPrimEnabled(bool enabled);
static size_t CheckSkipCompOps(const std::string& op_type);
static void AddSkipCompOps(const std::string& op_type);
static void SetPrimBackwardBlacklist(
const std::unordered_set<std::string>& op_types);
static void RemoveSkipCompOps(const std::string& op_type);
static void SetTargetGradName(const std::map<std::string, std::string>& m);
};
......
......@@ -1419,6 +1419,8 @@ All parameter, weight, gradient are variables in Paddle.
defalut_val.index() - 1);
});
m.def("_add_skip_comp_ops", &paddle::prim::PrimCommonUtils::AddSkipCompOps);
m.def("_set_bwd_prim_blacklist",
&paddle::prim::PrimCommonUtils::SetPrimBackwardBlacklist);
m.def("_remove_skip_comp_ops",
&paddle::prim::PrimCommonUtils::RemoveSkipCompOps);
m.def("get_grad_op_desc",
......
......@@ -308,8 +308,6 @@ try:
from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent
from .libpaddle import _set_current_stream
from .libpaddle import _get_phi_kernel_name
from .libpaddle import _add_skip_comp_ops
from .libpaddle import _remove_skip_comp_ops
# prim controller flags
from .libpaddle import __set_bwd_prim_enabled
......@@ -320,6 +318,9 @@ try:
from .libpaddle import _is_eager_prim_enabled
from .libpaddle import __set_eager_prim_enabled
from .libpaddle import _set_prim_target_grad_name
from .libpaddle import _add_skip_comp_ops
from .libpaddle import _set_bwd_prim_blacklist
from .libpaddle import _remove_skip_comp_ops
# custom devivce
from .libpaddle import _get_current_custom_device_stream
......@@ -487,23 +488,21 @@ ops_contain_none = {
}
def _set_prim_forward_blacklist(ops=None):
if ops is None:
prim_config["forward_blacklist"] = []
elif isinstance(ops, str):
prim_config["forward_blacklist"].add(ops)
elif isinstance(ops, (list, tuple)):
for item in ops:
if not isinstance(item, str):
raise TypeError(
"ops set in forward_blacklist must belong to [str, str of tuple or list]"
)
else:
prim_config["forward_blacklist"].add(item)
else:
raise TypeError(
"ops set in forward_blacklist must belong to [str, str of tuple or list]"
)
def _set_prim_forward_blacklist(*args):
for item in args:
if not isinstance(item, str):
raise TypeError("ops set in forward_blacklist must belong to str")
else:
prim_config["forward_blacklist"].add(item)
return
def _set_prim_backward_blacklist(*args):
ops = set(args)
for item in ops:
if not isinstance(item, str):
raise TypeError("all items in set must belong to string")
_set_bwd_prim_blacklist(ops)
return
......
......@@ -19,7 +19,7 @@ import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
from paddle.framework import core
from paddle.incubate.autograd import primapi
......@@ -133,12 +133,49 @@ class TestPrimBlacklistFlags(unittest.TestCase):
core._set_prim_forward_enabled(False)
return
def test_prim_forward_blackward(self):
# self.not_in_blacklist()
def test_prim_forward_blacklist(self):
self.not_in_blacklist()
core._set_prim_forward_blacklist("softmax")
self.in_blacklist()
class PrimeNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
def forward(self, x):
x1 = F.softmax(x)
x2 = paddle.exp(x1)
res = paddle.nn.functional.relu(x2)
return res
class TestPrimBackwardBlacklistFlags(unittest.TestCase):
def train(self):
x = paddle.randn([2, 4])
x.stop_gradient = False
net = PrimeNet()
net = paddle.jit.to_static(net)
out = net(x)
loss = paddle.mean(out)
loss.backward()
self.check_prim(net)
return
def check_prim(self, net):
block = net.forward.program_cache.last()[-1][-1].train_program.block
ops = [op.type for op in block(0).ops]
self.assertTrue('softmax_grad' in ops)
self.assertTrue('exp_grad' in ops)
self.assertTrue('relu_grad' not in ops)
def test_prim_backward_blacklist(self):
core._set_prim_all_enabled(True)
core._set_prim_backward_blacklist("softmax", "exp")
self.train()
core._set_prim_all_enabled(False)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册