未验证 提交 e8673668 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

Add multihead op for ernie opt (#19933)

* Add multihead op for ernie opt

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine code

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine code

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine code

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine softmax

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine kernel.

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine code

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine code

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine code

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine cuda kernel

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine cuda version

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine code

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Refine cmake

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>
上级 bfa55c9d
......@@ -116,7 +116,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "multihead_matmul_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -55,7 +55,7 @@ if (NOT WITH_MKL)
endif()
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op
sync_batch_norm_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
sync_batch_norm_op multihead_matmul_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
if (WITH_GPU)
# warpctc_op needs cudnn 7 above
......@@ -73,6 +73,8 @@ if (WITH_GPU)
op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
endif()
op_library(multihead_matmul_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(multihead_matmul);\n")
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle {
namespace operators {
class MultiHeadMatMulOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(context->HasInput("Q"), true,
"Input(Q) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("K"), true,
"Input(K) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("V"), true,
"Input(V) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("BiasQ"), true,
"Input(BiasQ) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("BiasK"), true,
"Input(BiasQ) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("BiasV"), true,
"Input(BiasQ) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("BiasQK"), true,
"Input(BiasQK) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true,
"Output(Out) of MatMulOp should not be null.");
auto dim_q = context->GetInputDim("Q");
PADDLE_ENFORCE_GT(dim_q.size(), 2,
"Multihead input should be at least 3-D tensor.");
auto dim_k = context->GetInputDim("K");
PADDLE_ENFORCE_GT(dim_q.size(), 2,
"Multihead input should be at least 3-D tensor.");
auto dim_v = context->GetInputDim("V");
PADDLE_ENFORCE_GT(dim_q.size(), 2,
"Multihead input should be at least 3-D tensor.");
PADDLE_ENFORCE_EQ(dim_q[0], dim_k[0],
"Multihead input should have same batch size");
PADDLE_ENFORCE_EQ(dim_q[0], dim_v[0],
"Multihead input should have same batch size");
PADDLE_ENFORCE_EQ(dim_q[1], dim_k[1],
"Multihead input should have same size");
PADDLE_ENFORCE_EQ(dim_q[1], dim_v[1],
"Multihead input should have same size");
PADDLE_ENFORCE_EQ(dim_q[2], dim_k[2],
"Multihead input should have same size");
PADDLE_ENFORCE_EQ(dim_q[2], dim_v[2],
"Multihead input should have same size");
auto dim_bias_q = context->GetInputDim("BiasQ");
PADDLE_ENFORCE_GT(dim_bias_q.size(), 0,
"Multihead input should be at least 1-D tensor.");
auto dim_bias_k = context->GetInputDim("BiasK");
PADDLE_ENFORCE_GT(dim_bias_k.size(), 0,
"Multihead input should be at least 1-D tensor.");
auto dim_bias_v = context->GetInputDim("BiasV");
PADDLE_ENFORCE_GT(dim_bias_v.size(), 0,
"Multihead input should be at least 1-D tensor.");
PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_k[0],
"Multihead input bias should have same batch size");
PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_v[0],
"Multihead input bias should have same batch size");
PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_k[1],
"Multihead input bias should have same size");
PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_v[1],
"Multihead input bias should have same size");
auto dim_bias_qk = context->GetInputDim("BiasQK");
PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
"Multihead input bias qk should be at least 4-D tensor.");
int head_number = context->Attrs().Get<int>("head_number");
PADDLE_ENFORCE_GT(head_number, 1,
"Multihead input head number should be at least 1.");
context->SetOutputDim("Out", dim_q);
context->ShareLoD("Q", /*->*/ "Out");
}
};
class MultiHeadMatMulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Q", "The first input of MultiHeadMatMul op");
AddInput("K", "The second input of MMultiHeadMatMul op");
AddInput("V", "The third input of MultiHeadMatMul op");
AddInput("BiasQ", "The first bias input of MultiHeadMatMul op");
AddInput("BiasK", "The second bias input of MultiHeadMatMul op");
AddInput("BiasV", "The third bias input of MultiHeadMatMul op");
AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op");
AddOutput("Out", "The output of MultiHeadMatMul op");
AddAttr<bool>("transpose_Q",
R"DOC(If true, use the transpose of `Q`.
)DOC")
.SetDefault(false);
AddAttr<bool>("transpose_K",
R"DOC(If true, use the transpose of `K`.
)DOC")
.SetDefault(true);
AddAttr<bool>("transpose_V",
R"DOC(If true, use the transpose of `V`.
)DOC")
.SetDefault(false);
AddAttr<float>("alpha", "The scale of Out").SetDefault(1.0f);
AddAttr<int>("head_number", "The number of heads of the matrix")
.SetDefault(1);
AddComment(R"DOC(
MultiHeadMatMul Operator.
This op is used for optimize multi head calculation in ernie model.
Not suggest to use in other case except has same structure as ernie.
Example of matrix multiplication with head_number of H
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
Both the input `Q` and `K` can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD information with input `Q`, because
they are the same.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(multihead_matmul, ops::MultiHeadMatMulOp,
ops::MultiHeadMatMulOpMaker);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
#define FINAL_MASK 0xffffffff
#define HALF_WARP 16
#define WARP_SIZE 32
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
val += __shfl_xor_sync(FINAL_MASK, val, mask, warpSize);
#else
val += __shfl_xor(val, mask, warpSize);
#endif
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0) shared[wid] = val;
__syncthreads();
val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
template <typename T>
__inline__ __device__ T warpReduceMax(T val) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, warpSize));
#else
val = max(val, __shfl_xor(val, mask, warpSize));
#endif
return val;
}
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceMax(T val) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceMax(val);
if (lane == 0) shared[wid] = val;
__syncthreads();
val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : -1e10f;
val = warpReduceMax(val);
return val;
}
template <typename T>
__global__ void add_QKV(const T *Q, const T *K, const T *V, T *q_buf_,
T *k_buf_, T *v_buf_, const T *bias_q, const T *bias_k,
const T *bias_v, int batch_size, int seq_len,
int head_num, int size_per_head) {
const T *data_ptr_q, *data_ptr_k, *data_ptr_v;
const T *bias_ptr_q, *bias_ptr_k, *bias_ptr_v;
int m = batch_size * seq_len;
int n = head_num * size_per_head;
int row_offset = (blockIdx.x % m) * n;
data_ptr_q = Q + row_offset;
data_ptr_k = K + row_offset;
data_ptr_v = V + row_offset;
// bias ptr
bias_ptr_q = bias_q;
bias_ptr_k = bias_k;
bias_ptr_v = bias_v;
int batch_id = (blockIdx.x % m) / seq_len;
int head_id = threadIdx.x / size_per_head;
int id_in_head = threadIdx.x % size_per_head;
int word_start_id = (blockIdx.x) % seq_len;
#if __CUDA_ARCH__ >= 350
T tmp_q = __ldg(&data_ptr_q[threadIdx.x]) + __ldg(&bias_ptr_q[threadIdx.x]);
T tmp_k = __ldg(&data_ptr_k[threadIdx.x]) + __ldg(&bias_ptr_k[threadIdx.x]);
T tmp_v = __ldg(&data_ptr_v[threadIdx.x]) + __ldg(&bias_ptr_v[threadIdx.x]);
#else
T tmp_q = data_ptr_q[threadIdx.x] + bias_ptr_q[threadIdx.x];
T tmp_k = data_ptr_k[threadIdx.x] + bias_ptr_k[threadIdx.x];
T tmp_v = data_ptr_v[threadIdx.x] + bias_ptr_v[threadIdx.x];
#endif
int target_id = batch_id * (seq_len * head_num * size_per_head) +
head_id * seq_len * size_per_head +
word_start_id * size_per_head + id_in_head;
q_buf_[target_id] = tmp_q;
k_buf_[target_id] = tmp_k;
v_buf_[target_id] = tmp_v;
}
// Keep to compare performance
template <typename T>
__global__ void add_QKV_V2(const T *Q, const T *K, const T *V, T *q_buf_,
T *k_buf_, T *v_buf_, const T *bias_Q,
const T *bias_K, const T *bias_V, int batch_size,
int seq_len, int head_num, int size_per_head,
const int word_per_block) {
const T *data_ptr;
T *buf_ptr;
const T *bias_ptr;
int m = batch_size * seq_len;
int n = head_num * size_per_head;
int qkv_id = blockIdx.x * word_per_block / m;
int row_offset = (blockIdx.x * word_per_block % m) * n;
if (qkv_id == 0) {
data_ptr = Q + row_offset;
buf_ptr = q_buf_;
bias_ptr = bias_Q;
} else if (qkv_id == 1) {
data_ptr = K + row_offset;
buf_ptr = k_buf_;
bias_ptr = bias_K;
} else {
data_ptr = V + row_offset;
buf_ptr = v_buf_;
bias_ptr = bias_V;
}
int batch_id = (blockIdx.x * word_per_block % m) / seq_len;
int head_id = threadIdx.x / size_per_head;
int id_in_head = threadIdx.x % size_per_head;
int word_start_id = (blockIdx.x * word_per_block) % seq_len;
#if __CUDA_ARCH__ >= 350
T bias = __ldg(&bias_ptr[threadIdx.x]);
#else
T bias = bias_ptr[threadIdx.x];
#endif
for (int i = word_start_id; i < word_start_id + word_per_block; ++i) {
T tmp = data_ptr[threadIdx.x] + bias;
int target_id = batch_id * (seq_len * head_num * size_per_head) +
head_id * seq_len * size_per_head + i * size_per_head +
id_in_head;
buf_ptr[target_id] = tmp;
data_ptr += n;
}
}
template <typename T>
__global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len) {
int seq_id = blockIdx.x % seq_len;
int qk_offset = blockIdx.x * seq_len;
int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
__shared__ float s_sum, s_max;
float qk = threadIdx.x < seq_len
? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
bias_qk_[threadIdx.x + bias_offset]))
: 0.0f;
float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
float max_val = blockReduceMax<float>(tmp);
if (threadIdx.x == 0) s_max = max_val;
__syncthreads();
float qk_tmp =
threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
float sum_val = blockReduceSum<float>(qk_tmp);
if (threadIdx.x == 0) {
s_sum = sum_val + 1e-6f;
}
__syncthreads();
if (threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum);
}
// For verify result
template <typename T>
__global__ void elt_qk_add(const T *bias_qk, T *qk_buf, int head_num,
int seq_len, int size_per_head, int batch_size) {
int m = batch_size * head_num * seq_len;
int row_id = blockIdx.x % m;
int dst_id = row_id * seq_len + threadIdx.x;
const T *bias_ptr = bias_qk;
#if __CUDA_ARCH__ >= 350
int tmp_bias = __ldg(&bias_ptr[dst_id]);
#else
int tmp_bias = bias_ptr[dst_id];
#endif
qk_buf[dst_id] += tmp_bias;
}
// Compute Q*K->softmax->eltadd
template <typename T>
void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
int seq_len, int size_per_head, int batch_size,
bool q_trans, bool k_trans, T *q_buf_, T *k_buf_,
T *qk_buf_, const T *bias_qk, T alpha, T beta) {
CBLAS_TRANSPOSE transA = !q_trans ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !k_trans ? CblasNoTrans : CblasTrans;
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
auto stream = context.stream();
blas.BatchedGEMM(transA, transB, seq_len, seq_len, size_per_head, alpha,
q_buf_, k_buf_, beta, qk_buf_, batch_size * head_num,
seq_len * size_per_head, seq_len * size_per_head);
int m = batch_size * head_num * seq_len;
int k = seq_len;
int grid = m;
int block = k;
softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
qk_buf_, bias_qk, batch_size, head_num, seq_len);
}
template <typename T>
__global__ void transpose(T *src, T *dst, const int batch_size,
const int seq_len, const int head_num,
const int size_per_head) {
int batch_id = blockIdx.x / (head_num * seq_len);
int seq_id = blockIdx.x % seq_len;
int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len;
dst[batch_id * (head_num * seq_len * size_per_head) +
seq_id * head_num * size_per_head + head_id * size_per_head +
threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
}
// Compute QK*V->transpose
template <typename T>
void MatMulWithHeadQKV(const platform::CUDADeviceContext &context, int head_num,
int seq_len, int size_per_head, int batch_size,
bool qk_trans, bool v_trans, T *v_buf_, const T *qk_buf_,
T *dst, T *out, T alpha, T beta) {
int m = batch_size * seq_len;
int k = head_num * size_per_head;
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
auto stream = context.stream();
CBLAS_TRANSPOSE transA = !qk_trans ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !v_trans ? CblasNoTrans : CblasTrans;
blas.BatchedGEMM(transA, transB, seq_len, size_per_head, seq_len, alpha,
qk_buf_, v_buf_, beta, dst, batch_size * head_num,
seq_len * seq_len, seq_len * size_per_head);
int grid = batch_size * head_num * seq_len;
int block = size_per_head;
transpose<T><<<grid, block, 0, stream>>>(dst, out, batch_size, seq_len,
head_num, size_per_head);
}
template <typename T>
void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx,
int head_num, const framework::DDim &mat_q,
const framework::DDim &mat_k,
const framework::DDim &mat_v, const T *Q, const T *K,
const T *V, const T *bias_q, const T *bias_k,
const T *bias_v, const T *bias_qk, T *out, T alpha,
T beta, bool trans_q, bool trans_k, bool trans_v) {
int seq_len = mat_q[1];
int size_per_head = (mat_q[2] / head_num);
int batch_size = mat_q[0];
int buf_size = batch_size * head_num * seq_len * size_per_head;
int qk_buf_size = batch_size * head_num * seq_len * seq_len;
auto alloc_buf =
memory::Alloc(dev_ctx, (buf_size * 4 + qk_buf_size) * sizeof(T));
T *buf = reinterpret_cast<T *>(alloc_buf->ptr());
T *q_buf = buf;
T *k_buf = buf + buf_size;
T *v_buf = buf + 2 * buf_size;
T *qk_buf = buf + 3 * buf_size;
T *dst_buf = buf + 3 * buf_size + qk_buf_size;
int m = batch_size * seq_len;
int k = head_num * size_per_head;
// Each block process head*size-per_head element,
// have m lines. bias is m lines
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
auto stream = dev_ctx.stream();
int grid = m;
PADDLE_ENFORCE_LT(k, 1024,
"Input head_number * size_per_head should <= 1024");
int block = k <= 1024 ? k : 1024;
add_QKV<T><<<grid, block, 0, stream>>>(Q, K, V, q_buf, k_buf, v_buf, bias_q,
bias_k, bias_v, batch_size, seq_len,
head_num, size_per_head);
MatMulWithHeadQK<T>(dev_ctx, head_num, seq_len, size_per_head, batch_size,
trans_q, trans_k, q_buf, k_buf, qk_buf, bias_qk, alpha,
beta);
MatMulWithHeadQKV<T>(dev_ctx, head_num, seq_len, size_per_head, batch_size,
false, trans_v, v_buf, qk_buf, dst_buf, out, T(1.0),
beta);
}
template <typename DeviceContext, typename T>
class MultiHeadMatMulKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *q = context.Input<framework::Tensor>("Q");
auto *k = context.Input<framework::Tensor>("K");
auto *v = context.Input<framework::Tensor>("V");
auto &bias_q = detail::Ref(context.Input<framework::Tensor>("BiasQ"),
"Cannot find BiasQ");
auto &bias_k = detail::Ref(context.Input<framework::Tensor>("BiasK"),
"Cannot find BiasK");
auto &bias_v = detail::Ref(context.Input<framework::Tensor>("BiasV"),
"Cannot find BiasV");
auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"),
"Cannot find QK");
auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
T scale = static_cast<T>(context.Attr<float>("alpha"));
bool transpose_q = context.Attr<bool>("transpose_Q");
bool transpose_k = context.Attr<bool>("transpose_K");
bool transpose_v = context.Attr<bool>("transpose_V");
int head_number = context.Attr<int>("head_number");
// compute q*k with eltadd
auto &device_ctx = context.template device_context<DeviceContext>();
MultiHeadGPUCompute<T>(device_ctx, head_number, q->dims(), k->dims(),
v->dims(), q->data<T>(), k->data<T>(), v->data<T>(),
bias_q.data<T>(), bias_k.data<T>(), bias_v.data<T>(),
bias_qk.data<T>(), out->data<T>(), scale, T(0.0),
transpose_q, transpose_k, transpose_v);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
multihead_matmul,
ops::MultiHeadMatMulKernel<paddle::platform::CUDADeviceContext, float>,
ops::MultiHeadMatMulKernel<paddle::platform::CUDADeviceContext, double>);
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
import paddle.fluid as fluid
np.random.random(123)
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"Paddle core is not compiled with CUDA")
class TestFusedMultiheadMatmulOp(OpTest):
def config(self):
self.seq_len = 128
self.size_per_head = 64
self.head_number = 12
self.batch_size = 1
self.scale = 0.125
def setUp(self):
self.op_type = "multihead_matmul"
self.config()
h = self.seq_len
w = self.head_number * self.size_per_head
self.Q = np.random.random((self.batch_size, h, w)).astype("float32")
self.K = np.random.random((self.batch_size, h, w)).astype("float32")
self.V = np.random.random((self.batch_size, h, w)).astype("float32")
self.BiasQ = np.random.random((1, w)).astype("float32")
self.BiasK = np.random.random((1, w)).astype("float32")
self.BiasV = np.random.random((1, w)).astype("float32")
self.BiasQK = np.random.random(
(1, self.head_number, self.seq_len, self.seq_len)).astype("float32")
# Compute Q path
fc_q = self.Q + self.BiasQ
reshape_q = np.reshape(fc_q, (self.batch_size, self.seq_len,
self.head_number, self.size_per_head))
transpose_q = np.transpose(reshape_q, (0, 2, 1, 3))
scale_q = self.scale * transpose_q
# Compute K path
fc_k = self.K + self.BiasK
reshape_k = np.reshape(fc_k, (self.batch_size, self.seq_len,
self.head_number, self.size_per_head))
transpose_k = np.transpose(reshape_k, (0, 2, 3, 1))
# Compute Q*K
q_k = np.matmul(scale_q, transpose_k)
eltadd_qk = q_k + self.BiasQK
softmax_qk = np.apply_along_axis(stable_softmax, 3, eltadd_qk)
# Compute V path
fc_v = self.V + self.BiasV
reshape_v = np.reshape(fc_v, (self.batch_size, self.seq_len,
self.head_number, self.size_per_head))
transpose_v = np.transpose(reshape_v, (0, 2, 1, 3))
# Compute QK*V
qkv = np.matmul(softmax_qk, transpose_v)
transpose_qkv = np.transpose(qkv, (0, 2, 1, 3))
reshape_qkv = np.reshape(transpose_qkv, (self.batch_size, h, w))
self.inputs = {
"Q": self.Q,
"K": self.K,
"V": self.V,
"BiasQ": self.BiasQ,
"BiasK": self.BiasK,
"BiasV": self.BiasV,
"BiasQK": self.BiasQK
}
self.attrs = {
"transpose_Q": False,
"transpose_K": True,
"transpose_V": False,
"head_number": self.head_number,
"alpha": self.scale
}
self.outputs = {"Out": reshape_qkv}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=2e-3)
class TestFusedMultiHeadMatmulOp2(TestFusedMultiheadMatmulOp):
def config(self):
self.seq_len = 256
self.size_per_head = 32
self.head_number = 12
self.batch_size = 8
self.scale = 0.125
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册