Unverified commit 31ddaae2, authored by: W WangXi, committed by: GitHub

fused_attention fused_feedforward api support Model Tensor Parallel (#42985)

Parent 360b8383
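
A minimal sketch of how the two fused layers can be constructed with the new tensor-model-parallel arguments introduced here, assuming a 2-way model-parallel group on NCCL ring 0 has already been set up (e.g. through fleet) and that the code runs inside a launched GPU job; shapes and hyperparameters are illustrative.

```python
# Minimal sketch: build the fused layers with the new tensor-model-parallel
# arguments. Assumes a 2-way model-parallel group on NCCL ring 0 already exists
# (e.g. created through fleet) and that this runs inside a launched GPU job.
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention, FusedFeedForward

nranks, ring_id = 2, 0            # assumption: 2-way tensor model parallel
embed_dim, num_heads = 1024, 16   # num_heads must be divisible by nranks

attn = FusedMultiHeadAttention(embed_dim,
                               num_heads,
                               dropout_rate=0.0,
                               attn_dropout_rate=0.0,
                               nranks=nranks,
                               ring_id=ring_id)
ffn = FusedFeedForward(embed_dim,
                       4 * embed_dim,  # dim_feedforward must be divisible by nranks
                       dropout_rate=0.0,
                       activation='gelu',
                       nranks=nranks,
                       ring_id=ring_id)

x = paddle.rand([2, 128, embed_dim])  # [batch_size, seq_len, embed_dim]
out = ffn(attn(x))                    # each rank holds 1/nranks of the heads and FFN width
```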
...@@ -20,156 +20,11 @@ import paddle ...@@ -20,156 +20,11 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from test_dist_base import TestDistRunnerBase, runtime_main from test_dist_base import TestDistRunnerBase, runtime_main
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
import paddle.incubate.nn.functional as incubate_f from paddle.incubate.nn import FusedMultiHeadAttention
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import core
from paddle.nn.initializer import Constant
paddle.enable_static() paddle.enable_static()
def _set_var_distributed(var):
if var is None:
return
var.is_distributed = True
# NOTE: use current_block and find_var_recursive to support while_loop
startup_block = paddle.static.default_startup_program().current_block()
main_block = paddle.static.default_main_program().current_block()
startup_block._find_var_recursive(var.name).is_distributed = True
main_block._find_var_recursive(var.name).is_distributed = True
class ParallelFusedMultiHeadAttention(Layer):
def __init__(self,
embed_dim,
num_heads,
dropout_rate=0.5,
attn_dropout_rate=0.5,
kdim=None,
vdim=None,
normalize_before=False,
need_weights=False,
qkv_weight_attr=None,
qkv_bias_attr=None,
linear_weight_attr=None,
linear_bias_attr=None,
pre_ln_scale_attr=None,
pre_ln_bias_attr=None,
ln_scale_attr=None,
ln_bias_attr=None,
epsilon=1e-5,
nranks=1,
ring_id=-1,
name=None):
super(ParallelFusedMultiHeadAttention, self).__init__()
assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
"but received {}".format(embed_dim))
assert num_heads > 0, ("Expected nhead to be greater than 0, "
"but received {}".format(num_heads))
self.normalize_before = normalize_before
self._dtype = self._helper.get_default_dtype()
self._epsilon = epsilon
self._ring_id = ring_id
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.kdim = kdim
self.vdim = vdim
self.need_weights = need_weights
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
assert need_weights == False, "Only support need_weight is False now."
# tensor model parallel
assert num_heads % nranks == 0
num_heads = num_heads // nranks
self.qkv_weight = self.create_parameter(
shape=[3, num_heads, self.head_dim, embed_dim],
attr=qkv_weight_attr,
dtype=self._dtype,
is_bias=False)
self.qkv_bias = self.create_parameter(
shape=[3, num_heads, self.head_dim],
attr=qkv_bias_attr,
dtype=self._dtype,
is_bias=True)
self.linear_weight = self.create_parameter(
shape=[num_heads * self.head_dim, embed_dim],
attr=linear_weight_attr,
dtype=self._dtype,
is_bias=False)
self.linear_bias = self.create_parameter(shape=[embed_dim],
attr=linear_bias_attr,
dtype=self._dtype,
is_bias=True)
# tensor model parallel
if nranks > 1:
assert ring_id != -1
# column parallel
_set_var_distributed(self.qkv_weight)
_set_var_distributed(self.qkv_bias)
# row parallel
_set_var_distributed(self.linear_weight)
if normalize_before:
self.pre_ln_scale = self.create_parameter(
attr=pre_ln_scale_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.pre_ln_bias = self.create_parameter(attr=pre_ln_bias_attr,
shape=[embed_dim],
is_bias=True)
self.ln_scale = None
self.ln_bias = None
else:
self.pre_ln_scale = None
self.pre_ln_bias = None
self.ln_scale = self.create_parameter(
attr=ln_scale_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.ln_bias = self.create_parameter(attr=ln_bias_attr,
shape=[embed_dim],
is_bias=True)
self.dropout_rate = dropout_rate
self.attn_dropout_rate = attn_dropout_rate
self.name = name
def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
out = incubate_f.fused_multi_head_attention(
x=query,
qkv_weight=self.qkv_weight,
linear_weight=self.linear_weight,
pre_layer_norm=self.normalize_before,
pre_ln_scale=self.pre_ln_scale,
pre_ln_bias=self.pre_ln_bias,
ln_scale=self.ln_scale,
ln_bias=self.ln_bias,
pre_ln_epsilon=self._epsilon,
qkv_bias=self.qkv_bias,
linear_bias=self.linear_bias,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
attn_dropout_rate=self.attn_dropout_rate,
ln_epsilon=self._epsilon,
training=self.training,
ring_id=self._ring_id,
name=self.name)
return out
def get_param_attr(weight, bias): def get_param_attr(weight, bias):
weight_attr = paddle.ParamAttr( weight_attr = paddle.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(weight)) initializer=fluid.initializer.NumpyArrayInitializer(weight))
...@@ -208,40 +63,40 @@ def create_model(data, rank): ...@@ -208,40 +63,40 @@ def create_model(data, rank):
qkv_w_attr, qkv_b_attr = get_param_attr(col_qkv_w, col_qkv_b) qkv_w_attr, qkv_b_attr = get_param_attr(col_qkv_w, col_qkv_b)
linear_w_attr, linear_b_attr = get_param_attr(row_linear_w, linear_b) linear_w_attr, linear_b_attr = get_param_attr(row_linear_w, linear_b)
attn = ParallelFusedMultiHeadAttention(hidden, attn = FusedMultiHeadAttention(hidden,
n_head, n_head,
dropout_rate=0.0, dropout_rate=0.0,
attn_dropout_rate=0.0, attn_dropout_rate=0.0,
normalize_before=False, normalize_before=False,
qkv_weight_attr=qkv_w_attr, qkv_weight_attr=qkv_w_attr,
qkv_bias_attr=qkv_b_attr, qkv_bias_attr=qkv_b_attr,
linear_weight_attr=linear_w_attr, linear_weight_attr=linear_w_attr,
linear_bias_attr=linear_b_attr, linear_bias_attr=linear_b_attr,
pre_ln_scale_attr=pre_ln_w_attr, pre_ln_scale_attr=pre_ln_w_attr,
pre_ln_bias_attr=pre_ln_b_attr, pre_ln_bias_attr=pre_ln_b_attr,
ln_scale_attr=pre_ln_w_attr, ln_scale_attr=pre_ln_w_attr,
ln_bias_attr=pre_ln_b_attr, ln_bias_attr=pre_ln_b_attr,
nranks=MODEL_PARALLEL_SIZE, nranks=MODEL_PARALLEL_SIZE,
ring_id=0) ring_id=0)
result = attn(data) result = attn(data)
else: else:
pre_ln_w_attr, pre_ln_b_attr = get_param_attr(pre_ln_w, pre_ln_b) pre_ln_w_attr, pre_ln_b_attr = get_param_attr(pre_ln_w, pre_ln_b)
qkv_w_attr, qkv_b_attr = get_param_attr(qkv_w, qkv_b) qkv_w_attr, qkv_b_attr = get_param_attr(qkv_w, qkv_b)
linear_w_attr, linear_b_attr = get_param_attr(linear_w, linear_b) linear_w_attr, linear_b_attr = get_param_attr(linear_w, linear_b)
attn = ParallelFusedMultiHeadAttention(hidden, attn = FusedMultiHeadAttention(hidden,
n_head, n_head,
dropout_rate=0.0, dropout_rate=0.0,
attn_dropout_rate=0.0, attn_dropout_rate=0.0,
normalize_before=False, normalize_before=False,
qkv_weight_attr=qkv_w_attr, qkv_weight_attr=qkv_w_attr,
qkv_bias_attr=qkv_b_attr, qkv_bias_attr=qkv_b_attr,
linear_weight_attr=linear_w_attr, linear_weight_attr=linear_w_attr,
linear_bias_attr=linear_b_attr, linear_bias_attr=linear_b_attr,
pre_ln_scale_attr=pre_ln_w_attr, pre_ln_scale_attr=pre_ln_w_attr,
pre_ln_bias_attr=pre_ln_b_attr, pre_ln_bias_attr=pre_ln_b_attr,
ln_scale_attr=pre_ln_w_attr, ln_scale_attr=pre_ln_w_attr,
ln_bias_attr=pre_ln_b_attr) ln_bias_attr=pre_ln_b_attr)
result = attn(data) result = attn(data)
predict = paddle.sum(result) predict = paddle.sum(result)
......
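
The parameter-splitting code of this test is collapsed in the diff above; the sketch below illustrates, with assumed shapes, how a full QKV weight is typically partitioned column-wise per rank (a contiguous slice of heads) and the output-projection weight row-wise, matching the `[3, num_heads, head_dim, embed_dim]` and `[num_heads * head_dim, embed_dim]` layouts used by `FusedMultiHeadAttention`.

```python
# Illustrative only (the actual split code is collapsed in the diff above):
# partition a full QKV weight column-wise (a slice of heads per rank) and the
# output-projection weight row-wise, matching the per-rank shapes that
# FusedMultiHeadAttention expects.
import numpy as np

nranks, rank = 2, 0
num_heads, head_dim = 16, 64
embed_dim = num_heads * head_dim

qkv_w = np.random.rand(3, num_heads, head_dim, embed_dim).astype('float32')
qkv_b = np.random.rand(3, num_heads, head_dim).astype('float32')
linear_w = np.random.rand(num_heads * head_dim, embed_dim).astype('float32')

heads = num_heads // nranks
col_qkv_w = qkv_w[:, rank * heads:(rank + 1) * heads]            # column parallel
col_qkv_b = qkv_b[:, rank * heads:(rank + 1) * heads]
row_linear_w = linear_w.reshape(num_heads, head_dim, embed_dim)[
    rank * heads:(rank + 1) * heads].reshape(-1, embed_dim)      # row parallel
```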
...@@ -20,11 +20,7 @@ import paddle ...@@ -20,11 +20,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from test_dist_base import TestDistRunnerBase, runtime_main from test_dist_base import TestDistRunnerBase, runtime_main
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
from paddle.incubate.nn import FusedFeedForward
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.layer_helper import LayerHelper
from paddle.nn.initializer import Constant
paddle.enable_static() paddle.enable_static()
...@@ -34,239 +30,6 @@ IN_SIZE = 2 * MODEL_PARALLEL_SIZE ...@@ -34,239 +30,6 @@ IN_SIZE = 2 * MODEL_PARALLEL_SIZE
OUT_SIZE = 2 * MODEL_PARALLEL_SIZE OUT_SIZE = 2 * MODEL_PARALLEL_SIZE
def fused_feedforward(x,
linear1_weight,
linear2_weight,
linear1_bias=None,
linear2_bias=None,
ln1_scale=None,
ln1_bias=None,
ln2_scale=None,
ln2_bias=None,
dropout1_rate=0.5,
dropout2_rate=0.5,
activation="relu",
ln1_epsilon=1e-5,
ln2_epsilon=1e-5,
pre_layer_norm=False,
training=True,
mode='upscale_in_train',
ring_id=-1,
name=None):
seed = None
if mode not in ('downscale_in_infer', 'upscale_in_train'):
raise ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
)
mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer
helper = LayerHelper("fused_feedforward")
dtype = x.dtype
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'fused_feedforward')
check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
'fused_feedforward')
out = helper.create_variable_for_type_inference(x.dtype)
dropout1_mask = helper.create_variable_for_type_inference(
'uint8', stop_gradient=True)
dropout2_mask = helper.create_variable_for_type_inference(
'uint8', stop_gradient=True)
ln1_mean = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
ln1_variance = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
ln2_mean = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
ln2_variance = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
linear1_out = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
ln1_out = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
dropout1_out = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
dropout2_out = helper.create_variable_for_type_inference(x.dtype,
stop_gradient=True)
if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
seed = helper.main_program.random_seed
helper.append_op(type='fused_feedforward',
inputs={
'X': x,
'Linear1Weight': linear1_weight,
'Linear1Bias': linear1_bias,
'Linear2Weight': linear2_weight,
'Linear2Bias': linear2_bias,
'Ln1Scale': ln1_scale,
'Ln1Bias': ln1_bias,
'Ln2Scale': ln2_scale,
'Ln2Bias': ln2_bias,
},
outputs={
'Out': out,
'Dropout1Mask': dropout1_mask,
'Dropout2Mask': dropout2_mask,
'Ln1Mean': ln1_mean,
'Ln1Variance': ln1_variance,
'Ln2Mean': ln2_mean,
'Ln2Variance': ln2_variance,
'Linear1Out': linear1_out,
'Ln1Out': ln1_out,
'Dropout1Out': dropout1_out,
'Dropout2Out': dropout2_out,
},
attrs={
'dropout1_rate': dropout1_rate,
'dropout2_rate': dropout2_rate,
'act_method': activation,
'pre_layer_norm': pre_layer_norm,
'ln1_epsilon': ln1_epsilon,
'ln2_epsilon': ln2_epsilon,
'dropout1_is_test': not training,
'dropout2_is_test': not training,
'dropout1_fix_seed': seed is not None,
'dropout2_fix_seed': seed is not None,
'dropout1_seed': seed if seed is not None else 0,
'dropout2_seed': seed if seed is not None else 0,
'dropout1_implementation': mode,
'dropout2_implementation': mode,
'ring_id': ring_id,
})
return out
def _set_var_distributed(var):
if var is None:
return
var.is_distributed = True
# NOTE: use current_block and find_var_recursive to support while_loop
startup_block = paddle.static.default_startup_program().current_block()
main_block = paddle.static.default_main_program().current_block()
startup_block._find_var_recursive(var.name).is_distributed = True
main_block._find_var_recursive(var.name).is_distributed = True
class ParallelFusedFeedForward(Layer):
def __init__(self,
d_model,
dim_feedforward,
dropout_rate=0.1,
epsilon=1e-05,
activation="relu",
act_dropout_rate=None,
normalize_before=False,
linear1_weight_attr=None,
linear1_bias_attr=None,
linear2_weight_attr=None,
linear2_bias_attr=None,
ln1_scale_attr=None,
ln1_bias_attr=None,
ln2_scale_attr=None,
ln2_bias_attr=None,
nranks=1,
ring_id=-1,
name=None):
super(ParallelFusedFeedForward, self).__init__()
assert d_model > 0, (
"Expected d_model to be greater than 0, but received {}".format(
d_model))
assert dim_feedforward > 0, (
"Expected dim_feedforward to be greater than 0, but received {}".
format(dim_feedforward))
self._dtype = self._helper.get_default_dtype()
self._d_model = d_model
assert dim_feedforward % nranks == 0
dim_feedforward = dim_feedforward // nranks
self._dim_feedforward = dim_feedforward
self._dropout_rate = dropout_rate
self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
self._act_method = activation
self._normalize_before = normalize_before
self._epsilon = epsilon
self._ring_id = ring_id
self._linear1_weight = self.create_parameter(
shape=[d_model, dim_feedforward],
attr=linear1_weight_attr,
dtype=self._dtype,
is_bias=False)
self._linear1_bias = self.create_parameter(shape=[dim_feedforward],
attr=linear1_bias_attr,
dtype=self._dtype,
is_bias=True)
self._linear2_weight = self.create_parameter(
shape=[dim_feedforward, d_model],
attr=linear2_weight_attr,
dtype=self._dtype,
is_bias=False)
self._linear2_bias = self.create_parameter(shape=[d_model],
attr=linear2_bias_attr,
dtype=self._dtype,
is_bias=True)
if nranks > 1:
assert ring_id != -1
# column parallel
_set_var_distributed(self._linear1_weight)
_set_var_distributed(self._linear1_bias)
_set_var_distributed(self._linear2_weight)
if normalize_before:
self._ln1_scale = self.create_parameter(
shape=[d_model],
attr=ln1_scale_attr,
is_bias=False,
default_initializer=Constant(1.0))
self._ln1_bias = self.create_parameter(shape=[d_model],
attr=ln1_bias_attr,
is_bias=True)
self._ln2_scale = None
self._ln2_bias = None
else:
self._ln1_bias = None
self._ln2_bias = None
self._ln2_scale = self.create_parameter(
shape=[d_model],
attr=ln2_scale_attr,
is_bias=False,
default_initializer=Constant(1.0))
self._ln2_bias = self.create_parameter(shape=[d_model],
attr=ln2_bias_attr,
is_bias=True)
self.name = name
def forward(self, src, cache=None):
out = fused_feedforward(src,
self._linear1_weight,
self._linear2_weight,
self._linear1_bias,
self._linear2_bias,
self._ln1_scale,
self._ln1_bias,
self._ln2_scale,
self._ln2_bias,
dropout1_rate=self._act_dropout_rate,
dropout2_rate=self._dropout_rate,
activation=self._act_method,
ln1_epsilon=self._epsilon,
ln2_epsilon=self._epsilon,
pre_layer_norm=self._normalize_before,
training=self.training,
ring_id=self._ring_id,
name=self.name)
return out
def get_param_attr(weight, bias): def get_param_attr(weight, bias):
weight_attr = paddle.ParamAttr( weight_attr = paddle.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(weight)) initializer=fluid.initializer.NumpyArrayInitializer(weight))
...@@ -295,19 +58,19 @@ def create_model(data, rank): ...@@ -295,19 +58,19 @@ def create_model(data, rank):
w0_attr, b0_attr = get_param_attr(col_w0, col_b0) w0_attr, b0_attr = get_param_attr(col_w0, col_b0)
w1_attr, b1_attr = get_param_attr(row_w1, b1) w1_attr, b1_attr = get_param_attr(row_w1, b1)
ffn = ParallelFusedFeedForward(IN_SIZE, ffn = FusedFeedForward(IN_SIZE,
OUT_SIZE, OUT_SIZE,
dropout_rate=0.0, dropout_rate=0.0,
activation='gelu', activation='gelu',
normalize_before=True, normalize_before=True,
linear1_weight_attr=w0_attr, linear1_weight_attr=w0_attr,
linear1_bias_attr=b0_attr, linear1_bias_attr=b0_attr,
linear2_weight_attr=w1_attr, linear2_weight_attr=w1_attr,
linear2_bias_attr=b1_attr, linear2_bias_attr=b1_attr,
ln1_scale_attr=ln_w_attr, ln1_scale_attr=ln_w_attr,
ln1_bias_attr=ln_b_attr, ln1_bias_attr=ln_b_attr,
nranks=MODEL_PARALLEL_SIZE, nranks=MODEL_PARALLEL_SIZE,
ring_id=0) ring_id=0)
#ffn.eval() #ffn.eval()
result = ffn(data) result = ffn(data)
else: else:
...@@ -315,17 +78,17 @@ def create_model(data, rank): ...@@ -315,17 +78,17 @@ def create_model(data, rank):
w0_attr, b0_attr = get_param_attr(w0, b0) w0_attr, b0_attr = get_param_attr(w0, b0)
w1_attr, b1_attr = get_param_attr(w1, b1) w1_attr, b1_attr = get_param_attr(w1, b1)
ffn = ParallelFusedFeedForward(IN_SIZE, ffn = FusedFeedForward(IN_SIZE,
OUT_SIZE, OUT_SIZE,
dropout_rate=0.0, dropout_rate=0.0,
activation='gelu', activation='gelu',
normalize_before=True, normalize_before=True,
linear1_weight_attr=w0_attr, linear1_weight_attr=w0_attr,
linear1_bias_attr=b0_attr, linear1_bias_attr=b0_attr,
linear2_weight_attr=w1_attr, linear2_weight_attr=w1_attr,
linear2_bias_attr=b1_attr, linear2_bias_attr=b1_attr,
ln1_scale_attr=ln_w_attr, ln1_scale_attr=ln_w_attr,
ln1_bias_attr=ln_b_attr) ln1_bias_attr=ln_b_attr)
#ffn.eval() #ffn.eval()
result = ffn(data) result = ffn(data)
......
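
The decomposition this feed-forward test exercises is a column-split first linear followed by a row-split second linear; the short NumPy check below, with illustrative sizes and the activation and biases left out, shows why summing the per-rank partial outputs (the allreduce that `ring_id` triggers inside the fused op) reproduces the single-device result.

```python
# Column-split linear1 plus row-split linear2 gives per-rank partial outputs
# whose sum (the ring_id allreduce inside the fused op) equals the
# single-device FFN. Activation and biases are omitted to keep the check short.
import numpy as np

in_size, hidden, nranks = 4, 8, 2
x = np.random.rand(3, in_size)
w0 = np.random.rand(in_size, hidden)   # linear1, split by columns
w1 = np.random.rand(hidden, in_size)   # linear2, split by rows

full = x @ w0 @ w1
shard = hidden // nranks
partial = sum((x @ w0[:, r * shard:(r + 1) * shard]) @ w1[r * shard:(r + 1) * shard]
              for r in range(nranks))
assert np.allclose(full, partial)
```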
...@@ -83,7 +83,7 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, ...@@ -83,7 +83,7 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
if ln_bias is None: if ln_bias is None:
has_bias = False has_bias = False
if (pre_layer_norm): if pre_layer_norm:
ln_out = layer_norm(query, True, has_bias, ln_scale, ln_bias) ln_out = layer_norm(query, True, has_bias, ln_scale, ln_bias)
num_head = qkv_weight.shape[1] num_head = qkv_weight.shape[1]
...@@ -97,7 +97,7 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, ...@@ -97,7 +97,7 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
if qkv_bias is not None: if qkv_bias is not None:
qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] * qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
qkv_bias.shape[2]) qkv_bias.shape[2])
if (pre_layer_norm): if pre_layer_norm:
ln_out = ln_out.reshape(batch_size * seq_len, embed_dim) ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
qkv = fc(ln_out, qkv_weight) qkv = fc(ln_out, qkv_weight)
if qkv_bias is not None: if qkv_bias is not None:
...@@ -239,12 +239,12 @@ class TestFusedAttentionAPI(unittest.TestCase): ...@@ -239,12 +239,12 @@ class TestFusedAttentionAPI(unittest.TestCase):
attn_mask_tensor = paddle.to_tensor(self.attn_mask) attn_mask_tensor = paddle.to_tensor(self.attn_mask)
else: else:
attn_mask_tensor = None attn_mask_tensor = None
fused_attn = FusedMultiHeadAttention(self.embed_dim, self.num_heads, fused_attn = FusedMultiHeadAttention(
self.dropout_prob, self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.vdim, self.pre_layer_norm, self.need_weight, self.weight_attr, self.bias_attr,
self.need_weight, self.weight_attr, self.weight_attr, self.bias_attr, self.weight_attr, self.bias_attr,
self.bias_attr) self.weight_attr, self.bias_attr)
if self.bias_attr is not False: if self.bias_attr is not False:
qkv_bias = np.random.random( qkv_bias = np.random.random(
fused_attn.qkv_bias.shape).astype('float32') fused_attn.qkv_bias.shape).astype('float32')
...@@ -260,13 +260,19 @@ class TestFusedAttentionAPI(unittest.TestCase): ...@@ -260,13 +260,19 @@ class TestFusedAttentionAPI(unittest.TestCase):
if self.bias_attr is not False: if self.bias_attr is not False:
fused_attn_qkv_bias = fused_attn.qkv_bias.numpy() fused_attn_qkv_bias = fused_attn.qkv_bias.numpy()
fused_attn_linear_bias = fused_attn.linear_bias.numpy() fused_attn_linear_bias = fused_attn.linear_bias.numpy()
fused_attn_pre_ln_bias = fused_attn.pre_ln_bias.numpy() if self.pre_layer_norm:
fused_attn_ln_bias = fused_attn.ln_bias.numpy() fused_attn_pre_ln_bias = fused_attn.pre_ln_bias.numpy()
fused_attn_ln_bias = None
else:
fused_attn_pre_ln_bias = None
fused_attn_ln_bias = fused_attn.ln_bias.numpy()
ref_out = compute_reference( ref_out = compute_reference(
self.pre_layer_norm, self.query, self.attn_mask, self.pre_layer_norm, self.query, self.attn_mask,
fused_attn.pre_ln_scale.numpy(), fused_attn_pre_ln_bias, fused_attn.pre_ln_scale.numpy() if self.pre_layer_norm else None,
fused_attn.ln_scale.numpy(), fused_attn_ln_bias, fused_attn_pre_ln_bias,
fused_attn.ln_scale.numpy() if not self.pre_layer_norm else None,
fused_attn_ln_bias,
fused_attn.qkv_weight.numpy(), fused_attn_qkv_bias, fused_attn.qkv_weight.numpy(), fused_attn_qkv_bias,
fused_attn.linear_weight.numpy(), fused_attn_linear_bias) fused_attn.linear_weight.numpy(), fused_attn_linear_bias)
np.testing.assert_allclose(ref_out, np.testing.assert_allclose(ref_out,
...@@ -275,12 +281,12 @@ class TestFusedAttentionAPI(unittest.TestCase): ...@@ -275,12 +281,12 @@ class TestFusedAttentionAPI(unittest.TestCase):
atol=self.atol) atol=self.atol)
def run_static(self): def run_static(self):
fused_attn = FusedMultiHeadAttention(self.embed_dim, self.num_heads, fused_attn = FusedMultiHeadAttention(
self.dropout_prob, self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.vdim, self.pre_layer_norm, self.need_weight, self.weight_attr, self.bias_attr,
self.need_weight, self.weight_attr, self.weight_attr, self.bias_attr, self.weight_attr, self.bias_attr,
self.bias_attr) self.weight_attr, self.bias_attr)
x = paddle.static.data( x = paddle.static.data(
name='X', name='X',
...@@ -304,58 +310,118 @@ class TestFusedAttentionAPI(unittest.TestCase): ...@@ -304,58 +310,118 @@ class TestFusedAttentionAPI(unittest.TestCase):
qkv_bias = None qkv_bias = None
linear_bias = None linear_bias = None
ln_scale = None
ln_2_scale = None
ln_bias = None ln_bias = None
ln_2_bias = None ln_2_bias = None
if self.has_attn_mask: if self.has_attn_mask:
if self.bias_attr is False: if self.bias_attr is False:
out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run( if self.pre_layer_norm:
paddle.static.default_main_program(), out, qkv_weight, out_linear_weight, ln_scale = exe.run(
feed={ paddle.static.default_main_program(),
"X": self.query, feed={
"SrcMask": self.attn_mask "X": self.query,
}, "SrcMask": self.attn_mask
fetch_list=[ },
final_out, fused_attn.qkv_weight, fetch_list=[
fused_attn.linear_weight, fused_attn.pre_ln_scale, final_out,
fused_attn.ln_scale fused_attn.qkv_weight,
]) fused_attn.linear_weight,
fused_attn.pre_ln_scale,
])
else:
out, qkv_weight, out_linear_weight, ln_2_scale = exe.run(
paddle.static.default_main_program(),
feed={
"X": self.query,
"SrcMask": self.attn_mask
},
fetch_list=[
final_out, fused_attn.qkv_weight,
fused_attn.linear_weight, fused_attn.ln_scale
])
else: else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( if self.pre_layer_norm:
paddle.static.default_main_program(), out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias = exe.run(
feed={ paddle.static.default_main_program(),
"X": self.query, feed={
"SrcMask": self.attn_mask "X": self.query,
}, "SrcMask": self.attn_mask
fetch_list=[ },
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, fetch_list=[
fused_attn.linear_weight, fused_attn.linear_bias, final_out,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias, fused_attn.qkv_weight,
fused_attn.ln_scale, fused_attn.ln_bias fused_attn.qkv_bias,
]) fused_attn.linear_weight,
fused_attn.linear_bias,
fused_attn.pre_ln_scale,
fused_attn.pre_ln_bias,
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={
"X": self.query,
"SrcMask": self.attn_mask
},
fetch_list=[
final_out, fused_attn.qkv_weight,
fused_attn.qkv_bias, fused_attn.linear_weight,
fused_attn.linear_bias, fused_attn.ln_scale,
fused_attn.ln_bias
])
else: else:
if self.bias_attr is False: if self.bias_attr is False:
out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run( if self.pre_layer_norm:
paddle.static.default_main_program(), out, qkv_weight, out_linear_weight, ln_scale = exe.run(
feed={ paddle.static.default_main_program(),
"X": self.query, feed={
}, "X": self.query,
fetch_list=[ },
final_out, fused_attn.qkv_weight, fetch_list=[
fused_attn.linear_weight, fused_attn.pre_ln_scale, final_out,
fused_attn.ln_scale fused_attn.qkv_weight,
]) fused_attn.linear_weight,
fused_attn.pre_ln_scale,
])
else:
out, qkv_weight, out_linear_weight, ln_2_scale = exe.run(
paddle.static.default_main_program(),
feed={
"X": self.query,
},
fetch_list=[
final_out, fused_attn.qkv_weight,
fused_attn.linear_weight, fused_attn.ln_scale
])
else: else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( if self.pre_layer_norm:
paddle.static.default_main_program(), out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias = exe.run(
feed={ paddle.static.default_main_program(),
"X": self.query, feed={
}, "X": self.query,
fetch_list=[ },
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, fetch_list=[
fused_attn.linear_weight, fused_attn.linear_bias, final_out,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias, fused_attn.qkv_weight,
fused_attn.ln_scale, fused_attn.ln_bias fused_attn.qkv_bias,
]) fused_attn.linear_weight,
fused_attn.linear_bias,
fused_attn.pre_ln_scale,
fused_attn.pre_ln_bias,
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={
"X": self.query,
},
fetch_list=[
final_out, fused_attn.qkv_weight,
fused_attn.qkv_bias, fused_attn.linear_weight,
fused_attn.linear_bias, fused_attn.ln_scale,
fused_attn.ln_bias
])
return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias
def test_static_api(self): def test_static_api(self):
......
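
A small dygraph sketch of the property the updated test branches on: after this change only one of the two layer-norm parameter sets exists on the layer, depending on `normalize_before`. It assumes a Paddle build where the incubate fused layers are importable; only construction is exercised here.

```python
# Dygraph sketch of the branching the test now mirrors: only one layer-norm
# parameter set exists per layer after this change.
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

paddle.disable_static()
pre_ln = FusedMultiHeadAttention(128, 8, normalize_before=True)
post_ln = FusedMultiHeadAttention(128, 8, normalize_before=False)

assert pre_ln.pre_ln_scale is not None and pre_ln.ln_scale is None
assert post_ln.pre_ln_scale is None and post_ln.ln_scale is not None
```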
...@@ -45,6 +45,7 @@ def fused_feedforward(x, ...@@ -45,6 +45,7 @@ def fused_feedforward(x,
pre_layer_norm=False, pre_layer_norm=False,
training=True, training=True,
mode='upscale_in_train', mode='upscale_in_train',
ring_id=-1,
name=None): name=None):
r""" r"""
This is a fusion operator to compute feed forward layer in transformer model architecture. This is a fusion operator to compute feed forward layer in transformer model architecture.
...@@ -88,6 +89,7 @@ def fused_feedforward(x, ...@@ -88,6 +89,7 @@ def fused_feedforward(x,
- train: out = input * mask - train: out = input * mask
- inference: out = input * (1.0 - p) - inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using tensor parallel.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
...@@ -132,7 +134,8 @@ def fused_feedforward(x, ...@@ -132,7 +134,8 @@ def fused_feedforward(x,
"dropout1_fix_seed", seed is not None, "dropout2_fix_seed", seed "dropout1_fix_seed", seed is not None, "dropout2_fix_seed", seed
is not None, "dropout1_seed", seed if seed is not None else 0, is not None, "dropout1_seed", seed if seed is not None else 0,
"dropout2_seed", seed if seed is not None else 0, "dropout2_seed", seed if seed is not None else 0,
'dropout1_implementation', mode, 'dropout2_implementation', mode) 'dropout1_implementation', mode, 'dropout2_implementation', mode,
'ring_id', ring_id)
return out return out
helper = LayerHelper("fused_feedforward") helper = LayerHelper("fused_feedforward")
...@@ -206,7 +209,8 @@ def fused_feedforward(x, ...@@ -206,7 +209,8 @@ def fused_feedforward(x,
'dropout1_seed': seed if seed is not None else 0, 'dropout1_seed': seed if seed is not None else 0,
'dropout2_seed': seed if seed is not None else 0, 'dropout2_seed': seed if seed is not None else 0,
'dropout1_implementation': mode, 'dropout1_implementation': mode,
'dropout2_implementation': mode 'dropout2_implementation': mode,
'ring_id': ring_id,
}) })
return out return out
......
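
In the functional API, `ring_id` defaults to -1 (no tensor parallelism), so existing single-card callers are unchanged; a model-parallel caller passes the ring id of its model-parallel NCCL group. A minimal dygraph call, assuming a GPU build (the fused op is CUDA-only) and illustrative shapes:

```python
# Minimal dygraph call of the functional API after this change; ring_id stays at
# its default of -1, so single-card behaviour is unchanged.
import paddle
import paddle.incubate.nn.functional as incubate_f

x = paddle.rand([1, 16, 8], dtype='float32')
linear1_weight = paddle.rand([8, 32], dtype='float32')   # d_model -> dim_feedforward
linear2_weight = paddle.rand([32, 8], dtype='float32')   # dim_feedforward -> d_model

out = incubate_f.fused_feedforward(x,
                                   linear1_weight,
                                   linear2_weight,
                                   dropout1_rate=0.0,
                                   dropout2_rate=0.0,
                                   ring_id=-1)  # -1: no model-parallel allreduce
```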
...@@ -101,12 +101,12 @@ class FusedBiasDropoutResidualLayerNorm(Layer): ...@@ -101,12 +101,12 @@ class FusedBiasDropoutResidualLayerNorm(Layer):
Applies fused_bias_dropout_residual_layer_norm operation. Applies fused_bias_dropout_residual_layer_norm operation.
Parameters: Parameters:
x (Tensor): The input tensor. It is a tensor with shape x (Tensor): The input tensor. It is a tensor with shape
`[batch_size, seq_len, embed_dim]`. The data type should be `[batch_size, seq_len, embed_dim]`. The data type should be
float32 or float64. float32 or float64.
residual (Tensor, optional): The residual tensor. It is a tensor residual (Tensor, optional): The residual tensor. It is a tensor
with shape `[batch_size, value_length, vdim]`. The data type with shape `[batch_size, value_length, vdim]`. The data type
should be float32 or float64. should be float32 or float64.
Returns: Returns:
Tensor|tuple: It is a tensor that has the same shape and data type \ Tensor|tuple: It is a tensor that has the same shape and data type \
...@@ -158,15 +158,39 @@ class FusedMultiHeadAttention(Layer): ...@@ -158,15 +158,39 @@ class FusedMultiHeadAttention(Layer):
(True) or post_layer_norm architecture (False). Default False. (True) or post_layer_norm architecture (False). Default False.
need_weights (bool, optional): Indicate whether to return the attention need_weights (bool, optional): Indicate whether to return the attention
weights. Now, only False is supported. Default False. weights. Now, only False is supported. Default False.
weight_attr(ParamAttr, optional): To specify the weight parameter property. qkv_weight_attr(ParamAttr, optional): To specify the weight parameter property
Default: None, which means the default weight parameter property is used. for QKV projection computation. Default: None, which means the default weight
See usage for details in :code:`ParamAttr`. parameter property is used. See usage for details in :code:`ParamAttr`.
bias_attr (ParamAttr|bool, optional): To specify the bias parameter property. qkv_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
Default: None, which means the default bias parameter property is used. for QKV projection computation. The `False` value means the corresponding layer
If it is set to False, this layer will not have trainable bias parameter. would not have trainable bias parameter. Default: None, which means the
See usage for details in :code:`ParamAttr`. default bias parameter property is used. See usage for details in :code:`ParamAttr`.
linear_weight_attr(ParamAttr, optional): To specify the weight parameter property
for linear projection computation. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
linear_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
for linear projection computation. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
pre_ln_scale_attr(ParamAttr, optional): To specify the weight parameter property
for pre_layer_norm computation. Otherwise, all layers both use it as
`attr` to create parameters. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
pre_ln_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
for pre_layer_norm computation. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
ln_scale_attr(ParamAttr, optional): To specify the weight parameter property
for post_layer_norm computation. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
ln_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
for post_layer_norm computation. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
epsilon (float, optional): The small value added to the variance to prevent epsilon (float, optional): The small value added to the variance to prevent
division by zero. Default: 1e-05. division by zero. Default: 1e-05.
nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using tensor parallel.
ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using tensor parallel.
Examples: Examples:
...@@ -191,9 +215,17 @@ class FusedMultiHeadAttention(Layer): ...@@ -191,9 +215,17 @@ class FusedMultiHeadAttention(Layer):
vdim=None, vdim=None,
normalize_before=False, normalize_before=False,
need_weights=False, need_weights=False,
weight_attr=None, qkv_weight_attr=None,
bias_attr=None, qkv_bias_attr=None,
linear_weight_attr=None,
linear_bias_attr=None,
pre_ln_scale_attr=None,
pre_ln_bias_attr=None,
ln_scale_attr=None,
ln_bias_attr=None,
epsilon=1e-5, epsilon=1e-5,
nranks=1,
ring_id=-1,
name=None): name=None):
super(FusedMultiHeadAttention, self).__init__() super(FusedMultiHeadAttention, self).__init__()
...@@ -204,9 +236,8 @@ class FusedMultiHeadAttention(Layer): ...@@ -204,9 +236,8 @@ class FusedMultiHeadAttention(Layer):
self.normalize_before = normalize_before self.normalize_before = normalize_before
self._dtype = self._helper.get_default_dtype() self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self._epsilon = epsilon self._epsilon = epsilon
self._ring_id = ring_id
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.num_heads = num_heads self.num_heads = num_heads
...@@ -215,41 +246,61 @@ class FusedMultiHeadAttention(Layer): ...@@ -215,41 +246,61 @@ class FusedMultiHeadAttention(Layer):
self.vdim = vdim self.vdim = vdim
self.need_weights = need_weights self.need_weights = need_weights
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
assert need_weights == False, "Only support need_weight is False now." assert need_weights is False, "Only support need_weight is False now."
# tensor model parallel
assert num_heads % nranks == 0
num_heads = num_heads // nranks
self.qkv_weight = self.create_parameter( self.qkv_weight = self.create_parameter(
shape=[3, num_heads, self.head_dim, embed_dim], shape=[3, num_heads, self.head_dim, embed_dim],
attr=self._weight_attr, attr=qkv_weight_attr,
dtype=self._dtype, dtype=self._dtype,
is_bias=False) is_bias=False)
self.qkv_bias = self.create_parameter( self.qkv_bias = self.create_parameter(
shape=[3, num_heads, self.head_dim], shape=[3, num_heads, self.head_dim],
attr=self._bias_attr, attr=qkv_bias_attr,
dtype=self._dtype, dtype=self._dtype,
is_bias=True) is_bias=True)
self.linear_weight = self.create_parameter(shape=[embed_dim, embed_dim], self.linear_weight = self.create_parameter(
attr=self._weight_attr, shape=[num_heads * self.head_dim, embed_dim],
dtype=self._dtype, attr=linear_weight_attr,
is_bias=False) dtype=self._dtype,
is_bias=False)
self.linear_bias = self.create_parameter(shape=[embed_dim], self.linear_bias = self.create_parameter(shape=[embed_dim],
attr=self._bias_attr, attr=linear_bias_attr,
dtype=self._dtype, dtype=self._dtype,
is_bias=True) is_bias=True)
self.pre_ln_scale = self.create_parameter( # tensor model parallel
attr=self._weight_attr, if nranks > 1:
shape=[embed_dim], assert ring_id != -1
default_initializer=Constant(value=1.0)) # column parallel
self.pre_ln_bias = self.create_parameter(attr=self._bias_attr, _set_var_distributed(self.qkv_weight)
_set_var_distributed(self.qkv_bias)
# row parallel
_set_var_distributed(self.linear_weight)
if normalize_before:
self.pre_ln_scale = self.create_parameter(
attr=pre_ln_scale_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.pre_ln_bias = self.create_parameter(attr=pre_ln_bias_attr,
shape=[embed_dim],
is_bias=True)
self.ln_scale = None
self.ln_bias = None
else:
self.pre_ln_scale = None
self.pre_ln_bias = None
self.ln_scale = self.create_parameter(
attr=ln_scale_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.ln_bias = self.create_parameter(attr=ln_bias_attr,
shape=[embed_dim], shape=[embed_dim],
is_bias=True) is_bias=True)
self.ln_scale = self.create_parameter(
attr=self._weight_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.ln_bias = self.create_parameter(attr=self._bias_attr,
shape=[embed_dim],
is_bias=True)
self.dropout_rate = dropout_rate self.dropout_rate = dropout_rate
self.attn_dropout_rate = attn_dropout_rate self.attn_dropout_rate = attn_dropout_rate
...@@ -294,8 +345,6 @@ class FusedMultiHeadAttention(Layer): ...@@ -294,8 +345,6 @@ class FusedMultiHeadAttention(Layer):
# Support bool or int mask # Support bool or int mask
attn_mask = _convert_attention_mask(attn_mask, query.dtype) attn_mask = _convert_attention_mask(attn_mask, query.dtype)
assert cache == None, "Only support cache is None now."
out = incubate_f.fused_multi_head_attention( out = incubate_f.fused_multi_head_attention(
x=query, x=query,
qkv_weight=self.qkv_weight, qkv_weight=self.qkv_weight,
...@@ -308,11 +357,13 @@ class FusedMultiHeadAttention(Layer): ...@@ -308,11 +357,13 @@ class FusedMultiHeadAttention(Layer):
pre_ln_epsilon=self._epsilon, pre_ln_epsilon=self._epsilon,
qkv_bias=self.qkv_bias, qkv_bias=self.qkv_bias,
linear_bias=self.linear_bias, linear_bias=self.linear_bias,
cache_kv=cache,
attn_mask=attn_mask, attn_mask=attn_mask,
dropout_rate=self.dropout_rate, dropout_rate=self.dropout_rate,
attn_dropout_rate=self.attn_dropout_rate, attn_dropout_rate=self.attn_dropout_rate,
ln_epsilon=self._epsilon, ln_epsilon=self._epsilon,
training=self.training, training=self.training,
ring_id=self._ring_id,
name=self.name) name=self.name)
return out return out
...@@ -338,14 +389,38 @@ class FusedFeedForward(Layer): ...@@ -338,14 +389,38 @@ class FusedFeedForward(Layer):
If None, use the value of `dropout_rate`. Default None If None, use the value of `dropout_rate`. Default None
normalize_before (bool, optional): Indicate whether to put layer normalization normalize_before (bool, optional): Indicate whether to put layer normalization
into, preprocessing or postprocessing. Default False into, preprocessing or postprocessing. Default False
weight_attr (ParamAttr, optional): The attribute for the learnable weight of this layer. linear1_weight_attr(ParamAttr, optional): To specify the weight parameter property
The default value is None and the weight will be initialized to zero. For detailed for FFN first linear. Default: None, which means the default weight
information, please refer to paddle.ParamAttr. parameter property is used. See usage for details in :code:`ParamAttr`.
bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias of thi layer. linear1_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
If it is set to False, no bias will be added to the output. If it is set to None or one for FFN first linear. The `False` value means the corresponding layer would
kind of ParamAttr, a bias parameter will be created according to ParamAttr. For detailed not have trainable bias parameter. Default: None, which means the default bias
information, please refer to paddle.ParamAttr. The default value is None and the bias parameter property is used. See usage for details in :code:`ParamAttr`.
will be initialized to zero. linear2_weight_attr(ParamAttr, optional): To specify the weight parameter property
for FFN second linear. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
linear2_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
for FFN second linear. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
ln1_scale_attr(ParamAttr, optional): To specify the weight parameter property
for FFN pre_layer_norm. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
ln1_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
for FFN pre_layer_norm. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
ln2_scale_attr(ParamAttr, optional): To specify the weight parameter property
for FFN post_layer_norm. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
ln2_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
for FFN layer_norm. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using tensor parallel.
ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using tensor parallel.
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -369,8 +444,16 @@ class FusedFeedForward(Layer): ...@@ -369,8 +444,16 @@ class FusedFeedForward(Layer):
activation="relu", activation="relu",
act_dropout_rate=None, act_dropout_rate=None,
normalize_before=False, normalize_before=False,
weight_attr=None, linear1_weight_attr=None,
bias_attr=None, linear1_bias_attr=None,
linear2_weight_attr=None,
linear2_bias_attr=None,
ln1_scale_attr=None,
ln1_bias_attr=None,
ln2_scale_attr=None,
ln2_bias_attr=None,
nranks=1,
ring_id=-1,
name=None): name=None):
super(FusedFeedForward, self).__init__() super(FusedFeedForward, self).__init__()
...@@ -383,51 +466,68 @@ class FusedFeedForward(Layer): ...@@ -383,51 +466,68 @@ class FusedFeedForward(Layer):
self._dtype = self._helper.get_default_dtype() self._dtype = self._helper.get_default_dtype()
self._d_model = d_model self._d_model = d_model
assert dim_feedforward % nranks == 0
dim_feedforward = dim_feedforward // nranks
self._dim_feedforward = dim_feedforward self._dim_feedforward = dim_feedforward
self._dropout_rate = dropout_rate self._dropout_rate = dropout_rate
self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
self._act_method = activation self._act_method = activation
self._normalize_before = normalize_before self._normalize_before = normalize_before
self._epsilon = epsilon self._epsilon = epsilon
self._ring_id = ring_id
self._linear1_weight = self.create_parameter( self._linear1_weight = self.create_parameter(
shape=[d_model, dim_feedforward], shape=[d_model, dim_feedforward],
attr=weight_attr, attr=linear1_weight_attr,
dtype=self._dtype, dtype=self._dtype,
is_bias=False) is_bias=False)
self._linear1_bias = self.create_parameter(shape=[dim_feedforward], self._linear1_bias = self.create_parameter(shape=[dim_feedforward],
attr=bias_attr, attr=linear1_bias_attr,
dtype=self._dtype, dtype=self._dtype,
is_bias=True) is_bias=True)
self._linear2_weight = self.create_parameter( self._linear2_weight = self.create_parameter(
shape=[dim_feedforward, d_model], shape=[dim_feedforward, d_model],
attr=weight_attr, attr=linear2_weight_attr,
dtype=self._dtype, dtype=self._dtype,
is_bias=False) is_bias=False)
self._linear2_bias = self.create_parameter(shape=[d_model], self._linear2_bias = self.create_parameter(shape=[d_model],
attr=bias_attr, attr=linear2_bias_attr,
dtype=self._dtype, dtype=self._dtype,
is_bias=True) is_bias=True)
self._ln1_scale = self.create_parameter( if nranks > 1:
shape=[d_model], assert ring_id != -1
attr=None, # column parallel
is_bias=False, _set_var_distributed(self._linear1_weight)
default_initializer=Constant(1.0)) _set_var_distributed(self._linear1_bias)
self._ln1_bias = self.create_parameter(shape=[d_model], _set_var_distributed(self._linear2_weight)
attr=None,
is_bias=True) if normalize_before:
self._ln1_scale = self.create_parameter(
self._ln2_scale = self.create_parameter( shape=[d_model],
shape=[d_model], attr=ln1_scale_attr,
attr=None, is_bias=False,
is_bias=False, default_initializer=Constant(1.0))
default_initializer=Constant(1.0)) self._ln1_bias = self.create_parameter(shape=[d_model],
self._ln2_bias = self.create_parameter(shape=[d_model], attr=ln1_bias_attr,
attr=None, is_bias=True)
is_bias=True) self._ln2_scale = None
self._ln2_bias = None
else:
self._ln1_scale = None
self._ln1_bias = None
self._ln2_scale = self.create_parameter(
shape=[d_model],
attr=ln2_scale_attr,
is_bias=False,
default_initializer=Constant(1.0))
self._ln2_bias = self.create_parameter(shape=[d_model],
attr=ln2_bias_attr,
is_bias=True)
self.name = name self.name = name
def forward(self, src, cache=None): def forward(self, src, cache=None):
...@@ -448,6 +548,7 @@ class FusedFeedForward(Layer): ...@@ -448,6 +548,7 @@ class FusedFeedForward(Layer):
ln2_epsilon=self._epsilon, ln2_epsilon=self._epsilon,
pre_layer_norm=self._normalize_before, pre_layer_norm=self._normalize_before,
training=self.training, training=self.training,
ring_id=self._ring_id,
name=self.name) name=self.name)
return out return out
...@@ -553,8 +654,14 @@ class FusedTransformerEncoderLayer(Layer): ...@@ -553,8 +654,14 @@ class FusedTransformerEncoderLayer(Layer):
dropout_rate=dropout_rate, dropout_rate=dropout_rate,
attn_dropout_rate=attn_dropout_rate, attn_dropout_rate=attn_dropout_rate,
normalize_before=self.normalize_before, normalize_before=self.normalize_before,
weight_attr=weight_attrs[0], qkv_weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0]) qkv_bias_attr=bias_attrs[0],
linear_weight_attr=weight_attrs[0],
linear_bias_attr=bias_attrs[0],
pre_ln_scale_attr=weight_attrs[0],
pre_ln_bias_attr=bias_attrs[0],
ln_scale_attr=weight_attrs[0],
ln_bias_attr=bias_attrs[0])
self.ffn = FusedFeedForward(d_model, self.ffn = FusedFeedForward(d_model,
dim_feedforward, dim_feedforward,
...@@ -562,8 +669,10 @@ class FusedTransformerEncoderLayer(Layer): ...@@ -562,8 +669,10 @@ class FusedTransformerEncoderLayer(Layer):
activation=activation, activation=activation,
act_dropout_rate=act_dropout_rate, act_dropout_rate=act_dropout_rate,
normalize_before=self.normalize_before, normalize_before=self.normalize_before,
weight_attr=weight_attrs[1], linear1_weight_attr=weight_attrs[1],
bias_attr=bias_attrs[1]) linear1_bias_attr=bias_attrs[1],
linear2_weight_attr=weight_attrs[1],
linear2_bias_attr=bias_attrs[1])
def forward(self, src, src_mask=None, cache=None): def forward(self, src, src_mask=None, cache=None):
""" """
......
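
The public interface of `FusedTransformerEncoderLayer` is unchanged by this diff; its `weight_attr`/`bias_attr` arguments are now simply fanned out to the new per-parameter attributes of the fused sublayers. A usage sketch, assuming a GPU build and the constructor keeps its existing signature; shapes are illustrative.

```python
# Usage sketch: FusedTransformerEncoderLayer keeps its public weight_attr /
# bias_attr arguments; internally they now map onto the per-parameter
# attributes of the fused attention and feed-forward sublayers.
import paddle
from paddle.incubate.nn import FusedTransformerEncoderLayer

enc_input = paddle.rand((2, 4, 128))      # [batch_size, src_len, d_model]
attn_mask = paddle.rand((2, 2, 4, 4))     # [batch_size, n_head, src_len, src_len]
encoder_layer = FusedTransformerEncoderLayer(128, 2, 512)
enc_output = encoder_layer(enc_input, attn_mask)
```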