Unverified commit c2a5bb91, authored by: Z Zhang Zheng, committed by: GitHub

Add new attr of fused_multi_transformer (#43730)

* Add new attr of fused_multi_transformer

* fix format

* add note

* add in layer

* fixfixfixfix
Parent a9bba5ba
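For context, a minimal sketch (assuming paddle is installed; the tensor names are purely illustrative) of how the two qkv weight layouts selected by the new trans_qkvw attribute relate to each other:

```python
import paddle

num_head, dim_head = 16, 64
dim_embed = num_head * dim_head  # the op requires num_head * dim_head == dim_embed

# trans_qkvw=True (default): weight layout is [3, num_head, dim_head, dim_embed]
qkv_w_trans = paddle.randn([3, num_head, dim_head, dim_embed])

# trans_qkvw=False: weight layout is [dim_embed, 3, num_head, dim_head]
qkv_w_plain = paddle.transpose(qkv_w_trans, perm=[3, 0, 1, 2])
```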
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -63,6 +64,7 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputsDim("QKVW")[0];
bool trans_qkvw = ctx->Attrs().Get<bool>("trans_qkvw");
PADDLE_ENFORCE_EQ(
x_dim.size(),
3,
......@@ -79,24 +81,37 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
"but received dimensions of"
"Input is [%d]",
y_dim.size()));
PADDLE_ENFORCE_EQ(x_dim[2],
y_dim[3],
platform::errors::InvalidArgument(
"ShapeError: the dimension of x_dim[2] and y_dim[3]"
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]",
x_dim,
y_dim));
PADDLE_ENFORCE_EQ(
x_dim[2],
trans_qkvw ? y_dim[3] : y_dim[0],
platform::errors::InvalidArgument(
"ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is "
"true) or y_dim[0](trans_qkvw is false)"
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]",
x_dim,
y_dim));
if (ctx->Attrs().Get<int>("ring_id") == -1) {
PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
y_dim[3],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
if (trans_qkvw) {
PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
y_dim[3],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
} else {
PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3],
y_dim[0],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(dim_embed, 3, num_head, dim_head),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
}
}
if (ctx->HasInputs("CacheKV")) {
......@@ -122,11 +137,11 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
x_dim[0],
c_dim[1])); // batch_size
PADDLE_ENFORCE_EQ(c_dim[2],
y_dim[1],
trans_qkvw ? y_dim[1] : y_dim[2],
paddle::platform::errors::InvalidArgument(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d",
y_dim[1],
trans_qkvw ? y_dim[1] : y_dim[2],
c_dim[2])); // num_head
PADDLE_ENFORCE_GT(
c_dim[3],
......@@ -135,11 +150,11 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
"The forth dim of CacheKV must be greater than 0, but got %d",
c_dim[3])); // cache_seq_len
PADDLE_ENFORCE_EQ(c_dim[4],
y_dim[2],
trans_qkvw ? y_dim[2] : y_dim[3],
paddle::platform::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d",
y_dim[2],
trans_qkvw ? y_dim[2] : y_dim[3],
c_dim[4])); // head_size
}
......@@ -258,6 +273,13 @@ class FusedMultiTransformerOpOpMaker
"upscale_in_train"));
});
AddAttr<std::string>("act_method", "act_method").SetDefault("gelu");
AddAttr<bool>(
"trans_qkvw",
"Whether the weights of qkv should be transposed. If true,"
"the shape eights of qkv should be [3, num_head, dim_head, dim_embed]."
"Otherwise the shape of weights of qkv should be"
"[dim_embed, 3, num_head, dim_head]")
.SetDefault(true);
AddAttr<int>(
"ring_id",
......@@ -278,3 +300,12 @@ REGISTER_OPERATOR(
ops::FusedMultiTransformerOpOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_VERSION(fused_multi_transformer)
.AddCheckpoint(
R"ROC(
Add a new attribute [trans_qkvw] )ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"trans_qkvw",
"A flag to indicate whether to transpose for weights of qkv.",
true));
......@@ -1119,17 +1119,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto qkv_weights = ctx.MultiInput<Tensor>("QKVW");
auto qkv_biases = ctx.MultiInput<Tensor>("QKVBias");
const bool trans_qkvw = ctx.Attr<bool>("trans_qkvw");
const auto qkv_w_dims = qkv_weights[0]->dims();
int num_head = qkv_w_dims[1];
int dim_head = qkv_w_dims[2];
int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2];
int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3];
int hidden_size = num_head * dim_head;
int output_size = 3 * hidden_size;
int input_size = dim_embed;
bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
// (transA, transB, compute_bias) = (false, true, false)
auto qkv_compute = AttnMatMul<T>(
dev_ctx, false, true, bsz_seq, output_size, input_size, compute_bias);
// (transA, transB, compute_bias) = (false, trans_qkvw, false)
auto qkv_compute = AttnMatMul<T>(dev_ctx,
false,
trans_qkvw,
bsz_seq,
output_size,
input_size,
compute_bias);
Tensor qkv_out;
auto *qkv_out_data =
qkv_out.mutable_data<T>({bsz, seq_len, 3, num_head, dim_head}, place);
......
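A rough Python sketch (not the fused GPU kernel itself; the helper name is illustrative) of the qkv projection that the AttnMatMul call above performs, showing how trans_qkvw maps onto the transB flag:

```python
import paddle

def qkv_projection(x_2d, qkv_w, trans_qkvw=True):
    """Sketch of the qkv GEMM; x_2d is [bsz_seq, dim_embed]."""
    if trans_qkvw:
        # weight stored as [3, num_head, dim_head, dim_embed] -> 2D [output_size, input_size]
        w = paddle.reshape(qkv_w, [-1, x_2d.shape[-1]])
        return paddle.matmul(x_2d, w, transpose_y=True)  # transB = trans_qkvw = True
    # weight stored as [dim_embed, 3, num_head, dim_head] -> 2D [input_size, output_size]
    w = paddle.reshape(qkv_w, [x_2d.shape[-1], -1])
    return paddle.matmul(x_2d, w)  # transB = False
```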
......@@ -680,6 +680,7 @@ def fused_multi_transformer(x,
activation="gelu",
training=False,
mode='upscale_in_train',
trans_qkvw=True,
ring_id=-1,
name=None):
r"""
......@@ -756,6 +757,9 @@ def fused_multi_transformer(x,
- train: out = input * mask
- inference: out = input * (1.0 - p)
trans_qkvw (bool, optional): Whether to transpose the weights of qkv.
If true, the shape of the qkv weights should be [3, num_head, dim_head, dim_embed].
Otherwise the shape of the qkv weights should be [dim_embed, 3, num_head, dim_head]. Default: True.
ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using mp.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
......@@ -826,8 +830,8 @@ def fused_multi_transformer(x,
ffn_ln_biases, ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases,
cache_kvs, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon,
'dropout_rate', dropout_rate, 'is_test', not training,
'dropout_implementation', mode, 'act_method', activation, 'ring_id',
ring_id)
'dropout_implementation', mode, 'act_method', activation,
'trans_qkvw', trans_qkvw, 'ring_id', ring_id)
if cache_kvs is not None:
return final_out, cache_kv_out
return final_out
......@@ -875,6 +879,7 @@ def fused_multi_transformer(x,
'is_test': not training,
'dropout_implementation': mode,
'act_method': activation,
'trans_qkvw': trans_qkvw,
'ring_id': ring_id
}
......
......@@ -1048,6 +1048,9 @@ class FusedMultiTransformer(Layer):
is a list or tuple, the number of layers is obtained from `qkv_weight_attrs`. num_layers
only takes effect when `qkv_weight_attrs` is not a list or tuple. Default: -1.
nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using mp.
trans_qkvw (bool, optional): Whether to transpose the weights of qkv.
If true, the shape of the qkv weights should be [3, num_head, dim_head, dim_embed].
Otherwise the shape of the qkv weights should be [dim_embed, 3, num_head, dim_head]. Default: True.
ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using mp.
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
......@@ -1090,6 +1093,7 @@ class FusedMultiTransformer(Layer):
epsilon=1e-5,
num_layers=-1,
nranks=1,
trans_qkvw=True,
ring_id=-1,
name=None):
super(FusedMultiTransformer, self).__init__()
......@@ -1105,6 +1109,7 @@ class FusedMultiTransformer(Layer):
self.normalize_before = normalize_before
self._dtype = self._helper.get_default_dtype()
self._epsilon = epsilon
self._trans_qkvw = trans_qkvw
self._ring_id = ring_id
self.embed_dim = embed_dim
......@@ -1161,7 +1166,8 @@ class FusedMultiTransformer(Layer):
shape=[embed_dim],
is_bias=True)
qkv_weight = self.create_parameter(
shape=[3, num_heads, self.head_dim, embed_dim],
shape=[3, num_heads, self.head_dim, embed_dim]
if trans_qkvw else [embed_dim, 3, num_heads, self.head_dim],
attr=qkv_weight_attr,
dtype=self._dtype,
is_bias=False)
......@@ -1292,6 +1298,7 @@ class FusedMultiTransformer(Layer):
activation=self.activation,
training=self.training,
mode='upscale_in_train',
trans_qkvw=self._trans_qkvw,
ring_id=self._ring_id,
name=self.name)
return out
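A usage sketch of the new argument through the layer API (this assumes a CUDA build of Paddle, since the fused op runs on GPU only; the sizes are illustrative, not a definitive example):

```python
import paddle
from paddle.incubate.nn import FusedMultiTransformer

# with trans_qkvw=False, each qkv weight parameter is created as
# [dim_embed, 3, num_heads, head_dim] instead of [3, num_heads, head_dim, dim_embed]
layer = FusedMultiTransformer(embed_dim=128, num_heads=2, dim_feedforward=512,
                              num_layers=1, trans_qkvw=False)

src = paddle.rand([2, 4, 128])          # [batch_size, seq_len, embed_dim]
attn_mask = paddle.rand([2, 2, 4, 4])   # [batch_size, num_heads, seq_len, seq_len]
out = layer(src, attn_mask=attn_mask)   # [2, 4, 128]
```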