/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdint>
#include <vector>

#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Launches the generic LayerNormForward kernel directly on `stream`, without
// going through an ExecutionContext.
//
// `input_shape` is flattened to a 2-D [batch_size, feature_size] matrix at
// `begin_norm_axis`, and the kernel is launched with one block of kBlockDim
// threads per row (grid = batch_size). Mean and variance are written in the
// same type T as the input.
//
// Throws InvalidArgument when GetDesiredBlockDim(feature_size) yields a block
// size not covered by FIXED_BLOCK_DIM_CASE.
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
                                                const T *input,
                                                std::vector<int> input_shape,
                                                const T *bias, const T *scale,
                                                T *output, T *mean, T *variance,
                                                int begin_norm_axis, float eps) {
  const auto x_dims = framework::make_ddim(input_shape);
  auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
  int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
  int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
  switch (GetDesiredBlockDim(feature_size)) {
    FIXED_BLOCK_DIM_CASE(
        LayerNormForward<T, T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
            input, scale, bias, output, mean, variance, eps, feature_size));
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Product from begin_norm_axis to end in layer_norm must be larger "
          "than 1"));
      break;
  }
}

// GPU forward kernel of the layer_norm op.
//
// Inputs:     X, optional Scale, optional Bias.
// Outputs:    Y (type T); Mean and Variance (type U = LayerNormParamType<T>).
// Attributes: epsilon (float), begin_norm_axis (int).
//
// X is flattened to [batch_size, feature_size] at begin_norm_axis. Scale and
// Bias, when both present, must share one dtype, and that dtype must be either
// X's dtype or U. Otherwise InvalidArgument is thrown.
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *x = ctx.Input<Tensor>("X");

    auto *y = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    const auto x_dims = x->dims();
    auto *x_data = x->data<T>();
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
    auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
    auto *var_data = var->mutable_data<U>(ctx.GetPlace());

    // Scale/Bias are read type-erased because their element type is only
    // known at runtime (either T or U); they are cast at launch time.
    auto *void_scale_data = (scale == nullptr ? nullptr : scale->data<void>());
    auto *void_bias_data = (bias == nullptr ? nullptr : bias->data<void>());

    // Resolve the Scale/Bias dtype: Scale's type if present, else Bias's type
    // if present, else X's type (no affine parameters at all).
    framework::proto::VarType::Type x_dtype = x->type();
    framework::proto::VarType::Type scale_bias_dtype;
    if (void_scale_data != nullptr) {
      scale_bias_dtype = scale->type();
      if (void_bias_data != nullptr) {
        PADDLE_ENFORCE_EQ(scale_bias_dtype, bias->type(),
                          platform::errors::InvalidArgument(
                              "The Scale and Bias of layer_norm op "
                              "should have the same data type."));
      }
    } else {
      scale_bias_dtype = (void_bias_data != nullptr ? bias->type() : x_dtype);
    }

    // Scale/Bias may differ from X's dtype only if they are of type U.
    bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
    if (!is_scale_bias_same_dtype_with_x) {
      PADDLE_ENFORCE_EQ(scale_bias_dtype,
                        framework::DataTypeTrait<U>::DataType(),
                        platform::errors::InvalidArgument(
                            "Unsupported data type of Scale and Bias: %s",
                            framework::DataTypeToString(scale_bias_dtype)));
    }

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
    auto stream = ctx.cuda_device_context().stream();

// Generic launch: one block of kBlockDim threads per row, with ScaleBiasT
// selecting the element type that Scale/Bias are reinterpreted as.
#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX)     \
  do {                                                                        \
    switch (GetDesiredBlockDim(feature_size)) {                               \
      FIXED_BLOCK_DIM_CASE(                                                   \
          LayerNormForward<T, U, kBlockDim,                                   \
                           IsScaleBiasSameDTypeWithX><<<                      \
              batch_size, kBlockDim, 0, stream>>>(                            \
              x_data, static_cast<const ScaleBiasT *>(void_scale_data),       \
              static_cast<const ScaleBiasT *>(void_bias_data), y_data,        \
              mean_data, var_data, epsilon, feature_size));                   \
      default:                                                                \
        PADDLE_THROW(platform::errors::InvalidArgument(                       \
            "Product from begin_norm_axis to end must be larger than 1"));    \
        break;                                                                \
    }                                                                         \
  } while (0)

