Unverified · Commit c5f4a9cc authored by carryyu, committed by GitHub

add post layer norm (#44931)

Parent 9336dd3e
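Note on the feature under test: in a pre-layer-norm transformer block, LayerNorm is applied to the sublayer input and the residual is added to the un-normalized input; in a post-layer-norm block, the residual is added first and LayerNorm is applied to the sum. The sketch below only illustrates the two orderings; the layer_norm and sublayer helpers are simplified placeholders, not the operators used by the fused kernel under test.

import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the last (feature) axis; scale and shift are omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def residual_block(x, sublayer, pre_layer_norm=True):
    # One residual sublayer (attention or FFN) in the two normalization orders.
    if pre_layer_norm:
        # pre-LN: normalize the input, then add the residual to the raw input.
        return x + sublayer(layer_norm(x))
    # post-LN: add the residual first, then normalize the sum.
    return layer_norm(x + sublayer(x))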
@@ -39,6 +39,7 @@ default_main_program().random_seed = 42


 class TestFusedMultiTransformerOp(OpTest):
+
     def setUp(self):
         self.config()
         self.generate_input_data()
@@ -61,36 +62,30 @@ class TestFusedMultiTransformerOp(OpTest):
         bias_attr = paddle.fluid.ParamAttr(
             initializer=paddle.fluid.initializer.Constant(value=0.0005))
-        self.q_proj = Linear(
-            self.embed_dim,
-            self.embed_dim,
-            self.weight_attr,
-            bias_attr=bias_attr)
+        self.q_proj = Linear(self.embed_dim,
+                             self.embed_dim,
+                             self.weight_attr,
+                             bias_attr=bias_attr)
         #bias_attr=self.bias_attr)
-        self.k_proj = Linear(
-            self.kdim,
-            self.embed_dim,
-            self.weight_attr,
-            bias_attr=self.bias_attr)
-        self.v_proj = Linear(
-            self.vdim,
-            self.embed_dim,
-            self.weight_attr,
-            bias_attr=self.bias_attr)
-        self.out_proj = Linear(
-            self.embed_dim,
-            self.embed_dim,
-            self.weight_attr,
-            bias_attr=self.bias_attr)
-        self.ffn1_proj = Linear(
-            self.embed_dim,
-            4 * self.embed_dim,
-            self.weight_attr,
-            bias_attr=self.bias_attr)
-        self.ffn2_proj = Linear(
-            4 * self.embed_dim,
-            self.embed_dim,
-            self.weight_attr,
-            bias_attr=self.bias_attr)
+        self.k_proj = Linear(self.kdim,
+                             self.embed_dim,
+                             self.weight_attr,
+                             bias_attr=self.bias_attr)
+        self.v_proj = Linear(self.vdim,
+                             self.embed_dim,
+                             self.weight_attr,
+                             bias_attr=self.bias_attr)
+        self.out_proj = Linear(self.embed_dim,
+                               self.embed_dim,
+                               self.weight_attr,
+                               bias_attr=self.bias_attr)
+        self.ffn1_proj = Linear(self.embed_dim,
+                                4 * self.embed_dim,
+                                self.weight_attr,
+                                bias_attr=self.bias_attr)
+        self.ffn2_proj = Linear(4 * self.embed_dim,
+                                self.embed_dim,
+                                self.weight_attr,
+                                bias_attr=self.bias_attr)
@@ -228,8 +223,10 @@ class TestFusedMultiTransformerOp(OpTest):
         # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
         # --> [B, n_head, seq_len, out_seq_len]
-        qk_out = layers.matmul(
-            x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5)
+        qk_out = layers.matmul(x=q_out,
+                               y=k_out,
+                               transpose_y=True,
+                               alpha=self.head_dim**-0.5)

         if self.debug:
             print('qk out is')
@@ -249,8 +246,7 @@ class TestFusedMultiTransformerOp(OpTest):
             print('softmax out is')
             print(softmax_out[0][0][0])
         if self.dropout_prob:
-            dropout_out = F.dropout(
-                softmax_out,
-                self.dropout_prob,
-                training=self.training,
-                mode="upscale_in_train")
+            dropout_out = F.dropout(softmax_out,
+                                    self.dropout_prob,
+                                    training=self.training,
+                                    mode="upscale_in_train")
@@ -265,8 +261,7 @@ class TestFusedMultiTransformerOp(OpTest):
             print('fmha out is')
             print(fmha_out[0][0][0])
         out_linear_in = tensor.reshape(
-            x=fmha_out,
-            shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
+            x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
         out = self.out_proj(out_linear_in)

         residual_out = residual + self.dropout(out)
@@ -296,44 +291,44 @@ class TestFusedMultiTransformerOp(OpTest):
     def GetFusedMultiTransformerOut(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
-        q_proj_weight = paddle.to_tensor(
-            self.q_proj.weight, stop_gradient=False)
-        k_proj_weight = paddle.to_tensor(
-            self.k_proj.weight, stop_gradient=False)
-        v_proj_weight = paddle.to_tensor(
-            self.v_proj.weight, stop_gradient=False)
-        out_linear_weight = paddle.to_tensor(
-            self.out_proj.weight, stop_gradient=False)
-        ffn1_weight = paddle.to_tensor(
-            self.ffn1_proj.weight, stop_gradient=False)
-        ffn2_weight = paddle.to_tensor(
-            self.ffn2_proj.weight, stop_gradient=False)
+        q_proj_weight = paddle.to_tensor(self.q_proj.weight,
+                                         stop_gradient=False)
+        k_proj_weight = paddle.to_tensor(self.k_proj.weight,
+                                         stop_gradient=False)
+        v_proj_weight = paddle.to_tensor(self.v_proj.weight,
+                                         stop_gradient=False)
+        out_linear_weight = paddle.to_tensor(self.out_proj.weight,
+                                             stop_gradient=False)
+        ffn1_weight = paddle.to_tensor(self.ffn1_proj.weight,
+                                       stop_gradient=False)
+        ffn2_weight = paddle.to_tensor(self.ffn2_proj.weight,
+                                       stop_gradient=False)

         if self.bias_attr is False:
             qkv_bias_tensor = None
             out_linear_bias = None
         else:
-            q_proj_bias = paddle.to_tensor(
-                self.q_proj.bias, stop_gradient=False)
-            k_proj_bias = paddle.to_tensor(
-                self.k_proj.bias, stop_gradient=False)
-            v_proj_bias = paddle.to_tensor(
-                self.v_proj.bias, stop_gradient=False)
+            q_proj_bias = paddle.to_tensor(self.q_proj.bias,
+                                           stop_gradient=False)
+            k_proj_bias = paddle.to_tensor(self.k_proj.bias,
+                                           stop_gradient=False)
+            v_proj_bias = paddle.to_tensor(self.v_proj.bias,
+                                           stop_gradient=False)
             qkv_bias = np.concatenate(
                 (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
             qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
             qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
-            out_linear_bias = paddle.to_tensor(
-                self.out_proj.bias, stop_gradient=False)
+            out_linear_bias = paddle.to_tensor(self.out_proj.bias,
+                                               stop_gradient=False)

-        ffn1_bias = paddle.to_tensor(
-            self.ffn1_proj.bias, stop_gradient=False)
-        ffn2_bias = paddle.to_tensor(
-            self.ffn2_proj.bias, stop_gradient=False)
+        ffn1_bias = paddle.to_tensor(self.ffn1_proj.bias,
+                                     stop_gradient=False)
+        ffn2_bias = paddle.to_tensor(self.ffn2_proj.bias,
+                                     stop_gradient=False)

         ln_scale = paddle.to_tensor(self.norm.weight, stop_gradient=False)
         ln_bias = paddle.to_tensor(self.norm.bias, stop_gradient=False)
-        ffn_ln_scale = paddle.to_tensor(
-            self.ffn_norm.weight, stop_gradient=False)
+        ffn_ln_scale = paddle.to_tensor(self.ffn_norm.weight,
+                                        stop_gradient=False)
         ffn_ln_bias = paddle.to_tensor(self.ffn_norm.bias, stop_gradient=False)

         q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
@@ -351,8 +346,7 @@ class TestFusedMultiTransformerOp(OpTest):
         cache_kvs = []

         max_seq_length = (self.cache_length + 128) // 128 * 128
-        cache_kv = np.zeros(
-            [
-                2, self.batch_size, self.num_heads, max_seq_length,
-                self.head_dim
-            ],
+        cache_kv = np.zeros([
+            2, self.batch_size, self.num_heads, max_seq_length,
+            self.head_dim
+        ],
@@ -384,8 +378,9 @@ class TestFusedMultiTransformerOp(OpTest):
                 assert self.query_length == self.cache_length
                 cache_kv[:] = 0
             else:
-                time_step = paddle.to_tensor(
-                    [self.cache_length], dtype='int32', place=paddle.CPUPlace())
+                time_step = paddle.to_tensor([self.cache_length],
+                                             dtype='int32',
+                                             place=paddle.CPUPlace())
         if self.has_attn_mask:
             attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
         else:
@@ -417,12 +412,10 @@ class TestFusedMultiTransformerOp(OpTest):
             ffn_ln_scales.append(ffn_ln_scale)
             ffn_ln_biases.append(ffn_ln_bias)
             if self.has_cache_kv:
-                cache_kvs.append(
-                    paddle.to_tensor(
-                        cache_kv, stop_gradient=False))
+                cache_kvs.append(paddle.to_tensor(cache_kv,
+                                                  stop_gradient=False))

-        final_out = fused_multi_transformer(
-            x,
-            ln_scales,
-            ln_biases,
-            qkv_weights,
+        final_out = fused_multi_transformer(x,
+                                            ln_scales,
+                                            ln_biases,
+                                            qkv_weights,
@@ -463,9 +456,9 @@ class TestFusedMultiTransformerOp(OpTest):
             if self.debug:
                 print("cache_k out timestep=128")
-                print(cache_kv_out[0].reshape([
-                    2, bsz, num_head, v_elems, max_seq_len, elems
-                ])[0, 0, 0, :, self.cache_length, :])
+                print(cache_kv_out[0].reshape(
+                    [2, bsz, num_head, v_elems, max_seq_len,
+                     elems])[0, 0, 0, :, self.cache_length, :])

                 print("cache_v out timestep=128")
                 print(cache_kv_out[0][1, 0, 0, self.cache_length, :])
@@ -486,18 +479,25 @@ class TestFusedMultiTransformerOp(OpTest):
                 cache_v = cache_kv_out[i][1, :, :, :self.cache_length, :]

-                np.testing.assert_allclose(
-                    cache_k_ref, cache_k, rtol=self.rtol, atol=self.atol)
-                np.testing.assert_allclose(
-                    cache_v_ref, cache_v, rtol=self.rtol, atol=self.atol)
+                np.testing.assert_allclose(cache_k_ref,
+                                           cache_k,
+                                           rtol=self.rtol,
+                                           atol=self.atol)
+                np.testing.assert_allclose(cache_v_ref,
+                                           cache_v,
+                                           rtol=self.rtol,
+                                           atol=self.atol)
                 if i == 0:
                     break

-        np.testing.assert_allclose(
-            final_out_ref, final_out, rtol=self.rtol, atol=self.atol)
+        np.testing.assert_allclose(final_out_ref,
+                                   final_out,
+                                   rtol=self.rtol,
+                                   atol=self.atol)


 class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
+
     def config(self):
         super().config()
         self.x_type = np.float16
@@ -505,6 +505,7 @@ class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):


 class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp):
+
     def config(self):
         super().config()
         self.has_cache_kv = True
@@ -514,6 +515,7 @@ class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp):


 class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp):
+
     def config(self):
         super().config()
         self.has_cache_kv = True
@@ -523,6 +525,7 @@ class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp):


 class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp):
+
     def config(self):
         super().config()
         self.has_cache_kv = True
@@ -530,12 +533,68 @@ class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp):


 class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.gen_cache_kv = True
+        self.x_type = np.float16
+        self.layers = 3  # odd layers
+
+
+class TestFusedMultiTransformerOpPostLayerNormFp16(TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.x_type = np.float16
+        self.layers = 3  # odd layers
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpCacheKVPostLayerNorm(
+        TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.query_length = 1
+        self.key_length, self.value_length = 1, 1
+        self.layers = 3  # odd layers
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpCacheKVPostLayerNormFp16(
+        TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.query_length = 1
+        self.key_length, self.value_length = 1, 1
+        self.x_type = np.float16
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpGenCacheKVPostLayerNorm(
+        TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.gen_cache_kv = True
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16(
+        TestFusedMultiTransformerOp):
+
     def config(self):
         super().config()
         self.has_cache_kv = True
         self.gen_cache_kv = True
         self.x_type = np.float16
         self.layers = 3  # odd layers
+        self.pre_layer_norm = False


 if __name__ == "__main__":
...
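For reference, one way to run only the newly added post-layer-norm cases; this assumes the diff above belongs to test_fused_multi_transformer_op.py (the file path is not shown on this page) and that a CUDA device is available.

import unittest

# Module name is an assumption; adjust to the actual test file location.
from test_fused_multi_transformer_op import (
    TestFusedMultiTransformerOpPostLayerNormFp16,
    TestFusedMultiTransformerOpGenCacheKVPostLayerNorm)

if __name__ == "__main__":
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    # Collect just the post-layer-norm test cases added by this commit.
    suite.addTests(
        loader.loadTestsFromTestCase(
            TestFusedMultiTransformerOpPostLayerNormFp16))
    suite.addTests(
        loader.loadTestsFromTestCase(
            TestFusedMultiTransformerOpGenCacheKVPostLayerNorm))
    unittest.TextTestRunner(verbosity=2).run(suite)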