Unverified commit 2b4f44d5, authored by Chen Weihang, committed by GitHub

[Phi] Add fusion kernel dir and migrate fused_softmax_mask op (#45802)

* add fusion dir and fuse_softmax_mask kernel

* remove fusion kernel dir

* migrate infershape

* fix code error
Parent 4a675b7a
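The "migrate infershape" step follows the standard Phi migration pattern: the hand-written InferShape method on the operator class is removed, the shape logic moves into a phi InferMeta function, and the operator registration picks it up through DECLARE_INFER_SHAPE_FUNCTOR. A condensed sketch of that pattern, using the names that appear in the diff below:

```cpp
// In the fluid op file: declare a functor that forwards shape inference
// to the phi InferMeta function (declared in paddle/phi/infermeta/binary.h).
DECLARE_INFER_SHAPE_FUNCTOR(fused_softmax_mask,
                            SoftmaxMaskFuseInferShapeFunctor,
                            PD_INFER_META(phi::SoftmaxMaskFuseInferMeta));

// The functor is passed as an extra argument to REGISTER_OPERATOR,
// replacing the removed SoftmaxMaskFuseOp::InferShape override.
REGISTER_OPERATOR(fused_softmax_mask,
                  ops::SoftmaxMaskFuseOp,
                  ops::SoftmaxMaskFuseOpMaker,
                  ops::SoftmaxMaskFuseGradOpMaker<paddle::framework::OpDesc>,
                  ops::SoftmaxMaskFuseGradOpMaker<paddle::imperative::OpBase>,
                  SoftmaxMaskFuseInferShapeFunctor);
```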
......@@ -1890,7 +1890,7 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
PADDLE_ENFORCE_NE(
kernels_iter,
all_op_kernels.end(),
platform::errors::Unavailable(
platform::errors::Unimplemented(
"There are no kernels which are registered in the %s operator.",
type_));
......
......@@ -11,10 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused_softmax_mask_op.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -23,30 +27,6 @@ using framework::Tensor;
class SoftmaxMaskFuseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SoftmaxMaskFuse");
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "SoftmaxMaskFuse");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SoftmaxMaskFuse");
auto x_dims = ctx->GetInputDim("X");
auto mask_dims = ctx->GetInputDim("Mask");
PADDLE_ENFORCE_EQ(
x_dims.size(),
4,
platform::errors::InvalidArgument("Input x must be in 4D dimension but "
"received the dimension of X is %d",
x_dims.size()));
PADDLE_ENFORCE_EQ(mask_dims.size(),
4,
platform::errors::InvalidArgument(
"Input mask must be in 4D dimension but "
"received the dimension of mask is %d",
mask_dims.size()));
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", "Out");
}
};
class SoftmaxMaskFuseOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -80,17 +60,6 @@ By doing this fusion, we can optimize the training by
class SoftmaxMaskFuseOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"SoftmaxMaskFuseGrad");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), out_dims);
ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
}
};
template <typename T>
......@@ -111,12 +80,18 @@ class SoftmaxMaskFuseGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(fused_softmax_mask,
SoftmaxMaskFuseInferShapeFunctor,
PD_INFER_META(phi::SoftmaxMaskFuseInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(fused_softmax_mask_grad,
SoftmaxMaskFuseGradInferShapeFunctor,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(fused_softmax_mask,
ops::SoftmaxMaskFuseOp,
ops::SoftmaxMaskFuseOpMaker,
ops::SoftmaxMaskFuseGradOpMaker<paddle::framework::OpDesc>,
ops::SoftmaxMaskFuseGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_softmax_mask_grad, ops::SoftmaxMaskFuseOpGrad);
REGISTER_OP_CPU_KERNEL(fused_softmax_mask,
ops::SoftmaxMaskFuseCPUKernel<phi::CPUContext, float>,
ops::SoftmaxMaskFuseCPUKernel<phi::CPUContext, double>);
ops::SoftmaxMaskFuseGradOpMaker<paddle::imperative::OpBase>,
SoftmaxMaskFuseInferShapeFunctor);
REGISTER_OPERATOR(fused_softmax_mask_grad,
ops::SoftmaxMaskFuseOpGrad,
SoftmaxMaskFuseGradInferShapeFunctor);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// this file is inspired by:
// https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/fused_kernels/scaled_masked_softmax.h
/* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#endif
#include <stdint.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fused_softmax_mask_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using framework::Tensor;
#ifdef PADDLE_WITH_HIP
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#define MASK 0xffffffff
namespace plat = paddle::platform;
__device__ __inline__ void load_data(plat::float16* dst,
const plat::float16* src) {
*(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
}
__device__ __inline__ void load_data(float* dst, const float* src) {
*(reinterpret_cast<float4*>(dst)) = *(reinterpret_cast<const float4*>(src));
}
int get_pow2(int value) {
// get next pow2 index
int pow2_index = 0;
while ((1 << pow2_index) < value) {
++pow2_index;
}
return pow2_index;
}
template <typename T>
struct AddOP {
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct MaxOP {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T
warp_shfl_xor(T value, int laneMask, int width, unsigned int mask = MASK) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename T, int batch, int width, template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(T* sum) {
ReduceOp<T> r;
#pragma unroll
for (int offset = width / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < batch; ++i) {
T b = warp_shfl_xor(sum[i], offset, width);
sum[i] = r(sum[i], b);
}
}
}
// T == fp16
template <typename T, int pow2_index>
__global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
const T* mask_data,
T* y_data,
int batch_count,
int key_seq_len) {
// the forward gpu kernel
constexpr int next_pow2 = 1 << pow2_index;
constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
constexpr int kOneLoadingCounts = 4;
int data_first_idx =
(blockDim.y *
(blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
threadIdx.y) *
kLocalBatchSize;
int mask_fist_idx =
(blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
kLocalBatchSize;
// batch_count might not be a multiple of kLocalBatchSize. Check how
// many batches have to computed within this WARP.
int local_batches = batch_count - data_first_idx;
if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;
// might be many batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
int x_offset = data_first_idx * key_seq_len + kOneLoadingCounts * local_idx;
int mask_offset = mask_fist_idx * key_seq_len + kOneLoadingCounts * local_idx;
x_data += x_offset;
mask_data += mask_offset;
y_data += x_offset;
// using float for all inter compute
float data[kLocalBatchSize][kLocalIterations];
T temp_data[kOneLoadingCounts];
T temp_mask[kOneLoadingCounts];
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
int batch_data = (i >= local_batches) ? 0 : key_seq_len;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int data_index = kOneLoadingCounts * local_idx + ii * warp_size;
if (data_index < batch_data) {
int itr_idx = i * key_seq_len + ii * warp_size;
// efficiently load data from global memory
load_data(temp_data, x_data + itr_idx);
load_data(temp_mask, mask_data + itr_idx);
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
data[i][ii + counter] = static_cast<float>(temp_data[counter]) +
static_cast<float>(temp_mask[counter]);
}
} else {
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
data[i][ii + counter] = -std::numeric_limits<float>::infinity();
}
}
}
}
// compute max_value
// max value for each batch for current warp
float samples_max_value[kLocalBatchSize];
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
samples_max_value[i] = data[i][0];
#pragma unroll
for (int ii = 1; ii < kLocalIterations; ++ii) {
samples_max_value[i] = (samples_max_value[i] > data[i][ii])
? samples_max_value[i]
: data[i][ii];
}
}
// max value for each batch for all warp
warp_reduce<float, kLocalBatchSize, warp_size, MaxOP>(samples_max_value);
// compute the sum for each batch for current warp
float samples_sum[kLocalBatchSize]{0.0f};
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ++ii) {
data[i][ii] = std::exp((data[i][ii] - samples_max_value[i]));
samples_sum[i] += data[i][ii];
}
}
// samples_sum for each batch for all warp
warp_reduce<float, kLocalBatchSize, warp_size, AddOP>(samples_sum);
// load the result from device back to host
T samples_out[kOneLoadingCounts];
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int idx = kOneLoadingCounts * local_idx + ii * warp_size;
if (idx < key_seq_len) {
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
samples_out[counter] = data[i][ii + counter] / samples_sum[i];
}
load_data(y_data + i * key_seq_len + ii * warp_size, samples_out);
} else {
break;
}
}
}
}
template <typename T, int pow2_index>
__global__ void SoftmaxMaskFuseGradGPUKernel(const T* grad_input,
T* grad_output,
const T* softmax_rst,
int batch_count,
int key_seq_len) {
constexpr int next_pow2 = 1 << pow2_index;
constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
constexpr int kOneLoadingCounts = 4;
int data_first_idx =
(blockDim.y * blockIdx.x + threadIdx.y) * kLocalBatchSize;
// batch_count might not be a multiple of kLocalBatchSize. Check how
// many batches have to computed within this WARP.
int local_batches = batch_count - data_first_idx;
if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;
// might be many batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int offset = data_first_idx * key_seq_len + kOneLoadingCounts * local_idx;
grad_input += offset;
grad_output += offset;
softmax_rst += offset;
// using float for all inter compute
float grad_input_reg[kLocalBatchSize][kLocalIterations]{0.0f};
float softmax_rst_reg[kLocalBatchSize][kLocalIterations]{0.0f};
T temp_grad_input[kOneLoadingCounts];
T temp_softmax_rst[kOneLoadingCounts];
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
int batch_data = (i >= local_batches) ? 0 : key_seq_len;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int data_index = kOneLoadingCounts * local_idx + ii * WARP_SIZE;
if (data_index < batch_data) {
load_data(temp_grad_input,
grad_input + i * key_seq_len + ii * warp_size);
load_data(temp_softmax_rst,
softmax_rst + i * key_seq_len + ii * warp_size);
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
softmax_rst_reg[i][ii + counter] =
static_cast<float>(temp_softmax_rst[counter]);
}
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
grad_input_reg[i][ii + counter] =
static_cast<float>(temp_grad_input[counter]) *
softmax_rst_reg[i][ii + counter];
}
}
}
}
float samples_sum[kLocalBatchSize];
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
samples_sum[i] = grad_input_reg[i][0];
#pragma unroll
for (int ii = 1; ii < kLocalIterations; ++ii) {
samples_sum[i] += grad_input_reg[i][ii];
}
}
warp_reduce<float, kLocalBatchSize, warp_size, AddOP>(samples_sum);
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int data_index = kOneLoadingCounts * local_idx + ii * warp_size;
if (data_index < key_seq_len) {
// compute gradients
T samples_out[kOneLoadingCounts];
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
samples_out[counter] =
grad_input_reg[i][ii + counter] -
softmax_rst_reg[i][ii + counter] * samples_sum[i];
}
load_data(grad_output + i * key_seq_len + ii * warp_size, samples_out);
}
}
}
}
// T only supports fp16
// leave as template only for future update
template <typename Place, typename T>
class SoftmaxMaskFuseKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* mask = context.Input<Tensor>("Mask");
auto* y = context.Output<Tensor>("Out");
auto* x_data = x->data<T>();
auto* mask_data = mask->data<T>();
auto* y_data = y->mutable_data<T>(context.GetPlace());
auto x_dim = x->dims();
auto mask_dim = mask->dims();
auto batches = x_dim[0];
auto attn_heads = x_dim[1];
auto query_seq_len = x_dim[2];
auto key_seq_len = x_dim[3];
PADDLE_ENFORCE_GT(query_seq_len,
1,
platform::errors::InvalidArgument(
"Input x's second last dim must be large than 1 but "
"received the second last dimension of x is %d",
query_seq_len));
PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192,
true,
platform::errors::InvalidArgument(
"Input x's last dim must be between [32, 8192) "
"received the last dimension of x is %d",
key_seq_len));
PADDLE_ENFORCE_EQ(mask_dim[1],
1,
platform::errors::InvalidArgument(
"Input mask's second dim must be 1 "
"received the second dimension of mask is %d",
mask_dim[1]));
// dim of x and mask must be equal
for (size_t idx = 0; idx < 4; ++idx) {
if (idx == 1) continue;
PADDLE_ENFORCE_EQ(
x_dim[idx],
mask_dim[idx],
platform::errors::InvalidArgument(
"Input x's %dth dim should be equal with input mask's %dth dim "
"but "
"received the %dth dimension of x and mask are not equal "
"the %dth dim of x is %d, while the %dth dim of mask is %d.",
idx,
idx,
idx,
idx,
x_dim[idx],
idx,
mask_dim[idx]));
}
auto& place = *context.template device_context<Place>().eigen_device();
auto stream = context.cuda_device_context().stream();
int pow2_index = get_pow2(key_seq_len);
const int next_pow2 = 1 << pow2_index;
int batch_count = batches * attn_heads * query_seq_len;
int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
// use 128 threads per block to maximum gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
PADDLE_ENFORCE_EQ(
query_seq_len % batches_per_block,
0,
platform::errors::InvalidArgument(
"The query seq len (third dim of input X) must can divide the "
"number of batches per block. The query seq len is %d, while "
"the number of batches per block is %d.",
query_seq_len,
batches_per_block));
dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// launch the kernel based on the pow2_index
switch (pow2_index) {
case 5: // 32
SoftmaxMaskFuseGPUKernel<T, 5><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 6: // 64
SoftmaxMaskFuseGPUKernel<T, 6><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 7: // 128
SoftmaxMaskFuseGPUKernel<T, 7><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 8: // 256
SoftmaxMaskFuseGPUKernel<T, 8><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 9: // 512
SoftmaxMaskFuseGPUKernel<T, 9><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 10: // 1024
SoftmaxMaskFuseGPUKernel<T, 10><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 11: // 2048
SoftmaxMaskFuseGPUKernel<T, 11><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 12: // 4096
SoftmaxMaskFuseGPUKernel<T, 12><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 13: // 8192
SoftmaxMaskFuseGPUKernel<T, 13><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
default:
break;
}
}
};
template <typename Place, typename T>
class SoftmaxMaskFuseGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
auto* softmax_rst = context.Input<Tensor>("Softmax");
auto* grad_x_data = grad_x->mutable_data<T>(context.GetPlace());
auto* grad_y_data = grad_y->data<T>();
auto* softmax_rst_data = softmax_rst->data<T>();
auto y_dim = grad_y->dims();
auto batches = y_dim[0];
auto attn_heads = y_dim[1];
auto query_seq_len = y_dim[2];
auto key_seq_len = y_dim[3];
auto& place = *context.template device_context<Place>().eigen_device();
auto stream = context.cuda_device_context().stream();
int pow2_index = get_pow2(key_seq_len);
const int next_pow2 = 1 << pow2_index;
int batch_count = batches * attn_heads * query_seq_len;
int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
// use 128 threads per block to maximum gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = batch_count / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch the kernel based on the pow2_index
switch (pow2_index) {
case 5: // 32
SoftmaxMaskFuseGradGPUKernel<T, 5>
<<<blocks, threads, 0, stream>>>(grad_y_data,
grad_x_data,
softmax_rst_data,
batch_count,
key_seq_len);
break;
case 6: // 64
SoftmaxMaskFuseGradGPUKernel<T, 6>
<<<blocks, threads, 0, stream>>>(grad_y_data,
grad_x_data,
softmax_rst_data,
batch_count,
key_seq_len);
break;
case 7: // 128
SoftmaxMaskFuseGradGPUKernel<T, 7>
<<<blocks, threads, 0, stream>>>(grad_y_data,
grad_x_data,
softmax_rst_data,
batch_count,
key_seq_len);
break;
case 8: // 256
SoftmaxMaskFuseGradGPUKernel<T, 8>
<<<blocks, threads, 0, stream>>>(grad_y_data,
grad_x_data,
softmax_rst_data,
batch_count,
key_seq_len);
break;
case 9: // 512
SoftmaxMaskFuseGradGPUKernel<T, 9>
<<<blocks, threads, 0, stream>>>(grad_y_data,
grad_x_data,
softmax_rst_data,
batch_count,
key_seq_len);
break;
case 10: // 1024
SoftmaxMaskFuseGradGPUKernel<T, 10>
<<<blocks, threads, 0, stream>>>(grad_y_data,
grad_x_data,
softmax_rst_data,
batch_count,
key_seq_len);
break;
case 11: // 2048
SoftmaxMaskFuseGradGPUKernel<T, 11>
<<<blocks, threads, 0, stream>>>(grad_y_data,
grad_x_data,
softmax_rst_data,
batch_count,
key_seq_len);
break;
case 12: // 4096
SoftmaxMaskFuseGradGPUKernel<T, 12>
<<<blocks, threads, 0, stream>>>(grad_y_data,
grad_x_data,
softmax_rst_data,
batch_count,
key_seq_len);
break;
case 13: // 8192
SoftmaxMaskFuseGradGPUKernel<T, 13>
<<<blocks, threads, 0, stream>>>(grad_y_data,
grad_x_data,
softmax_rst_data,
batch_count,
key_seq_len);
break;
default:
break;
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
fused_softmax_mask,
ops::SoftmaxMaskFuseKernel<phi::GPUContext, plat::float16>,
ops::SoftmaxMaskFuseKernel<phi::GPUContext, float>);
REGISTER_OP_CUDA_KERNEL(
fused_softmax_mask_grad,
ops::SoftmaxMaskFuseGradKernel<phi::GPUContext, plat::float16>,
ops::SoftmaxMaskFuseGradKernel<phi::GPUContext, float>);
......@@ -2342,6 +2342,28 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
}
}
void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
const MetaTensor& mask,
MetaTensor* out) {
auto x_dims = x.dims();
auto mask_dims = mask.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
4,
phi::errors::InvalidArgument("Input x must be in 4D dimension but "
"received the dimension of X is %d",
x_dims.size()));
PADDLE_ENFORCE_EQ(
mask_dims.size(),
4,
phi::errors::InvalidArgument("Input mask must be in 4D dimension but "
"received the dimension of mask is %d",
mask_dims.size()));
out->share_meta(x);
}
void SegmentPoolInferMeta(const MetaTensor& x,
const MetaTensor& segment_ids,
const std::string& pooltype,
......
......@@ -358,6 +358,10 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
bool right,
MetaTensor* out);
void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
const MetaTensor& mask,
MetaTensor* out);
void SegmentPoolInferMeta(const MetaTensor& x,
const MetaTensor& segment_ids,
const std::string& pooltype,
......
......@@ -95,8 +95,8 @@ file(
"kps/*.cu"
"selected_rows/gpu/*.cu"
"sparse/gpu/*.cu"
"strings/*.cu"
"strings/gpu/*.cu")
"strings/gpu/*.cu"
"fusion/gpu/*.cu")
if(WITH_MKLDNN)
file(
......@@ -110,7 +110,9 @@ if(WITH_MKLDNN)
"sparse/cpu/*.cc"
"strings/*.cc"
"strings/cpu/*.cc"
"onednn/*.cc")
"onednn/*.cc"
"fusion/*.cc"
"fusion/cpu/*.cc")
else()
file(
GLOB
......@@ -122,10 +124,12 @@ else()
"sparse/*.cc"
"sparse/cpu/*.cc"
"strings/*.cc"
"strings/cpu/*.cc")
"strings/cpu/*.cc"
"fusion/*.cc"
"fusion/cpu/*.cc")
endif()
file(GLOB kernel_xpu "xpu/*.cc" "selected_rows/xpu/*.cc")
file(GLOB kernel_xpu "xpu/*.cc" "selected_rows/xpu/*.cc" "fusion/xpu/*.cc")
add_library(phi_cpu ${kernel_cc})
kernel_declare("${kernel_cc}")
......
# What's different about fusion kernels?
1. We don't recommend implementing a Python API for fusion kernels
- A fusion kernel generally has many input and output arguments, which makes it hard to use and understand as a Python API. We recommend invoking fusion kernels through pass optimization in dy2static mode or static mode instead.
- We also don't recommend reusing a fusion kernel inside other kernel implementations; rather, a fusion kernel should be implemented by reusing other kernels.
2. We don't require fusion kernels to have implementations for all devices
- A fusion kernel is generally used to accelerate a combined operation on a particular device; implementing it for every device would be relatively costly.
- We don't recommend implementing a pseudo kernel that just throws an exception; if a device is not required, the kernel can simply be left unimplemented.
3. Fusion kernels need to be in the `phi::fusion` namespace (see the sketch below)
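For reference, a minimal sketch that condenses the files added in this commit and shows the conventions above in practice (the kernel body is elided; only the GPU variant is registered, with no CPU stub):

```cpp
// Sketch condensed from paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace fusion {  // fusion kernels live in the phi::fusion namespace

template <typename T, typename Context>
void SoftmaxMaskFuseKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& mask,
                           DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);
  // ... fused softmax(x + mask) computation, launched on dev_ctx.stream() ...
}

}  // namespace fusion
}  // namespace phi

// Registered only for the device the fusion targets (GPU here).
PD_REGISTER_KERNEL(fused_softmax_mask,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::SoftmaxMaskFuseKernel,
                   float,
                   phi::dtype::float16) {}
```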
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SoftmaxMaskFuseGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SoftmaxMaskFuseKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& mask,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/phi/kernels/fusion/fused_softmax_mask_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h"
namespace phi {
namespace fusion {
template <typename T, int pow2_index>
__global__ void SoftmaxMaskFuseGradGPUKernel(const T* grad_input,
T* grad_output,
const T* softmax_rst,
int batch_count,
int key_seq_len) {
constexpr int next_pow2 = 1 << pow2_index;
constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
constexpr int kOneLoadingCounts = 4;
int data_first_idx =
(blockDim.y * blockIdx.x + threadIdx.y) * kLocalBatchSize;
// batch_count might not be a multiple of kLocalBatchSize. Check how
// many batches have to computed within this WARP.
int local_batches = batch_count - data_first_idx;
if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;
// might be many batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int offset = data_first_idx * key_seq_len + kOneLoadingCounts * local_idx;
grad_input += offset;
grad_output += offset;
softmax_rst += offset;
// using float for all inter compute
float grad_input_reg[kLocalBatchSize][kLocalIterations]{0.0f};
float softmax_rst_reg[kLocalBatchSize][kLocalIterations]{0.0f};
T temp_grad_input[kOneLoadingCounts];
T temp_softmax_rst[kOneLoadingCounts];
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
int batch_data = (i >= local_batches) ? 0 : key_seq_len;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int data_index = kOneLoadingCounts * local_idx + ii * WARP_SIZE;
if (data_index < batch_data) {
load_data(temp_grad_input,
grad_input + i * key_seq_len + ii * warp_size);
load_data(temp_softmax_rst,
softmax_rst + i * key_seq_len + ii * warp_size);
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
softmax_rst_reg[i][ii + counter] =
static_cast<float>(temp_softmax_rst[counter]);
}
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
grad_input_reg[i][ii + counter] =
static_cast<float>(temp_grad_input[counter]) *
softmax_rst_reg[i][ii + counter];
}
}
}
}
float samples_sum[kLocalBatchSize];
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
samples_sum[i] = grad_input_reg[i][0];
#pragma unroll
for (int ii = 1; ii < kLocalIterations; ++ii) {
samples_sum[i] += grad_input_reg[i][ii];
}
}
warp_reduce<float, kLocalBatchSize, warp_size, AddOP>(samples_sum);
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int data_index = kOneLoadingCounts * local_idx + ii * warp_size;
if (data_index < key_seq_len) {
// compute gradients
T samples_out[kOneLoadingCounts];
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
samples_out[counter] =
grad_input_reg[i][ii + counter] -
softmax_rst_reg[i][ii + counter] * samples_sum[i];
}
load_data(grad_output + i * key_seq_len + ii * warp_size, samples_out);
}
}
}
}
template <typename T, typename Context>
void SoftmaxMaskFuseGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
auto* grad_x_data = dev_ctx.template Alloc<T>(x_grad);
auto* grad_y_data = out_grad.data<T>();
auto* softmax_rst_data = out.data<T>();
auto y_dim = out_grad.dims();
auto batches = y_dim[0];
auto attn_heads = y_dim[1];
auto query_seq_len = y_dim[2];
auto key_seq_len = y_dim[3];
auto stream = dev_ctx.stream();
int pow2_index = get_pow2(key_seq_len);
const int next_pow2 = 1 << pow2_index;
int batch_count = batches * attn_heads * query_seq_len;
int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
// use 128 threads per block to maximum gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = batch_count / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// launch the kernel based on the pow2_index
switch (pow2_index) {
case 5: // 32
SoftmaxMaskFuseGradGPUKernel<T, 5><<<blocks, threads, 0, stream>>>(
grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len);
break;
case 6: // 64
SoftmaxMaskFuseGradGPUKernel<T, 6><<<blocks, threads, 0, stream>>>(
grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len);
break;
case 7: // 128
SoftmaxMaskFuseGradGPUKernel<T, 7><<<blocks, threads, 0, stream>>>(
grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len);
break;
case 8: // 256
SoftmaxMaskFuseGradGPUKernel<T, 8><<<blocks, threads, 0, stream>>>(
grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len);
break;
case 9: // 512
SoftmaxMaskFuseGradGPUKernel<T, 9><<<blocks, threads, 0, stream>>>(
grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len);
break;
case 10: // 1024
SoftmaxMaskFuseGradGPUKernel<T, 10><<<blocks, threads, 0, stream>>>(
grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len);
break;
case 11: // 2048
SoftmaxMaskFuseGradGPUKernel<T, 11><<<blocks, threads, 0, stream>>>(
grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len);
break;
case 12: // 4096
SoftmaxMaskFuseGradGPUKernel<T, 12><<<blocks, threads, 0, stream>>>(
grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len);
break;
case 13: // 8192
SoftmaxMaskFuseGradGPUKernel<T, 13><<<blocks, threads, 0, stream>>>(
grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len);
break;
default:
break;
}
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_softmax_mask_grad,
GPU,
ALL_LAYOUT,
phi::fusion::SoftmaxMaskFuseGradKernel,
float,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h"
namespace phi {
namespace fusion {
// T == fp16
template <typename T, int pow2_index>
__global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
const T* mask_data,
T* y_data,
int batch_count,
int key_seq_len) {
// the forward gpu kernel
constexpr int next_pow2 = 1 << pow2_index;
constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
constexpr int kOneLoadingCounts = 4;
int data_first_idx =
(blockDim.y *
(blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
threadIdx.y) *
kLocalBatchSize;
int mask_fist_idx =
(blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
kLocalBatchSize;
// batch_count might not be a multiple of kLocalBatchSize. Check how
// many batches have to computed within this WARP.
int local_batches = batch_count - data_first_idx;
if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;
// might be many batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
int x_offset = data_first_idx * key_seq_len + kOneLoadingCounts * local_idx;
int mask_offset = mask_fist_idx * key_seq_len + kOneLoadingCounts * local_idx;
x_data += x_offset;
mask_data += mask_offset;
y_data += x_offset;
// using float for all inter compute
float data[kLocalBatchSize][kLocalIterations];
T temp_data[kOneLoadingCounts];
T temp_mask[kOneLoadingCounts];
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
int batch_data = (i >= local_batches) ? 0 : key_seq_len;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int data_index = kOneLoadingCounts * local_idx + ii * warp_size;
if (data_index < batch_data) {
int itr_idx = i * key_seq_len + ii * warp_size;
// efficiently load data from global memory
load_data(temp_data, x_data + itr_idx);
load_data(temp_mask, mask_data + itr_idx);
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
data[i][ii + counter] = static_cast<float>(temp_data[counter]) +
static_cast<float>(temp_mask[counter]);
}
} else {
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
data[i][ii + counter] = -std::numeric_limits<float>::infinity();
}
}
}
}
// compute max_value
// max value for each batch for current warp
float samples_max_value[kLocalBatchSize];
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
samples_max_value[i] = data[i][0];
#pragma unroll
for (int ii = 1; ii < kLocalIterations; ++ii) {
samples_max_value[i] = (samples_max_value[i] > data[i][ii])
? samples_max_value[i]
: data[i][ii];
}
}
// max value for each batch for all warp
warp_reduce<float, kLocalBatchSize, warp_size, MaxOP>(samples_max_value);
// compute the sum for each batch for current warp
float samples_sum[kLocalBatchSize]{0.0f};
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ++ii) {
data[i][ii] = std::exp((data[i][ii] - samples_max_value[i]));
samples_sum[i] += data[i][ii];
}
}
// samples_sum for each batch for all warp
warp_reduce<float, kLocalBatchSize, warp_size, AddOP>(samples_sum);
// load the result from device back to host
T samples_out[kOneLoadingCounts];
#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int idx = kOneLoadingCounts * local_idx + ii * warp_size;
if (idx < key_seq_len) {
#pragma unroll
for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
samples_out[counter] = data[i][ii + counter] / samples_sum[i];
}
load_data(y_data + i * key_seq_len + ii * warp_size, samples_out);
} else {
break;
}
}
}
}
// T only supports fp16
// leave as template only for future update
template <typename T, typename Context>
void SoftmaxMaskFuseKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& mask,
DenseTensor* out) {
auto* x_data = x.data<T>();
auto* mask_data = mask.data<T>();
auto* y_data = dev_ctx.template Alloc<T>(out);
auto x_dim = x.dims();
auto mask_dim = mask.dims();
auto batches = x_dim[0];
auto attn_heads = x_dim[1];
auto query_seq_len = x_dim[2];
auto key_seq_len = x_dim[3];
PADDLE_ENFORCE_GT(query_seq_len,
1,
phi::errors::InvalidArgument(
"Input x's second last dim must be large than 1 but "
"received the second last dimension of x is %d",
query_seq_len));
PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192,
true,
phi::errors::InvalidArgument(
"Input x's last dim must be between [32, 8192) "
"received the last dimension of x is %d",
key_seq_len));
PADDLE_ENFORCE_EQ(mask_dim[1],
1,
phi::errors::InvalidArgument(
"Input mask's second dim must be 1 "
"received the second dimension of mask is %d",
mask_dim[1]));
// dim of x and mask must be equal
for (size_t idx = 0; idx < 4; ++idx) {
if (idx == 1) continue;
PADDLE_ENFORCE_EQ(
x_dim[idx],
mask_dim[idx],
phi::errors::InvalidArgument(
"Input x's %dth dim should be equal with input mask's %dth dim "
"but "
"received the %dth dimension of x and mask are not equal "
"the %dth dim of x is %d, while the %dth dim of mask is %d.",
idx,
idx,
idx,
idx,
x_dim[idx],
idx,
mask_dim[idx]));
}
auto stream = dev_ctx.stream();
int pow2_index = get_pow2(key_seq_len);
const int next_pow2 = 1 << pow2_index;
int batch_count = batches * attn_heads * query_seq_len;
int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
// use 128 threads per block to maximum gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
PADDLE_ENFORCE_EQ(
query_seq_len % batches_per_block,
0,
phi::errors::InvalidArgument(
"The query seq len (third dim of input X) must can divide the "
"number of batches per block. The query seq len is %d, while "
"the number of batches per block is %d.",
query_seq_len,
batches_per_block));
dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// launch the kernel based on the pow2_index
switch (pow2_index) {
case 5: // 32
SoftmaxMaskFuseGPUKernel<T, 5><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 6: // 64
SoftmaxMaskFuseGPUKernel<T, 6><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 7: // 128
SoftmaxMaskFuseGPUKernel<T, 7><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 8: // 256
SoftmaxMaskFuseGPUKernel<T, 8><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 9: // 512
SoftmaxMaskFuseGPUKernel<T, 9><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 10: // 1024
SoftmaxMaskFuseGPUKernel<T, 10><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 11: // 2048
SoftmaxMaskFuseGPUKernel<T, 11><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 12: // 4096
SoftmaxMaskFuseGPUKernel<T, 12><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
case 13: // 8192
SoftmaxMaskFuseGPUKernel<T, 13><<<blocks, threads, 0, stream>>>(
x_data, mask_data, y_data, batch_count, key_seq_len);
break;
default:
break;
}
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_softmax_mask,
GPU,
ALL_LAYOUT,
phi::fusion::SoftmaxMaskFuseKernel,
float,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_HIP
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#define MASK 0xffffffff
namespace phi {
namespace fusion {
__device__ __inline__ void load_data(dtype::float16* dst,
const dtype::float16* src) {
*(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
}
__device__ __inline__ void load_data(float* dst, const float* src) {
*(reinterpret_cast<float4*>(dst)) = *(reinterpret_cast<const float4*>(src));
}
inline int get_pow2(int value) {
// get next pow2 index
int pow2_index = 0;
while ((1 << pow2_index) < value) {
++pow2_index;
}
return pow2_index;
}
template <typename T>
struct AddOP {
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct MaxOP {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T
warp_shfl_xor(T value, int laneMask, int width, unsigned int mask = MASK) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename T, int batch, int width, template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(T* sum) {
ReduceOp<T> r;
#pragma unroll
for (int offset = width / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < batch; ++i) {
T b = warp_shfl_xor(sum[i], offset, width);
sum[i] = r(sum[i], b);
}
}
}
} // namespace fusion
} // namespace phi
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,22 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SoftmaxMaskFuseCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()),
true,
platform::errors::Unimplemented(
"Softmax mask fuse op only supports GPU now."));
}
};
} // namespace operators
} // namespace paddle
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SoftmaxMaskFuseGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"fused_softmax_mask_grad", {"Softmax", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(fused_softmax_mask_grad,
phi::SoftmaxMaskFuseGradOpArgumentMapping);