#ifdef PADDLE_WITH_CUDA
    // Fast path (CUDA builds only, not HIP): a kernel specialized for rows of
    // exactly 1024 elements, taken only when both Scale and Bias are given.
    const bool can_call_1024_kernel =
        feature_size == 1024 && scale != nullptr && bias != nullptr;
    if (can_call_1024_kernel) {
      const int WARPS_M = 4;
      const int WARPS_N = 1;
      const int THREADS_PER_WARP = 32;
      const int BYTES_PER_LDG = 16;
      // Number of T elements that fit in one BYTES_PER_LDG-byte load.
      const int VecSize = BYTES_PER_LDG / sizeof(T);

      const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
      const int ROWS_PER_CTA = WARPS_M;

      // Exact integer ceil(batch_size / ROWS_PER_CTA). A float-based
      // std::ceil would lose precision once batch_size exceeds 2^24 and could
      // compute a grid one CTA too small.
      const int grid =
          static_cast<int>((batch_size + ROWS_PER_CTA - 1) / ROWS_PER_CTA);
      if (is_scale_bias_same_dtype_with_x) {
        ln_fwd_1024_kernel<T, U, T, VecSize, WARPS_M, WARPS_N,
                           BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
            batch_size, feature_size, epsilon, x_data,
            static_cast<const T *>(void_scale_data),
            static_cast<const T *>(void_bias_data), mean_data, var_data,
            y_data);
      } else {
        ln_fwd_1024_kernel<U, U, U, VecSize, WARPS_M, WARPS_N,
                           BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
            batch_size, feature_size, epsilon, x_data,
            static_cast<const U *>(void_scale_data),
            static_cast<const U *>(void_bias_data), mean_data, var_data,
            y_data);
      }
    } else {
#endif
      if (is_scale_bias_same_dtype_with_x) {
        PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
      } else {
        PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
      }
#ifdef PADDLE_WITH_CUDA
    }
#endif

#undef PADDLE_LAUNCH_LAYERNORM_FWD
  }
};

// GPU backward kernel of the layer_norm op.
//
// Reads X, Mean, Variance, optional Scale and Y@GRAD, and writes X@GRAD,
// Scale@GRAD and Bias@GRAD. Any of the three gradient outputs may be absent;
// nullptr is then passed through to LayerNormBackward for that output.
template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    // d_x, d_scale, d_bias may be nullptr
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto *x = ctx.Input<Tensor>("X");
    auto *mean = ctx.Input<Tensor>("Mean");
    auto *var = ctx.Input<Tensor>("Variance");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);

    auto *x_data = x->data<T>();
    auto *d_y_data = d_y->data<T>();

    auto *mean_data = mean->data<U>();
    auto *var_data = var->data<U>();

    // Element type of d_scale/d_bias: Scale's dtype if present, else Bias's
    // (saved) dtype, else X's dtype.
    framework::proto::VarType::Type x_dtype = x->type();
    framework::proto::VarType::Type scale_bias_dtype;
    if (scale != nullptr) {
      scale_bias_dtype = scale->type();
    } else if (bias != nullptr) {
      // FIXME(zengjinle): do not find a better way to get the right
      // data type of the d_scale and d_bias if scale == nullptr.
      // NOTE(review): saved_type() is used rather than type(), presumably
      // because Bias's buffer may not be kept for the grad op -- confirm.
      scale_bias_dtype = bias->saved_type();
    } else {
      scale_bias_dtype = x_dtype;
    }

// Allocates the requested gradient outputs (ScaleBiasT for d_scale/d_bias,
// T for d_x) and runs LayerNormBackward.
#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX)     \
  do {                                                                        \
    auto *scale_data =                                                        \
        (scale == nullptr ? nullptr : scale->data<ScaleBiasT>());             \
    auto *d_scale_data =                                                      \
        (d_scale == nullptr ? nullptr                                         \
                            : d_scale->mutable_data<ScaleBiasT>(              \
                                  ctx.GetPlace()));                           \
    auto *d_bias_data =                                                       \
        (d_bias == nullptr ? nullptr                                          \
                           : d_bias->mutable_data<ScaleBiasT>(                \
                                 ctx.GetPlace()));                            \
    auto *d_x_data =                                                          \
        (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));    \
    LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>(                       \
        x_data, d_y_data, scale_data, mean_data, var_data, d_x_data,          \
        d_scale_data, d_bias_data, epsilon, batch_size, feature_size,         \
        ctx.cuda_device_context());                                           \
  } while (0)

    if (scale_bias_dtype == x_dtype) {
      PADDLE_LAUNCH_LAYERNORM_BWD(T, true);
    } else {
      PADDLE_LAUNCH_LAYERNORM_BWD(U, false);
    }

#undef PADDLE_LAUNCH_LAYERNORM_BWD
  }
};

template class LayerNormDirectCUDAFunctor<float>;

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
#ifdef PADDLE_WITH_HIP
// MIOpen does not support double, so HIP builds register float and float16
// only.
REGISTER_OP_CUDA_KERNEL(
    layer_norm,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
                             plat::float16>);
#else
REGISTER_OP_CUDA_KERNEL(
    layer_norm,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
                             plat::float16>);
#endif