Unverified · Commit 64076727 · Authored by: J Jiabin Yang · Committed by: GitHub

【Prim】Support amp logic for layer_norm and softmax (#51473)

* support amp logic for layer_norm and softmax

* fix layer_norm amp

* fix layernorm api and dropout fp16

* fix layernorm api and dropout fp16

* fix bn, ln dtype in float16

* fix dropout fp16

* fix comment
Parent 14f1973d
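At a high level, every composite (prim) rule touched by this commit handles AMP the same way: if the input arrives as float16, it is cast to float32, the decomposed computation runs in float32, and the result is cast back to float16. A minimal sketch of that pattern using only public Paddle ops (the helper name amp_cast_around is illustrative and not code from this commit):

import paddle

def amp_cast_around(fn, x, *args, **kwargs):
    # Hypothetical helper mirroring the pattern used by the composite rules:
    # compute in float32 when the input is float16, then cast the result back.
    is_amp = x.dtype == paddle.float16
    if is_amp:
        x = paddle.cast(x, "float32")
    out = fn(x, *args, **kwargs)
    return paddle.cast(out, "float16") if is_amp else out

# Example: a float16 input is normalized in float32 and returned as float16.
x = paddle.cast(paddle.randn([4, 8]), "float16")
y = amp_cast_around(paddle.nn.functional.layer_norm, x, 8)
print(y.dtype)  # paddle.float16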
@@ -19,8 +19,9 @@ from utils import SUB_TOLERANCE
 import paddle
 import paddle.nn.functional as F
-from paddle.fluid import core
+from paddle.fluid import core, framework
 from paddle.incubate.autograd import primapi
+from paddle.nn import LayerNorm


 def generate_data(shape1, shape2, shape3, dtype="float32"):
@@ -43,7 +44,7 @@ class Attr:
         self.dtype = dtype
         return

-    def set_shape(self, n_shape, shape1, shape2, shape3) -> None:
+    def set_shape(self, n_shape, shape1=[], shape2=[], shape3=[]) -> None:
         self.n_shape = n_shape
         self.shape1 = shape1
         self.shape2 = shape2
@@ -72,7 +73,7 @@ def expect_forward(x, norm_shape, w, b):
 class TestCompositelayer_norm(unittest.TestCase):
     def setUp(self):
-        self.dtypes = ["float16", "float32", "float64"]
+        self.dtypes = ["float32", "float64"]
         self.n_shape = [[4], [64, 128], [64]]
         self.shape1s = [[3, 4], [64, 64, 128], [128, 64, 64]]
         self.shape2s = [[4], [64 * 128], [64]]
@@ -203,5 +204,72 @@ class TestCompositelayer_norm(unittest.TestCase):
             self.compare_forward()

+
+def apply_to_static(net, use_cinn):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(net, build_strategy=build_strategy)
+
+
+class PrimeNet(paddle.nn.Layer):
+    def __init__(self, n_shape):
+        super(PrimeNet, self).__init__()
+        self.ln = LayerNorm(n_shape)
+
+    def forward(self, x):
+        out = self.ln(x)
+        return out
+
+
+class TestPrimForwardAndBackward(unittest.TestCase):
+    """
+    Test PrimeNet with @to_static + prim forward + prim backward + cinn vs. dygraph.
+    """
+
+    def setUp(self):
+        paddle.seed(2022)
+        self.n_shape = [[4], [64, 128], [64]]
+        self.shape1s = [[3, 4], [64, 64, 128], [128, 64, 64]]
+
+    def train(self, use_prim):
+        self.x = paddle.randn(attrs.shape1, dtype="float32")
+        self.x.stop_gradient = False
+        core._set_prim_all_enabled(use_prim)
+        paddle.seed(2022)
+        net = PrimeNet(attrs.n_shape)
+        sgd = paddle.optimizer.SGD(
+            learning_rate=0.1, parameters=net.parameters()
+        )
+        net = paddle.amp.decorate(models=net, level='O2')
+        net = apply_to_static(net, False)
+
+        with paddle.amp.auto_cast(level='O2'):
+            out = net(self.x)
+            loss = paddle.mean(out)
+        loss.backward()
+        sgd.step()
+        sgd.clear_grad()
+        return loss
+
+    def compare_forward(self):
+        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+            expected = self.train(False)
+            actual = self.train(True)
+            np.testing.assert_allclose(
+                expected,
+                actual,
+                rtol=1e-3,
+                atol=1e-3,
+            )
+
+    def test_forward(self):
+        for t in range(0, len(self.shape1s)):
+            attrs.set_shape(
+                self.n_shape[t],
+                self.shape1s[t],
+            )
+            self.compare_forward()
+

 if __name__ == '__main__':
     unittest.main()
@@ -19,7 +19,7 @@ from utils import TOLERANCE
 import paddle
 import paddle.nn.functional as F
-from paddle.fluid import core
+from paddle.fluid import core, framework
 from paddle.incubate.autograd import primapi
@@ -129,5 +129,78 @@ class TestCompositeSoftmax(unittest.TestCase):
             self.compare_forward()

+
+def apply_to_static(net, use_cinn):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(net, build_strategy=build_strategy)
+
+
+class PrimeNet(paddle.nn.Layer):
+    def __init__(self):
+        super(PrimeNet, self).__init__()
+        self.sf = F.softmax
+
+    def forward(self, x, current_axis):
+        out = self.sf(x, axis=current_axis)
+        return out
+
+
+class TestPrimForwardAndBackward(unittest.TestCase):
+    """
+    Test PrimeNet with @to_static + prim forward + prim backward + cinn vs. dygraph.
+    """
+
+    def setUp(self):
+        paddle.seed(2022)
+        self.shapes = [[], [2, 3, 4], [2, 3]]
+        self.axes = [-1, 0, 1]
+
+    def train(self, use_prim):
+        self.x = paddle.randn(attrs.shape, dtype="float32")
+        self.x.stop_gradient = False
+        core._set_prim_all_enabled(use_prim)
+        paddle.seed(2022)
+        net = PrimeNet()
+        sgd = paddle.optimizer.SGD(
+            learning_rate=0.1, parameters=net.parameters()
+        )
+        net = paddle.amp.decorate(models=net, level='O2')
+        net = apply_to_static(net, False)
+
+        with paddle.amp.auto_cast(level='O2'):
+            out = net(self.x, attrs.axis)
+            loss = paddle.mean(out)
+        grad = paddle.grad(loss, self.x)
+        return loss, grad
+
+    def compare_forward(self):
+        if not attrs.shape and attrs.axis not in [-1, 0]:
+            # op softmax does not support this case
+            return
+        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+            expected = self.train(False)
+            actual = self.train(True)
+            np.testing.assert_allclose(
+                expected[0],
+                actual[0],
+                rtol=1e-3,
+                atol=1e-3,
+            )
+            np.testing.assert_allclose(
+                expected[1],
+                actual[1],
+                rtol=1e-3,
+                atol=1e-3,
+            )
+
+    def test_forward(self):
+        for i in self.axes:
+            for t in self.shapes:
+                attrs.set_axis(i)
+                attrs.set_shape(t)
+                self.compare_forward()
+

 if __name__ == '__main__':
     unittest.main()
@@ -34,15 +34,26 @@ def _composite(op, *args):
 @REGISTER_COMPOSITE('softmax')
 def softmax_composite(x, axis):
     """define composite rule of op softmax"""
+    is_amp = False
+    from paddle.fluid.data_feeder import convert_dtype
+
+    # Softmax needs fp32 compute, since it contains a sum (reduction) op.
+    if convert_dtype(x.dtype) == "float16":
+        is_amp = True
+        x = cast(x, "float32")
+
     if not x.shape:
         # do not return 1, to ensure gradients
         res = exp(x - x)
+        if is_amp:
+            res = cast(res, "float16")
         return res
     max_temp = max(x, axis, keepdim=True)
     max_temp.stop_gradient = True
     molecular = exp(x - max_temp)
     denominator = sum(molecular, axis=axis, keepdim=True)
     res = divide(molecular, denominator)
+    if is_amp:
+        res = cast(res, "float16")
     return res
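Concretely, the rule above combines the float16 cast-around pattern sketched earlier with the usual max-shift trick for numerical stability. A standalone sketch with public Paddle ops (stable_softmax_fp32 is illustrative, not a function added by this commit):

import paddle

def stable_softmax_fp32(x, axis=-1):
    # Hypothetical standalone version of the AMP path in the composite rule.
    is_amp = x.dtype == paddle.float16
    if is_amp:
        x = paddle.cast(x, "float32")  # exp/sum run in fp32 for accuracy
    shifted = x - x.max(axis=axis, keepdim=True)  # subtract max for stability
    exp_x = paddle.exp(shifted)
    res = exp_x / exp_x.sum(axis=axis, keepdim=True)
    return paddle.cast(res, "float16") if is_amp else res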
@@ -65,7 +76,6 @@ def composite_batchnorm(
     from paddle.fluid.data_feeder import convert_dtype

     if convert_dtype(x.dtype) == "float16":
-        print("Running batch_norm in amp")
         is_amp = True
         x = cast(x, "float32")
@@ -128,6 +138,12 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
     out = (x - mean(x)) / sqrt(var + epsilon)
     var = mean((x-mean(x))^2)
     """
+    is_amp = False
+    from paddle.fluid.data_feeder import convert_dtype
+
+    if convert_dtype(x.dtype) == "float16":
+        is_amp = True
+        x = cast(x, "float32")
+
     axis = tuple(range(begin_norm_axis, len(x.shape)))
     mean_ = mean(x, axis=axis, keepdim=True)
@@ -147,6 +163,8 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
     mean_ = reshape(mean_, [-1])
     variance = reshape(variance, [-1])

+    if is_amp:
+        out = cast(out, "float16")
     return out, mean_, variance
@@ -315,6 +333,7 @@ def dropout_composite(x, seed_tensor, p, is_test, mode, seed, fix_seed):
     fix_seed = True if fix_seed is None else fix_seed
     seed = seed if fix_seed else 0
     upscale_in_train = mode == "upscale_in_train"
     mask = bernoulli(shape=x.shape, dtype=x.dtype, p=p, seed=seed)

     if upscale_in_train:
......
@@ -631,7 +631,7 @@ def _lower_composite(
             elif new_out is not None:
                 assert orig_out.dtype == new_out.dtype, (
                     f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, '
-                    f'but orig_out.dtype={orig_out.dtype} and new_out.dtype={new_out.dtype}'
+                    f'but orig_out: {orig_out.name}.dtype={orig_out.dtype} and new_out: {new_out.name}.dtype={new_out.dtype}'
                 )
                 if orig_out.shape and new_out.shape:
                     assert (
@@ -639,7 +639,7 @@ def _lower_composite(
                     ), f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
                     assert orig_out.shape == new_out.shape, (
                         f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
-                        f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}'
+                        f'but orig_out: {orig_out.name}.shape={orig_out.shape} and new_out: {new_out.name}.shape={new_out.shape}'
                     )
                 assert not (orig_out is None) ^ (
                     new_out is None
......
@@ -320,9 +320,7 @@ class PartialProgramLayer:
     @switch_to_static_graph
     def _create_forward_backward_train_amp_program(self):
         whole_program = self._train_amp_program
-        _, forward_end_op_index = self._infer_info(
-            'amp', self._create_amp_program
-        )
+        forward_end_op_index = self.get_forward_end_op_idx(whole_program)
         assert forward_end_op_index >= 0

         return self._get_forward_backward_program_form(
@@ -332,9 +330,7 @@ class PartialProgramLayer:
     @switch_to_static_graph
     def _create_forward_backward_train_pure_fp16_program(self):
         whole_program = self._train_pure_fp16_program
-        _, forward_end_op_index = self._infer_info(
-            'fp16', self._create_pure_fp16_program
-        )
+        forward_end_op_index = self.get_forward_end_op_idx(whole_program)
         assert forward_end_op_index >= 0

         return self._get_forward_backward_program_form(
......
@@ -238,8 +238,11 @@ def batch_norm(
     }
     helper = LayerHelper('batch_norm', **locals())
+    from paddle.fluid.data_feeder import convert_dtype

-    param_dtype = x.dtype if x.dtype != 'float16' else 'float32'
+    param_dtype = (
+        x.dtype if convert_dtype(x.dtype) != 'float16' else 'float32'
+    )
     saved_mean = helper.create_variable_for_type_inference(
         dtype=param_dtype, stop_gradient=True
     )
@@ -348,15 +351,18 @@ def layer_norm(
     # create output
     helper = LayerHelper('layer_norm', **locals())
+    from paddle.fluid.data_feeder import convert_dtype

-    dtype = x.dtype
+    param_dtype = (
+        x.dtype if convert_dtype(x.dtype) != 'float16' else 'float32'
+    )
     mean_out = helper.create_variable_for_type_inference(
-        dtype=dtype, stop_gradient=True
+        dtype=param_dtype, stop_gradient=True
     )
     variance_out = helper.create_variable_for_type_inference(
-        dtype=dtype, stop_gradient=True
+        dtype=param_dtype, stop_gradient=True
     )
-    layer_norm_out = helper.create_variable_for_type_inference(dtype)
+    layer_norm_out = helper.create_variable_for_type_inference(x.dtype)

     helper.append_op(
         type="layer_norm",
......
@@ -654,7 +654,9 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
         )
     else:
         check_type(shape, 'shape', (list, tuple, Variable), 'uniform/rand')
-        check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform/rand')
+        check_dtype(
+            dtype, 'dtype', ('float16', 'float32', 'float64'), 'uniform/rand'
+        )
         check_type(min, 'min', (float, int, Variable), 'uniform/rand')
         check_type(max, 'max', (float, int, Variable), 'uniform/rand')
......