Unverified commit 31ddaae2, authored by WangXi, committed by GitHub

fused_attention fused_feedforward api support Model Tensor Parallel (#42985)

Parent: 360b8383
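With this change the public layers paddle.incubate.nn.FusedMultiHeadAttention and FusedFeedForward accept nranks and ring_id directly, which is why the tests below drop their local ParallelFused* copies. A minimal usage sketch, assuming a tensor-model-parallel NCCL ring with id 0 has already been created (for example by fleet's hybrid-parallel setup) and using illustrative sizes:

import paddle
from paddle.incubate.nn import FusedMultiHeadAttention, FusedFeedForward

nranks = 2   # assumed tensor-model-parallel degree; must divide num_heads and dim_feedforward
embed_dim, n_head, ffn_hidden = 1024, 16, 4096

# With nranks > 1 the layers shard qkv/linear1 column-wise and the output
# projection/linear2 row-wise, and all-reduce partial results over the NCCL
# ring identified by ring_id, as exercised by the tests in this commit.
attn = FusedMultiHeadAttention(embed_dim,
                               n_head,
                               dropout_rate=0.0,
                               attn_dropout_rate=0.0,
                               nranks=nranks,
                               ring_id=0)
ffn = FusedFeedForward(embed_dim,
                       ffn_hidden,
                       dropout_rate=0.0,
                       activation='gelu',
                       nranks=nranks,
                       ring_id=0)

x = paddle.randn([2, 128, embed_dim])  # [batch, seq_len, hidden]
y = ffn(attn(x))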
@@ -20,156 +20,11 @@ import paddle
import paddle.fluid as fluid
from test_dist_base import TestDistRunnerBase, runtime_main
import paddle.distributed.fleet as fleet
import paddle.incubate.nn.functional as incubate_f
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import core
from paddle.nn.initializer import Constant
from paddle.incubate.nn import FusedMultiHeadAttention
paddle.enable_static()
def _set_var_distributed(var):
if var is None:
return
var.is_distributed = True
# NOTE: use current_block and find_var_recursive to support while_loop
startup_block = paddle.static.default_startup_program().current_block()
main_block = paddle.static.default_main_program().current_block()
startup_block._find_var_recursive(var.name).is_distributed = True
main_block._find_var_recursive(var.name).is_distributed = True
class ParallelFusedMultiHeadAttention(Layer):
def __init__(self,
embed_dim,
num_heads,
dropout_rate=0.5,
attn_dropout_rate=0.5,
kdim=None,
vdim=None,
normalize_before=False,
need_weights=False,
qkv_weight_attr=None,
qkv_bias_attr=None,
linear_weight_attr=None,
linear_bias_attr=None,
pre_ln_scale_attr=None,
pre_ln_bias_attr=None,
ln_scale_attr=None,
ln_bias_attr=None,
epsilon=1e-5,
nranks=1,
ring_id=-1,
name=None):
super(ParallelFusedMultiHeadAttention, self).__init__()
assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
"but received {}".format(embed_dim))
assert num_heads > 0, ("Expected nhead to be greater than 0, "
"but received {}".format(num_heads))
self.normalize_before = normalize_before
self._dtype = self._helper.get_default_dtype()
self._epsilon = epsilon
self._ring_id = ring_id
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.kdim = kdim
self.vdim = vdim
self.need_weights = need_weights
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
assert need_weights == False, "Only support need_weight is False now."
# tensor model parallel
assert num_heads % nranks == 0
num_heads = num_heads // nranks
self.qkv_weight = self.create_parameter(
shape=[3, num_heads, self.head_dim, embed_dim],
attr=qkv_weight_attr,
dtype=self._dtype,
is_bias=False)
self.qkv_bias = self.create_parameter(
shape=[3, num_heads, self.head_dim],
attr=qkv_bias_attr,
dtype=self._dtype,
is_bias=True)
self.linear_weight = self.create_parameter(
shape=[num_heads * self.head_dim, embed_dim],
attr=linear_weight_attr,
dtype=self._dtype,
is_bias=False)
self.linear_bias = self.create_parameter(shape=[embed_dim],
attr=linear_bias_attr,
dtype=self._dtype,
is_bias=True)
# tensor model parallel
if nranks > 1:
assert ring_id != -1
# column parallel
_set_var_distributed(self.qkv_weight)
_set_var_distributed(self.qkv_bias)
# row parallel
_set_var_distributed(self.linear_weight)
if normalize_before:
self.pre_ln_scale = self.create_parameter(
attr=pre_ln_scale_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.pre_ln_bias = self.create_parameter(attr=pre_ln_bias_attr,
shape=[embed_dim],
is_bias=True)
self.ln_scale = None
self.ln_bias = None
else:
self.pre_ln_scale = None
self.pre_ln_bias = None
self.ln_scale = self.create_parameter(
attr=ln_scale_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.ln_bias = self.create_parameter(attr=ln_bias_attr,
shape=[embed_dim],
is_bias=True)
self.dropout_rate = dropout_rate
self.attn_dropout_rate = attn_dropout_rate
self.name = name
def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
out = incubate_f.fused_multi_head_attention(
x=query,
qkv_weight=self.qkv_weight,
linear_weight=self.linear_weight,
pre_layer_norm=self.normalize_before,
pre_ln_scale=self.pre_ln_scale,
pre_ln_bias=self.pre_ln_bias,
ln_scale=self.ln_scale,
ln_bias=self.ln_bias,
pre_ln_epsilon=self._epsilon,
qkv_bias=self.qkv_bias,
linear_bias=self.linear_bias,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
attn_dropout_rate=self.attn_dropout_rate,
ln_epsilon=self._epsilon,
training=self.training,
ring_id=self._ring_id,
name=self.name)
return out
def get_param_attr(weight, bias):
    weight_attr = paddle.ParamAttr(
        initializer=fluid.initializer.NumpyArrayInitializer(weight))
@@ -208,40 +63,40 @@ def create_model(data, rank):
        qkv_w_attr, qkv_b_attr = get_param_attr(col_qkv_w, col_qkv_b)
        linear_w_attr, linear_b_attr = get_param_attr(row_linear_w, linear_b)
        attn = ParallelFusedMultiHeadAttention(hidden,
        attn = FusedMultiHeadAttention(hidden,
                                       n_head,
                                       dropout_rate=0.0,
                                       attn_dropout_rate=0.0,
                                       normalize_before=False,
                                       qkv_weight_attr=qkv_w_attr,
                                       qkv_bias_attr=qkv_b_attr,
                                       linear_weight_attr=linear_w_attr,
                                       linear_bias_attr=linear_b_attr,
                                       pre_ln_scale_attr=pre_ln_w_attr,
                                       pre_ln_bias_attr=pre_ln_b_attr,
                                       ln_scale_attr=pre_ln_w_attr,
                                       ln_bias_attr=pre_ln_b_attr,
                                       nranks=MODEL_PARALLEL_SIZE,
                                       ring_id=0)
        result = attn(data)
    else:
        pre_ln_w_attr, pre_ln_b_attr = get_param_attr(pre_ln_w, pre_ln_b)
        qkv_w_attr, qkv_b_attr = get_param_attr(qkv_w, qkv_b)
        linear_w_attr, linear_b_attr = get_param_attr(linear_w, linear_b)
        attn = ParallelFusedMultiHeadAttention(hidden,
        attn = FusedMultiHeadAttention(hidden,
                                       n_head,
                                       dropout_rate=0.0,
                                       attn_dropout_rate=0.0,
                                       normalize_before=False,
                                       qkv_weight_attr=qkv_w_attr,
                                       qkv_bias_attr=qkv_b_attr,
                                       linear_weight_attr=linear_w_attr,
                                       linear_bias_attr=linear_b_attr,
                                       pre_ln_scale_attr=pre_ln_w_attr,
                                       pre_ln_bias_attr=pre_ln_b_attr,
                                       ln_scale_attr=pre_ln_w_attr,
                                       ln_bias_attr=pre_ln_b_attr)
        result = attn(data)
    predict = paddle.sum(result)
...
@@ -20,11 +20,7 @@ import paddle
import paddle.fluid as fluid
from test_dist_base import TestDistRunnerBase, runtime_main
import paddle.distributed.fleet as fleet
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.layer_helper import LayerHelper
from paddle.nn.initializer import Constant
from paddle.incubate.nn import FusedFeedForward
paddle.enable_static()
@@ -34,239 +30,6 @@ IN_SIZE = 2 * MODEL_PARALLEL_SIZE
OUT_SIZE = 2 * MODEL_PARALLEL_SIZE
def fused_feedforward(x,
linear1_weight,
linear2_weight,
linear1_bias=None,
linear2_bias=None,
ln1_scale=None,
ln1_bias=None,
ln2_scale=None,
ln2_bias=None,
dropout1_rate=0.5,
dropout2_rate=0.5,
activation="relu",
ln1_epsilon=1e-5,
ln2_epsilon=1e-5,
pre_layer_norm=False,
training=True,
mode='upscale_in_train',
ring_id=-1,
name=None):
seed = None
if mode not in ('downscale_in_infer', 'upscale_in_train'):
raise ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
)
mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer
helper = LayerHelper("fused_feedforward")
dtype = x.dtype
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'fused_feedforward')
check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
'fused_feedforward')
out = helper.create_variable_for_type_inference(x.dtype)
dropout1_mask = helper.create_variable_for_type_inference(
'uint8', stop_gradient=True)
dropout2_mask = helper.create_variable_for_type_inference(
'uint8', stop_gradient=True)
ln1_mean = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
ln1_variance = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
ln2_mean = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
ln2_variance = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
linear1_out = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
ln1_out = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
dropout1_out = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
dropout2_out = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
seed = helper.main_program.random_seed
helper.append_op(type='fused_feedforward',
inputs={
'X': x,
'Linear1Weight': linear1_weight,
'Linear1Bias': linear1_bias,
'Linear2Weight': linear2_weight,
'Linear2Bias': linear2_bias,
'Ln1Scale': ln1_scale,
'Ln1Bias': ln1_bias,
'Ln2Scale': ln2_scale,
'Ln2Bias': ln2_bias,
},
outputs={
'Out': out,
'Dropout1Mask': dropout1_mask,
'Dropout2Mask': dropout2_mask,
'Ln1Mean': ln1_mean,
'Ln1Variance': ln1_variance,
'Ln2Mean': ln2_mean,
'Ln2Variance': ln2_variance,
'Linear1Out': linear1_out,
'Ln1Out': ln1_out,
'Dropout1Out': dropout1_out,
'Dropout2Out': dropout2_out,
},
attrs={
'dropout1_rate': dropout1_rate,
'dropout2_rate': dropout2_rate,
'act_method': activation,
'pre_layer_norm': pre_layer_norm,
'ln1_epsilon': ln1_epsilon,
'ln2_epsilon': ln2_epsilon,
'dropout1_is_test': not training,
'dropout2_is_test': not training,
'dropout1_fix_seed': seed is not None,
'dropout2_fix_seed': seed is not None,
'dropout1_seed': seed if seed is not None else 0,
'dropout2_seed': seed if seed is not None else 0,
'dropout1_implementation': mode,
'dropout2_implementation': mode,
'ring_id': ring_id,
})
return out
def _set_var_distributed(var):
if var is None:
return
var.is_distributed = True
# NOTE: use current_block and find_var_recursive to support while_loop
startup_block = paddle.static.default_startup_program().current_block()
main_block = paddle.static.default_main_program().current_block()
startup_block._find_var_recursive(var.name).is_distributed = True
main_block._find_var_recursive(var.name).is_distributed = True
class ParallelFusedFeedForward(Layer):
def __init__(self,
d_model,
dim_feedforward,
dropout_rate=0.1,
epsilon=1e-05,
activation="relu",
act_dropout_rate=None,
normalize_before=False,
linear1_weight_attr=None,
linear1_bias_attr=None,
linear2_weight_attr=None,
linear2_bias_attr=None,
ln1_scale_attr=None,
ln1_bias_attr=None,
ln2_scale_attr=None,
ln2_bias_attr=None,
nranks=1,
ring_id=-1,
name=None):
super(ParallelFusedFeedForward, self).__init__()
assert d_model > 0, (
"Expected d_model to be greater than 0, but received {}".format(
d_model))
assert dim_feedforward > 0, (
"Expected dim_feedforward to be greater than 0, but received {}".
format(dim_feedforward))
self._dtype = self._helper.get_default_dtype()
self._d_model = d_model
assert dim_feedforward % nranks == 0
dim_feedforward = dim_feedforward // nranks
self._dim_feedforward = dim_feedforward
self._dropout_rate = dropout_rate
self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
self._act_method = activation
self._normalize_before = normalize_before
self._epsilon = epsilon
self._ring_id = ring_id
self._linear1_weight = self.create_parameter(
shape=[d_model, dim_feedforward],
attr=linear1_weight_attr,
dtype=self._dtype,
is_bias=False)
self._linear1_bias = self.create_parameter(shape=[dim_feedforward],
attr=linear1_bias_attr,
dtype=self._dtype,
is_bias=True)
self._linear2_weight = self.create_parameter(
shape=[dim_feedforward, d_model],
attr=linear2_weight_attr,
dtype=self._dtype,
is_bias=False)
self._linear2_bias = self.create_parameter(shape=[d_model],
attr=linear2_bias_attr,
dtype=self._dtype,
is_bias=True)
if nranks > 1:
assert ring_id != -1
# column parallel
_set_var_distributed(self._linear1_weight)
_set_var_distributed(self._linear1_bias)
_set_var_distributed(self._linear2_weight)
if normalize_before:
self._ln1_scale = self.create_parameter(
shape=[d_model],
attr=ln1_scale_attr,
is_bias=False,
default_initializer=Constant(1.0))
self._ln1_bias = self.create_parameter(shape=[d_model],
attr=ln1_bias_attr,
is_bias=True)
self._ln2_scale = None
self._ln2_bias = None
else:
self._ln1_bias = None
self._ln2_bias = None
self._ln2_scale = self.create_parameter(
shape=[d_model],
attr=ln2_scale_attr,
is_bias=False,
default_initializer=Constant(1.0))
self._ln2_bias = self.create_parameter(shape=[d_model],
attr=ln2_bias_attr,
is_bias=True)
self.name = name
def forward(self, src, cache=None):
out = fused_feedforward(src,
self._linear1_weight,
self._linear2_weight,
self._linear1_bias,
self._linear2_bias,
self._ln1_scale,
self._ln1_bias,
self._ln2_scale,
self._ln2_bias,
dropout1_rate=self._act_dropout_rate,
dropout2_rate=self._dropout_rate,
activation=self._act_method,
ln1_epsilon=self._epsilon,
ln2_epsilon=self._epsilon,
pre_layer_norm=self._normalize_before,
training=self.training,
ring_id=self._ring_id,
name=self.name)
return out
def get_param_attr(weight, bias):
    weight_attr = paddle.ParamAttr(
        initializer=fluid.initializer.NumpyArrayInitializer(weight))
@@ -295,19 +58,19 @@ def create_model(data, rank):
        w0_attr, b0_attr = get_param_attr(col_w0, col_b0)
        w1_attr, b1_attr = get_param_attr(row_w1, b1)
        ffn = ParallelFusedFeedForward(IN_SIZE,
        ffn = FusedFeedForward(IN_SIZE,
                               OUT_SIZE,
                               dropout_rate=0.0,
                               activation='gelu',
                               normalize_before=True,
                               linear1_weight_attr=w0_attr,
                               linear1_bias_attr=b0_attr,
                               linear2_weight_attr=w1_attr,
                               linear2_bias_attr=b1_attr,
                               ln1_scale_attr=ln_w_attr,
                               ln1_bias_attr=ln_b_attr,
                               nranks=MODEL_PARALLEL_SIZE,
                               ring_id=0)
        #ffn.eval()
        result = ffn(data)
    else:
@@ -315,17 +78,17 @@ def create_model(data, rank):
        w0_attr, b0_attr = get_param_attr(w0, b0)
        w1_attr, b1_attr = get_param_attr(w1, b1)
        ffn = ParallelFusedFeedForward(IN_SIZE,
        ffn = FusedFeedForward(IN_SIZE,
                               OUT_SIZE,
                               dropout_rate=0.0,
                               activation='gelu',
                               normalize_before=True,
                               linear1_weight_attr=w0_attr,
                               linear1_bias_attr=b0_attr,
                               linear2_weight_attr=w1_attr,
                               linear2_bias_attr=b1_attr,
                               ln1_scale_attr=ln_w_attr,
                               ln1_bias_attr=ln_b_attr)
        #ffn.eval()
        result = ffn(data)
...
@@ -83,7 +83,7 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
    if ln_bias is None:
        has_bias = False
    if (pre_layer_norm):
    if pre_layer_norm:
        ln_out = layer_norm(query, True, has_bias, ln_scale, ln_bias)
    num_head = qkv_weight.shape[1]
@@ -97,7 +97,7 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
    if qkv_bias is not None:
        qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
                                    qkv_bias.shape[2])
    if (pre_layer_norm):
    if pre_layer_norm:
        ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
        qkv = fc(ln_out, qkv_weight)
        if qkv_bias is not None:
@@ -239,12 +239,12 @@ class TestFusedAttentionAPI(unittest.TestCase):
            attn_mask_tensor = paddle.to_tensor(self.attn_mask)
        else:
            attn_mask_tensor = None
        fused_attn = FusedMultiHeadAttention(self.embed_dim, self.num_heads,
                                             self.dropout_prob,
                                             self.attn_dropout_prob, self.kdim,
                                             self.vdim, self.pre_layer_norm,
                                             self.need_weight, self.weight_attr,
                                             self.bias_attr)
        fused_attn = FusedMultiHeadAttention(
            self.embed_dim, self.num_heads, self.dropout_prob,
            self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
            self.need_weight, self.weight_attr, self.bias_attr,
            self.weight_attr, self.bias_attr, self.weight_attr, self.bias_attr,
            self.weight_attr, self.bias_attr)
        if self.bias_attr is not False:
            qkv_bias = np.random.random(
                fused_attn.qkv_bias.shape).astype('float32')
@@ -260,13 +260,19 @@ class TestFusedAttentionAPI(unittest.TestCase):
        if self.bias_attr is not False:
            fused_attn_qkv_bias = fused_attn.qkv_bias.numpy()
            fused_attn_linear_bias = fused_attn.linear_bias.numpy()
            fused_attn_pre_ln_bias = fused_attn.pre_ln_bias.numpy()
            fused_attn_ln_bias = fused_attn.ln_bias.numpy()
            if self.pre_layer_norm:
                fused_attn_pre_ln_bias = fused_attn.pre_ln_bias.numpy()
                fused_attn_ln_bias = None
            else:
                fused_attn_pre_ln_bias = None
                fused_attn_ln_bias = fused_attn.ln_bias.numpy()
        ref_out = compute_reference(
            self.pre_layer_norm, self.query, self.attn_mask,
            fused_attn.pre_ln_scale.numpy(), fused_attn_pre_ln_bias,
            fused_attn.ln_scale.numpy(), fused_attn_ln_bias,
            fused_attn.qkv_weight.numpy(), fused_attn_qkv_bias,
            fused_attn.linear_weight.numpy(), fused_attn_linear_bias)
        ref_out = compute_reference(
            self.pre_layer_norm, self.query, self.attn_mask,
            fused_attn.pre_ln_scale.numpy() if self.pre_layer_norm else None,
            fused_attn_pre_ln_bias,
            fused_attn.ln_scale.numpy() if not self.pre_layer_norm else None,
            fused_attn_ln_bias,
            fused_attn.qkv_weight.numpy(), fused_attn_qkv_bias,
            fused_attn.linear_weight.numpy(), fused_attn_linear_bias)
        np.testing.assert_allclose(ref_out,
@@ -275,12 +281,12 @@ class TestFusedAttentionAPI(unittest.TestCase):
                                   atol=self.atol)
    def run_static(self):
        fused_attn = FusedMultiHeadAttention(self.embed_dim, self.num_heads,
                                             self.dropout_prob,
                                             self.attn_dropout_prob, self.kdim,
                                             self.vdim, self.pre_layer_norm,
                                             self.need_weight, self.weight_attr,
                                             self.bias_attr)
        fused_attn = FusedMultiHeadAttention(
            self.embed_dim, self.num_heads, self.dropout_prob,
            self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
            self.need_weight, self.weight_attr, self.bias_attr,
            self.weight_attr, self.bias_attr, self.weight_attr, self.bias_attr,
            self.weight_attr, self.bias_attr)
        x = paddle.static.data(
            name='X',
@@ -304,58 +310,118 @@ class TestFusedAttentionAPI(unittest.TestCase):
        qkv_bias = None
        linear_bias = None
        ln_scale = None
        ln_2_scale = None
        ln_bias = None
        ln_2_bias = None
        if self.has_attn_mask:
            if self.bias_attr is False:
                out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run(
                    paddle.static.default_main_program(),
                    feed={
                        "X": self.query,
                        "SrcMask": self.attn_mask
                    },
                    fetch_list=[
                        final_out, fused_attn.qkv_weight,
                        fused_attn.linear_weight, fused_attn.pre_ln_scale,
                        fused_attn.ln_scale
                    ])
                if self.pre_layer_norm:
                    out, qkv_weight, out_linear_weight, ln_scale = exe.run(
                        paddle.static.default_main_program(),
                        feed={
                            "X": self.query,
                            "SrcMask": self.attn_mask
                        },
                        fetch_list=[
                            final_out,
                            fused_attn.qkv_weight,
                            fused_attn.linear_weight,
                            fused_attn.pre_ln_scale,
                        ])
                else:
                    out, qkv_weight, out_linear_weight, ln_2_scale = exe.run(
                        paddle.static.default_main_program(),
                        feed={
                            "X": self.query,
                            "SrcMask": self.attn_mask
                        },
                        fetch_list=[
                            final_out, fused_attn.qkv_weight,
                            fused_attn.linear_weight, fused_attn.ln_scale
                        ])
            else:
                out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
                    paddle.static.default_main_program(),
                    feed={
                        "X": self.query,
                        "SrcMask": self.attn_mask
                    },
                    fetch_list=[
                        final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
                        fused_attn.linear_weight, fused_attn.linear_bias,
                        fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
                        fused_attn.ln_scale, fused_attn.ln_bias
                    ])
                if self.pre_layer_norm:
                    out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias = exe.run(
                        paddle.static.default_main_program(),
                        feed={
                            "X": self.query,
                            "SrcMask": self.attn_mask
                        },
                        fetch_list=[
                            final_out,
                            fused_attn.qkv_weight,
                            fused_attn.qkv_bias,
                            fused_attn.linear_weight,
                            fused_attn.linear_bias,
                            fused_attn.pre_ln_scale,
                            fused_attn.pre_ln_bias,
                        ])
                else:
                    out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_2_scale, ln_2_bias = exe.run(
                        paddle.static.default_main_program(),
                        feed={
                            "X": self.query,
                            "SrcMask": self.attn_mask
                        },
                        fetch_list=[
                            final_out, fused_attn.qkv_weight,
                            fused_attn.qkv_bias, fused_attn.linear_weight,
                            fused_attn.linear_bias, fused_attn.ln_scale,
                            fused_attn.ln_bias
                        ])
        else:
            if self.bias_attr is False:
                out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run(
                    paddle.static.default_main_program(),
                    feed={
                        "X": self.query,
                    },
                    fetch_list=[
                        final_out, fused_attn.qkv_weight,
                        fused_attn.linear_weight, fused_attn.pre_ln_scale,
                        fused_attn.ln_scale
                    ])
                if self.pre_layer_norm:
                    out, qkv_weight, out_linear_weight, ln_scale = exe.run(
                        paddle.static.default_main_program(),
                        feed={
                            "X": self.query,
                        },
                        fetch_list=[
                            final_out,
                            fused_attn.qkv_weight,
                            fused_attn.linear_weight,
                            fused_attn.pre_ln_scale,
                        ])
                else:
                    out, qkv_weight, out_linear_weight, ln_2_scale = exe.run(
                        paddle.static.default_main_program(),
                        feed={
                            "X": self.query,
                        },
                        fetch_list=[
                            final_out, fused_attn.qkv_weight,
                            fused_attn.linear_weight, fused_attn.ln_scale
                        ])
            else:
                out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
                    paddle.static.default_main_program(),
                    feed={
                        "X": self.query,
                    },
                    fetch_list=[
                        final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
                        fused_attn.linear_weight, fused_attn.linear_bias,
                        fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
                        fused_attn.ln_scale, fused_attn.ln_bias
                    ])
                if self.pre_layer_norm:
                    out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias = exe.run(
                        paddle.static.default_main_program(),
                        feed={
                            "X": self.query,
                        },
                        fetch_list=[
                            final_out,
                            fused_attn.qkv_weight,
                            fused_attn.qkv_bias,
                            fused_attn.linear_weight,
                            fused_attn.linear_bias,
                            fused_attn.pre_ln_scale,
                            fused_attn.pre_ln_bias,
                        ])
                else:
                    out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_2_scale, ln_2_bias = exe.run(
                        paddle.static.default_main_program(),
                        feed={
                            "X": self.query,
                        },
                        fetch_list=[
                            final_out, fused_attn.qkv_weight,
                            fused_attn.qkv_bias, fused_attn.linear_weight,
                            fused_attn.linear_bias, fused_attn.ln_scale,
                            fused_attn.ln_bias
                        ])
        return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias
    def test_static_api(self):
...
@@ -45,6 +45,7 @@ def fused_feedforward(x,
                      pre_layer_norm=False,
                      training=True,
                      mode='upscale_in_train',
                      ring_id=-1,
                      name=None):
r""" r"""
This is a fusion operator to compute feed forward layer in transformer model architecture. This is a fusion operator to compute feed forward layer in transformer model architecture.
...@@ -88,6 +89,7 @@ def fused_feedforward(x, ...@@ -88,6 +89,7 @@ def fused_feedforward(x,
- train: out = input * mask - train: out = input * mask
- inference: out = input * (1.0 - p) - inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using tensor parallel.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
@@ -132,7 +134,8 @@ def fused_feedforward(x,
            "dropout1_fix_seed", seed is not None, "dropout2_fix_seed", seed
            is not None, "dropout1_seed", seed if seed is not None else 0,
            "dropout2_seed", seed if seed is not None else 0,
            'dropout1_implementation', mode, 'dropout2_implementation', mode)
            'dropout1_implementation', mode, 'dropout2_implementation', mode,
            'ring_id', ring_id)
        return out
helper = LayerHelper("fused_feedforward") helper = LayerHelper("fused_feedforward")
...@@ -206,7 +209,8 @@ def fused_feedforward(x, ...@@ -206,7 +209,8 @@ def fused_feedforward(x,
'dropout1_seed': seed if seed is not None else 0, 'dropout1_seed': seed if seed is not None else 0,
'dropout2_seed': seed if seed is not None else 0, 'dropout2_seed': seed if seed is not None else 0,
'dropout1_implementation': mode, 'dropout1_implementation': mode,
'dropout2_implementation': mode 'dropout2_implementation': mode,
'ring_id': ring_id,
}) })
return out return out
......
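The functional form gains the same knob. A short sketch of calling paddle.incubate.nn.functional.fused_feedforward with the new ring_id argument; it assumes a GPU build of Paddle, and the shapes and weights below are made up for illustration:

import paddle
import paddle.incubate.nn.functional as F

d_model, d_ffn = 8, 16
x = paddle.randn([1, 4, d_model], dtype='float32')
w1 = paddle.randn([d_model, d_ffn], dtype='float32')   # linear1_weight
w2 = paddle.randn([d_ffn, d_model], dtype='float32')   # linear2_weight

# ring_id=-1 (the default) keeps single-card behaviour; a non-negative id is
# only meaningful once the matching model-parallel NCCL group has been created.
out = F.fused_feedforward(x, w1, w2,
                          dropout1_rate=0.0,
                          dropout2_rate=0.0,
                          activation='relu',
                          ring_id=-1)
print(out.shape)  # same shape as x: [1, 4, 8]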