diff --git a/paddle/fluid/operators/fused_softmax_mask_op.cc b/paddle/fluid/operators/fused_softmax_mask_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a41380028338a12449fe3ba4b1b7425fabab82be
--- /dev/null
+++ b/paddle/fluid/operators/fused_softmax_mask_op.cc
@@ -0,0 +1,117 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "paddle/fluid/operators/fused_softmax_mask_op.h"
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/op_registry.h"
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class SoftmaxMaskFuseOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SoftmaxMaskFuse");
+    OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "SoftmaxMaskFuse");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SoftmaxMaskFuse");
+    auto x_dims = ctx->GetInputDim("X");
+    auto mask_dims = ctx->GetInputDim("Mask");
+
+    PADDLE_ENFORCE_EQ(
+        x_dims.size(), 4,
+        platform::errors::InvalidArgument("Input X must be a 4-D tensor, but "
+                                          "received a tensor of rank %d.",
+                                          x_dims.size()));
+    PADDLE_ENFORCE_EQ(mask_dims.size(), 4,
+                      platform::errors::InvalidArgument(
+                          "Input Mask must be a 4-D tensor, but "
+                          "received a tensor of rank %d.",
+                          mask_dims.size()));
+
+    ctx->SetOutputDim("Out", x_dims);
+    ctx->ShareLoD("X", "Out");
+  }
+};
+
+class SoftmaxMaskFuseOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "The input of softmax_mask_fuse op, "
+             "which is the result of matmul(QK)/sqrt(dk).");
+    AddInput("Mask", "The mask input of the op, i.e. the attention mask.");
+    AddOutput("Out", "The result of softmax_mask_fuse op.");
+
+    AddComment(R"DOC(
+Softmax Mask Fuse Operator.
+In general, the computation is:
+product = matmul(QK)/sqrt(dk)
+pre_softmax = product + attn_mask
+output = softmax(pre_softmax)
+To reduce the number of kernel launches in the forward and backward passes,
+and to avoid allocating memory for the intermediate pre_softmax variable,
+this op fuses the last two operations into one, so users can simply call
+product = matmul(QK)/sqrt(dk)
+output = softmax_mask_fuse(product, attn_mask)
+to get the final output.
+By doing this fusion, we optimize training by
+1. saving one kernel launch and its forward and backward cost
+2. 
saving the memory cost used to save the tmp var +)DOC"); + } +}; + +class SoftmaxMaskFuseOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "SoftmaxMaskFuseGrad"); + + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); + ctx->SetOutputDim(framework::GradVarName("X"), out_dims); + ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X")); + } +}; + +template +class SoftmaxMaskFuseGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("fused_softmax_mask_grad"); + op->SetInput("Softmax", this->Output("Out")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_softmax_mask, ops::SoftmaxMaskFuseOp, + ops::SoftmaxMaskFuseOpMaker, + ops::SoftmaxMaskFuseGradOpMaker, + ops::SoftmaxMaskFuseGradOpMaker); +REGISTER_OPERATOR(fused_softmax_mask_grad, ops::SoftmaxMaskFuseOpGrad); +REGISTER_OP_CPU_KERNEL( + fused_softmax_mask, + ops::SoftmaxMaskFuseCPUKernel, + ops::SoftmaxMaskFuseCPUKernel); diff --git a/paddle/fluid/operators/fused_softmax_mask_op.cu b/paddle/fluid/operators/fused_softmax_mask_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..2ba5c027a4d7609f045dcbd960d71ad5b022daa1 --- /dev/null +++ b/paddle/fluid/operators/fused_softmax_mask_op.cu @@ -0,0 +1,538 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +// this file is inspired by: +// https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/fused_kernels/scaled_masked_softmax.h + +#ifdef PADDLE_WITH_CUDA +#include +#include +#endif +#ifdef PADDLE_WITH_HIP +#include +#include +#endif +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/fused_softmax_mask_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +#ifdef PADDLE_WITH_HIP +#define WARP_SIZE 64 +#else +#define WARP_SIZE 32 +#endif + +#define MASK 0xffffffff + +namespace plat = paddle::platform; + +__device__ __inline__ void load_data(plat::float16* dst, + const plat::float16* src) { + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +__device__ __inline__ void load_data(float* dst, const float* src) { + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +int get_pow2(int value) { + // get next pow2 index + int pow2_index = 0; + while ((1 << pow2_index) < value) { + ++pow2_index; + } + return pow2_index; +} + +template +struct AddOP { + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct MaxOP { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T warp_shfl_xor(T value, int laneMask, int width, + unsigned int mask = MASK) { +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(T* sum) { + ReduceOp r; +#pragma unroll + for (int offset = width / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < batch; ++i) { + T b = warp_shfl_xor(sum[i], offset, width); + sum[i] = r(sum[i], b); + } + } +} + +// T == fp16 +template +__global__ void SoftmaxMaskFuseGPUKernel(const T* x_data, const T* mask_data, + T* y_data, int batch_count, + int key_seq_len) { + // the forward gpu kernel + constexpr int next_pow2 = 1 << pow2_index; + constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); + constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1; + constexpr int kOneLoadingCounts = 4; + + int data_first_idx = + (blockDim.y * + (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + kLocalBatchSize; + + int mask_fist_idx = + (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * + kLocalBatchSize; + + // batch_count might not be a multiple of kLocalBatchSize. Check how + // many batches have to computed within this WARP. + int local_batches = batch_count - data_first_idx; + if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; + + // might be many batches per warp. 
compute the index within the batch + int local_idx = threadIdx.x; + + int x_offset = data_first_idx * key_seq_len + kOneLoadingCounts * local_idx; + int mask_offset = mask_fist_idx * key_seq_len + kOneLoadingCounts * local_idx; + x_data += x_offset; + mask_data += mask_offset; + y_data += x_offset; + + // using float for all inter compute + float data[kLocalBatchSize][kLocalIterations]; + T temp_data[kOneLoadingCounts]; + T temp_mask[kOneLoadingCounts]; + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + int batch_data = (i >= local_batches) ? 0 : key_seq_len; + +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int data_index = kOneLoadingCounts * local_idx + ii * warp_size; + + if (data_index < batch_data) { + int itr_idx = i * key_seq_len + ii * warp_size; + + // efficiently load data from global memory + load_data(temp_data, x_data + itr_idx); + load_data(temp_mask, mask_data + itr_idx); + +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + data[i][ii + counter] = static_cast(temp_data[counter]) + + static_cast(temp_mask[counter]); + } + } else { +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + data[i][ii + counter] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + // max value for each batch for current warp + float samples_max_value[kLocalBatchSize]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + samples_max_value[i] = data[i][0]; +#pragma unroll + for (int ii = 1; ii < kLocalIterations; ++ii) { + samples_max_value[i] = (samples_max_value[i] > data[i][ii]) + ? samples_max_value[i] + : data[i][ii]; + } + } + // max value for each batch for all warp + warp_reduce(samples_max_value); + + // compute the sum for each batch for current warp + float samples_sum[kLocalBatchSize]{0.0f}; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ++ii) { + data[i][ii] = std::exp((data[i][ii] - samples_max_value[i])); + samples_sum[i] += data[i][ii]; + } + } + // samples_sum for each batch for all warp + warp_reduce(samples_sum); + + // load the result from device back to host + T samples_out[kOneLoadingCounts]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int idx = kOneLoadingCounts * local_idx + ii * warp_size; + if (idx < key_seq_len) { +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + samples_out[counter] = data[i][ii + counter] / samples_sum[i]; + } + load_data(y_data + i * key_seq_len + ii * warp_size, samples_out); + } else { + break; + } + } + } +} + +template +__global__ void SoftmaxMaskFuseGradGPUKernel(const T* grad_input, + T* grad_output, + const T* softmax_rst, + int batch_count, int key_seq_len) { + constexpr int next_pow2 = 1 << pow2_index; + constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); + constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1; + constexpr int kOneLoadingCounts = 4; + + int data_first_idx = + (blockDim.y * blockIdx.x + threadIdx.y) * kLocalBatchSize; + + // batch_count might not be a multiple of kLocalBatchSize. Check how + // many batches have to computed within this WARP. 
+ int local_batches = batch_count - data_first_idx; + if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; + + // might be many batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int offset = data_first_idx * key_seq_len + kOneLoadingCounts * local_idx; + grad_input += offset; + grad_output += offset; + softmax_rst += offset; + + // using float for all inter compute + float grad_input_reg[kLocalBatchSize][kLocalIterations]{0.0f}; + float softmax_rst_reg[kLocalBatchSize][kLocalIterations]{0.0f}; + T temp_grad_input[kOneLoadingCounts]; + T temp_softmax_rst[kOneLoadingCounts]; + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + int batch_data = (i >= local_batches) ? 0 : key_seq_len; + +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int data_index = kOneLoadingCounts * local_idx + ii * WARP_SIZE; + if (data_index < batch_data) { + load_data(temp_grad_input, + grad_input + i * key_seq_len + ii * warp_size); + load_data(temp_softmax_rst, + softmax_rst + i * key_seq_len + ii * warp_size); + +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + softmax_rst_reg[i][ii + counter] = + static_cast(temp_softmax_rst[counter]); + } +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + grad_input_reg[i][ii + counter] = + static_cast(temp_grad_input[counter]) * + softmax_rst_reg[i][ii + counter]; + } + } + } + } + + float samples_sum[kLocalBatchSize]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + samples_sum[i] = grad_input_reg[i][0]; +#pragma unroll + for (int ii = 1; ii < kLocalIterations; ++ii) { + samples_sum[i] += grad_input_reg[i][ii]; + } + } + warp_reduce(samples_sum); + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int data_index = kOneLoadingCounts * local_idx + ii * warp_size; + if (data_index < key_seq_len) { + // compute gradients + T samples_out[kOneLoadingCounts]; +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + samples_out[counter] = + grad_input_reg[i][ii + counter] - + softmax_rst_reg[i][ii + counter] * samples_sum[i]; + } + load_data(grad_output + i * key_seq_len + ii * warp_size, samples_out); + } + } + } +} + +// T only supports fp16 +// leave as template only for future update +template +class SoftmaxMaskFuseKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* mask = context.Input("Mask"); + auto* y = context.Output("Out"); + + auto* x_data = x->data(); + auto* mask_data = mask->data(); + auto* y_data = y->mutable_data(context.GetPlace()); + + auto x_dim = x->dims(); + auto mask_dim = mask->dims(); + auto batches = x_dim[0]; + auto attn_heads = x_dim[1]; + auto query_seq_len = x_dim[2]; + auto key_seq_len = x_dim[3]; + + PADDLE_ENFORCE_GT(query_seq_len, 1, + platform::errors::InvalidArgument( + "Input x's second last dim must be large than 1 but " + "received the second last dimension of x is %d", + query_seq_len)); + + PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192, true, + platform::errors::InvalidArgument( + "Input x's last dim must be between [32, 8192) " + "received the last dimension of x is %d", + key_seq_len)); + + PADDLE_ENFORCE_EQ(mask_dim[1], 1, + 
platform::errors::InvalidArgument( + "Input mask's second dim must be 1 " + "received the second dimension of mask is %d", + mask_dim[1])); + + // dim of x and mask must be equal + for (size_t idx = 0; idx < 4; ++idx) { + if (idx == 1) continue; + PADDLE_ENFORCE_EQ( + x_dim[idx], mask_dim[idx], + platform::errors::InvalidArgument( + "Input x's %dth dim should be equal with input mask's %dth dim " + "but " + "received the %dth dimension of x and mask are not equal " + "the %dth dim of x is %d, while the %dth dim of mask is %d.", + idx, idx, idx, idx, x_dim[idx], idx, mask_dim[idx])); + } + + auto& place = *context.template device_context().eigen_device(); + auto stream = context.cuda_device_context().stream(); + + int pow2_index = get_pow2(key_seq_len); + const int next_pow2 = 1 << pow2_index; + int batch_count = batches * attn_heads * query_seq_len; + int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; + // use 128 threads per block to maximum gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + PADDLE_ENFORCE_EQ( + query_seq_len % batches_per_block, 0, + platform::errors::InvalidArgument( + "The query seq len (third dim of input X) must can divide the " + "number of batches per block. The query seq len is %d, while " + "the number of batches per block is %d.", + query_seq_len, batches_per_block)); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + + // launch the kernel based on the pow2_index + switch (pow2_index) { + case 5: // 32 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 6: // 64 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 7: // 128 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 8: // 256 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 9: // 512 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 10: // 1024 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 11: // 2048 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 12: // 4096 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 13: // 8192 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + default: + break; + } + } +}; + +template +class SoftmaxMaskFuseGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* grad_x = context.Output(framework::GradVarName("X")); + auto* grad_y = context.Input(framework::GradVarName("Out")); + auto* softmax_rst = context.Input("Softmax"); + + auto* grad_x_data = grad_x->mutable_data(context.GetPlace()); + auto* grad_y_data = grad_y->data(); + auto* softmax_rst_data = softmax_rst->data(); + + auto y_dim = grad_y->dims(); + auto batches = y_dim[0]; + auto attn_heads = y_dim[1]; + auto query_seq_len = y_dim[2]; + auto key_seq_len = y_dim[3]; + + auto& place = *context.template device_context().eigen_device(); + auto stream = context.cuda_device_context().stream(); + + int pow2_index = 
get_pow2(key_seq_len); + const int next_pow2 = 1 << pow2_index; + int batch_count = batches * attn_heads * query_seq_len; + int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; + // use 128 threads per block to maximum gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + + // launch the kernel based on the pow2_index + switch (pow2_index) { + case 5: // 32 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, + key_seq_len); + break; + case 6: // 64 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, + key_seq_len); + break; + case 7: // 128 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, + key_seq_len); + break; + case 8: // 256 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, + key_seq_len); + break; + case 9: // 512 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, + key_seq_len); + break; + case 10: // 1024 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, + key_seq_len); + break; + case 11: // 2048 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, + key_seq_len); + break; + case 12: // 4096 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, + key_seq_len); + break; + case 13: // 8192 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, + key_seq_len); + break; + default: + break; + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + fused_softmax_mask, + ops::SoftmaxMaskFuseKernel, + ops::SoftmaxMaskFuseKernel); +REGISTER_OP_CUDA_KERNEL( + fused_softmax_mask_grad, + ops::SoftmaxMaskFuseGradKernel, + ops::SoftmaxMaskFuseGradKernel); diff --git a/paddle/fluid/operators/fused_softmax_mask_op.h b/paddle/fluid/operators/fused_softmax_mask_op.h new file mode 100644 index 0000000000000000000000000000000000000000..452eda730e8d8434511399d7832a215c59f5878e --- /dev/null +++ b/paddle/fluid/operators/fused_softmax_mask_op.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class SoftmaxMaskFuseCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::Unimplemented( + "Softmax mask fuse op only supports GPU now.")); + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_op.py b/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_op.py new file mode 100644 index 0000000000000000000000000000000000000000..cff06f9025fb14ea64bf494ecb27935682fcb469 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_op.py @@ -0,0 +1,120 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.incubate as incubate + +paddle.enable_static() + + +def _get_softmax(x, mask, fp16=True): + masked_x = (x + mask).astype("float32") + max_value = np.max(masked_x, axis=-1, keepdims=True) + before_exp = masked_x - max_value + exp = np.exp(before_exp) + exp_sum = np.sum(exp, axis=-1, keepdims=True) + rst = exp / exp_sum + if fp16: + rst = rst.astype("float16") + return rst + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxMaskFuseOp(OpTest): + def setUp(self): + self.op_type = "fused_softmax_mask" + x = np.random.random((1, 1, 8, 32)) + mask = np.random.randint(0, 2, (1, 1, 8, 32)) + mask_input = np.where(mask == 1, -10000.0, mask) + self.inputs = {'X': x, 'Mask': mask_input} + rst = _get_softmax(x, mask_input) + self.outputs = {'Out': rst} + + def test_check_output(self): + try: + self.check_output_with_place(core.CPUPlace()) + except NotImplementedError: + pass + + def test_check_grad(self): + try: + self.check_grad_with_place(core.CPUPlace(), ["X"], "Out") + except NotImplementedError: + pass + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxMaskFuseOp0(OpTest): + def setUp(self): + self.op_type = "fused_softmax_mask" + x = np.random.random((1, 1, 8, 32)).astype("float16") + mask = np.random.randint(0, 2, (1, 1, 8, 32)).astype("float16") + mask_input = np.where(mask == 1, -10000.0, mask) + self.inputs = {'X': x, 'Mask': mask_input} + rst = _get_softmax(x, mask_input) + self.outputs = {'Out': rst} + + def test_check_output(self): + self.check_output_with_place(core.CUDAPlace(0)) + + def test_check_grad(self): + self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out") + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class 
TestDropoutBiasFuseOp3(unittest.TestCase): + def test_static_result(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input_x = fluid.data(name="x", shape=[1, 1, 8, 32], dtype="float32") + input_mask = fluid.data( + name="mask", shape=[1, 1, 8, 32], dtype="float32") + rst = incubate.softmax_mask_fuse(input_x, input_mask) + + x_in_np = np.random.random((1, 1, 8, 32)).astype("float32") + mask = np.random.randint(0, 2, (1, 1, 8, 32)).astype("float32") + mask_in_np = np.where(mask == 1, -10000.0, mask) + rst_np = _get_softmax(x_in_np, mask_in_np, False) + + exe = fluid.Executor(fluid.CUDAPlace(0)) + fetches = exe.run(fluid.default_main_program(), + feed={"x": x_in_np, + "mask": mask_in_np}, + fetch_list=[rst]) + self.assertTrue(np.allclose(fetches[0], rst_np)) + + def test_dygraph(self): + with fluid.dygraph.guard(fluid.CUDAPlace(0)): + x_in_np = np.random.random((1, 1, 8, 32)).astype("float32") + mask = np.random.randint(0, 2, (1, 1, 8, 32)).astype("float32") + mask_in_np = np.where(mask == 1, -10000.0, mask) + rst_np = _get_softmax(x_in_np, mask_in_np, False) + input_x = fluid.dygraph.to_variable(x_in_np) + input_mask = fluid.dygraph.to_variable(mask_in_np) + + rst = incubate.softmax_mask_fuse(input_x, input_mask) + self.assertTrue(np.allclose(rst, rst_np)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index 9b9797ede717e7ffc7f1b710162b9bbc23098ad3..efaeda272087fcce65cf9a4b174b491e7e60d097 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -17,7 +17,8 @@ from .optimizer import ModelAverage # noqa: F401 from .checkpoint import auto_checkpoint # noqa: F401 from ..fluid.layer_helper import LayerHelper # noqa: F401 from .operators import softmax_mask_fuse_upper_triangle # noqa: F401 +from .operators import softmax_mask_fuse # noqa: F401 __all__ = [ # noqa - 'LookAhead', 'ModelAverage', 'softmax_mask_fuse_upper_triangle' + 'LookAhead', 'ModelAverage', 'softmax_mask_fuse_upper_triangle', 'softmax_mask_fuse' ] diff --git a/python/paddle/incubate/operators/__init__.py b/python/paddle/incubate/operators/__init__.py index 026bf32d81250dfb5613c242311dde37484b428d..694cde4f28624b76e8da1af880f530e40789ef68 100644 --- a/python/paddle/incubate/operators/__init__.py +++ b/python/paddle/incubate/operators/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .softmax_mask_fuse_upper_triangle import softmax_mask_fuse_upper_triangle # noqa: F401 +from .softmax_mask_fuse import softmax_mask_fuse # noqa: F401 diff --git a/python/paddle/incubate/operators/softmax_mask_fuse.py b/python/paddle/incubate/operators/softmax_mask_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..bbc4175b0d1c444b16950c3cae2fa9c0484546cb --- /dev/null +++ b/python/paddle/incubate/operators/softmax_mask_fuse.py @@ -0,0 +1,33 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import print_function
+
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid import core
+
+
+def softmax_mask_fuse(x, mask, name=None):
+    if in_dygraph_mode():
+        out = core.ops.fused_softmax_mask(x, mask)
+        return out
+    helper = LayerHelper('fused_softmax_mask', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type='fused_softmax_mask',
+        inputs={'X': [x],
+                'Mask': [mask]},
+        outputs={'Out': [out]})
+    return out
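
A minimal dygraph usage sketch of the new API, assuming a CUDA build with this patch applied; the shapes follow the constraints enforced in fused_softmax_mask_op.cu (4-D input, key_seq_len in [32, 8192), mask broadcast over the head dimension). The tensor values here are illustrative only.

```python
import numpy as np
import paddle
import paddle.incubate as incubate

paddle.disable_static()

# x: [batch, heads, query_seq_len, key_seq_len]; key_seq_len must lie in [32, 8192).
x = paddle.rand([2, 8, 32, 128], dtype="float32")

# The mask is broadcast over the head dimension, so its second dim must be 1;
# masked positions carry a large negative value, unmasked positions 0.
mask_np = np.where(np.random.randint(0, 2, (2, 1, 32, 128)) == 1, -10000.0, 0.0)
mask = paddle.to_tensor(mask_np, dtype="float32")

out = incubate.softmax_mask_fuse(x, mask)  # softmax(x + mask) along the last axis
print(out.shape)  # [2, 8, 32, 128]
```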
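To make the math implemented by the two CUDA kernels easy to check, here is a NumPy reference sketch: the forward matches `_get_softmax` in the unit test, and the backward mirrors `SoftmaxMaskFuseGradGPUKernel`, which computes dx = y * dy - y * sum(y * dy) row-wise over the key dimension. Function names are illustrative, not part of the patch.

```python
import numpy as np


def softmax_mask_fuse_ref(x, mask):
    # Forward: softmax(x + mask) along the last (key_seq_len) axis.
    masked = (x + mask).astype("float32")
    masked -= np.max(masked, axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(masked)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def softmax_mask_fuse_grad_ref(dout, softmax_out):
    # Backward, as in the grad kernel: dx = y * dy - y * sum(y * dy, axis=-1).
    prod = dout * softmax_out
    return prod - softmax_out * np.sum(prod, axis=-1, keepdims=True)
```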
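Both `Compute()` methods derive the launch configuration the same way, which is what makes the `query_seq_len % batches_per_block` check in the forward path necessary. A small Python sketch of that arithmetic, under the assumption of the CUDA warp size of 32 (HIP uses 64); names are illustrative only.

```python
def launch_geometry(key_seq_len, warp_size_max=32, threads_per_block=128):
    # Next power of two >= key_seq_len, as in get_pow2().
    pow2_index = 0
    while (1 << pow2_index) < key_seq_len:
        pow2_index += 1
    next_pow2 = 1 << pow2_index

    # Each warp handles one softmax row (two rows when next_pow2 <= 128).
    warp_size = min(next_pow2, warp_size_max)
    batches_per_warp = 2 if next_pow2 <= 128 else 1
    warps_per_block = threads_per_block // warp_size
    batches_per_block = warps_per_block * batches_per_warp
    return pow2_index, warp_size, batches_per_block


# Example: key_seq_len = 128 -> pow2_index 7, warp_size 32, batches_per_block 8,
# so query_seq_len must be a multiple of 8 for the forward kernel's grid.
print(launch_geometry(128))  # (7, 32, 8)
```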