From c2a5bb91448645130c317d94ea18b99a48e65b1a Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Thu, 30 Jun 2022 10:48:44 +0800 Subject: [PATCH] Add new attr of fused_multi_transformer (#43730) * Add new attr of fused_multi_transformer * fix format * add note * add in layer * fixfixfixfix --- .../fused/fused_multi_transformer_op.cc | 71 +++++++++++++------ .../fused/fused_multi_transformer_op.cu | 16 +++-- .../nn/functional/fused_transformer.py | 9 ++- .../incubate/nn/layer/fused_transformer.py | 9 ++- 4 files changed, 77 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc index aa05ebc43d..86de140b9c 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { @@ -63,6 +64,7 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { // y: qkv's weight: [3, num_head, dim_head, dim_embed] auto x_dim = ctx->GetInputDim("X"); auto y_dim = ctx->GetInputsDim("QKVW")[0]; + bool trans_qkvw = ctx->Attrs().Get("trans_qkvw"); PADDLE_ENFORCE_EQ( x_dim.size(), 3, @@ -79,24 +81,37 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { "but received dimensions of" "Input is [%d]", y_dim.size())); - PADDLE_ENFORCE_EQ(x_dim[2], - y_dim[3], - platform::errors::InvalidArgument( - "ShapeError: the dimension of x_dim[2] and y_dim[3]" - "must be equal. But received: the shape " - "of input x = [%s], and the shape of " - "input qkv_weight = [%s]", - x_dim, - y_dim)); + PADDLE_ENFORCE_EQ( + x_dim[2], + trans_qkvw ? y_dim[3] : y_dim[0], + platform::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is " + "true) or y_dim[0](trans_qkvw is false)" + "must be equal. But received: the shape " + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", + x_dim, + y_dim)); if (ctx->Attrs().Get("ring_id") == -1) { - PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], - y_dim[3], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(3, num_head, dim_head, dim_embed)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); + if (trans_qkvw) { + PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], + y_dim[3], + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "and must satisfy the limitations: " + "(num_head * dim_head == dim_embed)")); + + } else { + PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3], + y_dim[0], + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(dim_embed, 3, num_head, dim_head)," + "and must satisfy the limitations: " + "(num_head * dim_head == dim_embed)")); + } } if (ctx->HasInputs("CacheKV")) { @@ -122,11 +137,11 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { x_dim[0], c_dim[1])); // batch_size PADDLE_ENFORCE_EQ(c_dim[2], - y_dim[1], + trans_qkvw ? y_dim[1] : y_dim[2], paddle::platform::errors::InvalidArgument( "The third dim of CacheKV must be equal with num " "head %d, but got %d", - y_dim[1], + trans_qkvw ? 
y_dim[1] : y_dim[2], c_dim[2])); // num_head PADDLE_ENFORCE_GT( c_dim[3], @@ -135,11 +150,11 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { "The forth dim of CacheKV must be greater than 0, but got %d", c_dim[3])); // cache_seq_len PADDLE_ENFORCE_EQ(c_dim[4], - y_dim[2], + trans_qkvw ? y_dim[2] : y_dim[3], paddle::platform::errors::InvalidArgument( "The fifth dim of CacheKV must be equal with head " "size %d, but got %d", - y_dim[2], + trans_qkvw ? y_dim[2] : y_dim[3], c_dim[4])); // head_size } @@ -258,6 +273,13 @@ class FusedMultiTransformerOpOpMaker "upscale_in_train")); }); AddAttr("act_method", "act_method").SetDefault("gelu"); + AddAttr( + "trans_qkvw", + "Whether the weights of qkv should be transposed. If true," + "the shape eights of qkv should be [3, num_head, dim_head, dim_embed]." + "Otherwise the shape of weights of qkv should be" + "[dim_embed, 3, num_head, dim_head]") + .SetDefault(true); AddAttr( "ring_id", @@ -278,3 +300,12 @@ REGISTER_OPERATOR( ops::FusedMultiTransformerOpOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_VERSION(fused_multi_transformer) + .AddCheckpoint( + R"ROC( + Add a new attribute [trans_qkvw] )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "trans_qkvw", + "A flag to indicate whether to transpose for weights of qkv.", + true)); diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index ca2b884bf7..f806359093 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -1119,17 +1119,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { // y: qkv's weight: [3, num_head, dim_head, dim_embed] auto qkv_weights = ctx.MultiInput("QKVW"); auto qkv_biases = ctx.MultiInput("QKVBias"); + const bool trans_qkvw = ctx.Attr("trans_qkvw"); const auto qkv_w_dims = qkv_weights[0]->dims(); - int num_head = qkv_w_dims[1]; - int dim_head = qkv_w_dims[2]; + int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; int hidden_size = num_head * dim_head; int output_size = 3 * hidden_size; int input_size = dim_embed; bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; - // (transA, transB, compute_bias) = (false, true, false) - auto qkv_compute = AttnMatMul( - dev_ctx, false, true, bsz_seq, output_size, input_size, compute_bias); + // (transA, transB, compute_bias) = (false, trans_qkvw, false) + auto qkv_compute = AttnMatMul(dev_ctx, + false, + trans_qkvw, + bsz_seq, + output_size, + input_size, + compute_bias); Tensor qkv_out; auto *qkv_out_data = qkv_out.mutable_data({bsz, seq_len, 3, num_head, dim_head}, place); diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 3e4d015da1..506a282171 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -680,6 +680,7 @@ def fused_multi_transformer(x, activation="gelu", training=False, mode='upscale_in_train', + trans_qkvw=True, ring_id=-1, name=None): r""" @@ -756,6 +757,9 @@ def fused_multi_transformer(x, - train: out = input * mask - inference: out = input * (1.0 - p) + trans_qkvw (bool, optional): Whether to transpose for weights of qkv. + If true, the shape eights of qkv should be [3, num_head, dim_head, dim_embed]. 
+ Otherwise the shape of weights of qkv should be [dim_embed, 3, num_head, dim_head]. Default True. ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using mp. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -826,8 +830,8 @@ def fused_multi_transformer(x, ffn_ln_biases, ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases, cache_kvs, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, 'dropout_rate', dropout_rate, 'is_test', not training, - 'dropout_implementation', mode, 'act_method', activation, 'ring_id', - ring_id) + 'dropout_implementation', mode, 'act_method', activation, + 'trans_qkvw', trans_qkvw, 'ring_id', ring_id) if cache_kvs is not None: return final_out, cache_kv_out return final_out @@ -875,6 +879,7 @@ def fused_multi_transformer(x, 'is_test': not training, 'dropout_implementation': mode, 'act_method': activation, + 'trans_qkvw': trans_qkvw, 'ring_id': ring_id } diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 4a8f7815ae..ba14ac5b86 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -1048,6 +1048,9 @@ class FusedMultiTransformer(Layer): is a list or tuple, the number of layers is obtained from `qkv_weight_attrs`. num_layers only takes effect when `qkv_weight_attrs` is not a list or tuple. Default: -1. nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using mp. + trans_qkvw (bool, optional): Whether to transpose for weights of qkv. + If true, the shape eights of qkv should be [3, num_head, dim_head, dim_embed]. + Otherwise the shape of weights of qkv should be [dim_embed, 3, num_head, dim_head]. Default: True. ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using mp. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -1090,6 +1093,7 @@ class FusedMultiTransformer(Layer): epsilon=1e-5, num_layers=-1, nranks=1, + trans_qkvw=True, ring_id=-1, name=None): super(FusedMultiTransformer, self).__init__() @@ -1105,6 +1109,7 @@ class FusedMultiTransformer(Layer): self.normalize_before = normalize_before self._dtype = self._helper.get_default_dtype() self._epsilon = epsilon + self._trans_qkvw = trans_qkvw self._ring_id = ring_id self.embed_dim = embed_dim @@ -1161,7 +1166,8 @@ class FusedMultiTransformer(Layer): shape=[embed_dim], is_bias=True) qkv_weight = self.create_parameter( - shape=[3, num_heads, self.head_dim, embed_dim], + shape=[3, num_heads, self.head_dim, embed_dim] + if trans_qkvw else [embed_dim, 3, num_heads, self.head_dim], attr=qkv_weight_attr, dtype=self._dtype, is_bias=False) @@ -1292,6 +1298,7 @@ class FusedMultiTransformer(Layer): activation=self.activation, training=self.training, mode='upscale_in_train', + trans_qkvw=self._trans_qkvw, ring_id=self._ring_id, name=self.name) return out -- GitLab
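A short illustration of what the new `trans_qkvw` attribute controls may help readers of this patch. Below is a minimal NumPy sketch (not part of the commit; the toy sizes and variable names are invented for the example) showing the two QKV weight layouts and why they produce the same projection: with `trans_qkvw=True` (the default, matching the op's previous fixed behavior) the weight is stored as [3, num_head, dim_head, dim_embed] and the QKV GEMM runs with transB=true, while with `trans_qkvw=False` it is stored as [dim_embed, 3, num_head, dim_head] and multiplied directly, as in the `AttnMatMul(dev_ctx, false, trans_qkvw, ...)` change in fused_multi_transformer_op.cu.

    # Minimal sketch, NumPy only; sizes and names here are illustrative, not taken from the patch.
    import numpy as np

    dim_embed, num_head, dim_head = 8, 2, 4        # requires num_head * dim_head == dim_embed
    x = np.random.randn(5, dim_embed)              # flattened [batch_size * seq_len, dim_embed]

    # trans_qkvw=False layout: [dim_embed, 3, num_head, dim_head], plain matmul (transB = false).
    w = np.random.randn(dim_embed, 3, num_head, dim_head)
    out_plain = x @ w.reshape(dim_embed, 3 * num_head * dim_head)

    # trans_qkvw=True layout (default): the same weight stored as
    # [3, num_head, dim_head, dim_embed], consumed with transB = true.
    w_t = w.transpose(1, 2, 3, 0)
    out_trans = x @ w_t.reshape(3 * num_head * dim_head, dim_embed).T

    assert np.allclose(out_plain, out_trans)       # identical QKV projection either way

Because the flag only changes how the parameter is laid out in memory, keeping the default `trans_qkvw=True` preserves the op's existing behavior, which is also why the compatibility checkpoint registered above gives the new attribute a default of true. Callers can opt into the untransposed layout through the new `trans_qkvw` argument of `paddle.incubate.nn.functional.fused_multi_transformer` and of the `FusedMultiTransformer` layer added in this patch.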