/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cuda_device_function.h"

namespace paddle {
namespace operators {

template <typename T>
static __device__ __forceinline__ T Relu(T x) {
  return (x > 0) ? x : 0;
}

static __device__ __forceinline__ float RealSqrt(float x) { return sqrtf(x); }
static __device__ __forceinline__ double RealSqrt(double x) { return sqrt(x); }

template <typename T>
struct PairForLayerNorm {
  __device__ __forceinline__ PairForLayerNorm() {}
  __device__ __forceinline__ PairForLayerNorm(const T& first, const T& second)
      : first_(first), second_(second) {}

  T first_;
  T second_;
};

template <typename T>
struct PairForLayerNormAddFunctor {
  __device__ __forceinline__ PairForLayerNorm<T> operator()(
      const PairForLayerNorm<T>& p1, const PairForLayerNorm<T>& p2) {
    return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
  }
};

template <typename T, bool DoRelu, int BlockDim>
__global__ void InplaceAddReluAddLayerNormKernel(const T* y, const T* bias_0,
                                                 const T* bias_1,
                                                 const T* scale, T* out,
                                                 T* mean, T* variance, int M,
                                                 int N, float epsilon) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  // The first BlockDim slots cache row elements; the last two slots hold the
  // row's mean and variance.
  __shared__ T shared_mem[BlockDim + 2];

  for (int i = blockIdx.x; i < M; i += gridDim.x) {
    int index = i * N + threadIdx.x;

    // The first BlockDim elements will be saved to shared memory.
    int save_index = threadIdx.x;
    T* save_ptr = shared_mem;

    T sum_i = 0;
    T square_sum_i = 0;
    for (int j = threadIdx.x; j < N; j += blockDim.x) {
      T tmp_0 = out[index];
      // Add bias
      T tmp_1 = bias_0 ? tmp_0 + bias_0[j] : tmp_0;
      // Relu
      T tmp_2 = DoRelu ? Relu(tmp_1) : tmp_1;
      // elementwise_add
      T tmp_3 = tmp_2 + y[index];

      // Save
      save_ptr[save_index] = tmp_3;
      save_ptr = out;

      index += blockDim.x;
      save_index = index;

      // For layer_norm, reduce to calculate mean and std
      sum_i += tmp_3;
      square_sum_i += (tmp_3 * tmp_3);
    }
    auto pair = BlockReduce(temp_storage)
                    .Reduce(PairForLayerNorm<T>(sum_i, square_sum_i),
                            PairForLayerNormAddFunctor<T>());

    if (threadIdx.x == 0) {
      T mean_i = static_cast<T>(pair.first_ / N);
      T variance_i = static_cast<T>(pair.second_ / N - mean_i * mean_i);
      shared_mem[BlockDim] = mean_i;
      shared_mem[BlockDim + 1] = variance_i;
      // Index by the row i (not blockIdx.x), since a block may process
      // several rows when gridDim.x < M.
      if (mean) {
        mean[i] = mean_i;
      }
      if (variance) {
        variance[i] = variance_i;
      }
    }
    __syncthreads();
    T mean_i = shared_mem[BlockDim];
    T std_i = static_cast<T>(RealSqrt(shared_mem[BlockDim + 1] + epsilon));

    index = i * N + threadIdx.x;
    // First BlockDim elements are loaded from shared memory.
    save_index = threadIdx.x;
    save_ptr = shared_mem;

    // For layer_norm, calculate out
    for (int j = threadIdx.x; j < N; j += blockDim.x) {
      T tmp_0 = (save_ptr[save_index] - mean_i) / std_i;
      T tmp_1 = scale ? scale[j] * tmp_0 : tmp_0;
      out[index] = bias_1 ? tmp_1 + bias_1[j] : tmp_1;

      save_ptr = out;
      index += blockDim.x;
      save_index = index;
    }
  }
}
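
// Computes Out = LayerNorm(Activation(X * W + Bias0) + Y) in two stages: a
// GEMM writes X * W into Out, then InplaceAddReluAddLayerNormKernel, in a
// single pass per row, adds Bias0, applies relu when activation_type is
// "relu", adds Y, and layer-normalizes the row with the optional Scale and
// Bias1, also filling the per-row Mean and Variance outputs when provided.
// Bias0, Bias1, Scale, Mean, and Variance may all be absent (nullptr).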
template <typename T>
class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* out = ctx.Output<framework::Tensor>("Out");

    auto w_dims = w->dims();
    int N = w_dims[1];
    int K = w_dims[0];
    int M = framework::product(x->dims()) / K;

    const T* x_data = x->data<T>();
    const T* w_data = w->data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
    blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), x_data, K, w_data, N,
              static_cast<T>(0.0), out_data, N);

    auto* y = ctx.Input<framework::Tensor>("Y");
    auto* bias_0 = ctx.Input<framework::Tensor>("Bias0");
    auto* bias_1 = ctx.Input<framework::Tensor>("Bias1");
    auto* scale = ctx.Input<framework::Tensor>("Scale");

    const T* y_data = y->data<T>();
    const T* bias_0_data = bias_0 ? bias_0->data<T>() : nullptr;
    const T* bias_1_data = bias_1 ? bias_1->data<T>() : nullptr;
    const T* scale_data = scale ? scale->data<T>() : nullptr;

    auto* mean = ctx.Output<framework::Tensor>("Mean");
    auto* variance = ctx.Output<framework::Tensor>("Variance");

    T* mean_data = mean ? mean->mutable_data<T>(ctx.GetPlace()) : nullptr;
    T* variance_data =
        variance ? variance->mutable_data<T>(ctx.GetPlace()) : nullptr;

    bool with_relu =
        (ctx.Attr<std::string>("activation_type") == "relu") ? true : false;
    float epsilon = ctx.Attr<float>("epsilon");

    // Pick a power-of-two block size from N; each block handles one or more
    // rows of the [M, N] output.
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    if (with_relu) {
      switch (platform::RoundToPowerOfTwo(N)) {
        CUDA_LAUNCH_KERNEL_HELPER(
            InplaceAddReluAddLayerNormKernel<
                T, true,
                kPowerOfTwoDim><<<std::max(max_threads / kPowerOfTwoDim, 1),
                                  kPowerOfTwoDim, 0, dev_ctx.stream()>>>(
                y_data, bias_0_data, bias_1_data, scale_data, out_data,
                mean_data, variance_data, M, N, epsilon));
      }
    } else {
      switch (platform::RoundToPowerOfTwo(N)) {
        CUDA_LAUNCH_KERNEL_HELPER(
            InplaceAddReluAddLayerNormKernel<
                T, false,
                kPowerOfTwoDim><<<std::max(max_threads / kPowerOfTwoDim, 1),
                                  kPowerOfTwoDim, 0, dev_ctx.stream()>>>(
                y_data, bias_0_data, bias_1_data, scale_data, out_data,
                mean_data, variance_data, M, N, epsilon));
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fused_fc_elementwise_layernorm,
                        ops::FusedFCElementwiseLayerNormOpKernel<float>);