Unverified commit c5f4a9cc authored by carryyu, committed by GitHub

add post layer norm (#44931)

Parent 9336dd3e
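Note: the test classes added at the bottom of this diff run `fused_multi_transformer` with `pre_layer_norm=False`. As a rough sketch of what that flag toggles (not code from this patch; `attn`, `ffn`, `norm1`, and `norm2` are hypothetical stand-ins for the fused sub-layers), the two block layouts differ only in where LayerNorm sits relative to the residual add:

# Minimal sketch, assuming generic callables for the sub-layers.
def pre_ln_block(x, attn, ffn, norm1, norm2):
    # pre_layer_norm=True: normalize the input of each sub-layer,
    # then add the residual on the un-normalized path.
    x = x + attn(norm1(x))
    x = x + ffn(norm2(x))
    return x

def post_ln_block(x, attn, ffn, norm1, norm2):
    # pre_layer_norm=False (the case this commit adds): add the
    # residual first, then normalize the sum.
    x = norm1(x + attn(x))
    x = norm2(x + ffn(x))
    return x

The new PostLayerNorm test classes at the end of the diff only flip this flag, combined with the existing FP16 and cache-kv variants.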
@@ -39,6 +39,7 @@ default_main_program().random_seed = 42
class TestFusedMultiTransformerOp(OpTest):
def setUp(self):
self.config()
self.generate_input_data()
@@ -61,39 +62,33 @@ class TestFusedMultiTransformerOp(OpTest):
bias_attr = paddle.fluid.ParamAttr(
initializer=paddle.fluid.initializer.Constant(value=0.0005))
self.q_proj = Linear(
self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=bias_attr)
self.q_proj = Linear(self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=bias_attr)
#bias_attr=self.bias_attr)
self.k_proj = Linear(
self.kdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.v_proj = Linear(
self.vdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.out_proj = Linear(
self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.ffn1_proj = Linear(
self.embed_dim,
4 * self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.ffn2_proj = Linear(
4 * self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.k_proj = Linear(self.kdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.v_proj = Linear(self.vdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.out_proj = Linear(self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.ffn1_proj = Linear(self.embed_dim,
4 * self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.ffn2_proj = Linear(4 * self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
paddle.set_default_dtype(np.float32)
self.norm = LayerNorm(self.embed_dim)
@@ -228,8 +223,10 @@ class TestFusedMultiTransformerOp(OpTest):
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
qk_out = layers.matmul(
x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5)
qk_out = layers.matmul(x=q_out,
y=k_out,
transpose_y=True,
alpha=self.head_dim**-0.5)
if self.debug:
print('qk out is')
@@ -249,11 +246,10 @@ class TestFusedMultiTransformerOp(OpTest):
print('softmax out is')
print(softmax_out[0][0][0])
if self.dropout_prob:
dropout_out = F.dropout(
softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train")
dropout_out = F.dropout(softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train")
# [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, head_dim]
qktv_out = tensor.matmul(dropout_out, v_out)
@@ -265,8 +261,7 @@ class TestFusedMultiTransformerOp(OpTest):
print('fmha out is')
print(fmha_out[0][0][0])
out_linear_in = tensor.reshape(
x=fmha_out,
shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
out = self.out_proj(out_linear_in)
residual_out = residual + self.dropout(out)
@@ -296,44 +291,44 @@ class TestFusedMultiTransformerOp(OpTest):
def GetFusedMultiTransformerOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_proj_weight = paddle.to_tensor(
self.q_proj.weight, stop_gradient=False)
k_proj_weight = paddle.to_tensor(
self.k_proj.weight, stop_gradient=False)
v_proj_weight = paddle.to_tensor(
self.v_proj.weight, stop_gradient=False)
out_linear_weight = paddle.to_tensor(
self.out_proj.weight, stop_gradient=False)
ffn1_weight = paddle.to_tensor(
self.ffn1_proj.weight, stop_gradient=False)
ffn2_weight = paddle.to_tensor(
self.ffn2_proj.weight, stop_gradient=False)
q_proj_weight = paddle.to_tensor(self.q_proj.weight,
stop_gradient=False)
k_proj_weight = paddle.to_tensor(self.k_proj.weight,
stop_gradient=False)
v_proj_weight = paddle.to_tensor(self.v_proj.weight,
stop_gradient=False)
out_linear_weight = paddle.to_tensor(self.out_proj.weight,
stop_gradient=False)
ffn1_weight = paddle.to_tensor(self.ffn1_proj.weight,
stop_gradient=False)
ffn2_weight = paddle.to_tensor(self.ffn2_proj.weight,
stop_gradient=False)
if self.bias_attr is False:
qkv_bias_tensor = None
out_linear_bias = None
else:
q_proj_bias = paddle.to_tensor(
self.q_proj.bias, stop_gradient=False)
k_proj_bias = paddle.to_tensor(
self.k_proj.bias, stop_gradient=False)
v_proj_bias = paddle.to_tensor(
self.v_proj.bias, stop_gradient=False)
q_proj_bias = paddle.to_tensor(self.q_proj.bias,
stop_gradient=False)
k_proj_bias = paddle.to_tensor(self.k_proj.bias,
stop_gradient=False)
v_proj_bias = paddle.to_tensor(self.v_proj.bias,
stop_gradient=False)
qkv_bias = np.concatenate(
(q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
out_linear_bias = paddle.to_tensor(
self.out_proj.bias, stop_gradient=False)
ffn1_bias = paddle.to_tensor(
self.ffn1_proj.bias, stop_gradient=False)
ffn2_bias = paddle.to_tensor(
self.ffn2_proj.bias, stop_gradient=False)
out_linear_bias = paddle.to_tensor(self.out_proj.bias,
stop_gradient=False)
ffn1_bias = paddle.to_tensor(self.ffn1_proj.bias,
stop_gradient=False)
ffn2_bias = paddle.to_tensor(self.ffn2_proj.bias,
stop_gradient=False)
ln_scale = paddle.to_tensor(self.norm.weight, stop_gradient=False)
ln_bias = paddle.to_tensor(self.norm.bias, stop_gradient=False)
ffn_ln_scale = paddle.to_tensor(
self.ffn_norm.weight, stop_gradient=False)
ffn_ln_scale = paddle.to_tensor(self.ffn_norm.weight,
stop_gradient=False)
ffn_ln_bias = paddle.to_tensor(self.ffn_norm.bias, stop_gradient=False)
q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
@@ -351,12 +346,11 @@ class TestFusedMultiTransformerOp(OpTest):
cache_kvs = []
max_seq_length = (self.cache_length + 128) // 128 * 128
cache_kv = np.zeros(
[
2, self.batch_size, self.num_heads, max_seq_length,
self.head_dim
],
dtype=self.x_type)
cache_kv = np.zeros([
2, self.batch_size, self.num_heads, max_seq_length,
self.head_dim
],
dtype=self.x_type)
elems = 4
if self.x_type is np.float16:
@@ -384,8 +378,9 @@ class TestFusedMultiTransformerOp(OpTest):
assert self.query_length == self.cache_length
cache_kv[:] = 0
else:
time_step = paddle.to_tensor(
[self.cache_length], dtype='int32', place=paddle.CPUPlace())
time_step = paddle.to_tensor([self.cache_length],
dtype='int32',
place=paddle.CPUPlace())
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
@@ -417,31 +412,29 @@ class TestFusedMultiTransformerOp(OpTest):
ffn_ln_scales.append(ffn_ln_scale)
ffn_ln_biases.append(ffn_ln_bias)
if self.has_cache_kv:
cache_kvs.append(
paddle.to_tensor(
cache_kv, stop_gradient=False))
final_out = fused_multi_transformer(
x,
ln_scales,
ln_biases,
qkv_weights,
qkv_biases,
out_weights,
out_biases,
ffn_ln_scales,
ffn_ln_biases,
ffn1_weights,
ffn1_biases,
ffn2_weights,
ffn2_biases,
pre_layer_norm=self.pre_layer_norm,
epsilon=epsilon,
cache_kvs=cache_kvs,
time_step=time_step,
attn_mask=attn_mask,
dropout_rate=self.dropout_prob,
training=self.training)
cache_kvs.append(paddle.to_tensor(cache_kv,
stop_gradient=False))
final_out = fused_multi_transformer(x,
ln_scales,
ln_biases,
qkv_weights,
qkv_biases,
out_weights,
out_biases,
ffn_ln_scales,
ffn_ln_biases,
ffn1_weights,
ffn1_biases,
ffn2_weights,
ffn2_biases,
pre_layer_norm=self.pre_layer_norm,
epsilon=epsilon,
cache_kvs=cache_kvs,
time_step=time_step,
attn_mask=attn_mask,
dropout_rate=self.dropout_prob,
training=self.training)
if self.has_cache_kv:
return final_out[0], final_out[1]
@@ -463,9 +456,9 @@ class TestFusedMultiTransformerOp(OpTest):
if self.debug:
print("cache_k out timestep=128")
print(cache_kv_out[0].reshape([
2, bsz, num_head, v_elems, max_seq_len, elems
])[0, 0, 0, :, self.cache_length, :])
print(cache_kv_out[0].reshape(
[2, bsz, num_head, v_elems, max_seq_len,
elems])[0, 0, 0, :, self.cache_length, :])
print("cache_v out timestep=128")
print(cache_kv_out[0][1, 0, 0, self.cache_length, :])
@@ -486,18 +479,25 @@ class TestFusedMultiTransformerOp(OpTest):
cache_v = cache_kv_out[i][1, :, :, :self.cache_length, :]
np.testing.assert_allclose(
cache_k_ref, cache_k, rtol=self.rtol, atol=self.atol)
np.testing.assert_allclose(
cache_v_ref, cache_v, rtol=self.rtol, atol=self.atol)
np.testing.assert_allclose(cache_k_ref,
cache_k,
rtol=self.rtol,
atol=self.atol)
np.testing.assert_allclose(cache_v_ref,
cache_v,
rtol=self.rtol,
atol=self.atol)
if i == 0:
break
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol)
np.testing.assert_allclose(final_out_ref,
final_out,
rtol=self.rtol,
atol=self.atol)
class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.x_type = np.float16
@@ -505,6 +505,7 @@ class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
@@ -514,6 +515,7 @@ class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp):
class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
@@ -523,6 +525,7 @@ class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp):
class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
@@ -530,12 +533,68 @@ class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp):
class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.x_type = np.float16
self.layers = 3 # odd layers
class TestFusedMultiTransformerOpPostLayerNormFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpCacheKVPostLayerNorm(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpCacheKVPostLayerNormFp16(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.x_type = np.float16
self.pre_layer_norm = False
class TestFusedMultiTransformerOpGenCacheKVPostLayerNorm(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.pre_layer_norm = False
class TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
if __name__ == "__main__":
......