Unverified commit 1be2cbeb authored by Xiaoxu Chen, committed by GitHub

【prim】add dropout composite rule (#50497)

* map output from composite rule to origin op

add mean layer_norm dropout op map

add input map check

composite softmax support input shape []

* polish log

* [prim] add dropout composite rule

---------
Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Parent 2fa91d71
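
For context, the rule added here lets the static-graph prim pass decompose dropout into primitive ops. A minimal sketch of how the lowering is switched on, using the same flags and to_prim call exercised by the test below (shapes here are illustrative):

import paddle
from paddle.fluid import core

paddle.enable_static()
core._set_prim_forward_enabled(True)  # turn on forward decomposition
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
    x = paddle.static.data('x', shape=[8, 16], dtype='float32')
    y = paddle.nn.functional.dropout(x, p=0.5, training=True)
    paddle.incubate.autograd.to_prim(mp.blocks)  # dropout is rewritten here
# after lowering, no raw dropout op remains in the program
assert 'dropout' not in [op.type for op in mp.block(0).ops]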
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
import parameterized as param

import paddle
from paddle.fluid import core

np.random.seed(2023)

place = (
    paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
)
@param.parameterized_class(
    ('name', 'x', 'p', 'is_test', 'mode', 'seed', 'dtype', 'place'),
    (
        (
            'fp32',
            np.random.rand(100000),
            0.3,
            False,
            'upscale_in_train',
            1002,
            'float32',
            place,
        ),
        (
            'fp64',
            np.random.rand(100000),
            0.7,
            False,
            'upscale_in_train',
            9999,
            'float64',
            place,
        ),
        (
            'is_test=True',
            np.random.rand(100000),
            0.5,
            True,
            'upscale_in_train',
            1002,
            'float32',
            place,
        ),
        (
            'p=1.0',
            np.random.rand(100000),
            1.0,
            True,
            'upscale_in_train',
            1002,
            'float32',
            place,
        ),
        (
            'p=1.0,test=False',
            np.random.rand(100000),
            1.0,
            False,
            'upscale_in_train',
            1002,
            'float32',
            place,
        ),
        (
            'p=0.0',
            np.random.rand(100000),
            0.0,
            True,
            'upscale_in_train',
            1002,
            'float32',
            place,
        ),
        (
            'downgrade_train',
            np.random.rand(100000),
            0.5,
            False,
            'downscale_in_infer',
            1002,
            'float32',
            place,
        ),
        (
            'fp32_cpu',
            np.random.rand(100000),
            0.6,
            False,
            'upscale_in_train',
            9899,
            'float32',
            paddle.CPUPlace(),
        ),
        (
            'fp64_cpu',
            np.random.rand(100000),
            0.6,
            False,
            'upscale_in_train',
            9899,
            'float64',
            paddle.CPUPlace(),
        ),
        (
            'downgrade_train_cpu',
            np.random.rand(100000),
            0.5,
            False,
            'downscale_in_infer',
            1002,
            'float32',
            paddle.CPUPlace(),
        ),
    ),
)
class TestCompositeDropout(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        paddle.enable_static()
        cls.x = cls.x.astype(cls.dtype)

    @classmethod
    def tearDownClass(cls):
        paddle.disable_static()

    def test_comp(self):
        def dropout(x, p, is_test, mode, seed=0):
            paddle.seed(seed)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                input_ = paddle.static.data('x', shape=x.shape, dtype=x.dtype)
                input_.stop_gradient = False
                output = paddle.nn.functional.dropout(
                    input_, p, training=(not is_test), mode=mode
                )
                if core._is_fwd_prim_enabled():
                    # Lower composite ops (including dropout) to primitives.
                    paddle.incubate.autograd.to_prim(mp.blocks)
                grad = paddle.static.gradients(output, input_)[0]
            exe = paddle.static.Executor(self.place)
            exe.run(sp)
            fwd, rev = exe.run(
                mp, feed={input_.name: x}, fetch_list=[output, grad]
            )
            return fwd, rev, mp

        # Baseline: run the original dropout op without lowering.
        core._set_prim_forward_enabled(False)
        desired_fwd, desired_rev, _ = dropout(
            self.x, self.p, self.is_test, self.mode, self.seed
        )

        # Forward lowering only: dropout must be decomposed out of the program.
        core._set_prim_forward_enabled(True)
        actual_fwd, actual_rev, prog = dropout(
            self.x, self.p, self.is_test, self.mode, self.seed
        )
        self.assertTrue('dropout' not in [op.type for op in prog.block(0).ops])
        # The random masks differ between runs, so compare sums rather than
        # elementwise values; a loose rtol absorbs the sampling noise.
        np.testing.assert_allclose(
            actual_fwd.sum(),
            desired_fwd.sum(),
            rtol=1e-2,
            atol=0,
        )
        np.testing.assert_allclose(
            actual_rev.sum(),
            desired_rev.sum(),
            rtol=1e-2,
            atol=0,
        )

        # Forward and backward lowering together.
        core._set_prim_all_enabled(True)
        actual_fwd, actual_rev, _ = dropout(
            self.x, self.p, self.is_test, self.mode, self.seed
        )
        np.testing.assert_allclose(
            actual_fwd.sum(),
            desired_fwd.sum(),
            rtol=1e-2,
            atol=0,
        )
        np.testing.assert_allclose(
            actual_rev.sum(),
            desired_rev.sum(),
            rtol=1e-2,
            atol=0,
        )


if __name__ == '__main__':
    unittest.main()
@@ -20,6 +20,8 @@
import functools
import operator
from paddle.fluid import core
from .primitives import * # noqa: F403
from .primreg import REGISTER_COMPOSITE, lookup_composite
@@ -178,3 +180,46 @@ def mean_composite(x, axis, keepdim):
        dtype=sum_x.dtype,
    )
    return divide(sum_x, norm)
@REGISTER_COMPOSITE('dropout')
def dropout_composite(x, seed_tensor, p, is_test, mode, seed, fix_seed):
    """Define the composite rule of op dropout.

    upscale_in_train:
        train: out = input * mask / (1.0 - p)
        inference: out = input
    downscale_in_infer:
        train: out = input * mask
        inference: out = input * (1.0 - p)
    """
    fix_seed = True if fix_seed is None else fix_seed
    seed = seed if fix_seed else 0
    upscale_in_train = mode == "upscale_in_train"
    mask = bernoulli(shape=x.shape, dtype=x.dtype, p=p, seed=seed)

    if upscale_in_train:
        if not is_test:
            # Handle p == 1.0 separately to avoid dividing by zero
            # in x * mask / (1.0 - p).
            if p == 1.0:
                return 0.0 * x, zeros(x.shape, core.VarDesc.VarType.UINT8)
            else:
                return x * mask / (1.0 - p), cast(
                    mask, core.VarDesc.VarType.UINT8
                )
        else:
            return assign(x), cast(mask, core.VarDesc.VarType.UINT8)
    else:
        if not is_test:
            return x * mask, cast(mask, core.VarDesc.VarType.UINT8)
        else:
            return x * (1.0 - p), cast(mask, core.VarDesc.VarType.UINT8)


def bernoulli(shape, dtype, p, seed=0):
    # Keep-mask: 1.0 where a uniform sample >= p (element kept), else 0.0.
    return cast(
        greater_equal(
            uniform(shape, dtype, min=0.0, max=1.0, seed=seed),
            fill_constant(shape, dtype, p),
        ),
        dtype,
    )
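
As a quick sanity check on the upscale_in_train branch above, here is a minimal NumPy sketch (illustrative only, not part of the patch) of why rescaling by 1/(1 - p) keeps the expected output equal to the input:

import numpy as np

rng = np.random.default_rng(2023)
x = rng.random(100000).astype('float32')
p = 0.3
# keep-mask mirrors bernoulli() above: 1.0 where a uniform sample >= p
mask = (rng.random(x.shape) >= p).astype(x.dtype)
out = x * mask / (1.0 - p)  # upscale_in_train, training branch
# E[out] == E[x], so the sums agree up to sampling noise -- exactly
# what TestCompositeDropout asserts with rtol=1e-2
print(abs(out.sum() - x.sum()) / x.sum() < 1e-2)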
@@ -33,6 +33,7 @@ from paddle.tensor import erfinv # noqa: F401
from paddle.tensor import exp # noqa: F401
from paddle.tensor import expm1 # noqa: F401
from paddle.tensor import full # noqa: F401
from paddle.tensor import greater_equal # noqa: F401
from paddle.tensor import lgamma # noqa: F401
from paddle.tensor import log # noqa: F401
from paddle.tensor import log1p # noqa: F401
@@ -55,6 +56,7 @@ from paddle.tensor import subtract # noqa: F401
from paddle.tensor import sum # noqa: F401
from paddle.tensor import tan # noqa: F401
from paddle.tensor import tanh # noqa: F401
from paddle.tensor import uniform # noqa: F401
from paddle.tensor import zeros # noqa: F401
from paddle.tensor.creation import assign # noqa: F401
from paddle.tensor.manipulation import cast # noqa: F401
@@ -116,5 +118,7 @@ others = [
    'fill_constant',
    'reshape',
    'full',
    'uniform',
    'greater_equal',
]
"""
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册