Commit 82713eb6, authored by W wangruting

init layer_norm

Parent 637dfe49
@@ -49,7 +49,7 @@ class TestPrimForward(unittest.TestCase):
    def setUp(self):
        paddle.seed(2022)
        self.x = paddle.randn([2, 4])
-       self.n_shape = x.shape[1:]
+       self.n_shape = self.x.shape
        self.w = paddle.randn([4])
        self.b = paddle.randn([4])
        self.x.stop_gradient = False
@@ -86,7 +86,7 @@ class TestPrimForward(unittest.TestCase):
        self.assertTrue('layer_norm' not in fwd_ops)

    def test_cinn_prim_forward(self):
        dy_res = self.train(use_prim=False)
        cinn_res = self.train(use_prim=True)
@@ -94,7 +94,7 @@ class TestPrimForward(unittest.TestCase):
            np.testing.assert_allclose(
                cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
            )

class TestPrimForwardAndBackward(unittest.TestCase):
    """
@@ -104,7 +104,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
    def setUp(self):
        paddle.seed(2022)
        self.x = paddle.randn([2, 4])
-       self.n_shape = x.shape[1:]
+       self.n_shape = self.x.shape
        self.w = paddle.randn([4])
        self.b = paddle.randn([4])
        self.x.stop_gradient = False
......
@@ -20,7 +20,6 @@ from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
-from paddle import _C_ops, in_dynamic_mode


def generate_data(shape1, shape2, shape3, dtype="float32"):
@@ -38,7 +37,6 @@ class Attr:
        self.shape1 = None
        self.shape2 = None
        self.shape3 = None

    def set_dtype(self, dtype) -> None:
        self.dtype = dtype
@@ -66,14 +64,15 @@ attrs = Attr()
def fn(x, norm_shape, w, b):
    return F.layer_norm(x, norm_shape, w, b)

-def layer_norm_ (input, weight, bias, epsilon=1e-05, begin_norm_axis = 0):
-    axis = np.arange(begin_norm_axis,len(input.shape))
+def layer_norm_(input, weight, bias, epsilon=1e-05, begin_norm_axis=0):
+    axis = np.arange(begin_norm_axis, len(input.shape))
    mean = paddle.mean(input, axis=axis, keepdim=True)
    t1 = input - mean
-   t2 = paddle.pow( t1, 2.0)
-   t3 = paddle.mean( t2, axis=axis, keepdim=True)
+   t2 = paddle.pow(t1, 2.0)
+   t3 = paddle.mean(t2, axis=axis, keepdim=True)
    t4 = t3 + epsilon
-   t5 = paddle.sqrt( t4 )
+   t5 = paddle.sqrt(t4)
    t7 = t1 / t5
    out = t7
    if weight is not None:
@@ -82,15 +81,15 @@ def layer_norm_ (input, weight, bias, epsilon=1e-05, begin_norm_axis = 0):
    if bias is not None:
        bias = paddle.reshape(bias, input.shape[begin_norm_axis:])
        out = out + paddle.broadcast_to(bias, out.shape)
    return out


def composite_forward(x, norm_shape, w, b):
    b_axis = len(x.shape) - len(norm_shape)
    return layer_norm_(x, w, b, begin_norm_axis=b_axis)


def expect_forward(x, norm_shape, w, b):
    return fn(x, norm_shape, w, b)
@@ -98,10 +97,10 @@ def expect_forward(x, norm_shape, w, b):
class TestCompositelayer_norm(unittest.TestCase):
    def setUp(self):
        self.dtypes = ["float16", "float32"]
-       self.n_shape = [[3, 4],[3], [2, 3]]
-       self.shape1s = [[3, 4],[2, 4, 3], [2, 2, 3]]
-       self.shape2s = [[12],[3],[6]]
-       self.shape3s = [[12],[3],[6]]
+       self.n_shape = [[3, 4], [3], [2, 3]]
+       self.shape1s = [[3, 4], [2, 4, 3], [2, 2, 3]]
+       self.shape2s = [[12], [3], [6]]
+       self.shape3s = [[12], [3], [6]]

    def cal_composite(self, inputs, norm_shape, weight, bias):
        paddle.enable_static()
@@ -115,11 +114,9 @@ class TestCompositelayer_norm(unittest.TestCase):
            w = paddle.static.data(
                'w', shape=weight.shape, dtype=str(weight.dtype)
            )
-           b = paddle.static.data(
-               'b', shape=bias.shape, dtype=str(bias.dtype)
-           )
+           b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype))
            y = fn(x, norm_shape, w, b)

            blocks = main_program.blocks
            fwd_ops = [op.type for op in blocks[0].ops]
@@ -135,13 +132,14 @@ class TestCompositelayer_norm(unittest.TestCase):
        exe = paddle.static.Executor()
        exe.run(startup_program)
        res = exe.run(
            main_program,
            feed={
                'x': inputs,
                'w': weight,
                'b': bias,
            },
-           fetch_list=[y])
+           fetch_list=[y],
+       )
        paddle.disable_static()
        core._set_prim_forward_enabled(False)
        return res
@@ -154,12 +152,9 @@ class TestCompositelayer_norm(unittest.TestCase):
        b_p = paddle.to_tensor(b)

        expect = expect_forward(x_p, n_shape, w_p, b_p).numpy()
-       print("expect = ", expect)
-       #actual = self.cal_composite(x_p, n_shape, w_p, b_p)
+       # actual = self.cal_composite(x_p, n_shape, w_p, b_p)
        actual = composite_forward(x_p, n_shape, w_p, b_p).numpy()
-       print("actual = ", actual)

        assert expect.dtype == actual.dtype
        np.testing.assert_allclose(
            expect,
@@ -180,9 +175,14 @@ class TestCompositelayer_norm(unittest.TestCase):
    def test_forward(self):
        for j in self.dtypes:
-           for t in range(0,len(self.shape1s)):
+           for t in range(0, len(self.shape1s)):
                attrs.set_dtype(j)
-               attrs.set_shape(self.n_shape[t], self.shape1s[t], self.shape2s[t], self.shape3s[t])
+               attrs.set_shape(
+                   self.n_shape[t],
+                   self.shape1s[t],
+                   self.shape2s[t],
+                   self.shape3s[t],
+               )
                self.compare_forward()
......
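For readers following the forward test above, here is a minimal standalone sketch of the same mean/variance decomposition that `layer_norm_` and `composite_forward` exercise, checked against `F.layer_norm` in dynamic graph mode. The shapes and tolerances below are illustrative choices, not values taken from the test itself:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

# Illustrative shapes: normalize over the trailing two axes of a [2, 3, 4] input.
x = paddle.randn([2, 3, 4])
w = paddle.randn([12])  # scale, flattened over the normalized dims (3 * 4)
b = paddle.randn([12])  # bias, flattened the same way
begin_norm_axis = 1

# Decomposition: centre, divide by sqrt(var + eps), then broadcast scale and bias.
axis = list(range(begin_norm_axis, len(x.shape)))
mean = paddle.mean(x, axis=axis, keepdim=True)
var = paddle.mean((x - mean) ** 2, axis=axis, keepdim=True)
out = (x - mean) / paddle.sqrt(var + 1e-5)
out = out * paddle.broadcast_to(
    paddle.reshape(w, x.shape[begin_norm_axis:]), out.shape
)
out = out + paddle.broadcast_to(
    paddle.reshape(b, x.shape[begin_norm_axis:]), out.shape
)

expect = F.layer_norm(x, x.shape[begin_norm_axis:], w, b)
np.testing.assert_allclose(out.numpy(), expect.numpy(), rtol=1e-5, atol=1e-5)
```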
@@ -20,7 +20,6 @@ from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
-from paddle import _C_ops, in_dynamic_mode


def generate_data(shape1, shape2, shape3, dtype="float32"):
@@ -38,7 +37,6 @@ class Attr:
        self.shape1 = None
        self.shape2 = None
        self.shape3 = None

    def set_dtype(self, dtype) -> None:
        self.dtype = dtype
@@ -66,6 +64,7 @@ attrs = Attr()
def fn(x, norm_shape, w, b):
    return F.layer_norm(x, norm_shape, w, b)

# def layer_norm_ (input, weight, bias, epsilon=1e-05, begin_norm_axis = 0):
# axis = np.arange(begin_norm_axis,len(input.shape))
# mean = paddle.mean(input, axis=axis, keepdim=True)
@@ -82,7 +81,7 @@ def fn(x, norm_shape, w, b):
# if bias is not None:
# bias = paddle.reshape(bias, input.shape[begin_norm_axis:])
# out = out + paddle.broadcast_to(bias, out.shape)
# return out


# def composite_forward(x, norm_shape, w, b):
@@ -90,11 +89,10 @@ def fn(x, norm_shape, w, b):
# return layer_norm_(x, w, b, begin_norm_axis=b_axis)


def expect_backward(x, norm_shape, w, b):
    paddle.disable_static()
    x.stop_gradient = False
-   res = fn(x, norm_shape, w, b )
+   res = fn(x, norm_shape, w, b)
    gradients = paddle.grad(res, x)
    return gradients
@@ -103,10 +101,10 @@ def expect_backward(x, norm_shape, w, b):
class TestCompositelayer_norm(unittest.TestCase):
    def setUp(self):
        self.dtypes = ["float16", "float32"]
-       self.n_shape = [[3, 4],[3], [2, 3]]
-       self.shape1s = [[3, 4],[2, 4, 3], [2, 2, 3]]
-       self.shape2s = [[12],[3],[6]]
-       self.shape3s = [[12],[3],[6]]
+       self.n_shape = [[3, 4], [3], [2, 3]]
+       self.shape1s = [[3, 4], [2, 4, 3], [2, 2, 3]]
+       self.shape2s = [[12], [3], [6]]
+       self.shape3s = [[12], [3], [6]]

    def cal_composite_backward(self, inputs, norm_shape, weight, bias):
        paddle.enable_static()
@@ -121,11 +119,9 @@ class TestCompositelayer_norm(unittest.TestCase):
            w = paddle.static.data(
                'w', shape=weight.shape, dtype=str(weight.dtype)
            )
-           b = paddle.static.data(
-               'b', shape=bias.shape, dtype=str(bias.dtype)
-           )
+           b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype))
            y = fn(x, norm_shape, w, b)

            blocks = main_program.blocks
            fwd_ops = [op.type for op in blocks[0].ops]
@@ -147,13 +143,14 @@ class TestCompositelayer_norm(unittest.TestCase):
        exe = paddle.static.Executor()
        exe.run(startup_program)
        res = exe.run(
            main_program,
            feed={
                'x': inputs,
                'w': weight,
                'b': bias,
            },
-           fetch_list=[z])
+           fetch_list=[z],
+       )
        paddle.disable_static()
        core._set_prim_forward_enabled(False)
        return res
@@ -188,9 +185,14 @@ class TestCompositelayer_norm(unittest.TestCase):
    def test_backward(self):
        for j in self.dtypes:
-           for t in range(0,len(self.shape1s)):
+           for t in range(0, len(self.shape1s)):
                attrs.set_dtype(j)
-               attrs.set_shape(self.n_shape[t], self.shape1s[t], self.shape2s[t], self.shape3s[t])
+               attrs.set_shape(
+                   self.n_shape[t],
+                   self.shape1s[t],
+                   self.shape2s[t],
+                   self.shape3s[t],
+               )
                self.compare_backward()
@@ -198,10 +200,10 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
    def setUp(self):
        core._set_prim_backward_enabled(True)
        self.dtypes = ["float16", "float32"]
-       self.n_shape = [[3, 4],[3], [2, 3]]
-       self.shape1s = [[3, 4],[2, 4, 3], [2, 2, 3]]
-       self.shape2s = [[12],[3],[6]]
-       self.shape3s = [[12],[3],[6]]
+       self.n_shape = [[3, 4], [3], [2, 3]]
+       self.shape1s = [[3, 4], [2, 4, 3], [2, 2, 3]]
+       self.shape2s = [[12], [3], [6]]
+       self.shape3s = [[12], [3], [6]]

    def cal_composite_backward(self, inputs, norm_shape, weight, bias):
        paddle.enable_static()
@@ -216,11 +218,9 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
            w = paddle.static.data(
                'w', shape=weight.shape, dtype=str(weight.dtype)
            )
-           b = paddle.static.data(
-               'b', shape=bias.shape, dtype=str(bias.dtype)
-           )
+           b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype))
            y = fn(x, norm_shape, w, b)

            blocks = main_program.blocks
            paddle.incubate.autograd.to_prim(blocks)
            z = paddle.static.gradients([y], x)
@@ -228,13 +228,14 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
        exe = paddle.static.Executor()
        exe.run(startup_program)
        res = exe.run(
            main_program,
            feed={
                'x': inputs,
                'w': weight,
                'b': bias,
            },
-           fetch_list=[z])
+           fetch_list=[z],
+       )
        paddle.disable_static()
        core._set_prim_all_enabled(False)
        return res
@@ -269,9 +270,14 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
    def test_prim_backward(self):
        for j in self.dtypes:
-           for t in range(0,len(self.shape1s)):
+           for t in range(0, len(self.shape1s)):
                attrs.set_dtype(j)
-               attrs.set_shape(self.n_shape[t], self.shape1s[t], self.shape2s[t], self.shape3s[t])
+               attrs.set_shape(
+                   self.n_shape[t],
+                   self.shape1s[t],
+                   self.shape2s[t],
+                   self.shape3s[t],
+               )
                self.compare_backward()
......
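The backward tests compare gradients of the decomposed static program against the dygraph gradient of `F.layer_norm`. A minimal dygraph sketch of the same kind of check, differentiating through the primitive-only decomposition directly; the shapes, the helper name `decomposed_layer_norm`, and the tolerances are illustrative assumptions rather than values from the tests:

```python
import numpy as np
import paddle
import paddle.nn.functional as F


def decomposed_layer_norm(x, w, begin_norm_axis=1, epsilon=1e-5):
    # Same primitive-only decomposition as the reference in the forward test.
    axis = list(range(begin_norm_axis, len(x.shape)))
    mean = paddle.mean(x, axis=axis, keepdim=True)
    var = paddle.mean((x - mean) ** 2, axis=axis, keepdim=True)
    out = (x - mean) / paddle.sqrt(var + epsilon)
    return out * paddle.broadcast_to(
        paddle.reshape(w, x.shape[begin_norm_axis:]), out.shape
    )


x = paddle.randn([2, 4, 3])
x.stop_gradient = False
w = paddle.randn([12])  # scale over the normalized dims (4 * 3)

# paddle.grad seeds non-scalar outputs with ones, as in expect_backward above;
# the gradient w.r.t. x should match between the built-in op and the decomposition.
(dx_ref,) = paddle.grad(F.layer_norm(x, x.shape[1:], w), x)
(dx_dec,) = paddle.grad(decomposed_layer_norm(x, w), x)
np.testing.assert_allclose(dx_ref.numpy(), dx_dec.numpy(), rtol=1e-5, atol=1e-5)
```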
@@ -104,21 +104,21 @@ def composite_batchnorm(
@REGISTER_COMPOSITE('layer_norm')
-def layernorm_composite (x, scale, bias, epsilon, begin_norm_axis):
-    axis = np.arange(begin_norm_axis,len(x.shape))
+def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
+    axis = np.arange(begin_norm_axis, len(x.shape))
    mean_ = mean(x, axis=axis, keepdim=True)
    difference = x - mean_
-   var_tmp1 = pow( difference, 2.0)
-   variance = mean( var_tmp1, axis=axis, keepdim=True)
+   var_tmp1 = pow(difference, 2.0)
+   variance = mean(var_tmp1, axis=axis, keepdim=True)
    var_tmp3 = variance + epsilon
-   sqrt_var = sqrt( var_tmp3 )
+   sqrt_var = sqrt(var_tmp3)
    out = difference / sqrt_var
    if scale is not None:
        scale = reshape(scale, x.shape[begin_norm_axis:])
-       out = t7 * broadcast_to(scale, out.shape)
+       out = out * broadcast_to(scale, out.shape)
    if bias is not None:
        bias = reshape(bias, x.shape[begin_norm_axis:])
        out = out + broadcast_to(bias, out.shape)
    return out, mean_, variance
\ No newline at end of file
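For reference, the rule registered above implements the standard layer-normalization formula over the axes from `begin_norm_axis` onward, returning the saved mean and variance alongside the normalized output:

$$
\mu = \operatorname{mean}(x), \qquad
\sigma^2 = \operatorname{mean}\big((x - \mu)^2\big), \qquad
y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta
$$

where both means are taken over the normalized axes, and the scale γ and bias β are applied only when `scale` and `bias` are not None.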