Unverified commit 9dadf7df authored by WangXi, committed by GitHub

Add fused_multi_transformer op to optimize transformer generation performance (#41814)

Parent 30838aa6
@@ -19,6 +19,7 @@ register_operators(EXCLUDES
fused_attention_op
fused_transformer_op
fused_feedforward_op
fused_multi_transformer_op
resnet_unit_op
fused_gemm_epilogue_op)
@@ -73,6 +74,7 @@ if (WITH_GPU OR WITH_ROCM)
op_library(fused_feedforward_op)
# fused_attention_op
op_library(fused_attention_op)
op_library(fused_multi_transformer_op)
endif()
# resnet_unit needs cudnn 8.0 above
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class FusedMultiTransformerOp : public framework::OperatorWithKernel {
private:
static constexpr const char *OpName = "FusedMultiTransformerOp";
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
#define CHECK_INPUT(name) \
OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName)
#define CHECK_INPUTS(name) \
OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName)
#define CHECK_OUTPUT(name) \
OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName)
#define CHECK_OUTPUTS(name) \
OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName)
CHECK_INPUT(X);
// attention
CHECK_INPUTS(QKVW);
CHECK_INPUTS(OutLinearW);
if (ctx->HasInput("TimeStep")) {
CHECK_INPUTS(CacheKV);
}
if (ctx->HasInputs("CacheKV")) {
CHECK_OUTPUTS(CacheKVOut);
}
// ffn
CHECK_INPUTS(FFN1Weight);
CHECK_INPUTS(FFN2Weight);
CHECK_OUTPUT(Out);
// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputsDim("QKVW")[0];
PADDLE_ENFORCE_EQ(x_dim.size(), 3, platform::errors::InvalidArgument(
"The dimensions of x must be 3"
"(batch_size, seq_len, dim_embed),"
"but received dimensions of"
"Input is [%d]",
x_dim.size()));
PADDLE_ENFORCE_EQ(y_dim.size(), 4,
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"but received dimensions of"
"Input is [%d]",
y_dim.size()));
PADDLE_ENFORCE_EQ(x_dim[2], y_dim[3],
platform::errors::InvalidArgument(
"ShapeError: the dimension of x_dim[2] and y_dim[3]"
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]",
x_dim, y_dim));
if (ctx->Attrs().Get<int>("ring_id") == -1) {
PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
}
if (ctx->HasInputs("CacheKV")) {
// [2, batch_size, num_head, max_seq_len, head_size]
const auto &c_dims = ctx->GetInputsDim("CacheKV");
const auto &c_dim = c_dims[0];
PADDLE_ENFORCE_EQ(
c_dim.size(), 5,
paddle::platform::errors::InvalidArgument(
"The CacheKV must be 5 dims, but got %d", c_dim.size()));
PADDLE_ENFORCE_EQ(c_dim[0], 2,
paddle::platform::errors::InvalidArgument(
"The first dim of CacheKV must be 2, but got %d",
c_dim[0])); // 2
PADDLE_ENFORCE_EQ(c_dim[1], x_dim[0],
paddle::platform::errors::InvalidArgument(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d",
x_dim[0], c_dim[1])); // batch_size
PADDLE_ENFORCE_EQ(c_dim[2], y_dim[1],
paddle::platform::errors::InvalidArgument(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d",
y_dim[1], c_dim[2])); // num_head
PADDLE_ENFORCE_GT(
c_dim[3], 0,
paddle::platform::errors::InvalidArgument(
"The forth dim of CacheKV must be greater than 0, but got %d",
c_dim[3])); // cache_seq_len
PADDLE_ENFORCE_EQ(c_dim[4], y_dim[2],
paddle::platform::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d",
y_dim[2], c_dim[4])); // head_size
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "TimeStep") {
VLOG(10) << "var_name:" << var_name << " need not to transform";
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class FusedMultiTransformerOpOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddInput("LnScale",
"Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDuplicable();
AddInput("LnBias",
"Bias is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDuplicable();
AddInput("QKVW", "The qkv weight tensor.").AsDuplicable();
AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable();
AddInput("CacheKV", "(optional) The cached KV for generation inference.")
.AsDispensable()
.AsDuplicable();
AddInput("TimeStep",
"(optional, int) The time step for generation inference.")
.AsDispensable();
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable();
AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable();
AddInput("OutLinearBias", "The out_linear bias tensor.")
.AsDispensable()
.AsDuplicable();
AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op")
.AsDuplicable();
AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op")
.AsDuplicable();
AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op")
.AsDuplicable();
AddInput("FFN1Bias", "The linear1 bias of FusedFeedForward op")
.AsDispensable()
.AsDuplicable();
AddInput("FFN2Weight", "The linear2 weight of FusedFeedForward op")
.AsDuplicable();
AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op")
.AsDispensable()
.AsDuplicable();
AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV")
.AsDispensable()
.AsDuplicable();
AddOutput("Out", "Result after multi .");
AddAttr<bool>("pre_layer_norm",
"if true, the attention op uses pre_layer_norm architecure, "
"else, uses post_layer_norm architecuture. "
"[default true].")
.SetDefault(true);
AddAttr<float>("epsilon",
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
platform::errors::InvalidArgument(
"'epsilon' in Op(LayerNorm) should be between"
"0.0 and 0.001, But received [%s].",
epsilon));
});
AddAttr<float>("dropout_rate", "Probability of setting units to zero.")
.SetDefault(.5f)
.AddCustomChecker([](const float &drop_p) {
PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true,
platform::errors::InvalidArgument(
"'dropout_rate' must be between 0.0 and 1.0."));
});
AddAttr<bool>("dropout_is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<std::string>(
"dropout_implementation",
"[\"downgrade_in_infer\"|\"upscale_in_train\"]"
"The meaning is the same as 'attn_dropout_implementation'.")
.SetDefault("downgrade_in_infer")
.AddCustomChecker([](const std::string &type) {
PADDLE_ENFORCE_EQ(
type == "downgrade_in_infer" || type == "upscale_in_train", true,
platform::errors::InvalidArgument(
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"));
});
AddAttr<std::string>("act_method", "act_method").SetDefault("gelu");
AddAttr<int>(
"ring_id",
"ring id for tensor model parallel. distributed training and inference")
.SetDefault(-1);
AddComment(R"DOC(fused multi transformer layers op)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fused_multi_transformer, ops::FusedMultiTransformerOp,
ops::FusedMultiTransformerOpOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
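
The shape contract enforced by InferShape above can be summarized compactly. Below is a minimal sketch of the expected layouts; the sizes are illustrative assumptions, while the layouts themselves come from the checks above:

    # Illustrative sizes; layouts mirror the InferShape checks above.
    batch_size, seq_len = 2, 8
    num_head, dim_head = 4, 32
    dim_embed = num_head * dim_head  # enforced when ring_id == -1
    max_seq_len = 16                 # fourth CacheKV dim, must be > 0

    x_shape = [batch_size, seq_len, dim_embed]        # X (3-D)
    qkv_w_shape = [3, num_head, dim_head, dim_embed]  # QKVW (4-D); x_shape[2] == qkv_w_shape[3]
    cache_kv_shape = [2, batch_size, num_head, max_seq_len, dim_head]  # CacheKV (5-D)
    out_shape = x_shape                               # Out matches X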
@@ -32,6 +32,10 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"fused_attention",
{"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "SrcMask",
"OutLinearW", "OutLinearBias", "Ln2Scale", "Ln2Bias"}},
{"fused_multi_transformer",
{"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "TimeStep",
"SrcMask", "OutLinearW", "OutLinearBias", "FFNLnScale", "FFNLnBias",
"FFN1Weight", "FFN1Bias", "FFN2Weight", "FFN2Bias"}},
{"instance_norm", {"X", "Scale", "Bias"}},
{"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
{"label_smooth", {"X", "PriorDist"}},
@@ -176,6 +180,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"fused_multi_transformer", {"CacheKVOut", "Out"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
@@ -253,6 +258,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"assign_value", {"Out"}},
{"split", {"Out"}},
{"concat", {"Out"}},
{"fused_multi_transformer", {"CacheKVOut"}},
};
// NOTE(pangyoki): Tensor View Strategy.
......
@@ -162,6 +162,7 @@ gray_list = {
'split',
'fused_feedforward',
'fused_attention',
'fused_multi_transformer',
}
# The set of ops that don't support fp16 calculation
......
@@ -109,6 +109,8 @@ def _keep_fp32_input(op, in_name):
return in_name in {
'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias"
}
if op_type == 'fused_multi_transformer':
return in_name in {'LnScale', 'LnBias', 'FFNLnScale', 'FFNLnBias'}
return False
......
@@ -25,6 +25,7 @@ list(APPEND DIST_TEST_OPS test_ir_pass_pipeline)
list(APPEND DIST_TEST_OPS test_static_model_parallel)
list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward)
list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention)
list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_multi_transformer)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height)
@@ -128,6 +129,7 @@ if(NOT WITH_GPU)
LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
LIST(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
endif()
@@ -1187,6 +1189,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties(test_static_model_parallel PROPERTIES TIMEOUT 240)
set_tests_properties(test_static_model_parallel_fused_feedforward PROPERTIES TIMEOUT 120)
set_tests_properties(test_static_model_parallel_fused_attention PROPERTIES TIMEOUT 120)
set_tests_properties(test_static_model_parallel_fused_multi_transformer PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_split_embedding
test_collective_split_embedding_none_divisible
test_collective_split_row_linear
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from test_dist_base import TestDistRunnerBase, runtime_main
from paddle.incubate.nn import FusedMultiTransformer
import paddle.distributed.fleet as fleet
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import core
from paddle.nn.initializer import Constant
paddle.enable_static()
def get_param_attr(weight, bias):
weight_attr = paddle.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(weight))
bias_attr = paddle.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(bias))
return weight_attr, bias_attr
DTYPE = "float32"
MODEL_PARALLEL_SIZE = 2
num_head = 2 * MODEL_PARALLEL_SIZE
dim_head = 4
hidden = num_head * dim_head
dim_ffn = 4 * hidden
def create_model(data, rank):
np.random.seed(2021)
ln_w = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
ln_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
qkv_w = np.random.uniform(
-1, 1, size=(3, num_head, dim_head, hidden)).astype(DTYPE)
qkv_b = np.random.uniform(-1, 1, size=(3, num_head, dim_head)).astype(DTYPE)
linear_w = np.random.uniform(
-1, 1, size=(num_head * dim_head, hidden)).astype(DTYPE)
linear_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
ffn_ln_w = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
ffn_ln_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
ffn1_w = np.random.uniform(-1, 1, size=(hidden, dim_ffn)).astype(DTYPE)
ffn1_b = np.random.uniform(-1, 1, size=(dim_ffn, )).astype(DTYPE)
ffn2_w = np.random.uniform(-1, 1, size=(dim_ffn, hidden)).astype(DTYPE)
ffn2_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
if rank is not None:
start = 0 if rank == 0 else (num_head // MODEL_PARALLEL_SIZE)
end = start + (num_head // MODEL_PARALLEL_SIZE)
col_qkv_w = qkv_w[:, start:end, :, :]
col_qkv_b = qkv_b[:, start:end, :]
row_linear_w = linear_w[(start * dim_head):(end * dim_head), :]
ln_w_attr, ln_b_attr = get_param_attr(ln_w, ln_b)
qkv_w_attr, qkv_b_attr = get_param_attr(col_qkv_w, col_qkv_b)
linear_w_attr, linear_b_attr = get_param_attr(row_linear_w, linear_b)
start = 0 if rank == 0 else (dim_ffn // MODEL_PARALLEL_SIZE)
end = start + (dim_ffn // MODEL_PARALLEL_SIZE)
col_ffn1_w = ffn1_w[:, start:end]
col_ffn1_b = ffn1_b[start:end]
row_ffn2_w = ffn2_w[start:end, :]
ffn_ln_w_attr, ffn_ln_b_attr = get_param_attr(ffn_ln_w, ffn_ln_b)
ffn1_w_attr, ffn1_b_attr = get_param_attr(col_ffn1_w, col_ffn1_b)
ffn2_w_attr, ffn2_b_attr = get_param_attr(row_ffn2_w, ffn2_b)
multi_transformer = FusedMultiTransformer(
hidden,
num_head,
dim_ffn,
dropout_rate=0.0,
activation="gelu",
normalize_before=True,
ln_scale_attrs=[ln_w_attr],
ln_bias_attrs=[ln_b_attr],
qkv_weight_attrs=[qkv_w_attr],
qkv_bias_attrs=[qkv_b_attr],
linear_weight_attrs=[linear_w_attr],
linear_bias_attrs=[linear_b_attr],
ffn_ln_scale_attrs=[ffn_ln_w_attr],
ffn_ln_bias_attrs=[ffn_ln_b_attr],
ffn1_weight_attrs=[ffn1_w_attr],
ffn1_bias_attrs=[ffn1_b_attr],
ffn2_weight_attrs=[ffn2_w_attr],
ffn2_bias_attrs=[ffn2_b_attr],
nranks=MODEL_PARALLEL_SIZE,
ring_id=0)
result = multi_transformer(data)
else:
ln_w_attr, ln_b_attr = get_param_attr(ln_w, ln_b)
qkv_w_attr, qkv_b_attr = get_param_attr(qkv_w, qkv_b)
linear_w_attr, linear_b_attr = get_param_attr(linear_w, linear_b)
ffn_ln_w_attr, ffn_ln_b_attr = get_param_attr(ffn_ln_w, ffn_ln_b)
ffn1_w_attr, ffn1_b_attr = get_param_attr(ffn1_w, ffn1_b)
ffn2_w_attr, ffn2_b_attr = get_param_attr(ffn2_w, ffn2_b)
multi_transformer = FusedMultiTransformer(
hidden,
num_head,
dim_ffn,
dropout_rate=0.0,
activation="gelu",
normalize_before=True,
ln_scale_attrs=[ln_w_attr],
ln_bias_attrs=[ln_b_attr],
qkv_weight_attrs=[qkv_w_attr],
qkv_bias_attrs=[qkv_b_attr],
linear_weight_attrs=[linear_w_attr],
linear_bias_attrs=[linear_b_attr],
ffn_ln_scale_attrs=[ffn_ln_w_attr],
ffn_ln_bias_attrs=[ffn_ln_b_attr],
ffn1_weight_attrs=[ffn1_w_attr],
ffn1_bias_attrs=[ffn1_b_attr],
ffn2_weight_attrs=[ffn2_w_attr],
ffn2_bias_attrs=[ffn2_b_attr])
result = multi_transformer(data)
# fused_multi_transformer has no backward
result.stop_gradient = True
predict = paddle.mean(result)
return predict
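# NOTE: the slicing above follows the standard tensor-parallel scheme:
# qkv_w/ffn1_w are split by columns (heads / dim_ffn) and linear_w/ffn2_w
# by rows, so the per-rank partial products sum to the full result. A
# minimal sketch of that identity (illustrative sizes, hypothetical check,
# not part of this test):
#
#     rng = np.random.default_rng(0)
#     x = rng.standard_normal((2, 16))
#     w1 = rng.standard_normal((16, 32))  # column-split across 2 ranks
#     w2 = rng.standard_normal((32, 16))  # row-split across 2 ranks
#     full = x @ w1 @ w2
#     parts = sum((x @ w1[:, s]) @ w2[s, :]
#                 for s in (slice(0, 16), slice(16, 32)))
#     np.testing.assert_allclose(full, parts)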
class TestModelParallel(TestDistRunnerBase):
def get_model(self, batch_size=2, use_dgc=False, dist_strategy=None):
# Input data
seq_len = 2
data_in = fluid.data(
name='data_in', shape=[batch_size, seq_len, hidden], dtype=DTYPE)
if dist_strategy:
data_loader = fluid.io.DataLoader.from_generator(
feed_list=[data_in],
capacity=64,
use_double_buffer=False,
iterable=False)
if dist_strategy:
fleet.init(is_collective=True)
strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2}
rank = fleet.worker_index() if dist_strategy else None
avg_cost = create_model(data_in, rank)
opt = fluid.optimizer.SGD(0.1)
if dist_strategy:
dist_opt = fleet.distributed_optimizer(
optimizer=opt, strategy=strategy)
dist_opt.minimize(avg_cost)
else:
opt.minimize(avg_cost)
def gen_data():
np.random.seed(2021)
while True:
data = [np.random.random([seq_len, hidden]).astype(DTYPE)]
yield data
train_reader = paddle.batch(gen_data, batch_size=batch_size)
if dist_strategy:
return None, avg_cost, train_reader, None, None, None, data_loader
else:
return None, avg_cost, train_reader, None, None, None
if __name__ == "__main__":
runtime_main(TestModelParallel)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_dist_base import TestDistBase
import os
import paddle
paddle.enable_static()
flag_name = os.path.splitext(__file__)[0]
class TestStaticModelParallel(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._use_reader_alloc = False
self._nccl_comm_num = 1
self._pipeline_mode = True
def test_dist_static_model_parallel_fused_multi_transformer(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"static_model_parallel_fused_multi_transformer.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == '__main__':
unittest.main()
@@ -15,10 +15,11 @@
from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401
from .layer.fused_transformer import FusedFeedForward # noqa: F401
from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401
from .layer.fused_transformer import FusedMultiTransformer # noqa: F401
__all__ = [ #noqa
'FusedMultiHeadAttention',
'FusedFeedForward',
'FusedTransformerEncoderLayer',
'FusedMultiTransformer',
]
@@ -14,5 +14,10 @@
from .fused_transformer import fused_multi_head_attention
from .fused_transformer import fused_feedforward
from .fused_transformer import fused_multi_transformer
__all__ = ['fused_multi_head_attention', 'fused_feedforward']
__all__ = [
'fused_multi_head_attention',
'fused_feedforward',
'fused_multi_transformer',
]
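
With these exports in place, the new API is importable from both the layer and the functional namespaces, matching the imports used in the test above; a minimal sketch:

    from paddle.incubate.nn import FusedMultiTransformer
    from paddle.incubate.nn.functional import fused_multi_transformer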
@@ -488,3 +488,238 @@ def fused_multi_head_attention(x,
attrs=attrs)
return (final_out, cache_kv_out) if cache_kv else final_out
def fused_multi_transformer(x,
ln_scales,
ln_biases,
qkv_weights,
qkv_biases,
linear_weights,
linear_biases,
ffn_ln_scales,
ffn_ln_biases,
ffn1_weights,
ffn1_biases,
ffn2_weights,
ffn2_biases,
pre_layer_norm=True,
epsilon=1e-05,
cache_kvs=None,
time_step=None,
attn_mask=None,
dropout_rate=0.0,
activation="gelu",
training=False,
mode='upscale_in_train',
ring_id=-1,
name=None):
r"""
This is a fusion operator to compute multiple transformer layers in the transformer model architecture.
This operator only supports running on GPU. The function of the transformer layer is consistent
with the following pseudo code:
.. code-block:: python
if pre_layer_norm:
out = layer_norm(x)
out = qkv_linear(out) + qkv_bias
else:
out = qkv_linear(x) + qkv_bias
out = transpose(out, perm=[2, 0, 3, 1, 4])
# extract q, k and v from out.
q = out[0:1, ::]
k = out[1:2, ::]
v = out[2:3, ::]
out = q * k^t
out = attn_mask + out
out = softmax(out)
out = dropout(out)
out = out * v
out = transpose(out, perm=[0, 2, 1, 3])
out = linear(out)
if pre_layer_norm:
out = x + dropout(out + bias)
else:
out = layer_norm(x + dropout(out + bias))
residual = out
if pre_layer_norm:
out = ffn_layer_norm(out)
out = ffn1_linear(out)
out = dropout(activation(out + ffn1_bias))
out = ffn2_linear(out)
out = residual + dropout(out + ffn2_bias)
if not pre_layer_norm:
out = ffn_layer_norm(out)
Args:
x (Tensor): the input tensor. It is a 3-D tensor with shape `[batch\_size, sequence\_length, d\_model]`; the data type can be float16 or float32.
ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of attention layer_norm, the shape is `[d\_model]`.
ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of attention layer_norm, the shape is `[d\_model]`.
qkv_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head, d\_model]`.
qkv_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head]`.
linear_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention linear. The shape is `[num\_head * dim\_head, d\_model]`.
linear_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention linear. The shape is `[d\_model]`.
ffn_ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward layer_norm, the shape is `[d\_model]`.
ffn_ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of feedforward layer_norm, the shape is `[d\_model]`.
ffn1_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward first linear, the shape is `[d\_model, dim\_feedforward]`.
ffn1_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward first linear, the shape is `[dim\_feedforward]`.
ffn2_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward second linear, the shape is `[dim\_feedforward, d\_model]`.
ffn2_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward second linear, the shape is `[d\_model]`.
pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). Default True.
epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation model, used in the decode stage to represent the time step (that is, the real seq_len of CacheKV). The shape is `[1]`, and it must be in CPUPlace. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevent attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape `[batch_size, 1, sequence_length, sequence_length]`. Default None.
dropout_rate (float, optional): The dropout probability of setting units to zero. Default 0.0.
activation (str, optional): The activation. Default "gelu".
training (bool, optional): A flag indicating whether it is in the training phase or not. Default False.
mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - p )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in tensor model parallel; only NCCL is supported. Default is -1, which means tensor model parallelism is not used.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor|tuple: If `cache_kvs` is None, return a tensor that has
the same shape and data type as `x`, representing the output
of the Transformer layers. If `cache_kvs` is not None, return the
tuple (output, cache_kvs), where output is the output of the
Transformer layers and cache_kvs is updated in place with the input `cache_kvs`.
Examples:
.. code-block:: python
# required: gpu
import paddle
import paddle.incubate.nn.functional as F
import numpy as np
# input: [batch_size, seq_len, embed_dim]
x = paddle.rand(shape=(2, 4, 128), dtype="float32")
# ln_scale: [embed_dim], ln_bias: [embed_dim]
ln_scale = paddle.rand(shape=(128,), dtype="float32")
ln_bias = paddle.rand(shape=(128,), dtype="float32")
# qkv_weight: [3, num_head, head_dim, embed_dim], qkv_bias: [3, num_head, head_dim]
qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")
# linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim]
linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
linear_bias = paddle.rand(shape=(128,), dtype="float32")
# ffn_ln_scale: [embed_dim], ffn_ln_bias: [embed_dim]
ffn_ln_scale = paddle.rand(shape=(128,), dtype="float32")
ffn_ln_bias = paddle.rand(shape=(128,), dtype="float32")
# ffn1_weight: [embed_dim, 4*embed_dim], ffn1_bias: [4*embed_dim]
ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float32")
ffn1_bias = paddle.rand(shape=(4*128,), dtype="float32")
# ffn2_weight: [4*embed_dim, embed_dim], ffn2_bias: [embed_dim]
ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float32")
ffn2_bias = paddle.rand(shape=(128,), dtype="float32")
# self attention mask: [batch_size, 1, seq_len, seq_len]
attn_mask = paddle.rand(shape=(2, 1, 4, 4), dtype="float32")
# output: [batch_size, seq_len, embed_dim]
output = F.fused_multi_transformer(
x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias],
[linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias],
[ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias],
attn_mask=attn_mask)
# [2, 4, 128]
print(output.shape)
"""
if mode not in ('downscale_in_infer', 'upscale_in_train'):
raise ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  # map to the attribute value used by the op
if _non_static_mode():
cache_kv_out, final_out = _C_ops.fused_multi_transformer(
x, ln_scales, ln_biases, qkv_weights, qkv_biases, cache_kvs,
time_step, attn_mask, linear_weights, linear_biases, ffn_ln_scales,
ffn_ln_biases, ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases,
cache_kvs, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon,
'dropout_rate', dropout_rate, 'dropout_is_test', not training,
'dropout_implementation', mode, 'act_method', activation, 'ring_id',
ring_id)
if cache_kvs is not None:
return final_out, cache_kv_out
return final_out
else:
helper = LayerHelper('fused_multi_transformer', **locals())
dtype = x.dtype
# check dtypes
check_variable_and_dtype(x, 'x', ['float16', 'float32'],
'fused_multi_transformer')
check_dtype(dtype, 'dtype', ['float16', 'float32'],
'fused_multi_transformer')
# set inputs
inputs = dict()
inputs['X'] = [x]
inputs['LnScale'] = ln_scales
inputs['LnBias'] = ln_biases
inputs['QKVW'] = qkv_weights
if qkv_biases is not None:
inputs['QKVBias'] = qkv_biases
if cache_kvs is not None:
assert len(cache_kvs) == len(qkv_weights)
inputs['CacheKV'] = cache_kvs
if time_step is not None:
inputs['TimeStep'] = time_step
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = linear_weights
if linear_biases is not None:
inputs['OutLinearBias'] = linear_biases
inputs['FFNLnScale'] = ffn_ln_scales
inputs['FFNLnBias'] = ffn_ln_biases
inputs['FFN1Weight'] = ffn1_weights
if ffn1_biases is not None:
inputs['FFN1Bias'] = ffn1_biases
inputs['FFN2Weight'] = ffn2_weights
if ffn2_biases is not None:
inputs['FFN2Bias'] = ffn2_biases
# set attrs
attrs = {
'pre_layer_norm': pre_layer_norm,
'epsilon': epsilon,
'dropout_rate': dropout_rate,
'dropout_is_test': not training,
'dropout_implementation': mode,
'act_method': activation,
'ring_id': ring_id
}
outputs = dict()
final_out = helper.create_variable_for_type_inference(dtype=dtype)
outputs['Out'] = final_out
if cache_kvs:
# NOTE: inplace
outputs['CacheKVOut'] = cache_kvs
helper.append_op(
type='fused_multi_transformer',
inputs=inputs,
outputs=outputs,
attrs=attrs)
return (final_out, cache_kvs) if cache_kvs else final_out
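
The docstring above documents cache_kvs and time_step, but its embedded example does not exercise them. Below is a hedged sketch of decode-stage usage, assuming the shapes documented above (all sizes illustrative; per the docstring, time_step must be an int32 tensor of shape [1] on CPUPlace):

    import paddle
    import paddle.incubate.nn.functional as F

    batch, num_head, dim_head, max_len = 2, 4, 32, 16
    embed = num_head * dim_head
    x = paddle.rand((batch, 1, embed))  # decode stage: one token per step
    cache_kv = paddle.zeros((2, batch, num_head, max_len, dim_head))
    time_step = paddle.to_tensor([0], dtype='int32', place=paddle.CPUPlace())
    # With weights/biases built as in the docstring example (one entry per layer):
    # out, cache_kvs = F.fused_multi_transformer(
    #     x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias],
    #     [linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias],
    #     [ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias],
    #     cache_kvs=[cache_kv], time_step=time_step, attn_mask=attn_mask)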