Unverified · Commit 79bfb184 · Authored by: Yuanle Liu · Committed by: GitHub

multihead_matmul op support codegen and kernel remove to phi (#56846)

Parent: 7fd6ffb8
@@ -22,9 +22,9 @@
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
-#include "paddle/fluid/operators/math/bert_encoder_functor.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
 namespace paddle {
 namespace inference {
@@ -254,7 +254,7 @@ int MultiheadMatmulRoformerPlugin::enqueue(
         platform::CUDAPlace(device_id)));
     const phi::GPUContext &dev_ctx = *device_ctx;
-    operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
+    phi::funcs::MultiheadGPUComputeFunctor<float> multihead_compute_func;
     multihead_compute_func(dev_ctx,
                            batch,
                            seq_len,
@@ -341,7 +341,7 @@ int MultiheadMatmulRoformerPlugin::enqueue(
         tptr, static_cast<half>(scale_), n_q);
     const phi::GPUContext &dev_ctx = *device_ctx;
-    operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
+    phi::funcs::MultiheadGPUComputeFunctor<half> multihead_compute_func;
     multihead_compute_func(dev_ctx,
                            batch,
                            seq_len,
...
@@ -24,9 +24,9 @@
 #include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
 #include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
-#include "paddle/fluid/operators/math/bert_encoder_functor.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
 namespace paddle {
 namespace inference {
@@ -396,7 +396,7 @@ int QkvToContextPluginDynamic::enqueue(
         platform::CUDAPlace(device_id)));
     const phi::GPUContext &dev_ctx = *device_ctx;
-    operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
+    phi::funcs::MultiheadGPUComputeFunctor<float> multihead_compute_func;
     multihead_compute_func(dev_ctx,
                            batch,
                            seq_len,
@@ -506,7 +506,7 @@ int QkvToContextPluginDynamic::enqueue(
         tptr, static_cast<half>(scale_), n_q);
     const phi::GPUContext &dev_ctx = *device_ctx;
-    operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
+    phi::funcs::MultiheadGPUComputeFunctor<half> multihead_compute_func;
     multihead_compute_func(dev_ctx,
                            batch,
                            seq_len,
...
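Both plugin call sites change only in the functor's namespace and spelling: operators::math::MultiHeadGPUComputeFunctor becomes phi::funcs::MultiheadGPUComputeFunctor, with an identical argument list. A minimal sketch of a complete float32 invocation follows; the wrapper name, buffer names, and values are hypothetical, while the parameter order is taken from the header added later in this commit.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"

// Sketch only: buffer pointers and sizes are placeholders supplied by the caller.
void RunFusedSoftmaxQKV(const phi::GPUContext &dev_ctx,
                        int batch, int seq_len, int head_num, int head_size,
                        float *qk_scores,       // [batch, head_num, seq_len, seq_len] scratch
                        const float *bias_qk,   // additive attention bias / mask
                        float *qkv_transposed,  // 3 x batch x head_num x seq_len x head_size
                        float scale) {
  phi::funcs::MultiheadGPUComputeFunctor<float> multihead_compute_func;
  multihead_compute_func(dev_ctx,
                         batch,
                         seq_len,
                         head_num,
                         head_size,
                         qk_scores,
                         bias_qk,
                         /*bias_is_mask=*/false,
                         qkv_transposed,
                         scale,
                         /*beta=*/0.0f);
}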
@@ -10,7 +10,6 @@ register_operators(
     fusion_transpose_flatten_concat_op
     fusion_conv_inception_op
     fused_fc_elementwise_layernorm_op
-    multihead_matmul_op
     self_dp_attention_op
     skip_layernorm_op
     yolo_box_head_op
@@ -74,8 +73,6 @@ if(WITH_GPU OR WITH_ROCM)
   endif()
   # fused_fc_elementwise_layernorm_op
   op_library(fused_fc_elementwise_layernorm_op)
-  # multihead_matmul_op
-  op_library(multihead_matmul_op)
   op_library(skip_layernorm_op)
   op_library(yolo_box_head_op)
   op_library(yolo_box_post_op)
...
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <vector>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/errors.h"
-
-namespace paddle {
-namespace operators {
-
-class MultiHeadMatMulV2Op : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext *context) const override {
-    PADDLE_ENFORCE_EQ(
-        context->HasInput("Input"),
-        true,
-        platform::errors::InvalidArgument(
-            "Input(Input) of MultiHeadMatMul should not be null."));
-    PADDLE_ENFORCE_EQ(context->HasInput("W"),
-                      true,
-                      platform::errors::InvalidArgument(
-                          "Input(W) of MultiHeadMatMul should not be null."));
-    PADDLE_ENFORCE_EQ(
-        context->HasInput("Bias"),
-        true,
-        platform::errors::InvalidArgument(
-            "Input(Bias) of MultiHeadMatMul should not be null."));
-    PADDLE_ENFORCE_EQ(
-        context->HasOutput("Out"),
-        true,
-        platform::errors::InvalidArgument(
-            "Output(Out) of MultiHeadMatMul should not be null."));
-
-    auto dim_w = context->GetInputDim("W");
-    PADDLE_ENFORCE_GT(
-        dim_w.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "Multihead input is expected at least a 3-D tensor, but "
-            "it's %d-D tensor now.",
-            dim_w.size()));
-
-    auto dim_bias_q = context->GetInputDim("Bias");
-    PADDLE_ENFORCE_GT(
-        dim_bias_q.size(),
-        1,
-        platform::errors::InvalidArgument(
-            "Multihead input should be at least 2-D tensor, but it's "
-            "%d-D tensor now.",
-            dim_bias_q.size()));
-
-    auto dim_input = context->GetInputDim("Input");
-    context->SetOutputDim("Out", dim_input);
-    context->ShareLoD("Input", /*->*/ "Out");
-  }
-};
-
-class MultiHeadMatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("Input", "The input of MultiHeadMatMul op");
-    AddInput("W", "The weight input of MultiHeadMatMul op");
-    AddInput("Bias", "The bias input of MultiHeadMatMul op");
-    AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op")
-        .AsDispensable();
-    AddOutput("Out", "The output of MultiHeadMatMul op");
-    AddAttr<bool>("transpose_Q",
-                  R"DOC(If true, use the transpose of `Q`.
-        )DOC")
-        .SetDefault(false);
-    AddAttr<bool>("transpose_K",
-                  R"DOC(If true, use the transpose of `K`.
-        )DOC")
-        .SetDefault(true);
-    AddAttr<bool>("transpose_V",
-                  R"DOC(If true, use the transpose of `V`.
-        )DOC")
-        .SetDefault(false);
-    AddAttr<float>("alpha", "The scale of Out").SetDefault(1.0f);
-    AddAttr<int>("head_number", "The number of heads of the matrix")
-        .SetDefault(1);
-    AddComment(R"DOC(
-MultiHeadMatMul Operator.
-
-This op is used for optimize multi head calculation in ernie model.
-Not suggest to use in other case except has same structure as ernie.
-
-Example of matrix multiplication with head_number of B
-- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
-
-)DOC");
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(multihead_matmul,
-                             ops::MultiHeadMatMulV2Op,
-                             ops::MultiHeadMatMulV2OpMaker);
@@ -77,35 +77,6 @@ class EmbEltwiseLayerNormFunctor {
                   gpuStream_t stream);
 };
-// This functor involves a fusion calculation in Ernie or Bert.
-// The fusion mode is as follows:
-//
-//      |      |
-//      matmul
-//        |
-//    eltwise_add
-//        |
-//     softmax    /
-//         \     /
-//          matmul
-//            |
-
-template <typename T>
-class MultiHeadGPUComputeFunctor {
- public:
-  void operator()(const phi::GPUContext &dev_ctx,
-                  int batch,
-                  int seq_len,
-                  int head_num,
-                  int head_size,
-                  T *qkptr,
-                  const T *bias_qk_ptr,
-                  bool bias_is_mask,
-                  T *tptr,
-                  T alpha,
-                  T beta);
-};
-
 // This functor involves a fusion calculation in Ernie or Bert.
 // The fusion mode is as follows:
 //
...
@@ -189,6 +189,16 @@
     data_type : x
   optional : mask, seq_lod, max_seq_len, x_fp16, out_fp16
+
+- op : multihead_matmul
+  args : (Tensor input, Tensor w, Tensor bias, Tensor bias_qk, bool transpose_q = false, bool transpose_k = true, bool transpose_v = false, float alpha = 1.0f, int head_number = 1)
+  output : Tensor(out)
+  infer_meta :
+    func : MultiheadMatmulInferMeta
+  kernel :
+    func : multihead_matmul
+    data_type : input
+  optional : bias_qk
 - op : yolo_box_xpu
   args : (Tensor x, Tensor x_max, Tensor grid, Tensor stride, Tensor anchor_grid, float offset)
   output : Tensor(out), Tensor(out_max)
...
@@ -1935,6 +1935,14 @@
   outputs :
     {out : Out, index : Index, nms_rois_num : NmsRoisNum}
+
+- op : multihead_matmul
+  inputs :
+    {input : Input, w : W, bias : Bias, bias_qk : BiasQK}
+  outputs :
+    out : Out
+  attrs :
+    {transpose_q : transpose_Q, transpose_k : transpose_K, transpose_v : transpose_V}
 - op : multinomial
   inputs :
     {x : X}
...
@@ -4126,6 +4126,39 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
   out_count->set_dtype(DataType::INT32);
 }
+
+void MultiheadMatmulInferMeta(const MetaTensor& input,
+                              const MetaTensor& w,
+                              const MetaTensor& bias,
+                              const MetaTensor& bias_qk,
+                              const bool transpose_q,
+                              const bool transpose_k,
+                              const bool transpose_v,
+                              const float alpha,
+                              const int head_number,
+                              MetaTensor* out) {
+  auto w_dims = w.dims();
+  PADDLE_ENFORCE_GT(
+      w_dims.size(),
+      2,
+      errors::InvalidArgument(
+          "MultiheadMatmul's w is expected at least a 3-D tensor, but "
+          "it's %d-D tensor now.",
+          w_dims.size()));
+
+  auto bias_dims = bias.dims();
+  PADDLE_ENFORCE_GT(
+      bias_dims.size(),
+      1,
+      errors::InvalidArgument(
+          "MultiheadMatmul's bias should be at least 2-D tensor, but it's "
+          "%d-D tensor now.",
+          bias_dims.size()));
+
+  out->set_dims(input.dims());
+  out->set_dtype(input.dtype());
+  out->share_lod(input);
+}
+
 void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        const MetaTensor& cache_kv,
                                        const MetaTensor& bias,
...
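For a concrete view of the shape contract these checks encode, here is a small hypothetical sketch; the helper name and all shapes are illustrative and assume the layout documented in the kernel below, input as (B, S, hidden) and w as (hidden, 3, all_head_size).

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/fusion.h"

// Hypothetical shapes: B=2, S=128, hidden=768, 12 heads of size 64.
void MultiheadMatmulInferMetaSketch() {
  phi::DenseTensor input, w, bias, bias_qk, out;
  input.Resize(phi::make_ddim({2, 128, 768}));    // (B, S, hidden)
  w.Resize(phi::make_ddim({768, 3, 768}));        // rank must be > 2
  bias.Resize(phi::make_ddim({3, 768}));          // rank must be > 1
  bias_qk.Resize(phi::make_ddim({2, 1, 1, 128}));

  phi::MetaTensor input_meta(&input), w_meta(&w), bias_meta(&bias),
      bias_qk_meta(&bias_qk), out_meta(&out);
  phi::MultiheadMatmulInferMeta(input_meta,
                                w_meta,
                                bias_meta,
                                bias_qk_meta,
                                /*transpose_q=*/false,
                                /*transpose_k=*/true,
                                /*transpose_v=*/false,
                                /*alpha=*/0.125f,
                                /*head_number=*/12,
                                &out_meta);
  // out now mirrors input: dims {2, 128, 768}, same dtype and LoD.
}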
@@ -811,6 +811,17 @@ void FusedRopeInferMeta(const MetaTensor& q,
                         MetaTensor* out_k,
                         MetaTensor* out_v);
+
+void MultiheadMatmulInferMeta(const MetaTensor& input,
+                              const MetaTensor& w,
+                              const MetaTensor& bias,
+                              const MetaTensor& bias_qk,
+                              const bool transpose_q,
+                              const bool transpose_k,
+                              const bool transpose_v,
+                              const float alpha,
+                              const int head_number,
+                              MetaTensor* out);
 void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        const MetaTensor& cache_kv,
                                        const MetaTensor& bias,
...
This file's diff is collapsed.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+
+namespace phi {
+namespace funcs {
+
+// This functor involves a fusion calculation in Ernie or Bert.
+// The fusion mode is as follows:
+//
+//      |      |
+//      matmul
+//        |
+//    eltwise_add
+//        |
+//     softmax    /
+//         \     /
+//          matmul
+//            |
+
+template <typename T>
+class MultiheadGPUComputeFunctor {
+ public:
+  void operator()(const phi::GPUContext &dev_ctx,
+                  int batch,
+                  int seq_len,
+                  int head_num,
+                  int head_size,
+                  T *qkptr,
+                  const T *bias_qk_ptr,
+                  bool bias_is_mask,
+                  T *tptr,
+                  T alpha,
+                  T beta);
+};
+
+}  // namespace funcs
+}  // namespace phi
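Read together with the ASCII diagram above and the call sites earlier in this diff, the fused computation performed by MultiheadGPUComputeFunctor can be summarized informally (per batch element and attention head; the notation is mine, not taken from the source) as:

  \text{out} = \mathrm{softmax}\left(\alpha \, Q K^{\top} + \text{bias\_qk}\right) V

where alpha is the caller's attention scale (the op attribute `alpha`, `scale_` in the TensorRT plugins), qkptr is the scratch buffer used for the score matrix, and tptr holds the transposed Q/K/V data passed in by the caller.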
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,20 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include <paddle/fluid/platform/device_context.h>
-
 #include <algorithm>
 #include <type_traits>
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/memory/malloc.h"
-#include "paddle/fluid/operators/math/bert_encoder_functor.h"
-#include "paddle/fluid/platform/float16.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/errors.h"
+#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
-namespace paddle {
-namespace operators {
+namespace phi {
+namespace fusion {
 template <typename T>
 __global__ void transpose(T *src,
@@ -149,7 +148,7 @@ void TransQKVWithBias(const int batch,
     // limit h * head_num to max block size(1024).
     PADDLE_ENFORCE_LE(h * head_num,
                       1024,
-                      platform::errors::InvalidArgument(
+                      phi::errors::InvalidArgument(
                           "head_num (%d) * head_size (%d) should <= %d",
                           head_num,
                           head_size,
@@ -165,7 +164,7 @@ void TransQKVWithBias(const int batch,
     // limit h * head_num to max block size(1024).
     PADDLE_ENFORCE_LE(h * head_num,
                       1024,
-                      platform::errors::InvalidArgument(
+                      phi::errors::InvalidArgument(
                           "head_num (%d) * head_size (%d) should <= %d",
                           head_num,
                           head_size,
@@ -177,7 +176,7 @@ void TransQKVWithBias(const int batch,
     // limit head_size * head_num to max block size(1024).
     PADDLE_ENFORCE_LE(head_size * head_num,
                       1024,
-                      platform::errors::InvalidArgument(
+                      phi::errors::InvalidArgument(
                           "head_num (%d) * head_size (%d) should <= %d",
                           head_num,
                           head_size,
@@ -193,9 +192,9 @@ void TransQKVWithBias(const int batch,
                       const int seq_len,
                       const int head_size,
                       const int head_num,
-                      const platform::float16 *input,
-                      const platform::float16 *bias,
-                      platform::float16 *output,
+                      const phi::dtype::float16 *input,
+                      const phi::dtype::float16 *bias,
+                      phi::dtype::float16 *output,
                       gpuStream_t stream) {
   // BxSx3xNxH + 3xNxH -> 3xBxNxSxH
   int scratch_size = batch * head_num * seq_len * seq_len;
@@ -209,7 +208,7 @@ void TransQKVWithBias(const int batch,
     // limit h * head_num to max block size(1024).
     PADDLE_ENFORCE_LE(h * head_num,
                       1024,
-                      platform::errors::InvalidArgument(
+                      phi::errors::InvalidArgument(
                           "head_num (%d) * head_size (%d) should <= %d",
                           head_num,
                           head_size,
@@ -225,7 +224,7 @@ void TransQKVWithBias(const int batch,
     // limit head_size * head_num to max block size(1024).
     PADDLE_ENFORCE_LE(head_size * head_num,
                       1024,
-                      platform::errors::InvalidArgument(
+                      phi::errors::InvalidArgument(
                           "head_num (%d) * head_size (%d) should <= %d",
                           head_num,
                           head_size,
@@ -240,7 +239,7 @@ inline int round_up(int seq_len, int multiple = 32) {
   PADDLE_ENFORCE_GT(
       multiple,
       0,
-      platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
          "multiple should be a positive number, but it's (%d)", multiple));
   return ((seq_len + multiple - 1) / multiple) * multiple;
 }
@@ -270,168 +269,166 @@ __global__ void broadcast_batch_head_number(const T *src,
   }
 }
-template <typename T, typename DeviceContext>
-class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    auto *input = context.Input<phi::DenseTensor>("Input");
-    auto *w = context.Input<phi::DenseTensor>("W");
-    auto *bias = context.Input<phi::DenseTensor>("Bias");
-    auto *bias_qk = context.Input<phi::DenseTensor>("BiasQK");
-
-    auto *input_d = input->data<T>();
-    auto *w_d = w->data<T>();
-    auto *bias_d = bias->data<T>();
-    auto *bias_qk_d = bias_qk ? bias_qk->data<T>() : nullptr;
-    T scale = static_cast<T>(context.Attr<float>("alpha"));
-
-    int head_number = context.Attr<int>("head_number");
-    // compute q*k with eltadd
-    auto &device_ctx = context.template device_context<DeviceContext>();
-    auto stream = device_ctx.stream();
-    // should be (B * S * hidden)
-    auto input_dims = input->dims();
-    // shouble be (hidden * 3 * all_head_size)
-    auto w_dims = w->dims();
-    int batch = input_dims[0];
-    int seq_len = input_dims[1];
-    int hidden = input_dims[2];
-    phi::DenseTensor temp_bias_tensor;
-    // if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be
-    // broadcasted
-    if (bias_qk && bias_qk->numel() == (batch * seq_len)) {
-      VLOG(4) << "Do broadcasted bias_qk from [batch, 1, 1, seq_len]";
-      temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
-      auto *temp_qk_bias = device_ctx.template Alloc<T>(
-          &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
-      int grid = batch * head_number * seq_len;
-      int block = round_up(seq_len);
-      broadcast<<<grid, block, 0, stream>>>(
-          bias_qk_d, temp_qk_bias, seq_len, head_number);
-      bias_qk_d = static_cast<const T *>(temp_qk_bias);
-    }
-    // if bias_qk is[1, 1, seq_len, seq_len], the bias_qk_d need to be
-    // broadcasted
-    if (bias_qk && bias_qk->numel() == (1 * seq_len * seq_len)) {
-      VLOG(4) << "do broadcasted bias_qk from [1, 1, seq_len, seq_len]";
-      temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
-      auto *temp_qk_bias = device_ctx.template Alloc<T>(
-          &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
-      int grid = batch * head_number * seq_len;
-      int block = round_up(seq_len);
-      broadcast_batch_head_number<<<grid, block, 0, stream>>>(
-          bias_qk_d, temp_qk_bias, batch, seq_len, head_number);
-      bias_qk_d = static_cast<const T *>(temp_qk_bias);
-    }
-    if (!bias_qk) {
-      int size = batch * head_number * seq_len * seq_len;
-      temp_bias_tensor.Resize({size});
-      auto *temp_qk_bias = device_ctx.template Alloc<T>(
-          &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
-#ifdef PADDLE_WITH_HIP
-      hipMemset(temp_qk_bias, 0, sizeof(float) * size);
-#else
-      cudaMemset(temp_qk_bias, 0, sizeof(float) * size);
-#endif
-      bias_qk_d = static_cast<const T *>(temp_qk_bias);
-    }
-    int all_head_size = w_dims[2];
-    int head_size = all_head_size / head_number;
-
-    auto *out = context.Output<phi::DenseTensor>("Out");
-    out->Resize({batch, seq_len, all_head_size});
-    auto *output_d =
-        device_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
-
-    // (B*S, hidden)
-    const phi::DenseTensor input_matrix =
-        phi::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
-    // (hidden, 3 * all_head_size)
-    const phi::DenseTensor w_matrix =
-        phi::ReshapeToMatrix(*w, 1 /*y_num_col_dims*/);
-
-    phi::DenseTensor temp_out_tensor;
-    auto temp_out_dims =
-        phi::make_ddim({batch, seq_len, 3, head_number, head_size});
-    temp_out_tensor.Resize(
-        {batch * seq_len, phi::product(temp_out_dims) / (batch * seq_len)});
-    auto *temp_out_data = device_ctx.template Alloc<T>(
-        &temp_out_tensor, temp_out_tensor.numel() * sizeof(T));
-
-    // (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
-    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(device_ctx);
-    blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
-    VLOG(2) << "(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)";
-    VLOG(2) << temp_out_tensor;
-    // temp_out_tensor.Resize(temp_out_dims);
-
-    phi::DenseTensor multihead_temp_tensor;
-    // B * head_number * S * S * 1 + B * S * 3 * N * H
-    int scratch_size = batch * head_number * seq_len * seq_len * 1;
-    multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()});
-    auto *multihead_temp_data = device_ctx.template Alloc<T>(
-        &multihead_temp_tensor, multihead_temp_tensor.numel() * sizeof(T));
-
-    auto *qkptr = multihead_temp_data;
-    auto *tptr = multihead_temp_data + scratch_size;
-
-    // Do the transpose with bias.
-    // BxSx3xNxH => tptr: 3xBxNxSxH.
-    TransQKVWithBias(batch,
-                     seq_len,
-                     head_size,
-                     head_number,
-                     temp_out_data,
-                     bias_d,
-                     tptr,
-                     stream);
-    if (std::is_same<T, platform::float16>::value) {
-      math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
-      multihead_compute_func(device_ctx,
-                             batch,
-                             seq_len,
-                             head_number,
-                             head_size,
-                             reinterpret_cast<half *>(qkptr),
-                             reinterpret_cast<const half *>(bias_qk_d),
-                             false,
-                             reinterpret_cast<half *>(tptr),
-                             __float2half(static_cast<float>(scale)),
-                             __float2half(0.0));
-    } else {
-      math::MultiHeadGPUComputeFunctor<T> multihead_compute_func;
-      multihead_compute_func(device_ctx,
-                             batch,
-                             seq_len,
-                             head_number,
-                             head_size,
-                             qkptr,
-                             bias_qk_d,
-                             false,
-                             tptr,
-                             scale,
-                             T(0.0));
-    }
-
-    int grid = batch * head_number * seq_len;
-    int block = head_size;
-    transpose<T><<<grid, block, 0, stream>>>(
-        tptr, output_d, batch, seq_len, head_number, head_size);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
+template <typename T, typename Context>
+void MultiheadMatmulKernel(const Context &dev_ctx,
+                           const DenseTensor &input,
+                           const DenseTensor &w,
+                           const DenseTensor &bias,
+                           const paddle::optional<DenseTensor> &bias_qk,
+                           const bool transpose_q,
+                           const bool transpose_k,
+                           const bool transpose_v,
+                           const float alpha,
+                           const int head_number,
+                           DenseTensor *out) {
+  auto *input_d = input.data<T>();
+  auto *w_d = w.data<T>();
+  auto *bias_d = bias.data<T>();
+  auto *bias_qk_d = bias_qk ? bias_qk->data<T>() : nullptr;
+  T scale = static_cast<T>(alpha);
+
+  // compute q*k with eltadd
+  auto stream = dev_ctx.stream();
+  // should be (B * S * hidden)
+  auto input_dims = input.dims();
+  // shouble be (hidden * 3 * all_head_size)
+  auto w_dims = w.dims();
+  int batch = input_dims[0];
+  int seq_len = input_dims[1];
+  int hidden = input_dims[2];
+  phi::DenseTensor temp_bias_tensor;
+  // if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted
+  if (bias_qk && bias_qk->numel() == (batch * seq_len)) {
+    VLOG(4) << "Do broadcasted bias_qk from [batch, 1, 1, seq_len]";
+    temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
+    auto *temp_qk_bias = dev_ctx.template Alloc<T>(
+        &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
+    int grid = batch * head_number * seq_len;
+    int block = round_up(seq_len);
+    broadcast<<<grid, block, 0, stream>>>(
+        bias_qk_d, temp_qk_bias, seq_len, head_number);
+    bias_qk_d = static_cast<const T *>(temp_qk_bias);
+  }
+  // if bias_qk is[1, 1, seq_len, seq_len], the bias_qk_d need to be
+  // broadcasted
+  if (bias_qk && bias_qk->numel() == (1 * seq_len * seq_len)) {
+    VLOG(4) << "do broadcasted bias_qk from [1, 1, seq_len, seq_len]";
+    temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
+    auto *temp_qk_bias = dev_ctx.template Alloc<T>(
+        &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
+    int grid = batch * head_number * seq_len;
+    int block = round_up(seq_len);
+    broadcast_batch_head_number<<<grid, block, 0, stream>>>(
+        bias_qk_d, temp_qk_bias, batch, seq_len, head_number);
+    bias_qk_d = static_cast<const T *>(temp_qk_bias);
+  }
+  if (!bias_qk) {
+    int size = batch * head_number * seq_len * seq_len;
+    temp_bias_tensor.Resize({size});
+    auto *temp_qk_bias = dev_ctx.template Alloc<T>(
+        &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
+#ifdef PADDLE_WITH_HIP
+    hipMemset(temp_qk_bias, 0, sizeof(float) * size);
+#else
+    cudaMemset(temp_qk_bias, 0, sizeof(float) * size);
+#endif
+    bias_qk_d = static_cast<const T *>(temp_qk_bias);
+  }
+  int all_head_size = w_dims[2];
+  int head_size = all_head_size / head_number;
+
+  out->Resize({batch, seq_len, all_head_size});
+  auto *output_d = dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
+
+  // (B*S, hidden)
+  const phi::DenseTensor input_matrix =
+      phi::ReshapeToMatrix(input, 2 /*x_num_col_dims */);
+  // (hidden, 3 * all_head_size)
+  const phi::DenseTensor w_matrix =
+      phi::ReshapeToMatrix(w, 1 /*y_num_col_dims*/);
+
+  phi::DenseTensor temp_out_tensor;
+  auto temp_out_dims =
+      phi::make_ddim({batch, seq_len, 3, head_number, head_size});
+  temp_out_tensor.Resize(
+      {batch * seq_len, phi::product(temp_out_dims) / (batch * seq_len)});
+  auto *temp_out_data = dev_ctx.template Alloc<T>(
+      &temp_out_tensor, temp_out_tensor.numel() * sizeof(T));
+
+  // (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
+  auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
+  blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
+  VLOG(2) << "(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)";
+  // temp_out_tensor.Resize(temp_out_dims);
+
+  phi::DenseTensor multihead_temp_tensor;
+  // B * head_number * S * S * 1 + B * S * 3 * N * H
+  int scratch_size = batch * head_number * seq_len * seq_len * 1;
+  multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()});
+  auto *multihead_temp_data = dev_ctx.template Alloc<T>(
+      &multihead_temp_tensor, multihead_temp_tensor.numel() * sizeof(T));
+
+  auto *qkptr = multihead_temp_data;
+  auto *tptr = multihead_temp_data + scratch_size;
+
+  // Do the transpose with bias.
+  // BxSx3xNxH => tptr: 3xBxNxSxH.
+  TransQKVWithBias(batch,
+                   seq_len,
+                   head_size,
+                   head_number,
+                   temp_out_data,
+                   bias_d,
+                   tptr,
+                   stream);
+  if (std::is_same<T, phi::dtype::float16>::value) {
+    phi::funcs::MultiheadGPUComputeFunctor<half> multihead_compute_func;
+    multihead_compute_func(dev_ctx,
+                           batch,
+                           seq_len,
+                           head_number,
+                           head_size,
+                           reinterpret_cast<half *>(qkptr),
+                           reinterpret_cast<const half *>(bias_qk_d),
+                           false,
+                           reinterpret_cast<half *>(tptr),
+                           __float2half(static_cast<float>(scale)),
+                           __float2half(0.0));
+  } else {
+    phi::funcs::MultiheadGPUComputeFunctor<T> multihead_compute_func;
+    multihead_compute_func(dev_ctx,
+                           batch,
+                           seq_len,
+                           head_number,
+                           head_size,
+                           qkptr,
+                           bias_qk_d,
+                           false,
+                           tptr,
+                           scale,
+                           T(0.0));
+  }
+
+  int grid = batch * head_number * seq_len;
+  int block = head_size;
+  transpose<T><<<grid, block, 0, stream>>>(
+      tptr, output_d, batch, seq_len, head_number, head_size);
+}
+
+}  // namespace fusion
+}  // namespace phi
+
 #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
-PD_REGISTER_STRUCT_KERNEL(multihead_matmul,
-                          GPU,
-                          ALL_LAYOUT,
-                          ops::MultiHeadMatMulV2Kernel,
-                          float,
-                          plat::float16) {}
+PD_REGISTER_KERNEL(multihead_matmul,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::MultiheadMatmulKernel,
+                   float,
+                   phi::dtype::float16) {}
 #else
-PD_REGISTER_STRUCT_KERNEL(
-    multihead_matmul, GPU, ALL_LAYOUT, ops::MultiHeadMatMulV2Kernel, float) {}
+PD_REGISTER_KERNEL(multihead_matmul,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::MultiheadMatmulKernel,
+                   float) {}
 #endif
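As a back-of-the-envelope check of the temporary sizing used by the kernel above, here is a standalone sketch with illustrative values (not taken from the source):

#include <cstdio>

// Mirrors the kernel's scratch sizing: temp_out_tensor holds the QKV projection
// (B x S x 3 x N x H) and multihead_temp_tensor prepends a B x N x S x S score buffer.
int main() {
  const long batch = 2, seq_len = 128, head_number = 12, head_size = 64;
  const long qkv_elems = batch * seq_len * 3 * head_number * head_size;  // 589824
  const long score_elems = batch * head_number * seq_len * seq_len;      // 393216
  std::printf("temp_out: %ld elems, qk scratch: %ld elems, multihead_temp: %ld elems\n",
              qkv_elems, score_elems, qkv_elems + score_elems);
  return 0;
}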