未验证 提交 0fd8ee63 编写于 作者: W Wilber 提交者: GitHub

Multihead matmul fp16 (#44792)

* multihead matmul add fp16

* fix windows error

* fix rocm error

* fix rocm error
上级 be0ec904
...@@ -18,6 +18,9 @@ ...@@ -18,6 +18,9 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/data_type.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -257,16 +260,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -257,16 +260,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
} }
PDNode* MultiHeadMatmulPattern::operator()() { PDNode* MultiHeadMatmulPattern::operator()() {
std::unordered_set<std::string> mul_ops{"mul", "matmul_v2"};
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto* input0 = pattern->NewNode(input0_repr()); auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("mul"); input0->assert_is_ops_input(mul_ops);
// First path with scale // First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul"); auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(mul_ops);
auto* mul0_w_var = pattern->NewNode(mul0_w_repr()) auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput() ->AsInput()
->assert_is_op_input("mul", "Y"); ->assert_is_ops_input(mul_ops, "Y");
auto* mul0_out_var = auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul"); pattern->NewNode(mul0_out_repr())->assert_is_ops_output(mul_ops);
decltype(mul0) eltadd0; decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var; decltype(mul0) eltadd0_b_var;
...@@ -299,11 +304,12 @@ PDNode* MultiHeadMatmulPattern::operator()() { ...@@ -299,11 +304,12 @@ PDNode* MultiHeadMatmulPattern::operator()() {
auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
auto* scale_out_var = auto* scale_out_var =
pattern->NewNode(scale_out_repr())->assert_is_op_output("scale"); pattern->NewNode(scale_out_repr())->assert_is_op_output("scale");
scale_out_var->AsIntermediate()->assert_is_op_input("matmul"); scale_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
auto* matmul_qk_out_var = auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk = auto* eltadd_qk =
...@@ -319,12 +325,12 @@ PDNode* MultiHeadMatmulPattern::operator()() { ...@@ -319,12 +325,12 @@ PDNode* MultiHeadMatmulPattern::operator()() {
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = auto* softmax_qk_out_var =
pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax"); pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul"); softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
auto* matmul_qkv = auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul"); pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
auto* matmul_qkv_out_var = auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul"); pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv = auto* transpose2_qkv =
...@@ -337,15 +343,15 @@ PDNode* MultiHeadMatmulPattern::operator()() { ...@@ -337,15 +343,15 @@ PDNode* MultiHeadMatmulPattern::operator()() {
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr()) auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2"); ->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_op_input("mul"); reshape2_qkv_out_var->assert_is_ops_input(mul_ops);
// Second path to matmul // Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul"); auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(mul_ops);
auto* mul1_w_var = pattern->NewNode(mul1_w_repr()) auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput() ->AsInput()
->assert_is_op_input("mul", "Y"); ->assert_is_ops_input(mul_ops, "Y");
auto* mul1_out_var = auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul"); pattern->NewNode(mul1_out_repr())->assert_is_ops_output(mul_ops);
decltype(mul1) eltadd1; decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var; decltype(mul1) eltadd1_b_var;
...@@ -372,16 +378,16 @@ PDNode* MultiHeadMatmulPattern::operator()() { ...@@ -372,16 +378,16 @@ PDNode* MultiHeadMatmulPattern::operator()() {
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2"); ->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input( transpose2_1_out_var->AsIntermediate()->assert_is_ops_input(
"matmul"); // link to matmul qk matmul_ops); // link to matmul qk
// Third path to matmul // Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul"); auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(mul_ops);
auto* mul2_w_var = pattern->NewNode(mul2_w_repr()) auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput() ->AsInput()
->assert_is_op_input("mul", "Y"); ->assert_is_ops_input(mul_ops, "Y");
auto* mul2_out_var = auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul"); pattern->NewNode(mul2_out_repr())->assert_is_ops_output(mul_ops);
decltype(mul2) eltadd2; decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var; decltype(mul2) eltadd2_b_var;
...@@ -408,8 +414,8 @@ PDNode* MultiHeadMatmulPattern::operator()() { ...@@ -408,8 +414,8 @@ PDNode* MultiHeadMatmulPattern::operator()() {
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2"); ->assert_is_op_output("transpose2");
transpose2_2_out_var->AsIntermediate()->assert_is_op_input( transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
"matmul"); // link to matmul qkv matmul_ops); // link to matmul qkv
// Q path // Q path
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var}); mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
...@@ -631,6 +637,68 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { ...@@ -631,6 +637,68 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
} }
} // namespace patterns } // namespace patterns
namespace {
template <typename T>
inline void QKVWeightsProcess(Tensor* wq_tensor,
Tensor* wk_tensor,
Tensor* wv_tensor,
Tensor* bq_tensor,
Tensor* bk_tensor,
Tensor* bv_tensor) {
auto* wq_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
auto* wk_data = wk_tensor->mutable_data<T>(platform::CPUPlace());
auto* wv_data = wv_tensor->mutable_data<T>(platform::CPUPlace());
auto* bq_data = bq_tensor->mutable_data<T>(platform::CPUPlace());
auto* bk_data = bk_tensor->mutable_data<T>(platform::CPUPlace());
auto* bv_data = bv_tensor->mutable_data<T>(platform::CPUPlace());
auto combined_w_dims =
phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});
framework::LoDTensor tmp_combined_w_tensor;
tmp_combined_w_tensor.Resize(combined_w_dims);
auto* tmp_combined_w_data =
tmp_combined_w_tensor.mutable_data<T>(platform::CPUPlace());
std::vector<T*> w_vec = {wq_data, wk_data, wv_data};
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
// Combine the three fc weights together.
for (int i = 0; i < dims_h; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < dims_w; k++) {
int out_index = i * (3 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k;
tmp_combined_w_data[out_index] = w_vec[j][in_index];
}
}
}
wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(
new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel());
framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<T>(platform::CPUPlace());
size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size);
memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size);
memcpy(
tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size);
bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(T) * bq_tensor->numel());
}
} // namespace
void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const { void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
...@@ -757,6 +825,23 @@ MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() { ...@@ -757,6 +825,23 @@ MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() {
.IsType<bool>() .IsType<bool>()
.End(); .End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax")) AddOpCompat(OpCompat("softmax"))
.AddInput("X") .AddInput("X")
.IsTensor() .IsTensor()
...@@ -820,16 +905,17 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, ...@@ -820,16 +905,17 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
auto* bv_tensor = auto* bv_tensor =
scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>(); scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();
auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace()); if (wq_tensor->dtype() == phi::DataType::FLOAT32) {
auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace()); QKVWeightsProcess<float>(
auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace()); wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor);
auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace()); } else if (wq_tensor->dtype() == phi::DataType::FLOAT16) {
auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace()); QKVWeightsProcess<platform::float16>(
auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace()); wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor);
} else {
auto combined_w_dims = PADDLE_THROW(platform::errors::Unavailable(
phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); "multihead_matmul not supported weight dtype. we now only support "
auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]}); "fp32 and fp16."));
}
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes. // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto* combined_w_desc = mul0_w->Var(); auto* combined_w_desc = mul0_w->Var();
...@@ -840,53 +926,7 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, ...@@ -840,53 +926,7 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
combined_bias_desc->SetShape({3, bq_tensor->dims()[0]}); combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
combined_bias_desc->SetPersistable(true); combined_bias_desc->SetPersistable(true);
framework::LoDTensor tmp_combined_w_tensor;
tmp_combined_w_tensor.Resize(combined_w_dims);
auto* tmp_combined_w_data =
tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());
std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
// Combine the three fc weights together.
for (int i = 0; i < dims_h; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < dims_w; k++) {
int out_index = i * (3 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k;
tmp_combined_w_data[out_index] = w_vec[j][in_index];
}
}
}
wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data =
wq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_w_data,
tmp_combined_w_data,
sizeof(float) * wq_tensor->numel());
scope->EraseVars({mul1_w->Name(), mul2_w->Name()}); scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());
size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(
tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size,
bv_data,
sizeof(float) * bias_size);
bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()}); scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
auto reshape_desc = reshape2->Op(); auto reshape_desc = reshape2->Op();
......
...@@ -154,18 +154,21 @@ const std::vector<std::string> kLiteSubgraphPasses({ ...@@ -154,18 +154,21 @@ const std::vector<std::string> kLiteSubgraphPasses({
// support fp16/bf16 precision, temporarily use low precision pass to prevent // support fp16/bf16 precision, temporarily use low precision pass to prevent
// running errors. After fusion operator supports low precision, delete this. // running errors. After fusion operator supports low precision, delete this.
const std::vector<std::string> kGpuLowerPrecisionPasses{ const std::vector<std::string> kGpuLowerPrecisionPasses{
"simplify_with_basic_ops_pass",
"conv_bn_fuse_pass", "conv_bn_fuse_pass",
"conv_eltwiseadd_bn_fuse_pass", "conv_eltwiseadd_bn_fuse_pass",
"conv_elementwise_add_act_fuse_pass", "conv_elementwise_add_act_fuse_pass",
"conv_elementwise_add2_act_fuse_pass", "conv_elementwise_add2_act_fuse_pass",
"conv_elementwise_add_fuse_pass", "conv_elementwise_add_fuse_pass",
"gpu_cpu_map_matmul_v2_to_mul_pass", // "multihead_matmul_fuse_pass_v2",
"gpu_cpu_map_matmul_v2_to_matmul_pass", // "gpu_cpu_map_matmul_v2_to_mul_pass",
"gpu_cpu_map_matmul_v2_to_matmul_pass",
"fc_fuse_pass", "fc_fuse_pass",
"fc_elementwise_layernorm_fuse_pass", "fc_elementwise_layernorm_fuse_pass",
}; };
const std::vector<std::string> kTrtLowerPrecisionPasses{ const std::vector<std::string> kTrtLowerPrecisionPasses{
"simplify_with_basic_ops_pass",
// "conv_bn_fuse_pass", // "conv_bn_fuse_pass",
// "conv_eltwiseadd_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass",
"trt_map_matmul_v2_to_mul_pass", "trt_map_matmul_v2_to_mul_pass",
......
...@@ -15,10 +15,12 @@ ...@@ -15,10 +15,12 @@
#include <paddle/fluid/platform/device_context.h> #include <paddle/fluid/platform/device_context.h>
#include <algorithm> #include <algorithm>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle { namespace paddle {
...@@ -64,6 +66,26 @@ __device__ float4 add_func<float4>(float4 a, float4 b) { ...@@ -64,6 +66,26 @@ __device__ float4 add_func<float4>(float4 a, float4 b) {
c.w = a.w + b.w; c.w = a.w + b.w;
return c; return c;
} }
#if defined(PADDLE_WITH_CUDA)
template <>
__device__ half2 add_func<half2>(half2 a, half2 b) {
#if __CUDA_ARCH__ >= 530
return __hadd2(a, b);
#else
return half2(__float2half(__half2float(a.x) + __half2float(b.x)),
__float2half(__half2float(b.x) + __half2float(b.y)));
#endif
}
template <>
__device__ half add_func<half>(half a, half b) {
#if __CUDA_ARCH__ >= 530
return __hadd(a, b);
#else
return __float2half(__half2float(a) + __half2float(b));
#endif
}
#endif
template <typename T> template <typename T>
__global__ void TransposeQkvKernel(const int H, __global__ void TransposeQkvKernel(const int H,
...@@ -71,7 +93,7 @@ __global__ void TransposeQkvKernel(const int H, ...@@ -71,7 +93,7 @@ __global__ void TransposeQkvKernel(const int H,
const T *bias, const T *bias,
T *output) { T *output) {
// Input: BxSx3xNxH // Input: BxSx3xNxH
// Bias: 3xSxB // Bias: 3xNxH
// Output: 3xBxNxSxH // Output: 3xBxNxSxH
int n = threadIdx.y; int n = threadIdx.y;
int s = blockIdx.x; int s = blockIdx.x;
...@@ -93,6 +115,17 @@ __global__ void TransposeQkvKernel(const int H, ...@@ -93,6 +115,17 @@ __global__ void TransposeQkvKernel(const int H,
add_func(input[in_offset + i], bias[bias_offset + i]); add_func(input[in_offset + i], bias[bias_offset + i]);
} }
template <typename T>
void TransQKVWithBias(const int batch,
const int seq_len,
const int head_size,
const int head_num,
const T *input,
const T *bias,
T *output,
gpuStream_t stream);
template <>
void TransQKVWithBias(const int batch, void TransQKVWithBias(const int batch,
const int seq_len, const int seq_len,
const int head_size, const int head_size,
...@@ -153,6 +186,55 @@ void TransQKVWithBias(const int batch, ...@@ -153,6 +186,55 @@ void TransQKVWithBias(const int batch,
} }
} }
#if defined(PADDLE_WITH_CUDA)
template <>
void TransQKVWithBias(const int batch,
const int seq_len,
const int head_size,
const int head_num,
const platform::float16 *input,
const platform::float16 *bias,
platform::float16 *output,
gpuStream_t stream) {
// BxSx3xNxH + 3xNxH -> 3xBxNxSxH
int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3);
if (head_size % 2 == 0 && scratch_size % 2 == 0) {
const int h = head_size / 2;
const half2 *input2 = reinterpret_cast<const half2 *>(input);
const half2 *bias2 = reinterpret_cast<const half2 *>(bias);
half2 *output2 = reinterpret_cast<half2 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024 * 2));
TransposeQkvKernel<half2>
<<<grid, block, 0, stream>>>(h, input2, bias2, output2);
} else {
const dim3 block(head_size, head_num, 1);
const half *input_half = reinterpret_cast<const half *>(input);
const half *bias_half = reinterpret_cast<const half *>(bias);
half *output_half = reinterpret_cast<half *>(output);
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024));
TransposeQkvKernel<half><<<grid, block, 0, stream>>>(
head_size, input_half, bias_half, output_half);
}
}
#endif
inline int round_up(int seq_len, int multiple = 32) { inline int round_up(int seq_len, int multiple = 32) {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
multiple, multiple,
...@@ -261,7 +343,19 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { ...@@ -261,7 +343,19 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
bias_d, bias_d,
tptr, tptr,
stream); stream);
if (std::is_same<T, platform::float16>::value) {
math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(device_ctx,
batch,
seq_len,
head_number,
head_size,
reinterpret_cast<half *>(qkptr),
reinterpret_cast<const half *>(bias_qk_d),
reinterpret_cast<half *>(tptr),
__float2half(static_cast<float>(scale)),
__float2half(0.0));
} else {
math::MultiHeadGPUComputeFunctor<T> multihead_compute_func; math::MultiHeadGPUComputeFunctor<T> multihead_compute_func;
multihead_compute_func(device_ctx, multihead_compute_func(device_ctx,
batch, batch,
...@@ -273,6 +367,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { ...@@ -273,6 +367,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
tptr, tptr,
scale, scale,
T(0.0)); T(0.0));
}
int grid = batch * head_number * seq_len; int grid = batch * head_number * seq_len;
int block = head_size; int block = head_size;
...@@ -285,5 +380,12 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { ...@@ -285,5 +380,12 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
REGISTER_OP_CUDA_KERNEL(
multihead_matmul,
ops::MultiHeadMatMulV2Kernel<phi::GPUContext, paddle::platform::float16>,
ops::MultiHeadMatMulV2Kernel<phi::GPUContext, float>);
#else
REGISTER_OP_CUDA_KERNEL(multihead_matmul, REGISTER_OP_CUDA_KERNEL(multihead_matmul,
ops::MultiHeadMatMulV2Kernel<phi::GPUContext, float>); ops::MultiHeadMatMulV2Kernel<phi::GPUContext, float>);
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册