Unverified commit c2a5bb91, authored by Zhang Zheng, committed by GitHub

Add new attr of fused_multi_transformer (#43730)

* Add new attr of fused_multi_transformer

* fix format

* add note

* add in layer

* fixfixfixfix
Parent a9bba5ba
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 namespace paddle {
 namespace operators {
@@ -63,6 +64,7 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
     // y: qkv's weight: [3, num_head, dim_head, dim_embed]
     auto x_dim = ctx->GetInputDim("X");
     auto y_dim = ctx->GetInputsDim("QKVW")[0];
+    bool trans_qkvw = ctx->Attrs().Get<bool>("trans_qkvw");
     PADDLE_ENFORCE_EQ(
         x_dim.size(),
         3,
@@ -79,24 +81,37 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
                           "but received dimensions of"
                           "Input is [%d]",
                           y_dim.size()));
-    PADDLE_ENFORCE_EQ(x_dim[2],
-                      y_dim[3],
-                      platform::errors::InvalidArgument(
-                          "ShapeError: the dimension of x_dim[2] and y_dim[3]"
-                          "must be equal. But received: the shape "
-                          "of input x = [%s], and the shape of "
-                          "input qkv_weight = [%s]",
-                          x_dim,
-                          y_dim));
+    PADDLE_ENFORCE_EQ(
+        x_dim[2],
+        trans_qkvw ? y_dim[3] : y_dim[0],
+        platform::errors::InvalidArgument(
+            "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is "
+            "true) or y_dim[0](trans_qkvw is false)"
+            "must be equal. But received: the shape "
+            "of input x = [%s], and the shape of "
+            "input qkv_weight = [%s]",
+            x_dim,
+            y_dim));
     if (ctx->Attrs().Get<int>("ring_id") == -1) {
-      PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
-                        y_dim[3],
-                        platform::errors::InvalidArgument(
-                            "The dimensions of qkv_weight must be 4"
-                            "(3, num_head, dim_head, dim_embed),"
-                            "and must satisfy the limitations: "
-                            "(num_head * dim_head == dim_embed)"));
+      if (trans_qkvw) {
+        PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
+                          y_dim[3],
+                          platform::errors::InvalidArgument(
+                              "The dimensions of qkv_weight must be 4"
+                              "(3, num_head, dim_head, dim_embed),"
+                              "and must satisfy the limitations: "
+                              "(num_head * dim_head == dim_embed)"));
+      } else {
+        PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3],
+                          y_dim[0],
+                          platform::errors::InvalidArgument(
+                              "The dimensions of qkv_weight must be 4"
+                              "(dim_embed, 3, num_head, dim_head),"
+                              "and must satisfy the limitations: "
+                              "(num_head * dim_head == dim_embed)"));
+      }
     }
     if (ctx->HasInputs("CacheKV")) {
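
To make the reworked QKVW shape checks above concrete, here is a small Python sketch (illustration only, not part of the patch; the sizes are made up) of the two weight layouts and the relations InferShape now enforces for each value of trans_qkvw:

    import numpy as np

    dim_embed, num_head, dim_head = 128, 8, 16                  # hypothetical sizes
    x = np.zeros((2, 10, dim_embed))                             # [batch, seq_len, dim_embed]

    qkv_w_trans = np.zeros((3, num_head, dim_head, dim_embed))   # trans_qkvw=True layout
    qkv_w_plain = np.zeros((dim_embed, 3, num_head, dim_head))   # trans_qkvw=False layout

    # x_dim[2] must match y_dim[3] when trans_qkvw is true, or y_dim[0] otherwise
    assert x.shape[2] == qkv_w_trans.shape[3] == qkv_w_plain.shape[0]
    # and, when not using tensor model parallel (ring_id == -1):
    assert num_head * dim_head == dim_embed
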
@@ -122,11 +137,11 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
                             x_dim[0],
                             c_dim[1]));  // batch_size
       PADDLE_ENFORCE_EQ(c_dim[2],
-                        y_dim[1],
+                        trans_qkvw ? y_dim[1] : y_dim[2],
                         paddle::platform::errors::InvalidArgument(
                             "The third dim of CacheKV must be equal with num "
                             "head %d, but got %d",
-                            y_dim[1],
+                            trans_qkvw ? y_dim[1] : y_dim[2],
                             c_dim[2]));  // num_head
       PADDLE_ENFORCE_GT(
           c_dim[3],
@@ -135,11 +150,11 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
           "The forth dim of CacheKV must be greater than 0, but got %d",
           c_dim[3]));  // cache_seq_len
       PADDLE_ENFORCE_EQ(c_dim[4],
-                        y_dim[2],
+                        trans_qkvw ? y_dim[2] : y_dim[3],
                         paddle::platform::errors::InvalidArgument(
                             "The fifth dim of CacheKV must be equal with head "
                             "size %d, but got %d",
-                            y_dim[2],
+                            trans_qkvw ? y_dim[2] : y_dim[3],
                             c_dim[4]));  // head_size
     }
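
The CacheKV checks can be read the same way. A short sketch, assuming the usual cache layout in which the leading dimension of size 2 holds the K and V caches (batch and sequence sizes below are hypothetical):

    import numpy as np

    batch_size, num_head, dim_head, max_seq_len = 2, 8, 16, 64
    cache_kv = np.zeros((2, batch_size, num_head, max_seq_len, dim_head), dtype='float32')
    qkv_w_trans = np.zeros((3, num_head, dim_head, num_head * dim_head))  # trans_qkvw=True layout

    c_dim = cache_kv.shape
    assert c_dim[2] == qkv_w_trans.shape[1]  # num_head (taken from y_dim[2] when trans_qkvw is false)
    assert c_dim[3] > 0                      # cache_seq_len
    assert c_dim[4] == qkv_w_trans.shape[2]  # head_size (taken from y_dim[3] when trans_qkvw is false)
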
@@ -258,6 +273,13 @@ class FusedMultiTransformerOpOpMaker
                            "upscale_in_train"));
         });
     AddAttr<std::string>("act_method", "act_method").SetDefault("gelu");
+    AddAttr<bool>(
+        "trans_qkvw",
+        "Whether the weights of qkv should be transposed. If true,"
+        "the shape of weights of qkv should be [3, num_head, dim_head, dim_embed]."
+        "Otherwise the shape of weights of qkv should be"
+        "[dim_embed, 3, num_head, dim_head]")
+        .SetDefault(true);
     AddAttr<int>(
         "ring_id",
@@ -278,3 +300,12 @@ REGISTER_OPERATOR(
     ops::FusedMultiTransformerOpOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OP_VERSION(fused_multi_transformer)
+    .AddCheckpoint(
+        R"ROC(
+              Add a new attribute [trans_qkvw] )ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "trans_qkvw",
+            "A flag to indicate whether to transpose for weights of qkv.",
+            true));
@@ -1119,17 +1119,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     // y: qkv's weight: [3, num_head, dim_head, dim_embed]
     auto qkv_weights = ctx.MultiInput<Tensor>("QKVW");
     auto qkv_biases = ctx.MultiInput<Tensor>("QKVBias");
+    const bool trans_qkvw = ctx.Attr<bool>("trans_qkvw");
     const auto qkv_w_dims = qkv_weights[0]->dims();
-    int num_head = qkv_w_dims[1];
-    int dim_head = qkv_w_dims[2];
+    int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2];
+    int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3];
     int hidden_size = num_head * dim_head;
     int output_size = 3 * hidden_size;
     int input_size = dim_embed;
     bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
-    // (transA, transB, compute_bias) = (false, true, false)
-    auto qkv_compute = AttnMatMul<T>(
-        dev_ctx, false, true, bsz_seq, output_size, input_size, compute_bias);
+    // (transA, transB, compute_bias) = (false, trans_qkvw, false)
+    auto qkv_compute = AttnMatMul<T>(dev_ctx,
+                                     false,
+                                     trans_qkvw,
+                                     bsz_seq,
+                                     output_size,
+                                     input_size,
+                                     compute_bias);
     Tensor qkv_out;
     auto *qkv_out_data =
         qkv_out.mutable_data<T>({bsz, seq_len, 3, num_head, dim_head}, place);
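
The kernel change boils down to which side of the QKV GEMM gets transposed: with trans_qkvw the weight is stored as [3, num_head, dim_head, dim_embed] and multiplied with transB = true, otherwise as [dim_embed, 3, num_head, dim_head] with transB = false. A numpy sketch (illustration only, hypothetical sizes) of why both give the same [bsz_seq, 3 * hidden_size] result:

    import numpy as np

    bsz_seq, dim_embed, num_head, dim_head = 4, 128, 8, 16
    hidden = num_head * dim_head                        # equals dim_embed here
    x = np.random.rand(bsz_seq, dim_embed)

    # trans_qkvw=True layout: [3, num_head, dim_head, dim_embed], GEMM with transB=true
    w_t = np.random.rand(3, num_head, dim_head, dim_embed)
    qkv_true = x @ w_t.reshape(3 * hidden, dim_embed).T

    # trans_qkvw=False layout: the same values stored as [dim_embed, 3, num_head, dim_head]
    w_f = np.transpose(w_t, (3, 0, 1, 2))
    qkv_false = x @ w_f.reshape(dim_embed, 3 * hidden)

    assert np.allclose(qkv_true, qkv_false)             # both are [bsz_seq, 3 * hidden]
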
@@ -680,6 +680,7 @@ def fused_multi_transformer(x,
                             activation="gelu",
                             training=False,
                             mode='upscale_in_train',
+                            trans_qkvw=True,
                             ring_id=-1,
                             name=None):
     r"""
@@ -756,6 +757,9 @@ def fused_multi_transformer(x,
                 - train: out = input * mask
                 - inference: out = input * (1.0 - p)
+        trans_qkvw (bool, optional): Whether to transpose for weights of qkv.
+            If true, the shape of weights of qkv should be [3, num_head, dim_head, dim_embed].
+            Otherwise the shape of weights of qkv should be [dim_embed, 3, num_head, dim_head]. Default True.
         ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using mp.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
@@ -826,8 +830,8 @@ def fused_multi_transformer(x,
             ffn_ln_biases, ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases,
             cache_kvs, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon,
             'dropout_rate', dropout_rate, 'is_test', not training,
-            'dropout_implementation', mode, 'act_method', activation, 'ring_id',
-            ring_id)
+            'dropout_implementation', mode, 'act_method', activation,
+            'trans_qkvw', trans_qkvw, 'ring_id', ring_id)
         if cache_kvs is not None:
             return final_out, cache_kv_out
         return final_out
@@ -875,6 +879,7 @@ def fused_multi_transformer(x,
             'is_test': not training,
             'dropout_implementation': mode,
             'act_method': activation,
+            'trans_qkvw': trans_qkvw,
             'ring_id': ring_id
         }
@@ -1048,6 +1048,9 @@ class FusedMultiTransformer(Layer):
             is a list or tuple, the number of layers is obtained from `qkv_weight_attrs`. num_layers
             only takes effect when `qkv_weight_attrs` is not a list or tuple. Default: -1.
         nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using mp.
+        trans_qkvw (bool, optional): Whether to transpose for weights of qkv.
+            If true, the shape of weights of qkv should be [3, num_head, dim_head, dim_embed].
+            Otherwise the shape of weights of qkv should be [dim_embed, 3, num_head, dim_head]. Default: True.
         ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using mp.
         name (str, optional): The default value is None. Normally there is no need for user to set
             this property. For more information, please refer to :ref:`api_guide_Name`.
@@ -1090,6 +1093,7 @@ class FusedMultiTransformer(Layer):
                  epsilon=1e-5,
                  num_layers=-1,
                  nranks=1,
+                 trans_qkvw=True,
                  ring_id=-1,
                  name=None):
         super(FusedMultiTransformer, self).__init__()
@@ -1105,6 +1109,7 @@ class FusedMultiTransformer(Layer):
         self.normalize_before = normalize_before
         self._dtype = self._helper.get_default_dtype()
         self._epsilon = epsilon
+        self._trans_qkvw = trans_qkvw
         self._ring_id = ring_id
         self.embed_dim = embed_dim
@@ -1161,7 +1166,8 @@ class FusedMultiTransformer(Layer):
                 shape=[embed_dim],
                 is_bias=True)
             qkv_weight = self.create_parameter(
-                shape=[3, num_heads, self.head_dim, embed_dim],
+                shape=[3, num_heads, self.head_dim, embed_dim]
+                if trans_qkvw else [embed_dim, 3, num_heads, self.head_dim],
                 attr=qkv_weight_attr,
                 dtype=self._dtype,
                 is_bias=False)
@@ -1292,6 +1298,7 @@ class FusedMultiTransformer(Layer):
                 activation=self.activation,
                 training=self.training,
                 mode='upscale_in_train',
+                trans_qkvw=self._trans_qkvw,
                 ring_id=self._ring_id,
                 name=self.name)
         return out
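
A minimal usage sketch of the layer-level switch (not from the patch), assuming a GPU build of PaddlePaddle since the fused op only provides a CUDA kernel; the sizes are hypothetical. With trans_qkvw=False, each per-layer qkv weight parameter is created in the [dim_embed, 3, num_head, dim_head] layout:

    import paddle
    from paddle.incubate.nn import FusedMultiTransformer

    paddle.set_device('gpu')
    encoder = FusedMultiTransformer(embed_dim=128,
                                    num_heads=8,
                                    dim_feedforward=512,
                                    num_layers=2,
                                    trans_qkvw=False)
    print(encoder.qkv_weights[0].shape)  # [128, 3, 8, 16] instead of [3, 8, 16, 128]

    x = paddle.rand([2, 16, 128], dtype='float32')  # [batch_size, seq_len, embed_dim]
    out = encoder(x)                                # same shape as x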