Unverified commit 2f4763ee, authored by xiaoguoguo626807, committed by GitHub

【prim】Layer norm (#50422)

* fix composite mean op map

* fix composite check output

* init layer_norm

* init layer_norm

* map output from composite rule to origin op

* add dropout op map

* add input map check

* polish log

* modify rules

* success test_forward

* modify test without cinn

* modify cinn test

* modify cinn test

* except fp64

* except fp64

* delete flatten

* delete unused change

* review

* pass cpu test

* code style

* delete flatten fp16 error

* modify flatten test

---------
Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Parent 27281e1f
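The change registers a composite (prim) rule that lowers layer_norm into primitive ops and adds tests that drive it through paddle.incubate.autograd.to_prim in static-graph mode. Below is a minimal sketch of that flow, condensed from the cal_composite helper in the new test file; the [2, 64] input shape and float32 dtype are illustrative choices, not fixed by the PR.

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid import core

paddle.enable_static()
core._set_prim_forward_enabled(True)

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data('x', shape=[2, 64], dtype='float32')
    y = F.layer_norm(x, [64], None, None)
    # Before lowering, the block still contains the fused layer_norm op.
    assert 'layer_norm' in [op.type for op in main_program.block(0).ops]
    # Lower composite ops (layer_norm among them) into primitive ops.
    paddle.incubate.autograd.to_prim(main_program.blocks)
    assert 'layer_norm' not in [op.type for op in main_program.block(0).ops]

exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(
    main_program,
    feed={'x': np.random.random([2, 64]).astype('float32')},
    fetch_list=[y],
)
core._set_prim_forward_enabled(False)
paddle.disable_static()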
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
TOLERANCE = {
"float16": {"rtol": 1e-2, "atol": 1e-2},
"float32": {"rtol": 1e-5, "atol": 1e-5},
"float64": {"rtol": 1e-13, "atol": 1e-13},
}
def generate_data(dtype="float32"):
np_data1 = np.random.random([2, 64]).astype(dtype)
np_data2 = np.random.random([64]).astype(dtype)
np_data3 = np.random.random([64]).astype(dtype)
return np_data1, np_data2, np_data3
def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(net, build_strategy=build_strategy)
class PrimeNet(paddle.nn.Layer):
def __init__(self):
super(PrimeNet, self).__init__()
self.fc = paddle.nn.Linear(64, 64)
def forward(self, x, w, b):
n_shape = x.shape[1:]
out = F.layer_norm(x, n_shape, w, b)
return out[0]
class TestPrimForward(unittest.TestCase):
"""
This case only tests prim_forward + to_static + cinn. Thus we need to
set the backward flag to False to avoid prim_backward:
core._set_prim_backward_enabled(False)
"""
def setUp(self):
self.x = None
self.w = None
self.b = None
self.dtypes = ["float16", "float32"]
def train(self, use_prim):
net = PrimeNet()
sgd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=net.parameters()
)
core._set_prim_forward_enabled(use_prim)
core._add_skip_comp_ops("sqrt")
# TODO(Ruting): delete this after sqrt is modified
if use_prim:
net = apply_to_static(net, use_prim)
out = net(self.x, self.w, self.b)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_grad()
self.check_prim(net, use_prim)
return out.numpy()
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
# Ensure that layer_norm is split into small ops
self.assertTrue('layer_norm' not in fwd_ops)
def test_cinn_prim_forward(self):
for dtype in self.dtypes:
if paddle.device.get_device() == "cpu":
print("need pass this case")
continue
x_n, w_n, b_n = generate_data(dtype)
self.x = paddle.to_tensor(x_n)
self.w = paddle.to_tensor(w_n)
self.b = paddle.to_tensor(b_n)
self.x.stop_gradient = False
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True)
np.testing.assert_allclose(
cinn_res,
dy_res,
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
class TestPrimForwardAndBackward(unittest.TestCase):
"""
Test PrimeNet with @to_static + prim forward + prim backward + cinn vs. Dygraph
"""
def setUp(self):
self.x = None
self.w = None
self.b = None
self.dtypes = ["float16", "float32"]
def train(self, use_prim):
net = PrimeNet()
sgd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=net.parameters()
)
core._set_prim_all_enabled(use_prim)
core._add_skip_comp_ops("sqrt")
# TODO(Ruting): delete this after sqrt is modified
if use_prim:
net = apply_to_static(net, use_prim)
out = net(self.x, self.w, self.b)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_grad()
self.check_prim(net, use_prim)
return out.numpy()
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
# Ensure that layer_norm is split into small ops
self.assertTrue('layer_norm' not in fwd_ops)
def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
for dtype in self.dtypes:
if paddle.device.get_device() == "cpu":
print("need pass this case")
continue
x_n, w_n, b_n = generate_data(dtype)
self.x = paddle.to_tensor(x_n)
self.w = paddle.to_tensor(w_n)
self.b = paddle.to_tensor(b_n)
self.x.stop_gradient = False
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True)
np.testing.assert_allclose(
cinn_res,
dy_res,
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
else:
pass
if __name__ == '__main__':
unittest.main()
...@@ -273,25 +273,6 @@ class TestFlatten2OpError(unittest.TestCase): ...@@ -273,25 +273,6 @@ class TestFlatten2OpError(unittest.TestCase):
self.assertRaises(ValueError, test_ValueError3) self.assertRaises(ValueError, test_ValueError3)
def test_type():
# dtype must be float32, float64, int8, int32, int64, uint8.
x2 = (
np.arange(
image_shape[0]
* image_shape[1]
* image_shape[2]
* image_shape[3]
).reshape(image_shape)
/ 100.0
)
x2 = x2.astype('float16')
x2_var = paddle.fluid.data(
name='x2', shape=[3, 2, 4, 5], dtype='float16'
)
paddle.flatten(x2_var)
self.assertRaises(TypeError, test_type)
def test_InputError():
out = paddle.flatten(x)
......
...@@ -272,24 +272,6 @@ class TestFlatten2OpError(unittest.TestCase): ...@@ -272,24 +272,6 @@ class TestFlatten2OpError(unittest.TestCase):
self.assertRaises(ValueError, test_ValueError3) self.assertRaises(ValueError, test_ValueError3)
def test_type():
# dtype must be float32, float64, int8, int32, int64, uint8.
x2 = (
np.arange(
image_shape[0]
* image_shape[1]
* image_shape[2]
* image_shape[3]
).reshape(image_shape)
/ 100.0
)
x2 = x2.astype('float16')
x2_var = paddle.fluid.data(
name='x2', shape=[3, 2, 4, 5], dtype='float16'
)
paddle.flatten(x2_var)
self.assertRaises(TypeError, test_type)
def test_InputError():
out = paddle.flatten(x)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from utils import SUB_TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
def generate_data(shape1, shape2, shape3, dtype="float32"):
np.random.seed(200)
np_data1 = np.random.random(shape1).astype(dtype)
np_data2 = np.random.random(shape2).astype(dtype)
np_data3 = np.random.random(shape3).astype(dtype)
return np_data1, np_data2, np_data3
class Attr:
def __init__(self) -> None:
self.dtype = None
self.n_shape = None
self.shape1 = None
self.shape2 = None
self.shape3 = None
def set_dtype(self, dtype) -> None:
self.dtype = dtype
return
def set_shape(self, n_shape, shape1, shape2, shape3) -> None:
self.n_shape = n_shape
self.shape1 = shape1
self.shape2 = shape2
self.shape3 = shape3
return
def get_rtol(self, flag):
rtol = SUB_TOLERANCE[self.dtype][flag].get("rtol")
return rtol
def get_atol(self, flag):
atol = SUB_TOLERANCE[self.dtype][flag].get("atol")
return atol
attrs = Attr()
def fn(x, norm_shape, w, b):
return F.layer_norm(x, norm_shape, w, b)
def expect_forward(x, norm_shape, w, b):
return fn(x, norm_shape, w, b)
class TestCompositelayer_norm(unittest.TestCase):
def setUp(self):
self.dtypes = ["float16", "float32", "float64"]
self.n_shape = [[4], [64, 128], [64]]
self.shape1s = [[3, 4], [64, 64, 128], [128, 64, 64]]
self.shape2s = [[4], [64 * 128], [64]]
self.shape3s = [[4], [64 * 128], [64]]
def cal_composite(self, inputs, norm_shape, weight, bias):
paddle.enable_static()
core._set_prim_forward_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
w = paddle.static.data(
'w', shape=weight.shape, dtype=str(weight.dtype)
)
b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype))
y = fn(x, norm_shape, w, b)
blocks = main_program.blocks
fwd_ops = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is in the original block
self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is split into small ops
self.assertTrue('layer_norm' not in fwd_ops_new)
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(
main_program,
feed={
'x': inputs,
'w': weight,
'b': bias,
},
fetch_list=[y],
)
paddle.disable_static()
core._set_prim_forward_enabled(False)
return res
def cal2_composite(self, inputs, norm_shape, weight, bias):
paddle.enable_static()
core._set_prim_forward_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
y = fn(x, norm_shape, weight, bias)
blocks = main_program.blocks
fwd_ops = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is in the original block
self.assertTrue('layer_norm' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is split into small ops
self.assertTrue('layer_norm' not in fwd_ops_new)
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(
main_program,
feed={
'x': inputs,
},
fetch_list=[y],
)
paddle.disable_static()
core._set_prim_forward_enabled(False)
return res
def compare_forward(self):
x, w, b = generate_data(
attrs.shape1, attrs.shape2, attrs.shape3, attrs.dtype
)
n_shape = attrs.n_shape
x_p = paddle.to_tensor(x)
w_p = paddle.to_tensor(w)
b_p = paddle.to_tensor(b)
expect = expect_forward(x_p, n_shape, w_p, b_p).numpy()
actual = self.cal_composite(x, n_shape, w, b)[0]
assert expect.dtype == actual.dtype
np.testing.assert_allclose(
expect,
actual,
rtol=attrs.get_rtol("forward"),
atol=attrs.get_atol("forward"),
)
expect_2 = expect_forward(x_p, n_shape, None, None).numpy()
actual_2 = self.cal2_composite(x, n_shape, None, None)[0]
assert expect_2.dtype == actual_2.dtype
np.testing.assert_allclose(
expect_2,
actual_2,
rtol=attrs.get_rtol("forward"),
atol=attrs.get_atol("forward"),
)
def test_forward(self):
for j in self.dtypes:
if paddle.device.get_device() == "cpu" and j == "float16":
print("need pass this case")
continue
for t in range(0, len(self.shape1s)):
attrs.set_dtype(j)
attrs.set_shape(
self.n_shape[t],
self.shape1s[t],
self.shape2s[t],
self.shape3s[t],
)
self.compare_forward()
if __name__ == '__main__':
unittest.main()
@@ -33,6 +33,11 @@ TOLERANCE = {
# this tolerance is for big composite ops like batch_norm.
SUB_TOLERANCE = {
"float16": {
"forward": {"rtol": 1e-2, "atol": 1e-2},
"backward": {"rtol": 1e-2, "atol": 1e-2},
"prim_backward": {"rtol": 1e-2, "atol": 1e-2},
},
"float32": { "float32": {
"forward": {"rtol": 1e-5, "atol": 1e-5}, "forward": {"rtol": 1e-5, "atol": 1e-5},
"backward": {"rtol": 1e-5, "atol": 1e-5}, "backward": {"rtol": 1e-5, "atol": 1e-5},
......
...@@ -206,25 +206,6 @@ class TestFlatten2OpError(unittest.TestCase): ...@@ -206,25 +206,6 @@ class TestFlatten2OpError(unittest.TestCase):
self.assertRaises(ValueError, test_ValueError5) self.assertRaises(ValueError, test_ValueError5)
def test_type():
# dtype must be float32, float64, int8, int32, int64, uint8.
x2 = (
np.arange(
image_shape[0]
* image_shape[1]
* image_shape[2]
* image_shape[3]
).reshape(image_shape)
/ 100.0
)
x2 = x2.astype('float16')
x2_var = paddle.fluid.data(
name='x2', shape=[3, 2, 4, 5], dtype='float16'
)
paddle.flatten(x2_var)
self.assertRaises(TypeError, test_type)
def test_InputError():
out = paddle.flatten(x)
......
...@@ -264,25 +264,6 @@ class TestFlatten2OpError(unittest.TestCase): ...@@ -264,25 +264,6 @@ class TestFlatten2OpError(unittest.TestCase):
self.assertRaises(ValueError, test_ValueError3) self.assertRaises(ValueError, test_ValueError3)
def test_type():
# dtype must be float32, float64, int8, int32, int64
x2 = (
np.arange(
image_shape[0]
* image_shape[1]
* image_shape[2]
* image_shape[3]
).reshape(image_shape)
/ 100.0
)
x2 = x2.astype('float16')
x2_var = paddle.fluid.data(
name='x2', shape=[3, 2, 4, 5], dtype='float16'
)
paddle.flatten(x2_var)
self.assertRaises(TypeError, test_type)
def test_InputError():
out = paddle.flatten(x)
......
@@ -110,6 +110,35 @@ def composite_batchnorm(
return y, run_mean_, run_var_, batch_mean_, batch_var_, reserve_space
@REGISTER_COMPOSITE('layer_norm')
def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
"""
define composite rule of op layer_norm
out = (x - mean(x)) / sqrt(var + epsilon)
var = mean((x - mean(x))^2)
"""
axis = tuple(range(begin_norm_axis, len(x.shape)))
mean_ = mean(x, axis=axis, keepdim=True)
difference = x - mean_
var_tmp1 = difference * difference
variance = mean(var_tmp1, axis=axis, keepdim=True)
var_tmp3 = variance + epsilon
sqrt_var = sqrt(var_tmp3)
out = difference / sqrt_var
if scale is not None:
scale = reshape(scale, x.shape[begin_norm_axis:])
out = out * scale
if bias is not None:
bias = reshape(bias, x.shape[begin_norm_axis:])
out = out + bias
mean_ = reshape(mean_, [-1])
variance = reshape(variance, [-1])
return out, mean_, variance
@REGISTER_COMPOSITE('gelu')
def gelu_composite(x, approximate):
"""define composite rule of op gelu"""
......
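The layernorm_composite rule above can be sanity-checked outside Paddle with plain NumPy. The sketch below mirrors the decomposition step by step; the helper name layer_norm_ref, the default epsilon of 1e-5, and the sample shapes are illustrative assumptions, not part of the PR.

import numpy as np

def layer_norm_ref(x, scale=None, bias=None, epsilon=1e-5, begin_norm_axis=1):
    # Normalize over the trailing axes, exactly as the composite rule does.
    axis = tuple(range(begin_norm_axis, x.ndim))
    mean_ = x.mean(axis=axis, keepdims=True)
    difference = x - mean_
    variance = (difference * difference).mean(axis=axis, keepdims=True)
    out = difference / np.sqrt(variance + epsilon)
    if scale is not None:
        out = out * scale.reshape(x.shape[begin_norm_axis:])
    if bias is not None:
        out = out + bias.reshape(x.shape[begin_norm_axis:])
    # layer_norm also returns the per-sample mean and variance, flattened.
    return out, mean_.reshape([-1]), variance.reshape([-1])

x = np.random.random([2, 64]).astype('float32')
w = np.random.random([64]).astype('float32')
b = np.random.random([64]).astype('float32')
out, mean_, var_ = layer_norm_ref(x, w, b)
# out should match paddle.nn.functional.layer_norm(x, [64], w, b) within float32 tolerance.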
@@ -184,7 +184,9 @@ def _get_args_values(op, phi_name):
and arg_name in op_content["attrs"].keys()
):
arg_name = op_content["attrs"][arg_name]
# Note: in some cases, attrs may be optional, thus assign None. Such case must be recorded.
if arg_name not in op.attr_names:
attrs.append(None)
else:
......
@@ -1582,7 +1582,16 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
check_variable_and_dtype(
x,
'x',
[
'float16',
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'uint8',
],
'flatten',
)
helper = LayerHelper('flatten', **locals())
@@ -3285,7 +3294,7 @@ def broadcast_to(x, shape, name=None):
check_variable_and_dtype(
x,
'x',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'broadcast_to',
)
check_type(shape, 'shape', (list, tuple, Variable), 'broadcast_to')
......