diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 0f675b68d274b3d30ffb83b56fb1d3d71be5d901..3d3ce56d89df41a47536f7aa5d8b430b198bd16d 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -116,7 +116,7 @@ function(op_library TARGET)
   # Define operators that don't need pybind here.
   foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
 "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
-"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op")
+"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "multihead_matmul_op")
     if ("${TARGET}" STREQUAL "${manual_pybind_op}")
       set(pybind_flag 1)
     endif()
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index f99cbc8762aab5ae420c48624b204cbff438a15a..886ff49f04909b35dbb682dbd484d1cfead0963a 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -55,7 +55,7 @@ if (NOT WITH_MKL)
 endif()

 register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op
-    sync_batch_norm_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
+    sync_batch_norm_op multihead_matmul_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})

 if (WITH_GPU)
     # warpctc_op needs cudnn 7 above
@@ -73,6 +73,8 @@ if (WITH_GPU)
         op_library(sync_batch_norm_op)
         file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
     endif()
+    op_library(multihead_matmul_op)
+    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(multihead_matmul);\n")
 else()
     op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 endif()
diff --git a/paddle/fluid/operators/multihead_matmul_op.cc b/paddle/fluid/operators/multihead_matmul_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b612be02b4f50ff1d50c6d8a3e1e0c5c1e9f61c6
--- /dev/null
+++ b/paddle/fluid/operators/multihead_matmul_op.cc
@@ -0,0 +1,153 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
+
+namespace paddle {
+namespace operators {
+
+class MultiHeadMatMulOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *context) const override {
+    PADDLE_ENFORCE_EQ(context->HasInput("Q"), true,
+                      "Input(Q) of MultiheadOp should not be null.");
+    PADDLE_ENFORCE_EQ(context->HasInput("K"), true,
+                      "Input(K) of MultiheadOp should not be null.");
+    PADDLE_ENFORCE_EQ(context->HasInput("V"), true,
+                      "Input(V) of MultiheadOp should not be null.");
+    PADDLE_ENFORCE_EQ(context->HasInput("BiasQ"), true,
+                      "Input(BiasQ) of MultiheadOp should not be null.");
+    PADDLE_ENFORCE_EQ(context->HasInput("BiasK"), true,
+                      "Input(BiasK) of MultiheadOp should not be null.");
+    PADDLE_ENFORCE_EQ(context->HasInput("BiasV"), true,
+                      "Input(BiasV) of MultiheadOp should not be null.");
+    PADDLE_ENFORCE_EQ(context->HasInput("BiasQK"), true,
+                      "Input(BiasQK) of MultiheadOp should not be null.");
+    PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true,
+                      "Output(Out) of MultiheadOp should not be null.");
+
+    auto dim_q = context->GetInputDim("Q");
+    PADDLE_ENFORCE_GT(dim_q.size(), 2,
+                      "Multihead input Q should be at least 3-D tensor.");
+
+    auto dim_k = context->GetInputDim("K");
+    PADDLE_ENFORCE_GT(dim_k.size(), 2,
+                      "Multihead input K should be at least 3-D tensor.");
+
+    auto dim_v = context->GetInputDim("V");
+    PADDLE_ENFORCE_GT(dim_v.size(), 2,
+                      "Multihead input V should be at least 3-D tensor.");
+
+    PADDLE_ENFORCE_EQ(dim_q[0], dim_k[0],
+                      "Multihead inputs should have the same batch size.");
+    PADDLE_ENFORCE_EQ(dim_q[0], dim_v[0],
+                      "Multihead inputs should have the same batch size.");
+
+    PADDLE_ENFORCE_EQ(dim_q[1], dim_k[1],
+                      "Multihead inputs should have the same sequence length.");
+    PADDLE_ENFORCE_EQ(dim_q[1], dim_v[1],
+                      "Multihead inputs should have the same sequence length.");
+
+    PADDLE_ENFORCE_EQ(dim_q[2], dim_k[2],
+                      "Multihead inputs should have the same hidden size.");
+    PADDLE_ENFORCE_EQ(dim_q[2], dim_v[2],
+                      "Multihead inputs should have the same hidden size.");
+
+    auto dim_bias_q = context->GetInputDim("BiasQ");
+    PADDLE_ENFORCE_GT(dim_bias_q.size(), 0,
+                      "Multihead input BiasQ should be at least 1-D tensor.");
+    auto dim_bias_k = context->GetInputDim("BiasK");
+    PADDLE_ENFORCE_GT(dim_bias_k.size(), 0,
+                      "Multihead input BiasK should be at least 1-D tensor.");
+    auto dim_bias_v = context->GetInputDim("BiasV");
+    PADDLE_ENFORCE_GT(dim_bias_v.size(), 0,
+                      "Multihead input BiasV should be at least 1-D tensor.");
+
+    PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_k[0],
+                      "Multihead input biases should have the same batch size.");
+    PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_v[0],
+                      "Multihead input biases should have the same batch size.");
+
+    PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_k[1],
+                      "Multihead input biases should have the same size.");
+    PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_v[1],
+                      "Multihead input biases should have the same size.");
+
+    auto dim_bias_qk = context->GetInputDim("BiasQK");
+    PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
+                      "Multihead input BiasQK should be at least 4-D tensor.");
+
+    int head_number = context->Attrs().Get<int>("head_number");
+    PADDLE_ENFORCE_GE(head_number, 1,
+                      "Multihead input head number should be at least 1.");
+
+    context->SetOutputDim("Out", dim_q);
+    context->ShareLoD("Q", /*->*/ "Out");
+  }
+};
+
+class MultiHeadMatMulOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Q", "The first input of MultiHeadMatMul op");
+    AddInput("K", "The second input of MultiHeadMatMul op");
+    AddInput("V", "The third input of MultiHeadMatMul op");
+    AddInput("BiasQ", "The first bias input of MultiHeadMatMul op");
+    AddInput("BiasK", "The second bias input of MultiHeadMatMul op");
+    AddInput("BiasV", "The third bias input of MultiHeadMatMul op");
+    AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op");
+    AddOutput("Out", "The output of MultiHeadMatMul op");
+    AddAttr<bool>("transpose_Q",
+                  R"DOC(If true, use the transpose of `Q`.
+        )DOC")
+        .SetDefault(false);
+    AddAttr<bool>("transpose_K",
+                  R"DOC(If true, use the transpose of `K`.
+        )DOC")
+        .SetDefault(true);
+    AddAttr<bool>("transpose_V",
+                  R"DOC(If true, use the transpose of `V`.
+        )DOC")
+        .SetDefault(false);
+    AddAttr<float>("alpha", "The scale of Out").SetDefault(1.0f);
+    AddAttr<int>("head_number", "The number of heads of the matrix")
+        .SetDefault(1);
+    AddComment(R"DOC(
+MultiHeadMatMul Operator.
+
+This op is used to optimize the multi-head attention computation in the Ernie
+model. It is not suggested for other cases unless the model has the same
+structure as Ernie.
+
+Example of matrix multiplication with head_number of H:
+- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
+
+Both inputs `Q` and `K` can carry the LoD (Level of Details) information or
+not. The output `Out` only shares the LoD information with input `Q`, because
+they have the same shape.
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(multihead_matmul, ops::MultiHeadMatMulOp,
+                             ops::MultiHeadMatMulOpMaker);
diff --git a/paddle/fluid/operators/multihead_matmul_op.cu b/paddle/fluid/operators/multihead_matmul_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..6e8aa712fbf00355b83bde5313ba0d04724e2ffb
--- /dev/null
+++ b/paddle/fluid/operators/multihead_matmul_op.cu
@@ -0,0 +1,394 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cuda_runtime.h>
+#include <algorithm>
+#include <type_traits>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/operators/math/blas.h"
+
+namespace paddle {
+namespace operators {
+
+#define FINAL_MASK 0xffffffff
+#define HALF_WARP 16
+#define WARP_SIZE 32
+
+template <typename T>
+__inline__ __device__ T warpReduceSum(T val) {
+  for (int mask = HALF_WARP; mask > 0; mask >>= 1)
+#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
+    val += __shfl_xor_sync(FINAL_MASK, val, mask, warpSize);
+#else
+    val += __shfl_xor(val, mask, warpSize);
+#endif
+  return val;
+}
+
+/* Calculate the sum of all elements in a block */
+template <typename T>
+__inline__ __device__ T blockReduceSum(T val) {
+  static __shared__ T shared[WARP_SIZE];
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
+
+  val = warpReduceSum<T>(val);
+
+  if (lane == 0) shared[wid] = val;
+
+  __syncthreads();
+
+  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum<T>(val);
+
+  return val;
+}
+
+template <typename T>
+__inline__ __device__ T warpReduceMax(T val) {
+  for (int mask = HALF_WARP; mask > 0; mask >>= 1)
+#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
+    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, warpSize));
+#else
+    val = max(val, __shfl_xor(val, mask, warpSize));
+#endif
+  return val;
+}
+
+/* Calculate the maximum of all elements in a block */
+template <typename T>
+__inline__ __device__ T blockReduceMax(T val) {
+  static __shared__ T shared[WARP_SIZE];
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
+
+  val = warpReduceMax<T>(val);
+
+  if (lane == 0) shared[wid] = val;
+
+  __syncthreads();
+
+  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : -1e10f;
+  val = warpReduceMax<T>(val);
+
+  return val;
+}
+
+// Add the biases to Q, K and V and transpose the result from the
+// [batch_size, seq_len, head_num, size_per_head] layout to the
+// [batch_size, head_num, seq_len, size_per_head] layout.
+template <typename T>
+__global__ void add_QKV(const T *Q, const T *K, const T *V, T *q_buf_,
+                        T *k_buf_, T *v_buf_, const T *bias_q, const T *bias_k,
+                        const T *bias_v, int batch_size, int seq_len,
+                        int head_num, int size_per_head) {
+  const T *data_ptr_q, *data_ptr_k, *data_ptr_v;
+  const T *bias_ptr_q, *bias_ptr_k, *bias_ptr_v;
+
+  int m = batch_size * seq_len;
+  int n = head_num * size_per_head;
+
+  int row_offset = (blockIdx.x % m) * n;
+
+  data_ptr_q = Q + row_offset;
+  data_ptr_k = K + row_offset;
+  data_ptr_v = V + row_offset;
+  // bias ptr
+  bias_ptr_q = bias_q;
+  bias_ptr_k = bias_k;
+  bias_ptr_v = bias_v;
+
+  int batch_id = (blockIdx.x % m) / seq_len;
+  int head_id = threadIdx.x / size_per_head;
+  int id_in_head = threadIdx.x % size_per_head;
+  int word_start_id = (blockIdx.x) % seq_len;
+
+#if __CUDA_ARCH__ >= 350
+  T tmp_q = __ldg(&data_ptr_q[threadIdx.x]) + __ldg(&bias_ptr_q[threadIdx.x]);
+  T tmp_k = __ldg(&data_ptr_k[threadIdx.x]) + __ldg(&bias_ptr_k[threadIdx.x]);
+  T tmp_v = __ldg(&data_ptr_v[threadIdx.x]) + __ldg(&bias_ptr_v[threadIdx.x]);
+#else
+  T tmp_q = data_ptr_q[threadIdx.x] + bias_ptr_q[threadIdx.x];
+  T tmp_k = data_ptr_k[threadIdx.x] + bias_ptr_k[threadIdx.x];
+  T tmp_v = data_ptr_v[threadIdx.x] + bias_ptr_v[threadIdx.x];
+#endif
+
+  int target_id = batch_id * (seq_len * head_num * size_per_head) +
+                  head_id * seq_len * size_per_head +
+                  word_start_id * size_per_head + id_in_head;
+
+  q_buf_[target_id] = tmp_q;
+  k_buf_[target_id] = tmp_k;
+  v_buf_[target_id] = tmp_v;
+}
+
+// Keep to compare performance
+template <typename T>
+__global__ void add_QKV_V2(const T *Q, const T *K, const T *V, T *q_buf_,
+                           T *k_buf_, T *v_buf_, const T *bias_Q,
+                           const T *bias_K, const T *bias_V, int batch_size,
+                           int seq_len, int head_num, int size_per_head,
+                           const int word_per_block) {
+  const T *data_ptr;
+  T *buf_ptr;
+  const T *bias_ptr;
+
+  int m = batch_size * seq_len;
+  int n = head_num * size_per_head;
+
+  int qkv_id = blockIdx.x * word_per_block / m;
+  int row_offset = (blockIdx.x * word_per_block % m) * n;
+
+  if (qkv_id == 0) {
+    data_ptr = Q + row_offset;
+    buf_ptr = q_buf_;
+    bias_ptr = bias_Q;
+  } else if (qkv_id == 1) {
+    data_ptr = K + row_offset;
+    buf_ptr = k_buf_;
+    bias_ptr = bias_K;
+  } else {
+    data_ptr = V + row_offset;
+    buf_ptr = v_buf_;
+    bias_ptr = bias_V;
+  }
+
+  int batch_id = (blockIdx.x * word_per_block % m) / seq_len;
+  int head_id = threadIdx.x / size_per_head;
+  int id_in_head = threadIdx.x % size_per_head;
+  int word_start_id = (blockIdx.x * word_per_block) % seq_len;
+
+#if __CUDA_ARCH__ >= 350
+  T bias = __ldg(&bias_ptr[threadIdx.x]);
+#else
+  T bias = bias_ptr[threadIdx.x];
+#endif
+
+  for (int i = word_start_id; i < word_start_id + word_per_block; ++i) {
+    T tmp = data_ptr[threadIdx.x] + bias;
+
+    int target_id = batch_id * (seq_len * head_num * size_per_head) +
+                    head_id * seq_len * size_per_head + i * size_per_head +
+                    id_in_head;
+
+    buf_ptr[target_id] = tmp;
+    data_ptr += n;
+  }
+}
+
+// Softmax over the last dimension of qk_buf_ after adding the attention
+// bias (bias_qk_); one block handles one (batch, head, row) of length seq_len.
+template <typename T>
+__global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
+                                           const int batch_size,
+                                           const int head_num,
+                                           const int seq_len) {
+  int seq_id = blockIdx.x % seq_len;
+  int qk_offset = blockIdx.x * seq_len;
+  int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
+
+  __shared__ float s_sum, s_max;
+
+  float qk = threadIdx.x < seq_len
+                 ? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
+                                      bias_qk_[threadIdx.x + bias_offset])
+                 : 0.0f;
+  float tmp = threadIdx.x < seq_len ? qk : -1e20f;
+  float max_val = blockReduceMax<float>(tmp);
+  if (threadIdx.x == 0) s_max = max_val;
+  __syncthreads();
+
+  float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - s_max) : 0.0f;
+  float sum_val = blockReduceSum<float>(qk_tmp);
+
+  if (threadIdx.x == 0) {
+    s_sum = sum_val + 1e-6f;
+  }
+  __syncthreads();
+
+  if (threadIdx.x < seq_len)
+    qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum);
+}
+
+// For verify result
+template <typename T>
+__global__ void elt_qk_add(const T *bias_qk, T *qk_buf, int head_num,
+                           int seq_len, int size_per_head, int batch_size) {
+  int m = batch_size * head_num * seq_len;
+  int row_id = blockIdx.x % m;
+  int dst_id = row_id * seq_len + threadIdx.x;
+  const T *bias_ptr = bias_qk;
+#if __CUDA_ARCH__ >= 350
+  T tmp_bias = __ldg(&bias_ptr[dst_id]);
+#else
+  T tmp_bias = bias_ptr[dst_id];
+#endif
+
+  qk_buf[dst_id] += tmp_bias;
+}
+
+// Compute Q*K->softmax->eltadd
+template <typename T>
+void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
+                      int seq_len, int size_per_head, int batch_size,
+                      bool q_trans, bool k_trans, T *q_buf_, T *k_buf_,
+                      T *qk_buf_, const T *bias_qk, T alpha, T beta) {
+  CBLAS_TRANSPOSE transA = !q_trans ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE transB = !k_trans ? CblasNoTrans : CblasTrans;
+
+  auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
+  auto stream = context.stream();
+
+  blas.BatchedGEMM(transA, transB, seq_len, seq_len, size_per_head, alpha,
+                   q_buf_, k_buf_, beta, qk_buf_, batch_size * head_num,
+                   seq_len * size_per_head, seq_len * size_per_head);
+
+  int m = batch_size * head_num * seq_len;
+  int k = seq_len;
+
+  int grid = m;
+  int block = k;
+
+  softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
+      qk_buf_, bias_qk, batch_size, head_num, seq_len);
+}
+
+// Transpose from [batch_size, head_num, seq_len, size_per_head]
+// back to [batch_size, seq_len, head_num, size_per_head].
+template <typename T>
+__global__ void transpose(T *src, T *dst, const int batch_size,
+                          const int seq_len, const int head_num,
+                          const int size_per_head) {
+  int batch_id = blockIdx.x / (head_num * seq_len);
+  int seq_id = blockIdx.x % seq_len;
+  int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len;
+  dst[batch_id * (head_num * seq_len * size_per_head) +
+      seq_id * head_num * size_per_head + head_id * size_per_head +
+      threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
+}
+
+// Compute QK*V->transpose
+template <typename T>
+void MatMulWithHeadQKV(const platform::CUDADeviceContext &context,
+                       int head_num, int seq_len, int size_per_head,
+                       int batch_size, bool qk_trans, bool v_trans, T *v_buf_,
+                       const T *qk_buf_, T *dst, T *out, T alpha, T beta) {
+  int m = batch_size * seq_len;
+  int k = head_num * size_per_head;
+
+  auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
+  auto stream = context.stream();
+  CBLAS_TRANSPOSE transA = !qk_trans ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE transB = !v_trans ? CblasNoTrans : CblasTrans;
+
+  blas.BatchedGEMM(transA, transB, seq_len, size_per_head, seq_len, alpha,
+                   qk_buf_, v_buf_, beta, dst, batch_size * head_num,
+                   seq_len * seq_len, seq_len * size_per_head);
+
+  int grid = batch_size * head_num * seq_len;
+  int block = size_per_head;
+  transpose<T><<<grid, block, 0, stream>>>(dst, out, batch_size, seq_len,
+                                           head_num, size_per_head);
+}
+
+template <typename T>
+void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx,
+                         int head_num, const framework::DDim &mat_q,
+                         const framework::DDim &mat_k,
+                         const framework::DDim &mat_v, const T *Q, const T *K,
+                         const T *V, const T *bias_q, const T *bias_k,
+                         const T *bias_v, const T *bias_qk, T *out, T alpha,
+                         T beta, bool trans_q, bool trans_k, bool trans_v) {
+  int seq_len = mat_q[1];
+  int size_per_head = (mat_q[2] / head_num);
+  int batch_size = mat_q[0];
+  int buf_size = batch_size * head_num * seq_len * size_per_head;
+  int qk_buf_size = batch_size * head_num * seq_len * seq_len;
+
+  auto alloc_buf =
+      memory::Alloc(dev_ctx, (buf_size * 4 + qk_buf_size) * sizeof(T));
+
+  T *buf = reinterpret_cast<T *>(alloc_buf->ptr());
+  T *q_buf = buf;
+  T *k_buf = buf + buf_size;
+  T *v_buf = buf + 2 * buf_size;
+  T *qk_buf = buf + 3 * buf_size;
+  T *dst_buf = buf + 3 * buf_size + qk_buf_size;
+
+  int m = batch_size * seq_len;
+  int k = head_num * size_per_head;
+
+  // Each block processes one row of head_num * size_per_head elements;
+  // there are m such rows, and the biases are shared across all rows.
+  auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
+  auto stream = dev_ctx.stream();
+
+  int grid = m;
+  PADDLE_ENFORCE_LE(k, 1024,
+                    "Input head_number * size_per_head should be <= 1024.");
+  int block = k <= 1024 ? k : 1024;
+  add_QKV<T><<<grid, block, 0, stream>>>(Q, K, V, q_buf, k_buf, v_buf, bias_q,
+                                         bias_k, bias_v, batch_size, seq_len,
+                                         head_num, size_per_head);
+
+  MatMulWithHeadQK<T>(dev_ctx, head_num, seq_len, size_per_head, batch_size,
+                      trans_q, trans_k, q_buf, k_buf, qk_buf, bias_qk, alpha,
+                      beta);
+  MatMulWithHeadQKV<T>(dev_ctx, head_num, seq_len, size_per_head, batch_size,
+                       false, trans_v, v_buf, qk_buf, dst_buf, out, T(1.0),
+                       beta);
+}
+
+template <typename DeviceContext, typename T>
+class MultiHeadMatMulKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *q = context.Input<framework::Tensor>("Q");
+    auto *k = context.Input<framework::Tensor>("K");
+    auto *v = context.Input<framework::Tensor>("V");
+
+    auto &bias_q = detail::Ref(context.Input<framework::Tensor>("BiasQ"),
+                               "Cannot find BiasQ");
+    auto &bias_k = detail::Ref(context.Input<framework::Tensor>("BiasK"),
+                               "Cannot find BiasK");
+    auto &bias_v = detail::Ref(context.Input<framework::Tensor>("BiasV"),
+                               "Cannot find BiasV");
+
+    auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"),
+                                "Cannot find BiasQK");
+
+    auto *out = context.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+
+    T scale = static_cast<T>(context.Attr<float>("alpha"));
+    bool transpose_q = context.Attr<bool>("transpose_Q");
+    bool transpose_k = context.Attr<bool>("transpose_K");
+    bool transpose_v = context.Attr<bool>("transpose_V");
+
+    int head_number = context.Attr<int>("head_number");
+    // compute q*k with eltadd
+    auto &device_ctx = context.template device_context<DeviceContext>();
+
+    MultiHeadGPUCompute<T>(device_ctx, head_number, q->dims(), k->dims(),
+                           v->dims(), q->data<T>(), k->data<T>(), v->data<T>(),
+                           bias_q.data<T>(), bias_k.data<T>(), bias_v.data<T>(),
+                           bias_qk.data<T>(), out->data<T>(), scale, T(0.0),
+                           transpose_q, transpose_k, transpose_v);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    multihead_matmul,
+    ops::MultiHeadMatMulKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MultiHeadMatMulKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..e574b987d85828b43a1a8041f60a1558cb6d6d3a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+from paddle.fluid import core
+import paddle.fluid as fluid
+
+np.random.seed(123)
+
+
+def stable_softmax(x):
+    """Compute the softmax of vector x in a numerically stable way."""
+    shiftx = x - np.max(x).clip(-64.)
+    exps = np.exp(shiftx)
+    return exps / np.sum(exps)
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "Paddle core is not compiled with CUDA")
+class TestFusedMultiheadMatmulOp(OpTest):
+    def config(self):
+        self.seq_len = 128
+        self.size_per_head = 64
+        self.head_number = 12
+        self.batch_size = 1
+        self.scale = 0.125
+
+    def setUp(self):
+        self.op_type = "multihead_matmul"
+        self.config()
+        h = self.seq_len
+        w = self.head_number * self.size_per_head
+        self.Q = np.random.random((self.batch_size, h, w)).astype("float32")
+        self.K = np.random.random((self.batch_size, h, w)).astype("float32")
+        self.V = np.random.random((self.batch_size, h, w)).astype("float32")
+        self.BiasQ = np.random.random((1, w)).astype("float32")
+        self.BiasK = np.random.random((1, w)).astype("float32")
+        self.BiasV = np.random.random((1, w)).astype("float32")
+        self.BiasQK = np.random.random(
+            (1, self.head_number, self.seq_len,
+             self.seq_len)).astype("float32")
+        # Compute Q path
+        fc_q = self.Q + self.BiasQ
+        reshape_q = np.reshape(fc_q, (self.batch_size, self.seq_len,
+                                      self.head_number, self.size_per_head))
+        transpose_q = np.transpose(reshape_q, (0, 2, 1, 3))
+        scale_q = self.scale * transpose_q
+        # Compute K path
+        fc_k = self.K + self.BiasK
+        reshape_k = np.reshape(fc_k, (self.batch_size, self.seq_len,
+                                      self.head_number, self.size_per_head))
+        transpose_k = np.transpose(reshape_k, (0, 2, 3, 1))
+
+        # Compute Q*K
+        q_k = np.matmul(scale_q, transpose_k)
+        eltadd_qk = q_k + self.BiasQK
+        softmax_qk = np.apply_along_axis(stable_softmax, 3, eltadd_qk)
+        # Compute V path
+        fc_v = self.V + self.BiasV
+        reshape_v = np.reshape(fc_v, (self.batch_size, self.seq_len,
+                                      self.head_number, self.size_per_head))
+        transpose_v = np.transpose(reshape_v, (0, 2, 1, 3))
+
+        # Compute QK*V
+        qkv = np.matmul(softmax_qk, transpose_v)
+        transpose_qkv = np.transpose(qkv, (0, 2, 1, 3))
+        reshape_qkv = np.reshape(transpose_qkv, (self.batch_size, h, w))
+
+        self.inputs = {
+            "Q": self.Q,
+            "K": self.K,
+            "V": self.V,
+            "BiasQ": self.BiasQ,
+            "BiasK": self.BiasK,
+            "BiasV": self.BiasV,
+            "BiasQK": self.BiasQK
+        }
+        self.attrs = {
+            "transpose_Q": False,
+            "transpose_K": True,
+            "transpose_V": False,
+            "head_number": self.head_number,
+            "alpha": self.scale
+        }
+        self.outputs = {"Out": reshape_qkv}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, atol=2e-3)
+
+
+class TestFusedMultiHeadMatmulOp2(TestFusedMultiheadMatmulOp):
+    def config(self):
+        self.seq_len = 256
+        self.size_per_head = 32
+        self.head_number = 12
+        self.batch_size = 8
+        self.scale = 0.125
+
+
+if __name__ == '__main__':
+    unittest.main()
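
Usage note: this patch only registers the CUDA-only C++ operator; there is no fluid.layers wrapper, so the unit test drives the op through OpTest. For manual experiments, a thin Python wrapper along the lines below could append the op to a program block. This is a minimal sketch built on the generic LayerHelper/append_op machinery, not part of the patch; the wrapper name and variable names are illustrative assumptions.

    from paddle.fluid.layer_helper import LayerHelper


    def multihead_matmul(q, k, v, bias_q, bias_k, bias_v, bias_qk,
                         head_number, alpha=1.0):
        # Hypothetical wrapper (not in this PR): appends the fused
        # multihead_matmul op to the current program block.
        helper = LayerHelper("multihead_matmul", **locals())
        out = helper.create_variable_for_type_inference(dtype=q.dtype)
        helper.append_op(
            type="multihead_matmul",
            inputs={
                "Q": q,
                "K": k,
                "V": v,
                "BiasQ": bias_q,
                "BiasK": bias_k,
                "BiasV": bias_v,
                "BiasQK": bias_qk
            },
            outputs={"Out": out},
            attrs={
                "transpose_Q": False,
                "transpose_K": True,
                "transpose_V": False,
                "head_number": head_number,
                "alpha": float(alpha)
            })
        return out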