Unverified commit 79bfb184, authored by Yuanle Liu, committed by GitHub

multihead_matmul op support codegen and kernel remove to phi (#56846)

Parent 7fd6ffb8
@@ -22,9 +22,9 @@
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
namespace paddle {
namespace inference {
@@ -254,7 +254,7 @@ int MultiheadMatmulRoformerPlugin::enqueue(
platform::CUDAPlace(device_id)));
const phi::GPUContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
phi::funcs::MultiheadGPUComputeFunctor<float> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
@@ -341,7 +341,7 @@ int MultiheadMatmulRoformerPlugin::enqueue(
tptr, static_cast<half>(scale_), n_q);
const phi::GPUContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
phi::funcs::MultiheadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
......
@@ -24,9 +24,9 @@
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
namespace paddle {
namespace inference {
@@ -396,7 +396,7 @@ int QkvToContextPluginDynamic::enqueue(
platform::CUDAPlace(device_id)));
const phi::GPUContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
phi::funcs::MultiheadGPUComputeFunctor<float> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
@@ -506,7 +506,7 @@ int QkvToContextPluginDynamic::enqueue(
tptr, static_cast<half>(scale_), n_q);
const phi::GPUContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
phi::funcs::MultiheadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
......
@@ -10,7 +10,6 @@ register_operators(
fusion_transpose_flatten_concat_op
fusion_conv_inception_op
fused_fc_elementwise_layernorm_op
multihead_matmul_op
self_dp_attention_op
skip_layernorm_op
yolo_box_head_op
@@ -74,8 +73,6 @@ if(WITH_GPU OR WITH_ROCM)
endif()
# fused_fc_elementwise_layernorm_op
op_library(fused_fc_elementwise_layernorm_op)
# multihead_matmul_op
op_library(multihead_matmul_op)
op_library(skip_layernorm_op)
op_library(yolo_box_head_op)
op_library(yolo_box_post_op)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace operators {
class MultiHeadMatMulV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(
context->HasInput("Input"),
true,
platform::errors::InvalidArgument(
"Input(Input) of MultiHeadMatMul should not be null."));
PADDLE_ENFORCE_EQ(context->HasInput("W"),
true,
platform::errors::InvalidArgument(
"Input(W) of MultiHeadMatMul should not be null."));
PADDLE_ENFORCE_EQ(
context->HasInput("Bias"),
true,
platform::errors::InvalidArgument(
"Input(Bias) of MultiHeadMatMul should not be null."));
PADDLE_ENFORCE_EQ(
context->HasOutput("Out"),
true,
platform::errors::InvalidArgument(
"Output(Out) of MultiHeadMatMul should not be null."));
auto dim_w = context->GetInputDim("W");
PADDLE_ENFORCE_GT(
dim_w.size(),
2,
platform::errors::InvalidArgument(
"Multihead input is expected at least a 3-D tensor, but "
"it's %d-D tensor now.",
dim_w.size()));
auto dim_bias_q = context->GetInputDim("Bias");
PADDLE_ENFORCE_GT(
dim_bias_q.size(),
1,
platform::errors::InvalidArgument(
"Multihead input should be at least 2-D tensor, but it's "
"%d-D tensor now.",
dim_bias_q.size()));
auto dim_input = context->GetInputDim("Input");
context->SetOutputDim("Out", dim_input);
context->ShareLoD("Input", /*->*/ "Out");
}
};
class MultiHeadMatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "The input of MultiHeadMatMul op");
AddInput("W", "The weight input of MultiHeadMatMul op");
AddInput("Bias", "The bias input of MultiHeadMatMul op");
AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op")
.AsDispensable();
AddOutput("Out", "The output of MultiHeadMatMul op");
AddAttr<bool>("transpose_Q",
R"DOC(If true, use the transpose of `Q`.
)DOC")
.SetDefault(false);
AddAttr<bool>("transpose_K",
R"DOC(If true, use the transpose of `K`.
)DOC")
.SetDefault(true);
AddAttr<bool>("transpose_V",
R"DOC(If true, use the transpose of `V`.
)DOC")
.SetDefault(false);
AddAttr<float>("alpha", "The scale of Out").SetDefault(1.0f);
AddAttr<int>("head_number", "The number of heads of the matrix")
.SetDefault(1);
AddComment(R"DOC(
MultiHeadMatMul Operator.
This op is used to optimize the multi-head attention calculation in the Ernie model.
It is not suggested for other cases unless the model has the same structure as Ernie.
Example of matrix multiplication with head_number of B
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(multihead_matmul,
ops::MultiHeadMatMulV2Op,
ops::MultiHeadMatMulV2OpMaker);
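The AddComment block in the operator definition above describes the head-wise matrix multiplication this op builds on: X of shape [B, M, K] times Y of shape [B, K, N] yields Out of shape [B, M, N], where B is head_number. The following self-contained C++ sketch only illustrates those shape semantics; the function name and the naive loop nest are illustrative and are not taken from the Paddle kernels.

#include <vector>

// Naive per-head matmul illustrating the shape rule quoted in the op comment:
// X: [B, M, K], Y: [B, K, N] => Out: [B, M, N], with B = head_number.
std::vector<float> HeadwiseMatMul(const std::vector<float> &x,
                                  const std::vector<float> &y,
                                  int B, int M, int K, int N) {
  std::vector<float> out(static_cast<size_t>(B) * M * N, 0.0f);
  for (int b = 0; b < B; ++b) {
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
          acc += x[(static_cast<size_t>(b) * M + m) * K + k] *
                 y[(static_cast<size_t>(b) * K + k) * N + n];
        }
        out[(static_cast<size_t>(b) * M + m) * N + n] = acc;
      }
    }
  }
  return out;
}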
@@ -77,35 +77,6 @@ class EmbEltwiseLayerNormFunctor {
gpuStream_t stream);
};
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
// | |
// matmul
// |
// eltwise_add
// |
// softmax /
// \ /
// matmul
// |
template <typename T>
class MultiHeadGPUComputeFunctor {
public:
void operator()(const phi::GPUContext &dev_ctx,
int batch,
int seq_len,
int head_num,
int head_size,
T *qkptr,
const T *bias_qk_ptr,
bool bias_is_mask,
T *tptr,
T alpha,
T beta);
};
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
......
@@ -189,6 +189,16 @@
data_type : x
optional : mask, seq_lod, max_seq_len, x_fp16, out_fp16
- op : multihead_matmul
args : (Tensor input, Tensor w, Tensor bias, Tensor bias_qk, bool transpose_q = false, bool transpose_k = true, bool transpose_v = false, float alpha = 1.0f, int head_number = 1)
output : Tensor(out)
infer_meta :
func : MultiheadMatmulInferMeta
kernel :
func : multihead_matmul
data_type : input
optional : bias_qk
- op : yolo_box_xpu
args : (Tensor x, Tensor x_max, Tensor grid, Tensor stride, Tensor anchor_grid, float offset)
output : Tensor(out), Tensor(out_max)
......
@@ -1935,6 +1935,14 @@
outputs :
{out : Out, index : Index, nms_rois_num : NmsRoisNum}
- op : multihead_matmul
inputs :
{input : Input, w : W, bias : Bias, bias_qk : BiasQK}
outputs :
out : Out
attrs :
{transpose_q : transpose_Q, transpose_k : transpose_K, transpose_v : transpose_V}
- op : multinomial
inputs :
{x : X}
......
@@ -4126,6 +4126,39 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dtype(DataType::INT32);
}
void MultiheadMatmulInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
const MetaTensor& bias_qk,
const bool transpose_q,
const bool transpose_k,
const bool transpose_v,
const float alpha,
const int head_number,
MetaTensor* out) {
auto w_dims = w.dims();
PADDLE_ENFORCE_GT(
w_dims.size(),
2,
errors::InvalidArgument(
"MultiheadMatmul's w is expected at least a 3-D tensor, but "
"it's %d-D tensor now.",
w_dims.size()));
auto bias_dims = bias.dims();
PADDLE_ENFORCE_GT(
bias_dims.size(),
1,
errors::InvalidArgument(
"MultiheadMatmul's bias should be at least 2-D tensor, but it's "
"%d-D tensor now.",
bias_dims.size()));
out->set_dims(input.dims());
out->set_dtype(input.dtype());
out->share_lod(input);
}
void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& bias,
......
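MultiheadMatmulInferMeta above only validates ranks (w must be more than 2-D, bias more than 1-D) and copies dims, dtype, and LoD from input to out. A concrete shape set consistent with the comments in the CUDA kernel later in this diff ("should be (B * S * hidden)", "should be (hidden * 3 * all_head_size)", "BxSx3xNxH + 3xNxH") would look as follows; the example numbers are illustrative only.

// input   : [batch, seq_len, hidden]              e.g. [8, 128, 768]
// w       : [hidden, 3, all_head_size]            e.g. [768, 3, 768]   (rank 3 > 2)
// bias    : [3, head_number, head_size] (3xNxH)   e.g. [3, 12, 64]     (rank 3 > 1)
// bias_qk : optional; [batch, 1, 1, seq_len] or [1, 1, seq_len, seq_len],
//           broadcast by the kernel to [batch, head_number, seq_len, seq_len]
// out     : same dims/dtype/LoD as input          -> [8, 128, 768]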
@@ -811,6 +811,17 @@ void FusedRopeInferMeta(const MetaTensor& q,
MetaTensor* out_k,
MetaTensor* out_v);
void MultiheadMatmulInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
const MetaTensor& bias_qk,
const bool transpose_q,
const bool transpose_k,
const bool transpose_v,
const float alpha,
const int head_number,
MetaTensor* out);
void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& bias,
......
This diff is collapsed.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
namespace funcs {
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
// | |
// matmul
// |
// eltwise_add
// |
// softmax /
// \ /
// matmul
// |
template <typename T>
class MultiheadGPUComputeFunctor {
public:
void operator()(const phi::GPUContext &dev_ctx,
int batch,
int seq_len,
int head_num,
int head_size,
T *qkptr,
const T *bias_qk_ptr,
bool bias_is_mask,
T *tptr,
T alpha,
T beta);
};
} // namespace funcs
} // namespace phi
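A minimal usage sketch for the functor declared above, mirroring the float call sites updated elsewhere in this commit (the TensorRT plugins and the phi multihead_matmul kernel). The wrapper name and the assumption that qkptr and tptr already hold the B x N x S x S scratch and the 3 x B x N x S x H transposed QKV buffer are illustrative.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"

void RunFusedSoftmaxQK(const phi::GPUContext &dev_ctx,
                       int batch, int seq_len, int head_num, int head_size,
                       float *qkptr, const float *bias_qk_ptr, float *tptr,
                       float scale) {
  phi::funcs::MultiheadGPUComputeFunctor<float> multihead_compute_func;
  // bias_is_mask = false and beta = 0.0f follow the existing float call sites.
  multihead_compute_func(dev_ctx, batch, seq_len, head_num, head_size,
                         qkptr, bias_qk_ptr, /*bias_is_mask=*/false,
                         tptr, scale, 0.0f);
}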
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
namespace paddle {
namespace operators {
namespace phi {
namespace fusion {
template <typename T>
__global__ void transpose(T *src,
@@ -149,7 +148,7 @@ void TransQKVWithBias(const int batch,
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
@@ -165,7 +164,7 @@ void TransQKVWithBias(const int batch,
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
@@ -177,7 +176,7 @@ void TransQKVWithBias(const int batch,
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num,
1024,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
@@ -193,9 +192,9 @@ void TransQKVWithBias(const int batch,
const int seq_len,
const int head_size,
const int head_num,
const platform::float16 *input,
const platform::float16 *bias,
platform::float16 *output,
const phi::dtype::float16 *input,
const phi::dtype::float16 *bias,
phi::dtype::float16 *output,
gpuStream_t stream) {
// BxSx3xNxH + 3xNxH -> 3xBxNxSxH
int scratch_size = batch * head_num * seq_len * seq_len;
@@ -209,7 +208,7 @@ void TransQKVWithBias(const int batch,
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
@@ -225,7 +224,7 @@ void TransQKVWithBias(const int batch,
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num,
1024,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
@@ -240,7 +239,7 @@ inline int round_up(int seq_len, int multiple = 32) {
PADDLE_ENFORCE_GT(
multiple,
0,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"multiple should be a positive number, but it's (%d)", multiple));
return ((seq_len + multiple - 1) / multiple) * multiple;
}
@@ -270,168 +269,166 @@ __global__ void broadcast_batch_head_number(const T *src,
}
}
template <typename T, typename DeviceContext>
class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *input = context.Input<phi::DenseTensor>("Input");
auto *w = context.Input<phi::DenseTensor>("W");
auto *bias = context.Input<phi::DenseTensor>("Bias");
auto *bias_qk = context.Input<phi::DenseTensor>("BiasQK");
auto *input_d = input->data<T>();
auto *w_d = w->data<T>();
auto *bias_d = bias->data<T>();
auto *bias_qk_d = bias_qk ? bias_qk->data<T>() : nullptr;
T scale = static_cast<T>(context.Attr<float>("alpha"));
int head_number = context.Attr<int>("head_number");
// compute q*k with eltadd
auto &device_ctx = context.template device_context<DeviceContext>();
auto stream = device_ctx.stream();
// should be (B * S * hidden)
auto input_dims = input->dims();
// should be (hidden * 3 * all_head_size)
auto w_dims = w->dims();
int batch = input_dims[0];
int seq_len = input_dims[1];
int hidden = input_dims[2];
phi::DenseTensor temp_bias_tensor;
// if bias_qk is [batch, 1, 1, seq_len], bias_qk_d needs to be broadcast
if (bias_qk && bias_qk->numel() == (batch * seq_len)) {
VLOG(4) << "Do broadcasted bias_qk from [batch, 1, 1, seq_len]";
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = device_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
int grid = batch * head_number * seq_len;
int block = round_up(seq_len);
broadcast<<<grid, block, 0, stream>>>(
bias_qk_d, temp_qk_bias, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
// if bias_qk is [1, 1, seq_len, seq_len], bias_qk_d needs to be
// broadcast
if (bias_qk && bias_qk->numel() == (1 * seq_len * seq_len)) {
VLOG(4) << "do broadcasted bias_qk from [1, 1, seq_len, seq_len]";
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = device_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
int grid = batch * head_number * seq_len;
int block = round_up(seq_len);
broadcast_batch_head_number<<<grid, block, 0, stream>>>(
bias_qk_d, temp_qk_bias, batch, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
if (!bias_qk) {
int size = batch * head_number * seq_len * seq_len;
temp_bias_tensor.Resize({size});
auto *temp_qk_bias = device_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
template <typename T, typename Context>
void MultiheadMatmulKernel(const Context &dev_ctx,
const DenseTensor &input,
const DenseTensor &w,
const DenseTensor &bias,
const paddle::optional<DenseTensor> &bias_qk,
const bool transpose_q,
const bool transpose_k,
const bool transpose_v,
const float alpha,
const int head_number,
DenseTensor *out) {
auto *input_d = input.data<T>();
auto *w_d = w.data<T>();
auto *bias_d = bias.data<T>();
auto *bias_qk_d = bias_qk ? bias_qk->data<T>() : nullptr;
T scale = static_cast<T>(alpha);
// compute q*k with eltadd
auto stream = dev_ctx.stream();
// should be (B * S * hidden)
auto input_dims = input.dims();
// should be (hidden * 3 * all_head_size)
auto w_dims = w.dims();
int batch = input_dims[0];
int seq_len = input_dims[1];
int hidden = input_dims[2];
phi::DenseTensor temp_bias_tensor;
// if bias_qk is [batch, 1, 1, seq_len], bias_qk_d needs to be broadcast
if (bias_qk && bias_qk->numel() == (batch * seq_len)) {
VLOG(4) << "Do broadcasted bias_qk from [batch, 1, 1, seq_len]";
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = dev_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
int grid = batch * head_number * seq_len;
int block = round_up(seq_len);
broadcast<<<grid, block, 0, stream>>>(
bias_qk_d, temp_qk_bias, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
// if bias_qk is [1, 1, seq_len, seq_len], bias_qk_d needs to be
// broadcast
if (bias_qk && bias_qk->numel() == (1 * seq_len * seq_len)) {
VLOG(4) << "do broadcasted bias_qk from [1, 1, seq_len, seq_len]";
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = dev_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
int grid = batch * head_number * seq_len;
int block = round_up(seq_len);
broadcast_batch_head_number<<<grid, block, 0, stream>>>(
bias_qk_d, temp_qk_bias, batch, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
if (!bias_qk) {
int size = batch * head_number * seq_len * seq_len;
temp_bias_tensor.Resize({size});
auto *temp_qk_bias = dev_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
#ifdef PADDLE_WITH_HIP
hipMemset(temp_qk_bias, 0, sizeof(float) * size);
hipMemset(temp_qk_bias, 0, sizeof(float) * size);
#else
cudaMemset(temp_qk_bias, 0, sizeof(float) * size);
cudaMemset(temp_qk_bias, 0, sizeof(float) * size);
#endif
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
int all_head_size = w_dims[2];
int head_size = all_head_size / head_number;
auto *out = context.Output<phi::DenseTensor>("Out");
out->Resize({batch, seq_len, all_head_size});
auto *output_d =
device_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
// (B*S, hidden)
const phi::DenseTensor input_matrix =
phi::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
// (hidden, 3 * all_head_size)
const phi::DenseTensor w_matrix =
phi::ReshapeToMatrix(*w, 1 /*y_num_col_dims*/);
phi::DenseTensor temp_out_tensor;
auto temp_out_dims =
phi::make_ddim({batch, seq_len, 3, head_number, head_size});
temp_out_tensor.Resize(
{batch * seq_len, phi::product(temp_out_dims) / (batch * seq_len)});
auto *temp_out_data = device_ctx.template Alloc<T>(
&temp_out_tensor, temp_out_tensor.numel() * sizeof(T));
// (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(device_ctx);
blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
VLOG(2) << "(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)";
VLOG(2) << temp_out_tensor;
// temp_out_tensor.Resize(temp_out_dims);
phi::DenseTensor multihead_temp_tensor;
// B * head_number * S * S * 1 + B * S * 3 * N * H
int scratch_size = batch * head_number * seq_len * seq_len * 1;
multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()});
auto *multihead_temp_data = device_ctx.template Alloc<T>(
&multihead_temp_tensor, multihead_temp_tensor.numel() * sizeof(T));
auto *qkptr = multihead_temp_data;
auto *tptr = multihead_temp_data + scratch_size;
// Do the transpose with bias.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransQKVWithBias(batch,
seq_len,
head_size,
head_number,
temp_out_data,
bias_d,
tptr,
stream);
if (std::is_same<T, platform::float16>::value) {
math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(device_ctx,
batch,
seq_len,
head_number,
head_size,
reinterpret_cast<half *>(qkptr),
reinterpret_cast<const half *>(bias_qk_d),
false,
reinterpret_cast<half *>(tptr),
__float2half(static_cast<float>(scale)),
__float2half(0.0));
} else {
math::MultiHeadGPUComputeFunctor<T> multihead_compute_func;
multihead_compute_func(device_ctx,
batch,
seq_len,
head_number,
head_size,
qkptr,
bias_qk_d,
false,
tptr,
scale,
T(0.0));
}
int grid = batch * head_number * seq_len;
int block = head_size;
transpose<T><<<grid, block, 0, stream>>>(
tptr, output_d, batch, seq_len, head_number, head_size);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
int all_head_size = w_dims[2];
int head_size = all_head_size / head_number;
out->Resize({batch, seq_len, all_head_size});
auto *output_d = dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
// (B*S, hidden)
const phi::DenseTensor input_matrix =
phi::ReshapeToMatrix(input, 2 /*x_num_col_dims */);
// (hidden, 3 * all_head_size)
const phi::DenseTensor w_matrix =
phi::ReshapeToMatrix(w, 1 /*y_num_col_dims*/);
phi::DenseTensor temp_out_tensor;
auto temp_out_dims =
phi::make_ddim({batch, seq_len, 3, head_number, head_size});
temp_out_tensor.Resize(
{batch * seq_len, phi::product(temp_out_dims) / (batch * seq_len)});
auto *temp_out_data = dev_ctx.template Alloc<T>(
&temp_out_tensor, temp_out_tensor.numel() * sizeof(T));
// (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
VLOG(2) << "(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)";
// temp_out_tensor.Resize(temp_out_dims);
phi::DenseTensor multihead_temp_tensor;
// B * head_number * S * S * 1 + B * S * 3 * N * H
int scratch_size = batch * head_number * seq_len * seq_len * 1;
multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()});
auto *multihead_temp_data = dev_ctx.template Alloc<T>(
&multihead_temp_tensor, multihead_temp_tensor.numel() * sizeof(T));
auto *qkptr = multihead_temp_data;
auto *tptr = multihead_temp_data + scratch_size;
// Do the transpose with bias.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransQKVWithBias(batch,
seq_len,
head_size,
head_number,
temp_out_data,
bias_d,
tptr,
stream);
if (std::is_same<T, phi::dtype::float16>::value) {
phi::funcs::MultiheadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
head_number,
head_size,
reinterpret_cast<half *>(qkptr),
reinterpret_cast<const half *>(bias_qk_d),
false,
reinterpret_cast<half *>(tptr),
__float2half(static_cast<float>(scale)),
__float2half(0.0));
} else {
phi::funcs::MultiheadGPUComputeFunctor<T> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
head_number,
head_size,
qkptr,
bias_qk_d,
false,
tptr,
scale,
T(0.0));
}
};
} // namespace operators
} // namespace paddle
int grid = batch * head_number * seq_len;
int block = head_size;
transpose<T><<<grid, block, 0, stream>>>(
tptr, output_d, batch, seq_len, head_number, head_size);
}
} // namespace fusion
} // namespace phi
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
PD_REGISTER_STRUCT_KERNEL(multihead_matmul,
GPU,
ALL_LAYOUT,
ops::MultiHeadMatMulV2Kernel,
float,
plat::float16) {}
PD_REGISTER_KERNEL(multihead_matmul,
GPU,
ALL_LAYOUT,
phi::fusion::MultiheadMatmulKernel,
float,
phi::dtype::float16) {}
#else
PD_REGISTER_STRUCT_KERNEL(
multihead_matmul, GPU, ALL_LAYOUT, ops::MultiHeadMatMulV2Kernel, float) {}
PD_REGISTER_KERNEL(multihead_matmul,
GPU,
ALL_LAYOUT,
phi::fusion::MultiheadMatmulKernel,
float) {}
#endif
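For orientation, the scratch sizing and launch-shape arithmetic used by the kernel above can be reproduced on the host in isolation. The sketch below is a standalone illustration with arbitrary example dimensions, not part of the Paddle sources.

#include <cstdio>

// Mirrors round_up() in the kernel: round seq_len up to a multiple (default 32).
static int RoundUp(int seq_len, int multiple = 32) {
  return ((seq_len + multiple - 1) / multiple) * multiple;
}

int main() {
  const int batch = 8, seq_len = 128, head_number = 12, head_size = 64;
  const int all_head_size = head_number * head_size;  // == w_dims[2]
  // Q*K^T scratch (B * N * S * S) is stored in front of the 3 x B x N x S x H buffer.
  const int scratch_size = batch * head_number * seq_len * seq_len;
  const int qkv_size = batch * seq_len * 3 * head_number * head_size;
  // The broadcast kernels use block = RoundUp(seq_len); the final transpose
  // uses block = head_size; both use grid = batch * head_number * seq_len.
  const int grid = batch * head_number * seq_len;
  std::printf(
      "all_head_size=%d scratch=%d qkv=%d grid=%d block_bcast=%d block_trans=%d\n",
      all_head_size, scratch_size, qkv_size, grid, RoundUp(seq_len), head_size);
  return 0;
}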