Commit 300f36c0 authored by cxxly, committed by Xiaoxu Chen

[Prim] enable whitelist and blacklist for custom_vjp

Parent e4a93b05
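In short, primapi.to_prim gains blacklist and whitelist arguments, and the dygraph-to-static PrimHooker uses them so that operators with custom VJPs (composite grad op makers) are kept whole while the forward program is lowered and are only lowered after the backward ops have been appended. A minimal sketch of the new call, assuming static mode with prim enabled; the program built here is illustrative and mirrors the tests below:

import paddle
from paddle.fluid import core
from paddle.incubate.autograd import primapi

paddle.enable_static()
core._set_prim_all_enabled(True)

program = paddle.static.Program()
with paddle.static.program_guard(program):
    paddle.nn.functional.softmax(
        paddle.nn.functional.dropout(paddle.rand((1,)))
    )
# Blacklisted ops are kept as-is; when the whitelist is non-empty, only
# whitelisted ops are lowered, and the blacklist wins if an op is in both.
primapi.to_prim(program.blocks, blacklist={'dropout'}, whitelist={'softmax'})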
...@@ -446,6 +446,10 @@ def __sync_stat_with_flag(flag): ...@@ -446,6 +446,10 @@ def __sync_stat_with_flag(flag):
) )
def _is_all_prim_enabled():
return _is_fwd_prim_enabled() and _is_bwd_prim_enabled()
# Alert!!! This method is only for test coverage; users should never use it directly, as it may cause serious system errors.
def _test_use_sync(value):
__sync_stat_with_flag(value)
......
...@@ -1046,14 +1046,51 @@ class TestToPrim(unittest.TestCase): ...@@ -1046,14 +1046,51 @@ class TestToPrim(unittest.TestCase):
paddle.disable_static() paddle.disable_static()
@param.parameterized.expand((({'dropout'},),)) @param.parameterized.expand((({'dropout'},),))
def test_exclude(self, exclude): def test_blacklist(self, blacklist):
program = paddle.static.Program() program = paddle.static.Program()
with paddle.static.program_guard(program): with paddle.static.program_guard(program):
x = paddle.rand((1,)) paddle.nn.functional.softmax(
y = paddle.nn.functional.dropout(x) paddle.nn.functional.dropout(paddle.rand((1,)))
primapi.to_prim(program.blocks, exclude) )
primapi.to_prim(program.blocks, blacklist=blacklist)
ops = tuple(op.type for op in program.block(0).ops)
self.assertTrue(all(tuple(op in ops for op in blacklist)))
@param.parameterized.expand((({'dropout'},),))
def test_whitelist(self, whitelist):
program = paddle.static.Program()
with paddle.static.program_guard(program):
paddle.nn.functional.softmax(
paddle.nn.functional.dropout(paddle.rand((1,)))
)
primapi.to_prim(program.blocks, whitelist=whitelist)
ops = tuple(op.type for op in program.block(0).ops) ops = tuple(op.type for op in program.block(0).ops)
self.assertTrue(all(tuple(op in ops for op in exclude))) self.assertTrue(all(tuple(op not in ops for op in whitelist)))
@param.parameterized.expand((({'softmax'}, {'softmax', 'dropout'}),))
def test_both_not_empty(self, blacklist, whitelist):
program = paddle.static.Program()
with paddle.static.program_guard(program):
paddle.nn.functional.softmax(
paddle.nn.functional.dropout(paddle.rand((1,)))
)
primapi.to_prim(
program.blocks, blacklist=blacklist, whitelist=whitelist
)
ops = tuple(op.type for op in program.block(0).ops)
self.assertTrue(all(tuple(op in ops for op in blacklist)))
@param.parameterized.expand(((('dropout',), 'softmax'),))
def test_type_error(self, blacklist, whitelist):
program = paddle.static.Program()
with paddle.static.program_guard(program):
paddle.nn.functional.softmax(
paddle.nn.functional.dropout(paddle.rand((1,)))
)
with self.assertRaises(TypeError):
primapi.to_prim(
program.blocks, blacklist=blacklist, whitelist=whitelist
)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -77,6 +77,8 @@ class TestPrimForward(unittest.TestCase): ...@@ -77,6 +77,8 @@ class TestPrimForward(unittest.TestCase):
def check_prim(self, net, use_prim): def check_prim(self, net, use_prim):
if not use_prim: if not use_prim:
return return
# Please use the PartialProgramLayer (the second return value of get_concrete_program) rather than
# main_program here, as main_program is the original program before to_prim.
fwd_ops = [ fwd_ops = [
op.type op.type
for op in net.forward.get_concrete_program(self.x)[1] for op in net.forward.get_concrete_program(self.x)[1]
......
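This comment matters because ConcreteProgram.main_program is captured before to_prim runs, so composite lowering is never visible there. A hedged helper sketching the pattern these checks rely on; the forward_program attribute is taken from this commit's new TestPrimHook test and may differ across versions:

def lowered_forward_op_types(static_fn, *inputs):
    # get_concrete_program returns (ConcreteProgram, PartialProgramLayer);
    # the PartialProgramLayer holds the program after composite lowering.
    _, partial_layer = static_fn.get_concrete_program(*inputs)
    return tuple(op.type for op in partial_layer.forward_program.block(0).ops)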
...@@ -93,6 +93,8 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -93,6 +93,8 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def check_prim(self, net, use_prim): def check_prim(self, net, use_prim):
if not use_prim: if not use_prim:
return return
# Please use the PartialProgramLayer (the second return value of get_concrete_program) rather than
# main_program here, as main_program is the original program before to_prim.
fwd_ops = [ fwd_ops = [
op.type op.type
for op in net.forward.get_concrete_program(self.x)[1] for op in net.forward.get_concrete_program(self.x)[1]
......
...@@ -89,6 +89,8 @@ class TestPrimForward(unittest.TestCase): ...@@ -89,6 +89,8 @@ class TestPrimForward(unittest.TestCase):
def check_prim(self, net, use_prim): def check_prim(self, net, use_prim):
if not use_prim: if not use_prim:
return return
# Please use the PartialProgramLayer (the second return value of get_concrete_program) rather than
# main_program here, as main_program is the original program before to_prim.
fwd_ops = [ fwd_ops = [
op.type op.type
for op in net.forward.get_concrete_program(self.x, self.w, self.b)[ for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
......
...@@ -100,6 +100,8 @@ class TestPrimForward(unittest.TestCase): ...@@ -100,6 +100,8 @@ class TestPrimForward(unittest.TestCase):
def check_prim(self, net, use_prim): def check_prim(self, net, use_prim):
if not use_prim: if not use_prim:
return return
# Please use the PartialProgramLayer (the second return value of get_concrete_program) rather than
# main_program here, as main_program is the original program before to_prim.
fwd_ops = [ fwd_ops = [
op.type op.type
for op in net.forward.get_concrete_program(self.x)[1] for op in net.forward.get_concrete_program(self.x)[1]
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid import core
from paddle.jit.dy2static import partial_program, program_translator
class TestPartiaProgramLayerHook(unittest.TestCase):
def setUp(self):
self._hook = partial_program.PartialProgramLayerHook()
def test_before_append_backward(self):
self.assertIsNone(self._hook.before_append_backward(None))
def test_after_append_backward(self):
self.assertIsNone(self._hook.after_append_backward(None, 0))
def test_after_infer(self):
self.assertIsNone(self._hook.after_infer(None))
class TestPrimHook(unittest.TestCase):
def setUp(self):
core._set_prim_all_enabled(False)
def f():
return paddle.nn.functional.dropout(paddle.rand((1,)))
concrete_program, partial_program = paddle.jit.to_static(
f
).get_concrete_program()
self._hook = program_translator.PrimHooker(
concrete_program.main_program
)
self._forward = partial_program.forward_program
self._whole = partial_program._train_program
core._set_prim_all_enabled(True)
def tearDown(self):
core._set_prim_all_enabled(False)
def test_before_append_backward(self):
self._hook.before_append_backward(self._forward)
self.assertNotIn(
'dropout', tuple(op.type for op in self._forward.blocks[0].ops)
)
def test_after_append_backward(self):
self._hook.after_append_backward(self._whole, 0)
self.assertNotIn(
'dropout_grad', tuple(op.type for op in self._whole.blocks[0].ops)
)
if __name__ == '__main__':
unittest.main()
...@@ -21,6 +21,7 @@ import paddle ...@@ -21,6 +21,7 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid import core, framework from paddle.fluid import core, framework
from paddle.incubate.autograd import primapi
from paddle.nn import BatchNorm from paddle.nn import BatchNorm
from paddle.tensor import ones # noqa: F401 from paddle.tensor import ones # noqa: F401
from paddle.tensor import zeros # noqa: F401 from paddle.tensor import zeros # noqa: F401
...@@ -183,7 +184,7 @@ class TestCompositeBatchNorm(unittest.TestCase): ...@@ -183,7 +184,7 @@ class TestCompositeBatchNorm(unittest.TestCase):
attrs.use_global_stats, attrs.use_global_stats,
) )
blocks = main_program.blocks blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
exe = paddle.static.Executor() exe = paddle.static.Executor()
exe.run(startup_program) exe.run(startup_program)
......
...@@ -20,6 +20,7 @@ from utils import SUB_TOLERANCE ...@@ -20,6 +20,7 @@ from utils import SUB_TOLERANCE
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
np.random.seed(2023) np.random.seed(2023)
...@@ -190,7 +191,7 @@ class TestCompositeBatchNorm(unittest.TestCase): ...@@ -190,7 +191,7 @@ class TestCompositeBatchNorm(unittest.TestCase):
attrs.use_global_stats, attrs.use_global_stats,
) )
blocks = main_program.blocks blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
z = paddle.static.gradients([y], [x1]) z = paddle.static.gradients([y], [x1])
......
...@@ -19,6 +19,7 @@ import parameterized as param ...@@ -19,6 +19,7 @@ import parameterized as param
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
np.random.seed(2023) np.random.seed(2023)
...@@ -154,7 +155,7 @@ class TestCompositeDropout(unittest.TestCase): ...@@ -154,7 +155,7 @@ class TestCompositeDropout(unittest.TestCase):
input_, p, training=(not is_test), mode=mode input_, p, training=(not is_test), mode=mode
) )
if core._is_fwd_prim_enabled(): if core._is_fwd_prim_enabled():
paddle.incubate.autograd.to_prim(mp.blocks) primapi.to_prim(mp.blocks)
grad = paddle.static.gradients(output, input_)[0] grad = paddle.static.gradients(output, input_)[0]
exe = paddle.static.Executor(self.place) exe = paddle.static.Executor(self.place)
exe.run(sp) exe.run(sp)
......
...@@ -22,6 +22,7 @@ np.random.seed(2013) ...@@ -22,6 +22,7 @@ np.random.seed(2013)
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"): def generate_data(shape, dtype="float32"):
...@@ -89,7 +90,7 @@ class TestCompositeGelu(unittest.TestCase): ...@@ -89,7 +90,7 @@ class TestCompositeGelu(unittest.TestCase):
# Ensure that gelu in original block # Ensure that gelu in original block
self.assertTrue('gelu' in fwd_ops) self.assertTrue('gelu' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that gelu is splitted into small ops # Ensure that gelu is splitted into small ops
......
...@@ -22,6 +22,7 @@ np.random.seed(2013) ...@@ -22,6 +22,7 @@ np.random.seed(2013)
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"): def generate_data(shape, dtype="float32"):
...@@ -97,7 +98,7 @@ class TestCompositeGelu(unittest.TestCase): ...@@ -97,7 +98,7 @@ class TestCompositeGelu(unittest.TestCase):
# Ensure that gelu in original block # Ensure that gelu in original block
self.assertTrue('gelu' in fwd_ops) self.assertTrue('gelu' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that gelu is splitted into small ops # Ensure that gelu is splitted into small ops
...@@ -164,7 +165,7 @@ class TestCompositeGeluPrimBackward(unittest.TestCase): ...@@ -164,7 +165,7 @@ class TestCompositeGeluPrimBackward(unittest.TestCase):
x.stop_gradient = False x.stop_gradient = False
y = fn(x) y = fn(x)
blocks = main_program.blocks blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
z = paddle.static.gradients([y], x) z = paddle.static.gradients([y], x)
exe = paddle.static.Executor() exe = paddle.static.Executor()
......
...@@ -20,6 +20,7 @@ from utils import SUB_TOLERANCE ...@@ -20,6 +20,7 @@ from utils import SUB_TOLERANCE
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape1, shape2, shape3, dtype="float32"): def generate_data(shape1, shape2, shape3, dtype="float32"):
...@@ -98,7 +99,7 @@ class TestCompositelayer_norm(unittest.TestCase): ...@@ -98,7 +99,7 @@ class TestCompositelayer_norm(unittest.TestCase):
# Ensure that layer_norm in original block # Ensure that layer_norm in original block
self.assertTrue('layer_norm' in fwd_ops) self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is splitted into small ops # Ensure that layer_norm is splitted into small ops
...@@ -137,7 +138,7 @@ class TestCompositelayer_norm(unittest.TestCase): ...@@ -137,7 +138,7 @@ class TestCompositelayer_norm(unittest.TestCase):
# Ensure that layer_norm in original block # Ensure that layer_norm in original block
self.assertTrue('layer_norm' in fwd_ops) self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is splitted into small ops # Ensure that layer_norm is splitted into small ops
......
...@@ -22,6 +22,7 @@ from utils import SUB_TOLERANCE ...@@ -22,6 +22,7 @@ from utils import SUB_TOLERANCE
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
TOLERANCE_NUMPY = { TOLERANCE_NUMPY = {
"float32": {"rtol": 2e-5, "atol": 2e-5}, "float32": {"rtol": 2e-5, "atol": 2e-5},
...@@ -196,7 +197,7 @@ class TestCompositelayer_norm(unittest.TestCase): ...@@ -196,7 +197,7 @@ class TestCompositelayer_norm(unittest.TestCase):
# Ensure that layer_norm in original block # Ensure that layer_norm in original block
self.assertTrue('layer_norm' in fwd_ops) self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is splitted into small ops # Ensure that layer_norm is splitted into small ops
...@@ -242,7 +243,7 @@ class TestCompositelayer_norm(unittest.TestCase): ...@@ -242,7 +243,7 @@ class TestCompositelayer_norm(unittest.TestCase):
# Ensure that layer_norm in original block # Ensure that layer_norm in original block
self.assertTrue('layer_norm' in fwd_ops) self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is splitted into small ops # Ensure that layer_norm is splitted into small ops
...@@ -341,7 +342,7 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase): ...@@ -341,7 +342,7 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
y = fn(x, norm_shape, w, b) y = fn(x, norm_shape, w, b)
blocks = main_program.blocks blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
z = paddle.static.gradients([y], x) z = paddle.static.gradients([y], x)
exe = paddle.static.Executor() exe = paddle.static.Executor()
...@@ -374,7 +375,7 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase): ...@@ -374,7 +375,7 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
y = fn(x, norm_shape, weight, bias) y = fn(x, norm_shape, weight, bias)
blocks = main_program.blocks blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
z = paddle.static.gradients([y], x) z = paddle.static.gradients([y], x)
exe = paddle.static.Executor() exe = paddle.static.Executor()
...@@ -480,7 +481,7 @@ class TestCompositeNumpylayer_norm(unittest.TestCase): ...@@ -480,7 +481,7 @@ class TestCompositeNumpylayer_norm(unittest.TestCase):
# Ensure that layer_norm in original block # Ensure that layer_norm in original block
self.assertTrue('layer_norm' in fwd_ops) self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is splitted into small ops # Ensure that layer_norm is splitted into small ops
...@@ -532,7 +533,7 @@ class TestCompositeNumpylayer_norm(unittest.TestCase): ...@@ -532,7 +533,7 @@ class TestCompositeNumpylayer_norm(unittest.TestCase):
) )
blocks = main_program.blocks blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
z = paddle.static.gradients([y], x) z = paddle.static.gradients([y], x)
exe = paddle.static.Executor() exe = paddle.static.Executor()
......
...@@ -20,6 +20,7 @@ from utils import TOLERANCE ...@@ -20,6 +20,7 @@ from utils import TOLERANCE
import paddle import paddle
import paddle.tensor as tensor import paddle.tensor as tensor
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"): def generate_data(shape, dtype="float32"):
...@@ -93,7 +94,7 @@ class TestCompositeMean(unittest.TestCase): ...@@ -93,7 +94,7 @@ class TestCompositeMean(unittest.TestCase):
# Ensure that reduce_mean in original block # Ensure that reduce_mean in original block
self.assertTrue('reduce_mean' in fwd_ops) self.assertTrue('reduce_mean' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that reduce_mean is splitted into small ops # Ensure that reduce_mean is splitted into small ops
......
...@@ -20,6 +20,7 @@ from utils import TOLERANCE ...@@ -20,6 +20,7 @@ from utils import TOLERANCE
import paddle import paddle
import paddle.tensor as tensor import paddle.tensor as tensor
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"): def generate_data(shape, dtype="float32"):
...@@ -99,7 +100,7 @@ class TestCompositeMean(unittest.TestCase): ...@@ -99,7 +100,7 @@ class TestCompositeMean(unittest.TestCase):
# Ensure that reduce_mean in original block # Ensure that reduce_mean in original block
self.assertTrue('reduce_mean' in fwd_ops) self.assertTrue('reduce_mean' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that reduce_mean is splitted into small ops # Ensure that reduce_mean is splitted into small ops
...@@ -173,7 +174,7 @@ class TestCompositeMeanPrimBackward(unittest.TestCase): ...@@ -173,7 +174,7 @@ class TestCompositeMeanPrimBackward(unittest.TestCase):
x.stop_gradient = False x.stop_gradient = False
y = fn(x) y = fn(x)
blocks = main_program.blocks blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
z = paddle.static.gradients([y], x) z = paddle.static.gradients([y], x)
exe = paddle.static.Executor() exe = paddle.static.Executor()
......
...@@ -20,6 +20,7 @@ from utils import TOLERANCE ...@@ -20,6 +20,7 @@ from utils import TOLERANCE
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"): def generate_data(shape, dtype="float32"):
...@@ -87,7 +88,7 @@ class TestCompositeSoftmax(unittest.TestCase): ...@@ -87,7 +88,7 @@ class TestCompositeSoftmax(unittest.TestCase):
# Ensure that softmax in original block # Ensure that softmax in original block
self.assertTrue('softmax' in fwd_ops) self.assertTrue('softmax' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is splitted into small ops # Ensure that softmax is splitted into small ops
......
...@@ -20,6 +20,7 @@ from utils import TOLERANCE ...@@ -20,6 +20,7 @@ from utils import TOLERANCE
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"): def generate_data(shape, dtype="float32"):
...@@ -93,7 +94,7 @@ class TestCompositeSoftmax(unittest.TestCase): ...@@ -93,7 +94,7 @@ class TestCompositeSoftmax(unittest.TestCase):
# Ensure that softmax in original block # Ensure that softmax in original block
self.assertTrue('softmax' in fwd_ops) self.assertTrue('softmax' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is splitted into small ops # Ensure that softmax is splitted into small ops
...@@ -158,7 +159,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase): ...@@ -158,7 +159,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
x.stop_gradient = False x.stop_gradient = False
y = fn(x) y = fn(x)
blocks = main_program.blocks blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
z = paddle.static.gradients([y], x) z = paddle.static.gradients([y], x)
exe = paddle.static.Executor() exe = paddle.static.Executor()
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
class TestPrimFlags(unittest.TestCase): class TestPrimFlags(unittest.TestCase):
...@@ -64,6 +65,12 @@ class TestPrimFlags(unittest.TestCase): ...@@ -64,6 +65,12 @@ class TestPrimFlags(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
core._test_use_sync("aaaa") core._test_use_sync("aaaa")
core._set_prim_all_enabled(True)
self.assertTrue(core._is_all_prim_enabled())
core._set_prim_all_enabled(False)
self.assertFalse(core._is_all_prim_enabled())
class TestPrimBlacklistFlags(unittest.TestCase): class TestPrimBlacklistFlags(unittest.TestCase):
def not_in_blacklist(self): def not_in_blacklist(self):
...@@ -83,7 +90,7 @@ class TestPrimBlacklistFlags(unittest.TestCase): ...@@ -83,7 +90,7 @@ class TestPrimBlacklistFlags(unittest.TestCase):
# Ensure that softmax in original block # Ensure that softmax in original block
self.assertTrue('softmax' in fwd_ops) self.assertTrue('softmax' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is splitted into small ops # Ensure that softmax is splitted into small ops
...@@ -113,7 +120,7 @@ class TestPrimBlacklistFlags(unittest.TestCase): ...@@ -113,7 +120,7 @@ class TestPrimBlacklistFlags(unittest.TestCase):
# Ensure that softmax in original block # Ensure that softmax in original block
self.assertTrue('softmax' in fwd_ops) self.assertTrue('softmax' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is splitted into small ops # Ensure that softmax is splitted into small ops
......
...@@ -18,6 +18,7 @@ import numpy as np ...@@ -18,6 +18,7 @@ import numpy as np
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.incubate.autograd import primapi
paddle.framework.random._manual_program_seed(2023) paddle.framework.random._manual_program_seed(2023)
...@@ -49,7 +50,7 @@ class TestCompositeCopyOp(unittest.TestCase): ...@@ -49,7 +50,7 @@ class TestCompositeCopyOp(unittest.TestCase):
# Ensure that dropout in original block # Ensure that dropout in original block
self.assertTrue('dropout' in fwd_ops) self.assertTrue('dropout' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that dropout is not splitted into small ops # Ensure that dropout is not splitted into small ops
......
...@@ -56,7 +56,7 @@ class TestCustomVJP(unittest.TestCase): ...@@ -56,7 +56,7 @@ class TestCustomVJP(unittest.TestCase):
'elementwise_mul', 'elementwise_mul',
'scale', 'scale',
'cast', 'cast',
'fill_constant', 'fill_any_like',
'cast', 'cast',
'elementwise_mul', 'elementwise_mul',
'fill_constant', 'fill_constant',
......
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
import paddle import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.framework import _dygraph_tracer, in_dygraph_mode from paddle.fluid.framework import _dygraph_tracer, in_dygraph_mode
from paddle.incubate.autograd import primapi
from paddle.jit.dy2static.utils import parse_arg_and_kwargs from paddle.jit.dy2static.utils import parse_arg_and_kwargs
...@@ -588,7 +589,7 @@ class PrimForwardChecker: ...@@ -588,7 +589,7 @@ class PrimForwardChecker:
args, len(inputs_sig) args, len(inputs_sig)
) )
ret = flatten(_as_list(self.python_api(*args))) ret = flatten(_as_list(self.python_api(*args)))
paddle.incubate.autograd.to_prim(main_program.blocks) primapi.to_prim(main_program.blocks)
exe = paddle.static.Executor(self.place) exe = paddle.static.Executor(self.place)
exe.run(startup_program) exe.run(startup_program)
ret = exe.run(main_program, feed=feed, fetch_list=ret) ret = exe.run(main_program, feed=feed, fetch_list=ret)
...@@ -1018,7 +1019,7 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -1018,7 +1019,7 @@ class PrimGradChecker(PrimForwardChecker):
outputs_dict = self.get_output_dict( outputs_dict = self.get_output_dict(
self.outputs, fw_outs, outputs_sig self.outputs, fw_outs, outputs_sig
) )
paddle.incubate.autograd.to_prim(main_program.blocks) primapi.to_prim(main_program.blocks)
ys = [] ys = []
if isinstance(self.output_names, list): if isinstance(self.output_names, list):
for output_name in self.output_names: for output_name in self.output_names:
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .functional import Hessian, Jacobian, jvp, vjp from .functional import Hessian, Jacobian, jvp, vjp
from .primapi import forward_grad, grad, to_prim from .primapi import forward_grad, grad
from .primx import prim2orig from .primx import prim2orig
from .utils import disable_prim, enable_prim, prim_enabled from .utils import disable_prim, enable_prim, prim_enabled
...@@ -25,5 +25,4 @@ __all__ = [ # noqa ...@@ -25,5 +25,4 @@ __all__ = [ # noqa
'disable_prim', 'disable_prim',
'forward_grad', 'forward_grad',
'grad', 'grad',
'to_prim',
] ]
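Since to_prim is dropped from the public __all__ here, the call sites updated throughout this commit import it from the module instead, for example:

from paddle.incubate.autograd import primapi

# main_program is any static Program whose blocks should be lowered.
# previously: paddle.incubate.autograd.to_prim(main_program.blocks)
primapi.to_prim(main_program.blocks)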
...@@ -217,11 +217,18 @@ def grad(outputs, inputs, grad_outputs=None): ...@@ -217,11 +217,18 @@ def grad(outputs, inputs, grad_outputs=None):
@framework.static_only @framework.static_only
def to_prim(blocks, exclude=frozenset()): def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
"""Search nonbasic ops which have be registered composite rules and replace them with primitive ops. """Search nonbasic ops which have be registered composite rules and replace them with primitive ops.
The operators in blacklist will be excluded from program when lowering into primitives, and only the
operators in whitelist will be lowering. The priority of blacklist is higher than whitelist, it means
an operator both in blacklist and whitelist will not be lowering.
The finally set that will be lowering is:
(blocks.ops & ops have decomposite rule & whitelist) - blacklist
Args: Args:
exclude(frozenset): The Operators that will be excluded in lowering. blacklist(frozenset): The Operators that will be excluded when lowering into primitives.
whitelist(frozenset): Only the operators in whitelist will be lowered into primitives.
""" """
if not core._is_fwd_prim_enabled(): if not core._is_fwd_prim_enabled():
return return
...@@ -239,15 +246,28 @@ def to_prim(blocks, exclude=frozenset()): ...@@ -239,15 +246,28 @@ def to_prim(blocks, exclude=frozenset()):
raise TypeError( raise TypeError(
f"Expect block or sequence of blocks, but got {type(blocks)}." f"Expect block or sequence of blocks, but got {type(blocks)}."
) )
if not isinstance(exclude, (set, frozenset)): if not isinstance(blacklist, (set, frozenset)):
raise TypeError(
f'Expected type of blacklist is set|frozenset, but got {type(blacklist)}.'
)
if not isinstance(whitelist, (set, frozenset)):
raise TypeError( raise TypeError(
f'Expected type of exclude is set|frozenset, but got {type(exclude)}.' f'Expected type of whitelist is set|frozenset, but got {type(whitelist)}.'
) )
blacklist = prim_config["forward_blacklist"] | blacklist
with framework.program_guard(main_program): with framework.program_guard(main_program):
print("Lowering composite forward ops begin...") print("Lowering composite forward ops begin...")
primx._lower_composite(
blocks, prim_config["forward_blacklist"] | exclude if len(blacklist) > 0 and len(whitelist) > 0:
) filter_ = lambda x: x.type in whitelist and x.type not in blacklist
elif len(blacklist) > 0 and len(whitelist) == 0:
filter_ = lambda x: x.type not in blacklist
elif len(blacklist) == 0 and len(whitelist) > 0:
filter_ = lambda x: x.type in whitelist
else:
filter_ = lambda x: True
primx._lower_composite(blocks, filter_)
replace_ops = prim_config["composite_ops_record"] replace_ops = prim_config["composite_ops_record"]
print(f"Lowering composite forward ops finish: {replace_ops}") print(f"Lowering composite forward ops finish: {replace_ops}")
...@@ -550,8 +550,11 @@ def _lower(block, reverse, blacklist): ...@@ -550,8 +550,11 @@ def _lower(block, reverse, blacklist):
block._sync_with_cpp() block._sync_with_cpp()
def _lower_composite(block, blacklist=frozenset()): def _lower_composite(
# Some functions which are only used in _lower. block, filter_: typing.Callable[[framework.Operator], bool] = lambda x: True
):
"""The operators in block wich satisfy the filter conditon will be decomposite into primitives."""
def bind(args, to_bind, value_table): def bind(args, to_bind, value_table):
for i in range(len(args)): for i in range(len(args)):
if isinstance(args[i], list): if isinstance(args[i], list):
...@@ -603,7 +606,7 @@ def _lower_composite(block, blacklist=frozenset()): ...@@ -603,7 +606,7 @@ def _lower_composite(block, blacklist=frozenset()):
for op_idx in range(len(block.ops)): for op_idx in range(len(block.ops)):
op = block.ops[op_idx] op = block.ops[op_idx]
ops_to_remove.append(op_idx) ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and op.type not in blacklist: if lookup_fn(op.type) is not None and filter_(op):
change = True change = True
op_name = op.type op_name = op.type
prim_config["composite_ops_record"].add(op_name) prim_config["composite_ops_record"].add(op_name)
...@@ -681,12 +684,12 @@ def _lower_composite(block, blacklist=frozenset()): ...@@ -681,12 +684,12 @@ def _lower_composite(block, blacklist=frozenset()):
# composite ops may contain other composite ops, thus, call _lower_composite again. # composite ops may contain other composite ops, thus, call _lower_composite again.
if change: if change:
_lower_composite(block, blacklist) _lower_composite(block, filter_)
return return
elif isinstance(block, typing.Sequence): elif isinstance(block, typing.Sequence):
for item in block: for item in block:
_lower_composite(item, blacklist) _lower_composite(item, filter_)
return return
else: else:
raise TypeError raise TypeError
......
...@@ -144,15 +144,13 @@ class ProgramInfo: ...@@ -144,15 +144,13 @@ class ProgramInfo:
class PartialProgramLayerHook: class PartialProgramLayerHook:
def before_append_backward(self, partial_program_layer, forward_program): def before_append_backward(self, forward_program):
... ...
def after_append_backward( def after_append_backward(self, whole_program, backward_start_idx):
self, partial_program_layer, whole_program, backward_start_idx
):
... ...
def after_infer(self, partial_program_layer, infer_program): def after_infer(self, infer_program):
... ...
...@@ -264,7 +262,7 @@ class PartialProgramLayer: ...@@ -264,7 +262,7 @@ class PartialProgramLayer:
for_test=is_infer_mode for_test=is_infer_mode
) )
if self._hooker: if self._hooker:
infer_program = self._hooker.after_infer(self, infer_program) infer_program = self._hooker.after_infer(infer_program)
return infer_program return infer_program
else: else:
train_program = self._append_backward_desc( train_program = self._append_backward_desc(
...@@ -298,11 +296,9 @@ class PartialProgramLayer: ...@@ -298,11 +296,9 @@ class PartialProgramLayer:
pure_fp16_program, self._amp_list, use_fp16_guard=False pure_fp16_program, self._amp_list, use_fp16_guard=False
) )
core.check_and_set_prim_all_enabled()
from paddle.incubate.autograd.primapi import to_prim
to_prim(pure_fp16_program.blocks)
if is_infer_mode: if is_infer_mode:
if self._hooker:
pure_fp16_program = self._hooker.after_infer(pure_fp16_program)
return pure_fp16_program return pure_fp16_program
else: else:
train_pure_fp16_program = self._append_backward_desc( train_pure_fp16_program = self._append_backward_desc(
...@@ -314,7 +310,6 @@ class PartialProgramLayer: ...@@ -314,7 +310,6 @@ class PartialProgramLayer:
@switch_to_static_graph @switch_to_static_graph
def _create_forward_backward_train_program(self): def _create_forward_backward_train_program(self):
whole_program = self._train_program whole_program = self._train_program
# _, forward_end_op_index = self._infer_info('fp32', self._create_program)
forward_end_op_index = self.get_forward_end_op_idx(whole_program) forward_end_op_index = self.get_forward_end_op_idx(whole_program)
assert forward_end_op_index >= 0 assert forward_end_op_index >= 0
...@@ -437,7 +432,9 @@ class PartialProgramLayer: ...@@ -437,7 +432,9 @@ class PartialProgramLayer:
return _param_grad_names(self._train_program.desc, self._params) return _param_grad_names(self._train_program.desc, self._params)
def get_forward_end_op_idx(self, program): def get_forward_end_op_idx(self, program):
return self._forward_end_index_map[_hash_with_id(program, self)] return self._forward_end_index_map[
paddle.utils._hash_with_id(program, self)
]
@LazyInitialized @LazyInitialized
def _out_grad_names(self): def _out_grad_names(self):
...@@ -637,7 +634,7 @@ class PartialProgramLayer: ...@@ -637,7 +634,7 @@ class PartialProgramLayer:
# make sure all status of is_test are False in train mode. # make sure all status of is_test are False in train mode.
program = _change_is_test_status(main_program.clone(), is_test=False) program = _change_is_test_status(main_program.clone(), is_test=False)
if self._hooker: if self._hooker:
program = self._hooker.before_append_backward(self, program) program = self._hooker.before_append_backward(program)
targets = [] targets = []
for out in self._outputs.tolist(): for out in self._outputs.tolist():
if isinstance(out, framework.Variable): if isinstance(out, framework.Variable):
...@@ -652,12 +649,14 @@ class PartialProgramLayer: ...@@ -652,12 +649,14 @@ class PartialProgramLayer:
if self._hooker: if self._hooker:
program, start_idx = self._hooker.after_append_backward( program, start_idx = self._hooker.after_append_backward(
self, program, start_idx program, start_idx
) )
self.prepare_gradient_aggregation(start_idx, main_program, program) self.prepare_gradient_aggregation(
start_idx + 1, main_program, program
)
self._forward_end_index_map[ self._forward_end_index_map[
_hash_with_id(program, self) paddle.utils._hash_with_id(program, self)
] = start_idx - len(self._outputs.tolist()) ] = start_idx - len(self._outputs.tolist())
return program return program
......
...@@ -19,7 +19,7 @@ import threading ...@@ -19,7 +19,7 @@ import threading
import warnings import warnings
import weakref import weakref
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard from paddle.amp.auto_cast import _in_amp_guard
from paddle.fluid import _non_static_mode, core, framework from paddle.fluid import _non_static_mode, core, framework
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import layers from paddle.fluid.dygraph import layers
...@@ -194,7 +194,7 @@ class CacheKey: ...@@ -194,7 +194,7 @@ class CacheKey:
input_args_with_spec, input_args_with_spec,
input_kwargs_with_spec, input_kwargs_with_spec,
class_instance, class_instance,
**kwargs **kwargs,
): ):
""" """
Initializes a cache key. Initializes a cache key.
...@@ -568,7 +568,7 @@ class StaticFunction: ...@@ -568,7 +568,7 @@ class StaticFunction:
self._class_instance, self._class_instance,
**self._kwargs, **self._kwargs,
with_hook=with_hook, with_hook=with_hook,
is_train=is_train is_train=is_train,
) )
# 3. check whether hit the cache or build a new program for the input arguments # 3. check whether hit the cache or build a new program for the input arguments
...@@ -674,7 +674,7 @@ class StaticFunction: ...@@ -674,7 +674,7 @@ class StaticFunction:
concrete_program, _ = self.get_concrete_program( concrete_program, _ = self.get_concrete_program(
*desired_input_spec, *desired_input_spec,
with_hook=with_hook, with_hook=with_hook,
is_train=self._is_train_mode() is_train=self._is_train_mode(),
) )
return concrete_program return concrete_program
else: else:
...@@ -946,7 +946,7 @@ class ConcreteProgram: ...@@ -946,7 +946,7 @@ class ConcreteProgram:
function, function,
main_program, main_program,
startup_program=None, startup_program=None,
**kwargs **kwargs,
): ):
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
...@@ -1049,7 +1049,7 @@ class ConcreteProgram: ...@@ -1049,7 +1049,7 @@ class ConcreteProgram:
function=dygraph_function, function=dygraph_function,
main_program=main_program, main_program=main_program,
startup_program=startup_program, startup_program=startup_program,
**kwargs **kwargs,
) )
...@@ -1152,7 +1152,7 @@ class ProgramCache: ...@@ -1152,7 +1152,7 @@ class ProgramCache:
input_spec=cache_key.input_args_with_spec, input_spec=cache_key.input_args_with_spec,
input_kwargs_spec=cache_key.input_kwargs_with_spec, input_kwargs_spec=cache_key.input_kwargs_with_spec,
class_instance=cache_key.class_instance, class_instance=cache_key.class_instance,
**cache_key.kwargs **cache_key.kwargs,
) )
except Exception as e: except Exception as e:
if enable_fallback: if enable_fallback:
...@@ -1182,48 +1182,11 @@ class ProgramCache: ...@@ -1182,48 +1182,11 @@ class ProgramCache:
) )
) )
class PrimHooker(PartialProgramLayerHook):
def __init__(self):
self.custom_vjps = set()
if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
self.custom_vjps = {
op.type
for op in concrete_program.main_program.block(0).ops
if core.has_comp_grad_op_maker(op.type)
}
def before_append_backward(
self, partial_program_layer, forward_program
):
if core._is_fwd_prim_enabled():
to_prim(forward_program.block(0), self.custom_vjps)
return forward_program
def after_append_backward(
self, partial_program_layer, whole_program, backward_start_idx
):
backward_length = (
len(whole_program.block(0).ops) - backward_start_idx
)
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
to_prim(whole_program.block(0))
new_start_index = (
len(whole_program.block(0).ops) - backward_length
)
return whole_program, new_start_index
def after_infer(self, partial_program_layer, infer_program):
if core._is_fwd_prim_enabled():
to_prim(infer_program.block(0))
return infer_program
partial_program = partial_program_from(concrete_program) partial_program = partial_program_from(concrete_program)
if ( if core._is_fwd_prim_enabled() and not _in_amp_guard():
core._is_fwd_prim_enabled() partial_program.set_hooker(
and not _in_amp_guard() PrimHooker(concrete_program.main_program)
and not _in_pure_fp16_guard() )
):
partial_program.set_hooker(PrimHooker())
return concrete_program, partial_program return concrete_program, partial_program
def __getitem__(self, item): def __getitem__(self, item):
...@@ -1279,6 +1242,38 @@ class ProgramCache: ...@@ -1279,6 +1242,38 @@ class ProgramCache:
self._caches = collections.OrderedDict() self._caches = collections.OrderedDict()
class PrimHooker(PartialProgramLayerHook):
def __init__(self, original_program):
if len(original_program.blocks) > 1:
raise ValueError(
'The primitive mode only supports one block currently.'
)
self.custom_vjps = set()
if core._is_all_prim_enabled():
self.custom_vjps = {
op.type
for op in original_program.block(0).ops
if core.has_comp_grad_op_maker(op.type)
}
def before_append_backward(self, forward_program):
if core._is_fwd_prim_enabled():
_to_prim(forward_program.blocks, blacklist=self.custom_vjps)
return forward_program
def after_append_backward(self, whole_program, backward_start_idx):
backward_length = len(whole_program.block(0).ops) - backward_start_idx
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
_to_prim(whole_program.blocks, whitelist=self.custom_vjps)
new_start_index = len(whole_program.block(0).ops) - backward_length
return whole_program, new_start_index
def after_infer(self, infer_program):
if core._is_fwd_prim_enabled():
_to_prim(infer_program.block(0))
return infer_program
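Taken together, PrimHooker implements a two-phase lowering for custom VJPs: before_append_backward blacklists every op type that has a composite grad op maker, so its custom VJP is used when the backward is appended, and after_append_backward then whitelists exactly those op types so they and their freshly added grad ops are lowered as well. A rough sketch of driving the two phases by hand, mirroring the new TestPrimHook test above (attribute names such as forward_program and _train_program come from that test and are not a stable API):

import paddle
from paddle.fluid import core
from paddle.jit.dy2static import program_translator

core._set_prim_all_enabled(False)

def f():
    return paddle.nn.functional.dropout(paddle.rand((1,)))

concrete_program, partial_program = paddle.jit.to_static(f).get_concrete_program()
# With both forward and backward prim enabled at construction time,
# custom_vjps would collect every op type that has a composite grad op maker.
hook = program_translator.PrimHooker(concrete_program.main_program)
core._set_prim_all_enabled(True)
# Phase 1: lower forward composite ops, keeping custom-VJP ops whole (blacklist).
hook.before_append_backward(partial_program.forward_program)
# Phase 2: after backward ops exist, lower only the custom-VJP ops (whitelist).
hook.after_append_backward(partial_program._train_program, 0)
core._set_prim_all_enabled(False)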
class ProgramTranslator: class ProgramTranslator:
""" """
Class to translate dygraph function into static graph function. The object Class to translate dygraph function into static graph function. The object
...@@ -1696,8 +1691,9 @@ def enable_to_static(enable_to_static_bool): ...@@ -1696,8 +1691,9 @@ def enable_to_static(enable_to_static_bool):
@switch_to_static_graph @switch_to_static_graph
def to_prim(blocks, exclude=frozenset()): def _to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
"""Swith to static graph and call to_prim."""
# TODO(Aurelius84): Fix this cycle import problem # TODO(Aurelius84): Fix this cycle import problem
from paddle.incubate.autograd import primapi from paddle.incubate.autograd import primapi
primapi.to_prim(blocks, exclude) primapi.to_prim(blocks, blacklist=blacklist, whitelist=whitelist)
...@@ -1568,7 +1568,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size): ...@@ -1568,7 +1568,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
min(fwd_end_op_index + out_size, program_desc.block(0).op_size()), min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
): ):
op = program_desc.block(0).op(i) op = program_desc.block(0).op(i)
if op.type() in ['fill_any_like', "fill_constant"]: if op.type() == 'fill_any_like':
var_name = op.output('Out')[0] var_name = op.output('Out')[0]
names.append(var_name) names.append(var_name)
return names return names
......