Commit 82713eb6 authored by W wangruting

init layer_norm

Parent 637dfe49
......@@ -49,7 +49,7 @@ class TestPrimForward(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.n_shape = x.shape[1:]
self.n_shape = self.x.shape
self.w = paddle.randn([4])
self.b = paddle.randn([4])
self.x.stop_gradient = False
......@@ -86,7 +86,7 @@ class TestPrimForward(unittest.TestCase):
self.assertTrue('layer_norm' not in fwd_ops)
def test_cinn_prim_forward(self):
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True)
......@@ -94,7 +94,7 @@ class TestPrimForward(unittest.TestCase):
np.testing.assert_allclose(
cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
)
class TestPrimForwardAndBackward(unittest.TestCase):
"""
......@@ -104,7 +104,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.n_shape = x.shape[1:]
self.n_shape = self.x.shape
self.w = paddle.randn([4])
self.b = paddle.randn([4])
self.x.stop_gradient = False
......
......@@ -20,7 +20,6 @@ from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
from paddle import _C_ops, in_dynamic_mode
def generate_data(shape1, shape2, shape3, dtype="float32"):
......@@ -38,7 +37,6 @@ class Attr:
self.shape1 = None
self.shape2 = None
self.shape3 = None
def set_dtype(self, dtype) -> None:
self.dtype = dtype
......@@ -66,14 +64,15 @@ attrs = Attr()
def fn(x, norm_shape, w, b):
return F.layer_norm(x, norm_shape, w, b)
def layer_norm_ (input, weight, bias, epsilon=1e-05, begin_norm_axis = 0):
axis = np.arange(begin_norm_axis,len(input.shape))
def layer_norm_(input, weight, bias, epsilon=1e-05, begin_norm_axis=0):
axis = np.arange(begin_norm_axis, len(input.shape))
mean = paddle.mean(input, axis=axis, keepdim=True)
t1 = input - mean
t2 = paddle.pow( t1, 2.0)
t3 = paddle.mean( t2, axis=axis, keepdim=True)
t2 = paddle.pow(t1, 2.0)
t3 = paddle.mean(t2, axis=axis, keepdim=True)
t4 = t3 + epsilon
t5 = paddle.sqrt( t4 )
t5 = paddle.sqrt(t4)
t7 = t1 / t5
out = t7
if weight is not None:
......@@ -82,15 +81,15 @@ def layer_norm_ (input, weight, bias, epsilon=1e-05, begin_norm_axis = 0):
if bias is not None:
bias = paddle.reshape(bias, input.shape[begin_norm_axis:])
out = out + paddle.broadcast_to(bias, out.shape)
return out
def composite_forward(x, norm_shape, w, b):
b_axis = len(x.shape) - len(norm_shape)
return layer_norm_(x, w, b, begin_norm_axis=b_axis)
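# A hand-written reference: layer_norm_ spells out layer normalization with
# basic paddle ops, y = (x - mean(x)) / sqrt(var(x) + epsilon), where mean and
# variance are reduced over the axes from begin_norm_axis onward, and the
# result is then scaled by weight and shifted by bias when they are provided.
# composite_forward derives begin_norm_axis from the rank difference between
# x and norm_shape.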
def expect_forward(x, norm_shape, w, b):
return fn(x, norm_shape, w, b)
......@@ -98,10 +97,10 @@ def expect_forward(x, norm_shape, w, b):
class TestCompositelayer_norm(unittest.TestCase):
def setUp(self):
self.dtypes = ["float16", "float32"]
self.n_shape = [[3, 4],[3], [2, 3]]
self.shape1s = [[3, 4],[2, 4, 3], [2, 2, 3]]
self.shape2s = [[12],[3],[6]]
self.shape3s = [[12],[3],[6]]
self.n_shape = [[3, 4], [3], [2, 3]]
self.shape1s = [[3, 4], [2, 4, 3], [2, 2, 3]]
self.shape2s = [[12], [3], [6]]
self.shape3s = [[12], [3], [6]]
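# Each input shape in shape1s ends with the corresponding normalized shape in
# n_shape, and the 1-D weight/bias sizes in shape2s/shape3s are its flattened
# product: 3 * 4 = 12, 3, and 2 * 3 = 6.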
def cal_composite(self, inputs, norm_shape, weight, bias):
paddle.enable_static()
......@@ -115,11 +114,9 @@ class TestCompositelayer_norm(unittest.TestCase):
w = paddle.static.data(
'w', shape=weight.shape, dtype=str(weight.dtype)
)
b = paddle.static.data(
'b', shape=bias.shape, dtype=str(bias.dtype)
)
b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype))
y = fn(x, norm_shape, w, b)
blocks = main_program.blocks
fwd_ops = [op.type for op in blocks[0].ops]
......@@ -135,13 +132,14 @@ class TestCompositelayer_norm(unittest.TestCase):
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(
main_program,
main_program,
feed={
'x': inputs,
'w': weight,
'b': bias,
},
fetch_list=[y])
},
fetch_list=[y],
)
paddle.disable_static()
core._set_prim_forward_enabled(False)
return res
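# cal_composite builds the same layer_norm graph in static mode, runs it with
# an executor, then switches back to dygraph and restores the default prim
# setting before returning the fetched result.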
......@@ -154,12 +152,9 @@ class TestCompositelayer_norm(unittest.TestCase):
b_p = paddle.to_tensor(b)
expect = expect_forward(x_p, n_shape, w_p, b_p).numpy()
print("expect = ", expect)
#actual = self.cal_composite(x_p, n_shape, w_p, b_p)
# actual = self.cal_composite(x_p, n_shape, w_p, b_p)
actual = composite_forward(x_p, n_shape, w_p, b_p).numpy()
print("actual = ", actual)
assert expect.dtype == actual.dtype
np.testing.assert_allclose(
expect,
......@@ -180,9 +175,14 @@ class TestCompositelayer_norm(unittest.TestCase):
def test_forward(self):
for j in self.dtypes:
for t in range(0,len(self.shape1s)):
for t in range(0, len(self.shape1s)):
attrs.set_dtype(j)
attrs.set_shape(self.n_shape[t], self.shape1s[t], self.shape2s[t], self.shape3s[t])
attrs.set_shape(
self.n_shape[t],
self.shape1s[t],
self.shape2s[t],
self.shape3s[t],
)
self.compare_forward()
......
......@@ -20,7 +20,6 @@ from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
from paddle import _C_ops, in_dynamic_mode
def generate_data(shape1, shape2, shape3, dtype="float32"):
......@@ -38,7 +37,6 @@ class Attr:
self.shape1 = None
self.shape2 = None
self.shape3 = None
def set_dtype(self, dtype) -> None:
self.dtype = dtype
......@@ -66,6 +64,7 @@ attrs = Attr()
def fn(x, norm_shape, w, b):
return F.layer_norm(x, norm_shape, w, b)
# def layer_norm_ (input, weight, bias, epsilon=1e-05, begin_norm_axis = 0):
# axis = np.arange(begin_norm_axis,len(input.shape))
# mean = paddle.mean(input, axis=axis, keepdim=True)
......@@ -82,7 +81,7 @@ def fn(x, norm_shape, w, b):
# if bias is not None:
# bias = paddle.reshape(bias, input.shape[begin_norm_axis:])
# out = out + paddle.broadcast_to(bias, out.shape)
# return out
# def composite_forward(x, norm_shape, w, b):
......@@ -90,11 +89,10 @@ def fn(x, norm_shape, w, b):
# return layer_norm_(x, w, b, begin_norm_axis=b_axis)
def expect_backward(x, norm_shape, w, b):
paddle.disable_static()
x.stop_gradient = False
res = fn(x, norm_shape, w, b )
res = fn(x, norm_shape, w, b)
gradients = paddle.grad(res, x)
return gradients
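# expect_backward evaluates layer_norm eagerly and uses paddle.grad to get
# d(out)/d(x); these dygraph gradients serve as the reference for the static
# composite results computed below.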
......@@ -103,10 +101,10 @@ def expect_backward(x, norm_shape, w, b):
class TestCompositelayer_norm(unittest.TestCase):
def setUp(self):
self.dtypes = ["float16", "float32"]
self.n_shape = [[3, 4],[3], [2, 3]]
self.shape1s = [[3, 4],[2, 4, 3], [2, 2, 3]]
self.shape2s = [[12],[3],[6]]
self.shape3s = [[12],[3],[6]]
self.n_shape = [[3, 4], [3], [2, 3]]
self.shape1s = [[3, 4], [2, 4, 3], [2, 2, 3]]
self.shape2s = [[12], [3], [6]]
self.shape3s = [[12], [3], [6]]
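# Same shape configuration as the forward test: input shapes end with the
# normalized shape, and weight/bias sizes are its flattened product.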
def cal_composite_backward(self, inputs, norm_shape, weight, bias):
paddle.enable_static()
......@@ -121,11 +119,9 @@ class TestCompositelayer_norm(unittest.TestCase):
w = paddle.static.data(
'w', shape=weight.shape, dtype=str(weight.dtype)
)
b = paddle.static.data(
'b', shape=bias.shape, dtype=str(bias.dtype)
)
b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype))
y = fn(x, norm_shape, w, b)
blocks = main_program.blocks
fwd_ops = [op.type for op in blocks[0].ops]
......@@ -147,13 +143,14 @@ class TestCompositelayer_norm(unittest.TestCase):
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(
main_program,
main_program,
feed={
'x': inputs,
'w': weight,
'b': bias,
},
fetch_list=[z])
},
fetch_list=[z],
)
paddle.disable_static()
core._set_prim_forward_enabled(False)
return res
......@@ -188,9 +185,14 @@ class TestCompositelayer_norm(unittest.TestCase):
def test_backward(self):
for j in self.dtypes:
for t in range(0,len(self.shape1s)):
for t in range(0, len(self.shape1s)):
attrs.set_dtype(j)
attrs.set_shape(self.n_shape[t], self.shape1s[t], self.shape2s[t], self.shape3s[t])
attrs.set_shape(
self.n_shape[t],
self.shape1s[t],
self.shape2s[t],
self.shape3s[t],
)
self.compare_backward()
......@@ -198,10 +200,10 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
def setUp(self):
core._set_prim_backward_enabled(True)
self.dtypes = ["float16", "float32"]
self.n_shape = [[3, 4],[3], [2, 3]]
self.shape1s = [[3, 4],[2, 4, 3], [2, 2, 3]]
self.shape2s = [[12],[3],[6]]
self.shape3s = [[12],[3],[6]]
self.n_shape = [[3, 4], [3], [2, 3]]
self.shape1s = [[3, 4], [2, 4, 3], [2, 2, 3]]
self.shape2s = [[12], [3], [6]]
self.shape3s = [[12], [3], [6]]
def cal_composite_backward(self, inputs, norm_shape, weight, bias):
paddle.enable_static()
......@@ -216,11 +218,9 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
w = paddle.static.data(
'w', shape=weight.shape, dtype=str(weight.dtype)
)
b = paddle.static.data(
'b', shape=bias.shape, dtype=str(bias.dtype)
)
b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype))
y = fn(x, norm_shape, w, b)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
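# to_prim rewrites ops with registered composite rules (layer_norm here) into
# primitive ops, so the static gradients below are taken through the
# decomposed graph.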
z = paddle.static.gradients([y], x)
......@@ -228,13 +228,14 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(
main_program,
main_program,
feed={
'x': inputs,
'w': weight,
'b': bias,
},
fetch_list=[z])
},
fetch_list=[z],
)
paddle.disable_static()
core._set_prim_all_enabled(False)
return res
......@@ -269,9 +270,14 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
def test_prim_backward(self):
for j in self.dtypes:
for t in range(0,len(self.shape1s)):
for t in range(0, len(self.shape1s)):
attrs.set_dtype(j)
attrs.set_shape(self.n_shape[t], self.shape1s[t], self.shape2s[t], self.shape3s[t])
attrs.set_shape(
self.n_shape[t],
self.shape1s[t],
self.shape2s[t],
self.shape3s[t],
)
self.compare_backward()
......
......@@ -104,21 +104,21 @@ def composite_batchnorm(
@REGISTER_COMPOSITE('layer_norm')
def layernorm_composite (x, scale, bias, epsilon, begin_norm_axis):
axis = np.arange(begin_norm_axis,len(x.shape))
def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
axis = np.arange(begin_norm_axis, len(x.shape))
mean_ = mean(x, axis=axis, keepdim=True)
difference = x - mean_
var_tmp1 = pow( difference, 2.0)
variance = mean( var_tmp1, axis=axis, keepdim=True)
var_tmp1 = pow(difference, 2.0)
variance = mean(var_tmp1, axis=axis, keepdim=True)
var_tmp3 = variance + epsilon
sqrt_var = sqrt( var_tmp3 )
sqrt_var = sqrt(var_tmp3)
out = difference / sqrt_var
if scale is not None:
scale = reshape(scale, x.shape[begin_norm_axis:])
out = t7 * broadcast_to(scale, out.shape)
out = out * broadcast_to(scale, out.shape)
if bias is not None:
bias = reshape(bias, x.shape[begin_norm_axis:])
out = out + broadcast_to(bias, out.shape)
return out, mean_, variance
\ No newline at end of file
return out, mean_, variance
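# The registered composite rule mirrors the layer_norm_ reference above and
# also returns mean_ and variance, presumably to match the Y/Mean/Variance
# outputs of the layer_norm op it decomposes.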