Commit 300f36c0 authored by C cxxly, committed by Xiaoxu Chen

[Prim] enable whitelist and blacklist for custom_vjp

Parent e4a93b05
......@@ -446,6 +446,10 @@ def __sync_stat_with_flag(flag):
)
def _is_all_prim_enabled():
return _is_fwd_prim_enabled() and _is_bwd_prim_enabled()
# Alert!!! This method is only for test coverage. Users should never call it directly; doing so may cause serious system errors.
def _test_use_sync(value):
__sync_stat_with_flag(value)
......
......@@ -1046,14 +1046,51 @@ class TestToPrim(unittest.TestCase):
paddle.disable_static()
@param.parameterized.expand((({'dropout'},),))
def test_exclude(self, exclude):
def test_blacklist(self, blacklist):
program = paddle.static.Program()
with paddle.static.program_guard(program):
x = paddle.rand((1,))
y = paddle.nn.functional.dropout(x)
primapi.to_prim(program.blocks, exclude)
paddle.nn.functional.softmax(
paddle.nn.functional.dropout(paddle.rand((1,)))
)
primapi.to_prim(program.blocks, blacklist=blacklist)
ops = tuple(op.type for op in program.block(0).ops)
self.assertTrue(all(tuple(op in ops for op in blacklist)))
@param.parameterized.expand((({'dropout'},),))
def test_whitelist(self, whitelist):
program = paddle.static.Program()
with paddle.static.program_guard(program):
paddle.nn.functional.softmax(
paddle.nn.functional.dropout(paddle.rand((1,)))
)
primapi.to_prim(program.blocks, whitelist=whitelist)
ops = tuple(op.type for op in program.block(0).ops)
self.assertTrue(all(tuple(op in ops for op in exclude)))
self.assertTrue(all(tuple(op not in ops for op in whitelist)))
@param.parameterized.expand((({'softmax'}, {'softmax', 'dropout'}),))
def test_both_not_empty(self, blacklist, whitelist):
program = paddle.static.Program()
with paddle.static.program_guard(program):
paddle.nn.functional.softmax(
paddle.nn.functional.dropout(paddle.rand((1,)))
)
primapi.to_prim(
program.blocks, blacklist=blacklist, whitelist=whitelist
)
ops = tuple(op.type for op in program.block(0).ops)
self.assertTrue(all(tuple(op in ops for op in blacklist)))
@param.parameterized.expand(((('dropout',), 'softmax'),))
def test_type_error(self, blacklist, whitelist):
program = paddle.static.Program()
with paddle.static.program_guard(program):
paddle.nn.functional.softmax(
paddle.nn.functional.dropout(paddle.rand((1,)))
)
with self.assertRaises(TypeError):
primapi.to_prim(
program.blocks, blacklist=blacklist, whitelist=whitelist
)
if __name__ == '__main__':
......
......@@ -77,6 +77,8 @@ class TestPrimForward(unittest.TestCase):
def check_prim(self, net, use_prim):
if not use_prim:
return
# Please use the PartialProgramLayer (the second output of get_concrete_program) rather than
# main_program here, because main_program is the original program before to_prim is applied.
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
......
......@@ -93,6 +93,8 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def check_prim(self, net, use_prim):
if not use_prim:
return
# Please use the PartialProgramLayer (the second output of get_concrete_program) rather than
# main_program here, because main_program is the original program before to_prim is applied.
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
......
......@@ -89,6 +89,8 @@ class TestPrimForward(unittest.TestCase):
def check_prim(self, net, use_prim):
if not use_prim:
return
# Please use the PartialProgramLayer (the second output of get_concrete_program) rather than
# main_program here, because main_program is the original program before to_prim is applied.
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
......
......@@ -100,6 +100,8 @@ class TestPrimForward(unittest.TestCase):
def check_prim(self, net, use_prim):
if not use_prim:
return
# Please use the PartialProgramLayer (the second output of get_concrete_program) rather than
# main_program here, because main_program is the original program before to_prim is applied.
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid import core
from paddle.jit.dy2static import partial_program, program_translator
class TestPartiaProgramLayerHook(unittest.TestCase):
def setUp(self):
self._hook = partial_program.PartialProgramLayerHook()
def test_before_append_backward(self):
self.assertIsNone(self._hook.before_append_backward(None))
def test_after_append_backward(self):
self.assertIsNone(self._hook.after_append_backward(None, 0))
def test_after_infer(self):
self.assertIsNone(self._hook.after_infer(None))
class TestPrimHook(unittest.TestCase):
def setUp(self):
core._set_prim_all_enabled(False)
def f():
return paddle.nn.functional.dropout(paddle.rand((1,)))
concrete_program, partial_program = paddle.jit.to_static(
f
).get_concrete_program()
self._hook = program_translator.PrimHooker(
concrete_program.main_program
)
self._forward = partial_program.forward_program
self._whole = partial_program._train_program
core._set_prim_all_enabled(True)
def tearDown(self):
core._set_prim_all_enabled(False)
def test_before_append_backward(self):
self._hook.before_append_backward(self._forward)
self.assertNotIn(
'dropout', tuple(op.type for op in self._forward.blocks[0].ops)
)
def test_after_append_backward(self):
self._hook.after_append_backward(self._whole, 0)
self.assertNotIn(
'dropout_grad', tuple(op.type for op in self._whole.blocks[0].ops)
)
if __name__ == '__main__':
unittest.main()
......@@ -21,6 +21,7 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.fluid import core, framework
from paddle.incubate.autograd import primapi
from paddle.nn import BatchNorm
from paddle.tensor import ones # noqa: F401
from paddle.tensor import zeros # noqa: F401
......@@ -183,7 +184,7 @@ class TestCompositeBatchNorm(unittest.TestCase):
attrs.use_global_stats,
)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
exe = paddle.static.Executor()
exe.run(startup_program)
......
......@@ -20,6 +20,7 @@ from utils import SUB_TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
from paddle.incubate.autograd import primapi
np.random.seed(2023)
......@@ -190,7 +191,7 @@ class TestCompositeBatchNorm(unittest.TestCase):
attrs.use_global_stats,
)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
z = paddle.static.gradients([y], [x1])
......
......@@ -19,6 +19,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
from paddle.incubate.autograd import primapi
np.random.seed(2023)
......@@ -154,7 +155,7 @@ class TestCompositeDropout(unittest.TestCase):
input_, p, training=(not is_test), mode=mode
)
if core._is_fwd_prim_enabled():
paddle.incubate.autograd.to_prim(mp.blocks)
primapi.to_prim(mp.blocks)
grad = paddle.static.gradients(output, input_)[0]
exe = paddle.static.Executor(self.place)
exe.run(sp)
......
......@@ -22,6 +22,7 @@ np.random.seed(2013)
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"):
......@@ -89,7 +90,7 @@ class TestCompositeGelu(unittest.TestCase):
# Ensure that gelu in original block
self.assertTrue('gelu' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that gelu is splitted into small ops
......
......@@ -22,6 +22,7 @@ np.random.seed(2013)
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"):
......@@ -97,7 +98,7 @@ class TestCompositeGelu(unittest.TestCase):
# Ensure that gelu in original block
self.assertTrue('gelu' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that gelu is splitted into small ops
......@@ -164,7 +165,7 @@ class TestCompositeGeluPrimBackward(unittest.TestCase):
x.stop_gradient = False
y = fn(x)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
z = paddle.static.gradients([y], x)
exe = paddle.static.Executor()
......
......@@ -20,6 +20,7 @@ from utils import SUB_TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape1, shape2, shape3, dtype="float32"):
......@@ -98,7 +99,7 @@ class TestCompositelayer_norm(unittest.TestCase):
# Ensure that layer_norm in original block
self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is splitted into small ops
......@@ -137,7 +138,7 @@ class TestCompositelayer_norm(unittest.TestCase):
# Ensure that layer_norm in original block
self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is splitted into small ops
......
......@@ -22,6 +22,7 @@ from utils import SUB_TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
from paddle.incubate.autograd import primapi
TOLERANCE_NUMPY = {
"float32": {"rtol": 2e-5, "atol": 2e-5},
......@@ -196,7 +197,7 @@ class TestCompositelayer_norm(unittest.TestCase):
# Ensure that layer_norm in original block
self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is splitted into small ops
......@@ -242,7 +243,7 @@ class TestCompositelayer_norm(unittest.TestCase):
# Ensure that layer_norm in original block
self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is splitted into small ops
......@@ -341,7 +342,7 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
y = fn(x, norm_shape, w, b)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
z = paddle.static.gradients([y], x)
exe = paddle.static.Executor()
......@@ -374,7 +375,7 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
y = fn(x, norm_shape, weight, bias)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
z = paddle.static.gradients([y], x)
exe = paddle.static.Executor()
......@@ -480,7 +481,7 @@ class TestCompositeNumpylayer_norm(unittest.TestCase):
# Ensure that layer_norm in original block
self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is splitted into small ops
......@@ -532,7 +533,7 @@ class TestCompositeNumpylayer_norm(unittest.TestCase):
)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
z = paddle.static.gradients([y], x)
exe = paddle.static.Executor()
......
......@@ -20,6 +20,7 @@ from utils import TOLERANCE
import paddle
import paddle.tensor as tensor
from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"):
......@@ -93,7 +94,7 @@ class TestCompositeMean(unittest.TestCase):
# Ensure that reduce_mean in original block
self.assertTrue('reduce_mean' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that reduce_mean is splitted into small ops
......
......@@ -20,6 +20,7 @@ from utils import TOLERANCE
import paddle
import paddle.tensor as tensor
from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"):
......@@ -99,7 +100,7 @@ class TestCompositeMean(unittest.TestCase):
# Ensure that reduce_mean in original block
self.assertTrue('reduce_mean' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that reduce_mean is splitted into small ops
......@@ -173,7 +174,7 @@ class TestCompositeMeanPrimBackward(unittest.TestCase):
x.stop_gradient = False
y = fn(x)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
z = paddle.static.gradients([y], x)
exe = paddle.static.Executor()
......
......@@ -20,6 +20,7 @@ from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"):
......@@ -87,7 +88,7 @@ class TestCompositeSoftmax(unittest.TestCase):
# Ensure that softmax in original block
self.assertTrue('softmax' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is splitted into small ops
......
......@@ -20,6 +20,7 @@ from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
from paddle.incubate.autograd import primapi
def generate_data(shape, dtype="float32"):
......@@ -93,7 +94,7 @@ class TestCompositeSoftmax(unittest.TestCase):
# Ensure that softmax in original block
self.assertTrue('softmax' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is splitted into small ops
......@@ -158,7 +159,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
x.stop_gradient = False
y = fn(x)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
z = paddle.static.gradients([y], x)
exe = paddle.static.Executor()
......
......@@ -20,6 +20,7 @@ import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
from paddle.incubate.autograd import primapi
class TestPrimFlags(unittest.TestCase):
......@@ -64,6 +65,12 @@ class TestPrimFlags(unittest.TestCase):
with self.assertRaises(TypeError):
core._test_use_sync("aaaa")
core._set_prim_all_enabled(True)
self.assertTrue(core._is_all_prim_enabled())
core._set_prim_all_enabled(False)
self.assertFalse(core._is_all_prim_enabled())
class TestPrimBlacklistFlags(unittest.TestCase):
def not_in_blacklist(self):
......@@ -83,7 +90,7 @@ class TestPrimBlacklistFlags(unittest.TestCase):
# Ensure that softmax in original block
self.assertTrue('softmax' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is splitted into small ops
......@@ -113,7 +120,7 @@ class TestPrimBlacklistFlags(unittest.TestCase):
# Ensure that softmax in original block
self.assertTrue('softmax' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is splitted into small ops
......
......@@ -18,6 +18,7 @@ import numpy as np
import paddle
from paddle.fluid import core
from paddle.incubate.autograd import primapi
paddle.framework.random._manual_program_seed(2023)
......@@ -49,7 +50,7 @@ class TestCompositeCopyOp(unittest.TestCase):
# Ensure that dropout in original block
self.assertTrue('dropout' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that dropout is not splitted into small ops
......
......@@ -56,7 +56,7 @@ class TestCustomVJP(unittest.TestCase):
'elementwise_mul',
'scale',
'cast',
'fill_constant',
'fill_any_like',
'cast',
'elementwise_mul',
'fill_constant',
......
......@@ -22,6 +22,7 @@ import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import _dygraph_tracer, in_dygraph_mode
from paddle.incubate.autograd import primapi
from paddle.jit.dy2static.utils import parse_arg_and_kwargs
......@@ -588,7 +589,7 @@ class PrimForwardChecker:
args, len(inputs_sig)
)
ret = flatten(_as_list(self.python_api(*args)))
paddle.incubate.autograd.to_prim(main_program.blocks)
primapi.to_prim(main_program.blocks)
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
ret = exe.run(main_program, feed=feed, fetch_list=ret)
......@@ -1018,7 +1019,7 @@ class PrimGradChecker(PrimForwardChecker):
outputs_dict = self.get_output_dict(
self.outputs, fw_outs, outputs_sig
)
paddle.incubate.autograd.to_prim(main_program.blocks)
primapi.to_prim(main_program.blocks)
ys = []
if isinstance(self.output_names, list):
for output_name in self.output_names:
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .functional import Hessian, Jacobian, jvp, vjp
from .primapi import forward_grad, grad, to_prim
from .primapi import forward_grad, grad
from .primx import prim2orig
from .utils import disable_prim, enable_prim, prim_enabled
......@@ -25,5 +25,4 @@ __all__ = [ # noqa
'disable_prim',
'forward_grad',
'grad',
'to_prim',
]
......@@ -217,11 +217,18 @@ def grad(outputs, inputs, grad_outputs=None):
@framework.static_only
def to_prim(blocks, exclude=frozenset()):
def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
"""Search nonbasic ops which have be registered composite rules and replace them with primitive ops.
The operators in blacklist will be excluded from program when lowering into primitives, and only the
operators in whitelist will be lowering. The priority of blacklist is higher than whitelist, it means
an operator both in blacklist and whitelist will not be lowering.
The finally set that will be lowering is:
(blocks.ops & ops have decomposite rule & whitelist) - blacklist
Args:
exclude(frozenset): The Operators that will be exclude in lowering.
blacklist(frozenset): The operators that will be excluded when lowering into primitives.
whitelist(frozenset): Only the operators in the whitelist will be lowered into primitives.
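Examples:
    A minimal, illustrative sketch (assumes prim lowering has already been
    enabled; 'dropout' is only an example op name):

    .. code-block:: python

        import paddle
        from paddle.fluid import core
        from paddle.incubate.autograd import primapi

        paddle.enable_static()
        core._set_prim_all_enabled(True)
        program = paddle.static.Program()
        with paddle.static.program_guard(program):
            x = paddle.rand((1,))
            y = paddle.nn.functional.dropout(x)
        # Keep 'dropout' intact; lower the other ops that have composite rules.
        primapi.to_prim(program.blocks, blacklist={'dropout'})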
"""
if not core._is_fwd_prim_enabled():
return
......@@ -239,15 +246,28 @@ def to_prim(blocks, exclude=frozenset()):
raise TypeError(
f"Expect block or sequence of blocks, but got {type(blocks)}."
)
if not isinstance(exclude, (set, frozenset)):
if not isinstance(blacklist, (set, frozenset)):
raise TypeError(
f'Expected type of blacklist is set|frozenset, but got {type(blacklist)}.'
)
if not isinstance(whitelist, (set, frozenset)):
raise TypeError(
f'Expected type of exclude is set|frozenset, but got {type(exclude)}.'
f'Expected type of whitelist is set|frozenset, but got {type(whitelist)}.'
)
blacklist = prim_config["forward_blacklist"] | blacklist
with framework.program_guard(main_program):
print("Lowering composite forward ops begin...")
primx._lower_composite(
blocks, prim_config["forward_blacklist"] | exclude
)
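# Combine blacklist and whitelist into a single op filter: if both are given, keep only
# whitelisted ops that are not blacklisted; if only one is given, apply it alone; if neither
# is given, lower every op that has a composite rule.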
if len(blacklist) > 0 and len(whitelist) > 0:
filter_ = lambda x: x.type in whitelist and x.type not in blacklist
elif len(blacklist) > 0 and len(whitelist) == 0:
filter_ = lambda x: x.type not in blacklist
elif len(blacklist) == 0 and len(whitelist) > 0:
filter_ = lambda x: x.type in whitelist
else:
filter_ = lambda x: True
primx._lower_composite(blocks, filter_)
replace_ops = prim_config["composite_ops_record"]
print(f"Lowering composite forward ops finish: {replace_ops}")
......@@ -550,8 +550,11 @@ def _lower(block, reverse, blacklist):
block._sync_with_cpp()
def _lower_composite(block, blacklist=frozenset()):
# Some functions which are only used in _lower.
def _lower_composite(
block, filter_: typing.Callable[[framework.Operator], bool] = lambda x: True
):
"""The operators in block wich satisfy the filter conditon will be decomposite into primitives."""
def bind(args, to_bind, value_table):
for i in range(len(args)):
if isinstance(args[i], list):
......@@ -603,7 +606,7 @@ def _lower_composite(block, blacklist=frozenset()):
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and op.type not in blacklist:
if lookup_fn(op.type) is not None and filter_(op):
change = True
op_name = op.type
prim_config["composite_ops_record"].add(op_name)
......@@ -681,12 +684,12 @@ def _lower_composite(block, blacklist=frozenset()):
# composite ops may contain other composite ops, thus, call _lower_composite again.
if change:
_lower_composite(block, blacklist)
_lower_composite(block, filter_)
return
elif isinstance(block, typing.Sequence):
for item in block:
_lower_composite(item, blacklist)
_lower_composite(item, filter_)
return
else:
raise TypeError
......
......@@ -144,15 +144,13 @@ class ProgramInfo:
class PartialProgramLayerHook:
def before_append_backward(self, partial_program_layer, forward_program):
def before_append_backward(self, forward_program):
...
def after_append_backward(
self, partial_program_layer, whole_program, backward_start_idx
):
def after_append_backward(self, whole_program, backward_start_idx):
...
def after_infer(self, partial_program_layer, infer_program):
def after_infer(self, infer_program):
...
......@@ -264,7 +262,7 @@ class PartialProgramLayer:
for_test=is_infer_mode
)
if self._hooker:
infer_program = self._hooker.after_infer(self, infer_program)
infer_program = self._hooker.after_infer(infer_program)
return infer_program
else:
train_program = self._append_backward_desc(
......@@ -298,11 +296,9 @@ class PartialProgramLayer:
pure_fp16_program, self._amp_list, use_fp16_guard=False
)
core.check_and_set_prim_all_enabled()
from paddle.incubate.autograd.primapi import to_prim
to_prim(pure_fp16_program.blocks)
if is_infer_mode:
if self._hooker:
pure_fp16_program = self._hooker.after_infer(pure_fp16_program)
return pure_fp16_program
else:
train_pure_fp16_program = self._append_backward_desc(
......@@ -314,7 +310,6 @@ class PartialProgramLayer:
@switch_to_static_graph
def _create_forward_backward_train_program(self):
whole_program = self._train_program
# _, forward_end_op_index = self._infer_info('fp32', self._create_program)
forward_end_op_index = self.get_forward_end_op_idx(whole_program)
assert forward_end_op_index >= 0
......@@ -437,7 +432,9 @@ class PartialProgramLayer:
return _param_grad_names(self._train_program.desc, self._params)
def get_forward_end_op_idx(self, program):
return self._forward_end_index_map[_hash_with_id(program, self)]
return self._forward_end_index_map[
paddle.utils._hash_with_id(program, self)
]
@LazyInitialized
def _out_grad_names(self):
......@@ -637,7 +634,7 @@ class PartialProgramLayer:
# make sure all status of is_test are False in train mode.
program = _change_is_test_status(main_program.clone(), is_test=False)
if self._hooker:
program = self._hooker.before_append_backward(self, program)
program = self._hooker.before_append_backward(program)
targets = []
for out in self._outputs.tolist():
if isinstance(out, framework.Variable):
......@@ -652,12 +649,14 @@ class PartialProgramLayer:
if self._hooker:
program, start_idx = self._hooker.after_append_backward(
self, program, start_idx
program, start_idx
)
self.prepare_gradient_aggregation(
start_idx + 1, main_program, program
)
self.prepare_gradient_aggregation(start_idx, main_program, program)
self._forward_end_index_map[
_hash_with_id(program, self)
paddle.utils._hash_with_id(program, self)
] = start_idx - len(self._outputs.tolist())
return program
......
......@@ -19,7 +19,7 @@ import threading
import warnings
import weakref
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.amp.auto_cast import _in_amp_guard
from paddle.fluid import _non_static_mode, core, framework
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import layers
......@@ -194,7 +194,7 @@ class CacheKey:
input_args_with_spec,
input_kwargs_with_spec,
class_instance,
**kwargs
**kwargs,
):
"""
Initializes a cache key.
......@@ -568,7 +568,7 @@ class StaticFunction:
self._class_instance,
**self._kwargs,
with_hook=with_hook,
is_train=is_train
is_train=is_train,
)
# 3. check whether hit the cache or build a new program for the input arguments
......@@ -674,7 +674,7 @@ class StaticFunction:
concrete_program, _ = self.get_concrete_program(
*desired_input_spec,
with_hook=with_hook,
is_train=self._is_train_mode()
is_train=self._is_train_mode(),
)
return concrete_program
else:
......@@ -946,7 +946,7 @@ class ConcreteProgram:
function,
main_program,
startup_program=None,
**kwargs
**kwargs,
):
self.inputs = inputs
self.outputs = outputs
......@@ -1049,7 +1049,7 @@ class ConcreteProgram:
function=dygraph_function,
main_program=main_program,
startup_program=startup_program,
**kwargs
**kwargs,
)
......@@ -1152,7 +1152,7 @@ class ProgramCache:
input_spec=cache_key.input_args_with_spec,
input_kwargs_spec=cache_key.input_kwargs_with_spec,
class_instance=cache_key.class_instance,
**cache_key.kwargs
**cache_key.kwargs,
)
except Exception as e:
if enable_fallback:
......@@ -1182,48 +1182,11 @@ class ProgramCache:
)
)
class PrimHooker(PartialProgramLayerHook):
def __init__(self):
self.custom_vjps = set()
if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
self.custom_vjps = {
op.type
for op in concrete_program.main_program.block(0).ops
if core.has_comp_grad_op_maker(op.type)
}
def before_append_backward(
self, partial_program_layer, forward_program
):
if core._is_fwd_prim_enabled():
to_prim(forward_program.block(0), self.custom_vjps)
return forward_program
def after_append_backward(
self, partial_program_layer, whole_program, backward_start_idx
):
backward_length = (
len(whole_program.block(0).ops) - backward_start_idx
)
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
to_prim(whole_program.block(0))
new_start_index = (
len(whole_program.block(0).ops) - backward_length
)
return whole_program, new_start_index
def after_infer(self, partial_program_layer, infer_program):
if core._is_fwd_prim_enabled():
to_prim(infer_program.block(0))
return infer_program
partial_program = partial_program_from(concrete_program)
if (
core._is_fwd_prim_enabled()
and not _in_amp_guard()
and not _in_pure_fp16_guard()
):
partial_program.set_hooker(PrimHooker())
if core._is_fwd_prim_enabled() and not _in_amp_guard():
partial_program.set_hooker(
PrimHooker(concrete_program.main_program)
)
return concrete_program, partial_program
def __getitem__(self, item):
......@@ -1279,6 +1242,38 @@ class ProgramCache:
self._caches = collections.OrderedDict()
class PrimHooker(PartialProgramLayerHook):
def __init__(self, original_program):
if len(original_program.blocks) > 1:
raise ValueError(
'The primitive mode only supports one block currently.'
)
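# custom_vjps collects the ops that have a composite grad op maker (custom VJP). They are
# excluded from the forward lowering (blacklist) so the backward pass can use their custom
# VJP, and are lowered afterwards over the whole program (whitelist).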
self.custom_vjps = set()
if core._is_all_prim_enabled():
self.custom_vjps = {
op.type
for op in original_program.block(0).ops
if core.has_comp_grad_op_maker(op.type)
}
def before_append_backward(self, forward_program):
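# Lower the forward program, but keep the custom_vjp ops intact so that appending the
# backward pass can still trigger their composite grad op makers.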
if core._is_fwd_prim_enabled():
_to_prim(forward_program.blocks, blacklist=self.custom_vjps)
return forward_program
def after_append_backward(self, whole_program, backward_start_idx):
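# Remember how many ops belong to the backward part; after lowering the remaining
# custom_vjp forward ops, the forward end index is recomputed from this length.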
backward_length = len(whole_program.block(0).ops) - backward_start_idx
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
_to_prim(whole_program.blocks, whitelist=self.custom_vjps)
new_start_index = len(whole_program.block(0).ops) - backward_length
return whole_program, new_start_index
def after_infer(self, infer_program):
if core._is_fwd_prim_enabled():
_to_prim(infer_program.block(0))
return infer_program
class ProgramTranslator:
"""
Class to translate dygraph function into static graph function. The object
......@@ -1696,8 +1691,9 @@ def enable_to_static(enable_to_static_bool):
@switch_to_static_graph
def to_prim(blocks, exclude=frozenset()):
def _to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
"""Swith to static graph and call to_prim."""
# TODO(Aurelius84): Fix this cycle import problem
from paddle.incubate.autograd import primapi
primapi.to_prim(blocks, exclude)
primapi.to_prim(blocks, blacklist=blacklist, whitelist=whitelist)
......@@ -1568,7 +1568,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
):
op = program_desc.block(0).op(i)
if op.type() in ['fill_any_like', "fill_constant"]:
if op.type() == 'fill_any_like':
var_name = op.output('Out')[0]
names.append(var_name)
return names
......