diff --git a/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc index 67e7c78b62e9d212b5c1738403361d77d7a3925b..496e8932a690dbcd87001da4f7e017fc86d6bff5 100644 --- a/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc @@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/layer_norm_op.h" + #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h" diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu index 861e98e4437564bfe5fae2a575741beb1d8823de..67d44184a76d0552b667c6d5a3d9466582e33558 100644 --- a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu @@ -17,7 +17,7 @@ #include #include "glog/logging.h" #include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h" -#include "paddle/fluid/operators/layer_norm_op.h" +#include "paddle/phi/kernels/layer_norm_kernel.h" namespace paddle { namespace inference { @@ -83,7 +83,7 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs, cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size, cudaMemcpyHostToDevice, stream); - paddle::operators::LayerNormDirectCUDAFunctor layer_norm; + phi::LayerNormDirectCUDAFunctor layer_norm; layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d, variance_d, begin_norm_axis, eps); return cudaGetLastError() != cudaSuccess; @@ -177,7 +177,7 @@ int LayerNormPluginDynamic::enqueue( cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size, cudaMemcpyHostToDevice, stream); - paddle::operators::LayerNormDirectCUDAFunctor layer_norm; + phi::LayerNormDirectCUDAFunctor layer_norm; layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d, variance_d, begin_norm_axis, eps); } else { diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h index 18c7187fc8e64c9fed8a86a984954b5420c1e5b5..a9b72a9cdf397f026f6ce24d83cc13066a3fd000 100644 --- a/paddle/fluid/operators/fused/fused_dropout_test.h +++ b/paddle/fluid/operators/fused/fused_dropout_test.h @@ -25,14 +25,16 @@ limitations under the License. */ #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/operators/layer_norm_kernel.cu.h" #include "paddle/fluid/string/printf.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/layer_norm_kernel.h" namespace framework = paddle::framework; namespace platform = paddle::platform; namespace memory = paddle::memory; USE_OP_ITSELF(dropout); -USE_OP(layer_norm); +USE_OP_ITSELF(layer_norm); template using CudnnDataType = platform::CudnnDataType; @@ -136,18 +138,23 @@ void LayerNorm(const std::vector> &scale, const platform::CUDADeviceContext &ctx) { framework::Scope scope; auto place = ctx.GetPlace(); + paddle::optional scale_opt = paddle::none; if (scale.size() > 0) { auto var_scale = scope.Var("Scale"); auto tensor_scale = var_scale->GetMutable(); framework::TensorFromVector(scale, ctx, tensor_scale); tensor_scale->Resize({cols}); + scale_opt = *tensor_scale; } + paddle::optional bias_opt = paddle::none; if (bias.size() > 0) { auto var_bias = scope.Var("Bias"); auto tensor_bias = var_bias->GetMutable(); framework::TensorFromVector(bias, ctx, tensor_bias); tensor_bias->Resize({cols}); + + bias_opt = *tensor_bias; } auto var_x = scope.Var("X"); @@ -157,20 +164,19 @@ void LayerNorm(const std::vector> &scale, auto var_y = scope.Var("Y"); auto tensor_y = var_y->GetMutable(); + tensor_y->Resize({rows, cols}); auto var_mean = scope.Var("Mean"); auto tensor_mean = var_mean->GetMutable(); + tensor_mean->Resize({rows}); auto var_variance = scope.Var("Variance"); auto tensor_variance = var_variance->GetMutable(); - - framework::AttributeMap attrs; - attrs.insert({"epsilon", epsilon}); - - auto op = framework::OpRegistry::CreateOp( - "layer_norm", {{"X", {"X"}}, {"Scale", {"Scale"}}, {"Bias", {"Bias"}}}, - {{"Y", {"Y"}}, {"Mean", {"Mean"}}, {"Variance", {"Variance"}}}, attrs); - op->Run(scope, place); + tensor_variance->Resize({rows}); + ctx.Wait(); + phi::LayerNormKernel(static_cast(ctx), *tensor_x, + scale_opt, bias_opt, 1e-5, 1, false, tensor_y, + tensor_mean, tensor_variance); framework::TensorToVector(*tensor_y, ctx, y); framework::TensorToVector(*tensor_mean, ctx, means); framework::TensorToVector(*tensor_variance, ctx, vars); diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu index 032440d7f0478dc087e3ba38274f2a31a9a66a23..c7e1f4a5463fe11b9fa96f147b71004140130399 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu @@ -198,7 +198,6 @@ struct TestFusedLayernormResidualDropoutBias { residual_vec[i * cols + j] + out2[i * cols + j]; } } - LayerNorm(scale_vec, layernorm_bias_vec, correct_out, &correct_means, &correct_vars, &correct_layernorm_out, epsilon, rows, cols, *ctx); diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h index 412ae3c49b5f3cc9fc2422aa220af324e6d99b69..c0a4b88fc76fd0d648b289e0d2f13536523f02d8 100644 --- a/paddle/fluid/operators/layer_norm_kernel.cu.h +++ b/paddle/fluid/operators/layer_norm_kernel.cu.h @@ -758,12 +758,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel( */ template -void ln_bwd_1024_kernel_driver( - const platform::CUDADeviceContext &dev_ctx, const int rows, const int cols, - float epsilon, const T *x_ptr, const ScaleT *scale_ptr, const U *mean_ptr, - const U *var_ptr, const T *dout_ptr, T *dx_ptr, ScaleT *dscale_ptr, - ScaleT *dbias_ptr, const MaskType *mask_ptr = nullptr, - T factor = static_cast(0), T *d_dropout_src_ptr = nullptr) { +void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows, + const int cols, float epsilon, const T *x_ptr, + const ScaleT *scale_ptr, const U *mean_ptr, + const U *var_ptr, const T *dout_ptr, T *dx_ptr, + ScaleT *dscale_ptr, ScaleT *dbias_ptr, + const MaskType *mask_ptr = nullptr, + T factor = static_cast(0), + T *d_dropout_src_ptr = nullptr) { auto stream = dev_ctx.stream(); if (cols == 1024) { // step-1: compute dx and reduced part results of dscale and dbias. @@ -1334,8 +1336,7 @@ static void LayerNormBackward( const U *mean, const U *var, T *d_x, LayerNormScaleBiasT *d_scale, LayerNormScaleBiasT *d_bias, float epsilon, - int64_t batch_size, int64_t feature_size, - const platform::CUDADeviceContext &dev_ctx) { + int64_t batch_size, int64_t feature_size, const phi::GPUContext &dev_ctx) { auto stream = dev_ctx.stream(); #ifdef __HIPCC__ const int kMaxBlockDim = 256; diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index e7d676479be0cc1176fa27c477bd35a5d6787cd3..224ab748dab6cdf8be246c4b400b4e55b6faf675 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -12,10 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/layer_norm_op.h" - #include #include +#include "paddle/fluid/framework/op_registry.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" @@ -278,10 +277,3 @@ REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker, ops::LayerNormGradOpMaker); REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp, ops::LayerNormGradNoNeedBufferVarInferer); -REGISTER_OP_CPU_KERNEL( - layer_norm, ops::LayerNormKernel, - ops::LayerNormKernel); -REGISTER_OP_CPU_KERNEL( - layer_norm_grad, - ops::LayerNormGradKernel, - ops::LayerNormGradKernel); diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu deleted file mode 100644 index dfe73d3727132ae9b8f71e2a415ef5193f303493..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/layer_norm_op.cu +++ /dev/null @@ -1,289 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/operators/layer_norm_kernel.cu.h" -#include "paddle/fluid/operators/layer_norm_op.h" -#include "paddle/fluid/platform/float16.h" - -namespace paddle { -namespace operators { - -template -void LayerNormDirectCUDAFunctor::operator()(gpuStream_t stream, - const T *input, - std::vector input_shape, - const T *bias, const T *scale, - T *output, T *mean, T *variance, - int begin_norm_axis, float eps) { - const auto x_dims = phi::make_ddim(input_shape); - auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); - int64_t batch_size = static_cast(matrix_dim[0]); - int64_t feature_size = static_cast(matrix_dim[1]); - switch (GetDesiredBlockDim(feature_size)) { - FIXED_BLOCK_DIM_CASE( - LayerNormForward<<>>( - input, scale, bias, output, mean, variance, eps, feature_size)); - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Product from begin_norm_axis to end in layer_norm must be larger " - "than 1")); - break; - } -} - -template -class LayerNormKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - using U = LayerNormParamType; - const float epsilon = ctx.Attr("epsilon"); - auto *scale = ctx.Input("Scale"); - auto *bias = ctx.Input("Bias"); - auto *x = ctx.Input("X"); - - auto *y = ctx.Output("Y"); - auto *mean = ctx.Output("Mean"); - auto *var = ctx.Output("Variance"); - const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); - - const auto x_dims = x->dims(); - auto *x_data = x->data(); - auto *y_data = y->mutable_data(ctx.GetPlace()); - auto *mean_data = mean->mutable_data(ctx.GetPlace()); - auto *var_data = var->mutable_data(ctx.GetPlace()); - - auto *void_scale_data = (scale == nullptr ? nullptr : scale->data()); - auto *void_bias_data = (bias == nullptr ? nullptr : bias->data()); - - framework::proto::VarType::Type x_dtype = - framework::TransToProtoVarType(x->dtype()); - framework::proto::VarType::Type scale_bias_dtype; - if (void_scale_data != nullptr) { - scale_bias_dtype = framework::TransToProtoVarType(scale->dtype()); - if (void_bias_data != nullptr) { - PADDLE_ENFORCE_EQ(scale_bias_dtype, - framework::TransToProtoVarType(bias->dtype()), - platform::errors::InvalidArgument( - "Thie Scale and Bias of layer_norm op " - "should have the same data type.")); - } - } else { - scale_bias_dtype = (void_bias_data != nullptr - ? framework::TransToProtoVarType(bias->dtype()) - : x_dtype); - } - - bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype; - if (!is_scale_bias_same_dtype_with_x) { - PADDLE_ENFORCE_EQ(scale_bias_dtype, - framework::DataTypeTrait::DataType(), - platform::errors::InvalidArgument( - "Unsupported data type of Scale and Bias: %s", - framework::DataTypeToString(scale_bias_dtype))); - } - - auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); - int64_t batch_size = static_cast(matrix_dim[0]); - int64_t feature_size = static_cast(matrix_dim[1]); - - auto stream = ctx.cuda_device_context().stream(); - -#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ - do { \ - switch (GetDesiredBlockDim(feature_size)) { \ - FIXED_BLOCK_DIM_CASE( \ - LayerNormForward<<< \ - batch_size, kBlockDim, 0, stream>>>( \ - x_data, static_cast(void_scale_data), \ - static_cast(void_bias_data), y_data, \ - mean_data, var_data, epsilon, feature_size)); \ - default: \ - PADDLE_THROW(platform::errors::InvalidArgument( \ - "Product from begin_norm_axis to end must be larger than 1")); \ - break; \ - } \ - } while (0) - -#ifdef PADDLE_WITH_CUDA - bool can_call_1024_kernel = false; - if (feature_size == 1024 && scale != nullptr && bias != nullptr) { - can_call_1024_kernel = true; - } - if (can_call_1024_kernel) { - const int WARPS_M = 4; - const int WARPS_N = 1; - const int THREADS_PER_WARP = 32; - const int BYTES_PER_LDG = 16; - const int VecSize = BYTES_PER_LDG / sizeof(T); - - const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; - const int ROWS_PER_CTA = WARPS_M; - - const int grid = static_cast( - std::ceil(batch_size / static_cast(ROWS_PER_CTA))); - if (is_scale_bias_same_dtype_with_x) { - ln_fwd_1024_kernel<<>>( - batch_size, feature_size, epsilon, x_data, - static_cast(void_scale_data), - static_cast(void_bias_data), mean_data, var_data, - y_data); - } else { - ln_fwd_1024_kernel<<>>( - batch_size, feature_size, epsilon, x_data, - static_cast(void_scale_data), - static_cast(void_bias_data), mean_data, var_data, - y_data); - } - } else { -#endif - if (is_scale_bias_same_dtype_with_x) { - PADDLE_LAUNCH_LAYERNORM_FWD(T, true); - } else { - PADDLE_LAUNCH_LAYERNORM_FWD(U, false); - } -#ifdef PADDLE_WITH_CUDA - } -#endif - -#undef PADDLE_LAUNCH_LAYERNORM_FWD - } -}; - -template -class LayerNormGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - using U = LayerNormParamType; - const float epsilon = ctx.Attr("epsilon"); - // d_x, d_scale, d_bias may be nullptr - auto *d_x = ctx.Output(framework::GradVarName("X")); - auto *d_scale = ctx.Output(framework::GradVarName("Scale")); - auto *d_bias = ctx.Output(framework::GradVarName("Bias")); - - auto *x = ctx.Input("X"); - auto *mean = ctx.Input("Mean"); - auto *var = ctx.Input("Variance"); - auto *scale = ctx.Input("Scale"); - auto *bias = ctx.Input("Bias"); - auto *d_y = ctx.Input(framework::GradVarName("Y")); - - const auto &x_dims = x->dims(); - const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); - auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); - int64_t batch_size = static_cast(matrix_dim[0]); - int64_t feature_size = static_cast(matrix_dim[1]); - - auto *x_data = x->data(); - auto *d_y_data = d_y->data(); - - auto *mean_data = mean->data(); - auto *var_data = var->data(); - - auto *d_x_data = - (d_x == nullptr ? nullptr : d_x->mutable_data(ctx.GetPlace())); - - framework::proto::VarType::Type x_dtype = - framework::TransToProtoVarType(x->dtype()); - framework::proto::VarType::Type scale_bias_dtype; - if (scale != nullptr) { - scale_bias_dtype = framework::TransToProtoVarType(scale->dtype()); - } else { - // FIXME(zengjinle): do not find a better way to get the right - // data type of the d_scale and d_bias if scale == nullptr. - auto *bias = ctx.Input("Bias"); - if (bias != nullptr) { - scale_bias_dtype = framework::TransToProtoVarType(bias->dtype()); - } else { - scale_bias_dtype = x_dtype; - } - } - -#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ - do { \ - auto *scale_data = \ - (scale == nullptr ? nullptr : scale->data()); \ - auto *d_scale_data = \ - (d_scale == nullptr ? nullptr : d_scale->mutable_data( \ - ctx.GetPlace())); \ - auto *d_bias_data = \ - (d_bias == nullptr ? nullptr : d_bias->mutable_data( \ - ctx.GetPlace())); \ - auto *d_x_data = \ - (d_x == nullptr ? nullptr : d_x->mutable_data(ctx.GetPlace())); \ - LayerNormBackward( \ - x_data, d_y_data, scale_data, mean_data, var_data, d_x_data, \ - d_scale_data, d_bias_data, epsilon, batch_size, feature_size, \ - ctx.cuda_device_context()); \ - } while (0) - - if (scale_bias_dtype == x_dtype) { - PADDLE_LAUNCH_LAYERNORM_BWD(T, true); - } else { - PADDLE_LAUNCH_LAYERNORM_BWD(U, false); - } - -#undef PADDLE_LAUNCH_LAYERNORM_BWD - } -}; - -template class LayerNormDirectCUDAFunctor; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -#ifdef PADDLE_WITH_HIP -// MIOPEN do not support double -REGISTER_OP_CUDA_KERNEL( - layer_norm, - ops::LayerNormKernel, - ops::LayerNormKernel); -REGISTER_OP_CUDA_KERNEL( - layer_norm_grad, - ops::LayerNormGradKernel, - ops::LayerNormGradKernel); -#elif CUDNN_VERSION_MIN(8, 1, 0) -REGISTER_OP_CUDA_KERNEL( - layer_norm, - ops::LayerNormKernel, - ops::LayerNormKernel, - ops::LayerNormKernel, - ops::LayerNormKernel); -REGISTER_OP_CUDA_KERNEL( - layer_norm_grad, - ops::LayerNormGradKernel, - ops::LayerNormGradKernel, - ops::LayerNormGradKernel, - ops::LayerNormGradKernel); -#else -REGISTER_OP_CUDA_KERNEL( - layer_norm, - ops::LayerNormKernel, - ops::LayerNormKernel, - ops::LayerNormKernel); -REGISTER_OP_CUDA_KERNEL( - layer_norm_grad, - ops::LayerNormGradKernel, - ops::LayerNormGradKernel, - ops::LayerNormGradKernel); -#endif diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h deleted file mode 100644 index 9d70b7cf707437136bf358d31ea6fd4cc0f2a534..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/layer_norm_op.h +++ /dev/null @@ -1,374 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \ - !defined(__OSX__) -#include "paddle/fluid/operators/jit/kernels.h" -#endif -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace platform { -class CPUDeviceContext; -class CUDADeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { - -// Wrap RowwiseMean and ColwiseMean. -// Reuse the cpu codes and replace the gpu codes with cublas_gemv, which is -// significantly faster. Unlike the RowwiseMean and ColwiseMean, the -// implementation only considers 2D. -template -struct RowwiseMean2D { - RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx); - - void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor* vec); -}; - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -template -class RowwiseMean2D { - public: - RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx) - : left_(left), right_(right) { - framework::DDim ones_dim({right_}); - divisor_.mutable_data(ones_dim, dev_ctx.GetPlace()); - phi::funcs::set_constant(dev_ctx, &divisor_, 1.0 / right); - } - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& input, framework::Tensor* out) { - phi::funcs::GetBlas(context).GEMV( - false, left_, right_, 1., input.data(), divisor_.data(), 0., - out->data()); - } - - private: - int left_; - int right_; - framework::Tensor divisor_; -}; -#endif - -template -class RowwiseMean2D { - public: - RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx) {} - - void operator()(const platform::CPUDeviceContext& context, - const framework::Tensor& input, framework::Tensor* out) { - row_mean_(context, input, out); - } - - private: - phi::funcs::RowwiseMean row_mean_; -}; - -template -struct ColwiseSum2D { - ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx); - - void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor* vec); -}; - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -template -class ColwiseSum2D { - public: - ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx) - : left_(left), right_(right) { - framework::DDim ones_dim({left_}); - divisor_.mutable_data(ones_dim, dev_ctx.GetPlace()); - phi::funcs::set_constant(dev_ctx, &divisor_, 1.0); - } - - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& input, framework::Tensor* out) { - phi::funcs::GetBlas(context).GEMV( - true, left_, right_, 1., input.data(), divisor_.data(), 0., - out->data()); - } - - private: - int left_; - int right_; - framework::Tensor divisor_; -}; -#endif - -template -class ColwiseSum2D { - public: - ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx) {} - - void operator()(const platform::CPUDeviceContext& context, - const framework::Tensor& input, framework::Tensor* out) { - col_wise_(context, input, out); - } - - private: - phi::funcs::ColwiseSum col_wise_; -}; - -template -struct SubAndSquareFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); } -}; - -template -struct DivAndSqrtFunctor { - explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; } - inline HOSTDEVICE T operator()(T a, T b) const { - return a / (sqrt(b + epsilon_)); - } - - private: - T epsilon_; -}; - -template -struct MulInvVarFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { - return a * std::sqrt(1.0 / b); - } -}; - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; -using DataLayout = framework::DataLayout; - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -template -class LayerNormDirectCUDAFunctor { - public: - void operator()(gpuStream_t stream, const T* input, - std::vector input_shape, const T* bias, const T* scale, - T* output, T* mean, T* variance, int begin_norm_axis, - float eps); -}; -#endif - -template -class LayerNormKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const float epsilon = ctx.Attr("epsilon"); - auto* scale = ctx.Input("Scale"); - auto* bias = ctx.Input("Bias"); - auto x = *ctx.Input("X"); - - auto* y = ctx.Output("Y"); - auto* mean = ctx.Output("Mean"); - auto* var = ctx.Output("Variance"); - const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); - - const auto x_dims = x.dims(); - - y->mutable_data(ctx.GetPlace()); - mean->mutable_data(ctx.GetPlace()); - var->mutable_data(ctx.GetPlace()); - - auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); - int left = static_cast(matrix_dim[0]); - int right = static_cast(matrix_dim[1]); - framework::DDim matrix_shape({left, right}); - - x.Resize(matrix_shape); - Tensor out; - out.ShareDataWith(*y); - out.Resize(matrix_shape); - -#if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \ - defined(__OSX__) - auto& dev_ctx = ctx.template device_context(); - RowwiseMean2D row_mean(left, right, ctx.device_context()); - - // get mean - row_mean(dev_ctx, x, mean); - - // get variance - ElementwiseComputeEx, DeviceContext, T>( - ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor(), &out); - row_mean(dev_ctx, out, var); - - // get x_norm - ElementwiseComputeEx, DeviceContext, T>( - ctx, &x, mean, /*axis*/ 0, SubFunctor(), &out); - ElementwiseComputeEx, DeviceContext, T>( - ctx, &out, var, /*axis*/ 0, - DivAndSqrtFunctor(static_cast(epsilon)), &out); - - if (scale) { - ElementwiseComputeEx, DeviceContext, T>( - ctx, &out, scale, /*axis*/ 1, MulFunctor(), &out); - } - if (bias) { - ElementwiseComputeEx, DeviceContext, T>( - ctx, &out, bias, /*axis*/ 1, AddFunctor(), &out); - } -#else - PADDLE_ENFORCE_EQ(mean->numel(), left, - platform::errors::InvalidArgument( - "mean's length (%d) is not equal with expected (%d).", - mean->numel(), left)); - PADDLE_ENFORCE_EQ(var->numel(), left, - platform::errors::InvalidArgument( - "var's length (%d) is not equal with expected (%d).", - var->numel(), left)); - if (scale) { - PADDLE_ENFORCE_EQ( - scale->numel(), right, - platform::errors::InvalidArgument( - "scale's length (%d) is not equal with expected (%d).", - scale->numel(), right)); - } - if (bias) { - PADDLE_ENFORCE_EQ( - bias->numel(), right, - platform::errors::InvalidArgument( - "bias's length (%d) is not equal with expected (%d).", - bias->numel(), right)); - } - - auto ker = - jit::KernelFuncs, platform::CPUPlace>::Cache() - .At(right); - ker(x.data(), out.data(), mean->data(), var->data(), - scale ? scale->data() : nullptr, bias ? bias->data() : nullptr, - static_cast(left), static_cast(epsilon), right); -#endif - } -}; - -template -class LayerNormGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const float epsilon = ctx.Attr("epsilon"); - auto x = *ctx.Input("X"); - auto* mean = ctx.Input("Mean"); - auto* var = ctx.Input("Variance"); - auto* scale = ctx.Input("Scale"); - auto d_y = *ctx.Input(framework::GradVarName("Y")); - const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); - - // init output - auto* d_x = ctx.Output(framework::GradVarName("X")); - auto* d_scale = ctx.Output(framework::GradVarName("Scale")); - auto* d_bias = ctx.Output(framework::GradVarName("Bias")); - - const auto& x_dims = x.dims(); - auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); - int left = static_cast(matrix_dim[0]); - int right = static_cast(matrix_dim[1]); - framework::DDim matrix_shape({left, right}); - - d_y.Resize(matrix_shape); - auto& dev_ctx = ctx.template device_context(); - ColwiseSum2D colwise_sum(left, right, - ctx.device_context()); - - Tensor temp; - Tensor temp_norm; - if (d_scale || d_x) { - x.Resize(matrix_shape); - temp.mutable_data(matrix_shape, ctx.GetPlace()); - - temp_norm.mutable_data(matrix_shape, ctx.GetPlace()); - // get x_norm - ElementwiseComputeEx, DeviceContext, T>( - ctx, &x, mean, /*axis*/ 0, SubFunctor(), &temp_norm); - ElementwiseComputeEx, DeviceContext, T>( - ctx, &temp_norm, var, /*axis*/ 0, - DivAndSqrtFunctor(static_cast(epsilon)), &temp_norm); - } - - if (d_bias) { - d_bias->mutable_data(ctx.GetPlace()); - colwise_sum(dev_ctx, d_y, d_bias); - } - if (d_scale) { - d_scale->mutable_data(ctx.GetPlace()); - ElementwiseComputeEx, DeviceContext, T>( - ctx, &temp_norm, &d_y, /*axis*/ 0, MulFunctor(), &temp); - colwise_sum(dev_ctx, temp, d_scale); - } - - if (d_x) { - framework::DDim vec_shape({left}); - d_x->mutable_data(ctx.GetPlace()); - auto dx_dim = d_x->dims(); - Tensor temp_vec; - temp_vec.mutable_data(vec_shape, ctx.GetPlace()); - - RowwiseMean2D row_mean(left, right, - ctx.device_context()); - - if (d_scale) { - // dy_dx - ElementwiseComputeEx, DeviceContext, T>( - ctx, &d_y, scale, /*axis*/ 1, MulFunctor(), &temp); - framework::TensorCopy(temp, ctx.GetPlace(), ctx.device_context(), d_x); - - // dy_dmean_dx - row_mean(dev_ctx, temp, &temp_vec); - ElementwiseComputeEx, DeviceContext, T>( - ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor(), d_x); - - // dy_var_dx - ElementwiseComputeEx, DeviceContext, T>( - ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor(), &temp); - } else { - // dy_dx - framework::TensorCopy(d_y, ctx.GetPlace(), ctx.device_context(), d_x); - - // dy_dmean_dx - row_mean(dev_ctx, d_y, &temp_vec); - ElementwiseComputeEx, DeviceContext, T>( - ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor(), d_x); - - // dy_var_dx - ElementwiseComputeEx, DeviceContext, T>( - ctx, &d_y, &temp_norm, /*axis*/ 0, MulFunctor(), &temp); - } - // dy_var_dx - row_mean(dev_ctx, temp, &temp_vec); - ElementwiseComputeEx, DeviceContext, T>( - ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor(), &temp); - ElementwiseComputeEx, DeviceContext, T>( - ctx, d_x, &temp, /*axis*/ 0, SubFunctor(), d_x); - - ElementwiseComputeEx, DeviceContext, T>( - ctx, d_x, var, /*axis*/ 0, - DivAndSqrtFunctor(static_cast(epsilon)), d_x); - d_x->Resize(dx_dim); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/layer_norm_op_npu.cc b/paddle/fluid/operators/layer_norm_op_npu.cc index c88880b43fff9fccd9764f145fba8ca4c61343c7..3c7e5bf9593e0ae2b3d8c04db1467c3b8fd1e174 100644 --- a/paddle/fluid/operators/layer_norm_op_npu.cc +++ b/paddle/fluid/operators/layer_norm_op_npu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/layer_norm_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/layer_norm_op_xpu.cc b/paddle/fluid/operators/layer_norm_op_xpu.cc index 0480a354c8bd8fdb81c95a576f57e9a12019ffc9..3b21a55f8df0dbb532729cf5cbca4c7362223b9c 100644 --- a/paddle/fluid/operators/layer_norm_op_xpu.cc +++ b/paddle/fluid/operators/layer_norm_op_xpu.cc @@ -14,7 +14,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/layer_norm_op.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc index 812c55cdd5055186d7fd83a2057d88256f3b34a3..2e82b47e8da1c6eb6f4a05fc4f7f356110f9fff1 100644 --- a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc @@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/layer_norm_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/phi/common/data_type.h" namespace paddle { namespace operators { @@ -139,7 +140,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel { layer_norm_p->execute(astream, args); astream.wait(); - y->set_layout(DataLayout::kMKLDNN); + y->set_layout(phi::DataLayout::kMKLDNN); y->set_format(platform::GetMKLDNNFormat(*dst_memory)); } }; diff --git a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..cee48ed96db1c60fb77dc7c870cb256b7ce0cb6e --- /dev/null +++ b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc @@ -0,0 +1,186 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/layer_norm_grad_kernel.h" +#include "paddle/phi/kernels/cpu/elementwise.h" +#include "paddle/phi/kernels/funcs/layer_norm_util.h" +#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \ + !defined(__OSX__) +#include "paddle/fluid/operators/jit/kernels.h" +#endif +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void LayerNormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& mean, + const DenseTensor& variance, + paddle::optional scale_opt, + paddle::optional bias_opt, + const DenseTensor& out_grad, + float epsilon, + int begin_norm_axis, + bool is_test, + DenseTensor* x_grad, + DenseTensor* scale_grad, + DenseTensor* bias_grad) { + auto* scale = scale_opt.get_ptr(); + auto d_y = out_grad; + + // init output + auto* d_x = x_grad; + auto* d_scale = scale_grad; + auto* d_bias = bias_grad; + + const auto& x_dims = x.dims(); + auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); + int left = static_cast(matrix_dim[0]); + int right = static_cast(matrix_dim[1]); + DDim matrix_shape({left, right}); + + d_y.Resize(matrix_shape); + + funcs::ColwiseSum2D colwise_sum(left, right, dev_ctx); + DenseTensor x_tmp = x; + + DenseTensor temp; + DenseTensor temp_norm; + if (d_scale || d_x) { + x_tmp.Resize(matrix_shape); + temp.Resize(matrix_shape); + dev_ctx.template Alloc(&temp); + + temp_norm.Resize(matrix_shape); + dev_ctx.template Alloc(&temp_norm); + // get x_norm + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, + x_tmp, + mean, + /*axis*/ 0, + funcs::SubtractFunctor(), + &temp_norm); + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, + temp_norm, + variance, + /*axis*/ 0, + funcs::DivAndSqrtFunctor(static_cast(epsilon)), + &temp_norm); + } + + if (d_bias) { + dev_ctx.template Alloc(d_bias); + colwise_sum(dev_ctx, d_y, d_bias); + } + if (d_scale) { + dev_ctx.template Alloc(d_scale); + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, temp_norm, d_y, 0, funcs::MultiplyFunctor(), &temp); + colwise_sum(dev_ctx, temp, d_scale); + } + + if (d_x) { + DDim vec_shape({left}); + dev_ctx.template Alloc(d_x); + auto dx_dim = d_x->dims(); + DenseTensor temp_vec; + temp_vec.Resize(vec_shape); + dev_ctx.template Alloc(&temp_vec); + + funcs::RowwiseMean2D row_mean(left, right, dev_ctx); + + if (d_scale) { + // dy_dx + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, d_y, *scale, /*axis*/ 1, funcs::MultiplyFunctor(), &temp); + phi::Copy(dev_ctx, temp, dev_ctx.GetPlace(), false, d_x); + + // dy_dmean_dx + row_mean(dev_ctx, temp, &temp_vec); + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, + *d_x, + temp_vec, + /*axis*/ 0, + funcs::SubtractFunctor(), + d_x); + + // dy_var_dx + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, + temp, + temp_norm, + /*axis*/ 0, + funcs::MultiplyFunctor(), + &temp); + } else { + // dy_dx + phi::Copy(dev_ctx, d_y, dev_ctx.GetPlace(), false, d_x); + + // dy_dmean_dx + row_mean(dev_ctx, d_y, &temp_vec); + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, + *d_x, + temp_vec, + /*axis*/ 0, + funcs::SubtractFunctor(), + d_x); + + // dy_var_dx + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, + d_y, + temp_norm, + /*axis*/ 0, + funcs::MultiplyFunctor(), + &temp); + } + // dy_var_dx + row_mean(dev_ctx, temp, &temp_vec); + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, + temp_norm, + temp_vec, + /*axis*/ 0, + funcs::MultiplyFunctor(), + &temp); + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, *d_x, temp, /*axis*/ 0, funcs::SubtractFunctor(), d_x); + + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, + *d_x, + variance, + /*axis*/ 0, + funcs::DivAndSqrtFunctor(static_cast(epsilon)), + d_x); + d_x->Resize(dx_dim); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + layer_norm_grad, CPU, ALL_LAYOUT, phi::LayerNormGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/cpu/layer_norm_kernel.cc b/paddle/phi/kernels/cpu/layer_norm_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..5b09d68c7ca081e9a6157857eea8338aaa93d34d --- /dev/null +++ b/paddle/phi/kernels/cpu/layer_norm_kernel.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/layer_norm_kernel.h" +#include "paddle/phi/kernels/cpu/elementwise.h" +#include "paddle/phi/kernels/funcs/layer_norm_util.h" +#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \ + !defined(__OSX__) +#include "paddle/fluid/operators/jit/kernels.h" +#endif +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void LayerNormKernel(const Context& dev_ctx, + const DenseTensor& x, + paddle::optional scale_opt, + paddle::optional bias_opt, + float epsilon, + int begin_norm_axis, + bool is_test, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var) { + const auto x_dims = x.dims(); + auto* scale = scale_opt.get_ptr(); + auto* bias = bias_opt.get_ptr(); + + dev_ctx.template Alloc(y); + dev_ctx.template Alloc(mean); + dev_ctx.template Alloc(var); + + auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); + int left = static_cast(matrix_dim[0]); + int right = static_cast(matrix_dim[1]); + DDim matrix_shape({left, right}); + + auto x_tmp = x; + x_tmp.Resize(matrix_shape); + DenseTensor out; + out.ShareDataWith(*y); + out.Resize(matrix_shape); + +#if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \ + defined(__OSX__) + + funcs::RowwiseMean2D row_mean(left, right, dev_ctx); + + // get mean + row_mean(dev_ctx, x_tmp, mean); + + // get variance + + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, x_tmp, *mean, 0, funcs::SubAndSquareFunctor(), &out); + + row_mean(dev_ctx, out, var); + + // get x_norm + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, x_tmp, *mean, 0, funcs::SubtractFunctor(), &out); + + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, + out, + *var, + 0, + funcs::DivAndSqrtFunctor(static_cast(epsilon)), + &out); + + if (scale) { + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, out, *scale, 1, funcs::MultiplyFunctor(), &out); + } + if (bias) { + phi::funcs::ElementwiseCompute, T, T>( + dev_ctx, out, *bias, 1, funcs::AddFunctor(), &out); + } +#else + PADDLE_ENFORCE_EQ(mean->numel(), + left, + phi::errors::InvalidArgument( + "mean's length (%d) is not equal with expected (%d).", + mean->numel(), + left)); + PADDLE_ENFORCE_EQ(var->numel(), + left, + phi::errors::InvalidArgument( + "var's length (%d) is not equal with expected (%d).", + var->numel(), + left)); + if (scale) { + PADDLE_ENFORCE_EQ( + scale->numel(), + right, + phi::errors::InvalidArgument( + "scale's length (%d) is not equal with expected (%d).", + scale->numel(), + right)); + } + if (bias) { + PADDLE_ENFORCE_EQ(bias->numel(), + right, + phi::errors::InvalidArgument( + "bias's length (%d) is not equal with expected (%d).", + bias->numel(), + right)); + } + + auto ker = paddle::operators::jit::KernelFuncs< + paddle::operators::jit::LayerNormTuple, + phi::CPUPlace>::Cache() + .At(right); + ker(x_tmp.data(), + out.data(), + mean->data(), + var->data(), + scale ? scale->data() : nullptr, + bias ? bias->data() : nullptr, + static_cast(left), + static_cast(epsilon), + right); +#endif +} + +} // namespace phi + +PD_REGISTER_KERNEL( + layer_norm, CPU, ALL_LAYOUT, phi::LayerNormKernel, float, double) {} diff --git a/paddle/phi/kernels/funcs/layer_norm_util.h b/paddle/phi/kernels/funcs/layer_norm_util.h new file mode 100644 index 0000000000000000000000000000000000000000..e78730cbf38495637e4bd4c455a3f522b38a9017 --- /dev/null +++ b/paddle/phi/kernels/funcs/layer_norm_util.h @@ -0,0 +1,165 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { +namespace funcs { + +// Wrap RowwiseMean and ColwiseMean. +// Reuse the cpu codes and replace the gpu codes with cublas_gemv, which is +// significantly faster. Unlike the RowwiseMean and ColwiseMean, the +// implementation only considers 2D. +template +struct RowwiseMean2D { + RowwiseMean2D(int left, int right, const DeviceContext& dev_ctx); + + void operator()(const DeviceContext& context, + const DenseTensor& input, + DenseTensor* vec); +}; + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +template +class RowwiseMean2D { + public: + RowwiseMean2D(int left, int right, const DeviceContext& dev_ctx) + : left_(left), right_(right) { + DDim ones_dim({right_}); + divisor_.Resize(ones_dim); + dev_ctx.template Alloc(&divisor_); + phi::funcs::set_constant(dev_ctx, &divisor_, 1.0 / right); + } + void operator()(const phi::GPUContext& context, + const DenseTensor& input, + DenseTensor* out) { + phi::funcs::GetBlas(context).GEMV(false, + left_, + right_, + 1., + input.data(), + divisor_.data(), + 0., + out->data()); + } + + private: + int left_; + int right_; + DenseTensor divisor_; +}; +#endif + +template +class RowwiseMean2D { + public: + RowwiseMean2D(int left, int right, const DeviceContext& dev_ctx) {} + + void operator()(const phi::CPUContext& context, + const DenseTensor& input, + DenseTensor* out) { + row_mean_(context, input, out); + } + + private: + phi::funcs::RowwiseMean row_mean_; +}; + +template +struct ColwiseSum2D { + ColwiseSum2D(int left, int right, const DeviceContext& dev_ctx); + + void operator()(const phi::DeviceContext& context, + const DenseTensor& input, + DenseTensor* vec); +}; + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +template +class ColwiseSum2D { + public: + ColwiseSum2D(int left, int right, const phi::GPUContext& dev_ctx) + : left_(left), right_(right) { + DDim ones_dim({left_}); + divisor_.Resize(ones_dim); + dev_ctx.template Alloc(&divisor_); + phi::funcs::set_constant(dev_ctx, &divisor_, 1.0); + } + + void operator()(const phi::GPUContext& context, + const DenseTensor& input, + DenseTensor* out) { + phi::funcs::GetBlas(context).GEMV(true, + left_, + right_, + 1., + input.data(), + divisor_.data(), + 0., + out->data()); + } + + private: + int left_; + int right_; + DenseTensor divisor_; +}; +#endif + +template +class ColwiseSum2D { + public: + ColwiseSum2D(int left, int right, const phi::CPUContext& dev_ctx) {} + + void operator()(const phi::CPUContext& context, + const DenseTensor& input, + DenseTensor* out) { + col_wise_(context, input, out); + } + + private: + phi::funcs::ColwiseSum col_wise_; +}; + +template +struct SubAndSquareFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); } +}; + +template +struct DivAndSqrtFunctor { + explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; } + inline HOSTDEVICE T operator()(T a, T b) const { + return a / (sqrt(b + epsilon_)); + } + + private: + T epsilon_; +}; + +template +struct MulInvVarFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { + return a * std::sqrt(1.0 / b); + } +}; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/math_function.cc b/paddle/phi/kernels/funcs/math_function.cc index 4201a75be8ac7ee9f7e633f6def1e002ce4b7e8a..afa2214f5b9df968d9fe01f6310e151c12e19362 100644 --- a/paddle/phi/kernels/funcs/math_function.cc +++ b/paddle/phi/kernels/funcs/math_function.cc @@ -331,12 +331,20 @@ template struct ColwiseSum; template struct ColwiseSum; template struct ColwiseSum; +template struct ColwiseSum; +template struct ColwiseSum; +template struct ColwiseSum; +template struct ColwiseSum; + template struct RowwiseSum; template struct RowwiseSum; template struct RowwiseMean; template struct RowwiseMean; +template struct RowwiseMean; +template struct RowwiseMean; + template struct ElementwiseAddTo { void operator()(paddle::platform::CPUDeviceContext* ctx, diff --git a/paddle/phi/kernels/gpu/layer_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..c3f7a5261712a1d33bb4ad47dd080a489b303717 --- /dev/null +++ b/paddle/phi/kernels/gpu/layer_norm_grad_kernel.cu @@ -0,0 +1,139 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/layer_norm_grad_kernel.h" + +#include "paddle/fluid/operators/layer_norm_kernel.cu.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/layer_norm_util.h" + +namespace phi { + +template +void LayerNormGradKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &mean, + const DenseTensor &variance, + paddle::optional scale_opt, + paddle::optional bias_opt, + const DenseTensor &out_grad, + float epsilon, + int begin_norm_axis, + bool is_test, + DenseTensor *x_grad, + DenseTensor *scale_grad, + DenseTensor *bias_grad) { + using U = paddle::operators::LayerNormParamType; + // d_x, d_scale, d_bias may be nullptr + auto *d_x = x_grad; + auto *d_scale = scale_grad; + auto *d_bias = bias_grad; + + auto *scale = scale_opt.get_ptr(); + auto *bias = bias_opt.get_ptr(); + auto *d_y = &out_grad; + + const auto &x_dims = x.dims(); + auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); + int64_t batch_size = static_cast(matrix_dim[0]); + int64_t feature_size = static_cast(matrix_dim[1]); + + auto *x_data = x.data(); + auto *d_y_data = d_y->data(); + + auto *mean_data = mean.data(); + auto *var_data = variance.data(); + + auto *d_x_data = (d_x == nullptr ? nullptr : dev_ctx.template Alloc(d_x)); + + auto x_dtype = x.dtype(); + + phi::DataType scale_bias_dtype; + if (scale != nullptr) { + scale_bias_dtype = scale->dtype(); + } else { + // FIXME(zengjinle): do not find a better way to get the right + // data type of the d_scale and d_bias if scale == nullptr. + if (bias != nullptr) { + scale_bias_dtype = bias->dtype(); + } else { + scale_bias_dtype = x_dtype; + } + } + +#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ + do { \ + auto *scale_data = \ + (scale == nullptr ? nullptr : scale->data()); \ + auto *d_scale_data = \ + (d_scale == nullptr ? nullptr \ + : dev_ctx.template Alloc(d_scale)); \ + auto *d_bias_data = \ + (d_bias == nullptr ? nullptr \ + : dev_ctx.template Alloc(d_bias)); \ + auto *d_x_data = \ + (d_x == nullptr ? nullptr : dev_ctx.template Alloc(d_x)); \ + paddle::operators::LayerNormBackward( \ + x_data, \ + d_y_data, \ + scale_data, \ + mean_data, \ + var_data, \ + d_x_data, \ + d_scale_data, \ + d_bias_data, \ + epsilon, \ + batch_size, \ + feature_size, \ + dev_ctx); \ + } while (0) + + if (scale_bias_dtype == x_dtype) { + PADDLE_LAUNCH_LAYERNORM_BWD(T, true); + } else { + PADDLE_LAUNCH_LAYERNORM_BWD(U, false); + } + +#undef PADDLE_LAUNCH_LAYERNORM_BWD +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +// MIOPEN do not support double +PD_REGISTER_KERNEL(layer_norm_grad, + GPU, + ALL_LAYOUT, + phi::LayerNormGradKernel, + float, + phi::dtype::float16) {} +#elif CUDNN_VERSION_MIN(8, 1, 0) +PD_REGISTER_KERNEL(layer_norm_grad, + GPU, + ALL_LAYOUT, + phi::LayerNormGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} +#else +PD_REGISTER_KERNEL(layer_norm_grad, + GPU, + ALL_LAYOUT, + phi::LayerNormGradKernel, + float, + double, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/gpu/layer_norm_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..d87b7c2193811cd6cf8138d1904c7fce01d3884a --- /dev/null +++ b/paddle/phi/kernels/gpu/layer_norm_kernel.cu @@ -0,0 +1,229 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/layer_norm_kernel.h" + +#include "paddle/fluid/operators/layer_norm_kernel.cu.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/layer_norm_util.h" + +namespace phi { + +template +void LayerNormDirectCUDAFunctor::operator()(gpuStream_t stream, + const T *input, + std::vector input_shape, + const T *bias, + const T *scale, + T *output, + T *mean, + T *variance, + int begin_norm_axis, + float eps) { + const auto x_dims = phi::make_ddim(input_shape); + auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); + int64_t batch_size = static_cast(matrix_dim[0]); + int64_t feature_size = static_cast(matrix_dim[1]); + switch (paddle::operators::GetDesiredBlockDim(feature_size)) { + FIXED_BLOCK_DIM_CASE(paddle::operators::LayerNormForward< + T, + T, + kBlockDim><<>>( + input, scale, bias, output, mean, variance, eps, feature_size)); + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "Product from begin_norm_axis to end in layer_norm must be larger " + "than 1")); + break; + } +} + +template class LayerNormDirectCUDAFunctor; + +template +void LayerNormKernel(const Context &dev_ctx, + const DenseTensor &x, + paddle::optional scale_opt, + paddle::optional bias_opt, + float epsilon, + int begin_norm_axis, + bool is_test, + DenseTensor *y, + DenseTensor *mean, + DenseTensor *var) { + using U = paddle::operators::LayerNormParamType; + auto *scale = scale_opt.get_ptr(); + auto *bias = bias_opt.get_ptr(); + + const auto x_dims = x.dims(); + auto *x_data = x.data(); + auto *y_data = dev_ctx.template Alloc(y); + auto *mean_data = dev_ctx.template Alloc(mean); + auto *var_data = dev_ctx.template Alloc(var); + + auto *void_scale_data = (scale == nullptr ? nullptr : scale->data()); + auto *void_bias_data = (bias == nullptr ? nullptr : bias->data()); + + auto x_dtype = x.dtype(); + phi::DataType scale_bias_dtype; + if (void_scale_data != nullptr) { + scale_bias_dtype = scale->dtype(); + if (void_bias_data != nullptr) { + PADDLE_ENFORCE_EQ( + scale->dtype(), + bias->dtype(), + phi::errors::InvalidArgument("Thie Scale and Bias of layer_norm op " + "should have the same data type.")); + } + } else { + scale_bias_dtype = (void_bias_data != nullptr ? bias->dtype() : x_dtype); + } + + bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype; + if (!is_scale_bias_same_dtype_with_x) { + PADDLE_ENFORCE_EQ(scale_bias_dtype, + paddle::experimental::CppTypeToDataType::Type(), + phi::errors::InvalidArgument( + "Unsupported data type of Scale and Bias")); + } + + auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); + int64_t batch_size = static_cast(matrix_dim[0]); + int64_t feature_size = static_cast(matrix_dim[1]); + + auto stream = dev_ctx.stream(); + +#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ + do { \ + switch (paddle::operators::GetDesiredBlockDim(feature_size)) { \ + FIXED_BLOCK_DIM_CASE(paddle::operators::LayerNormForward< \ + T, \ + U, \ + kBlockDim, \ + IsScaleBiasSameDTypeWithX><<>>( \ + x_data, \ + static_cast(void_scale_data), \ + static_cast(void_bias_data), \ + y_data, \ + mean_data, \ + var_data, \ + epsilon, \ + feature_size)); \ + default: \ + PADDLE_THROW(phi::errors::InvalidArgument( \ + "Product from begin_norm_axis to end must be larger than 1")); \ + break; \ + } \ + } while (0) + +#ifdef PADDLE_WITH_CUDA + bool can_call_1024_kernel = false; + if (feature_size == 1024 && scale != nullptr && bias != nullptr) { + can_call_1024_kernel = true; + } + if (can_call_1024_kernel) { + const int WARPS_M = 4; + const int WARPS_N = 1; + const int THREADS_PER_WARP = 32; + const int BYTES_PER_LDG = 16; + const int VecSize = BYTES_PER_LDG / sizeof(T); + + const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; + const int ROWS_PER_CTA = WARPS_M; + + const int grid = static_cast( + std::ceil(batch_size / static_cast(ROWS_PER_CTA))); + if (is_scale_bias_same_dtype_with_x) { + paddle::operators::ln_fwd_1024_kernel< + T, + U, + T, + VecSize, + WARPS_M, + WARPS_N, + BYTES_PER_LDG><<>>( + batch_size, + feature_size, + epsilon, + x_data, + static_cast(void_scale_data), + static_cast(void_bias_data), + mean_data, + var_data, + y_data); + } else { + paddle::operators::ln_fwd_1024_kernel< + T, + U, + U, + VecSize, + WARPS_M, + WARPS_N, + BYTES_PER_LDG><<>>( + batch_size, + feature_size, + epsilon, + x_data, + static_cast(void_scale_data), + static_cast(void_bias_data), + mean_data, + var_data, + y_data); + } + } else { +#endif + if (is_scale_bias_same_dtype_with_x) { + PADDLE_LAUNCH_LAYERNORM_FWD(T, true); + } else { + PADDLE_LAUNCH_LAYERNORM_FWD(U, false); + } +#ifdef PADDLE_WITH_CUDA + } +#endif + +#undef PADDLE_LAUNCH_LAYERNORM_FWD +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +// MIOPEN do not support double +PD_REGISTER_KERNEL(layer_norm, + GPU, + ALL_LAYOUT, + phi::LayerNormKernel, + float, + phi::dtype::float16) {} +#elif CUDNN_VERSION_MIN(8, 1, 0) +PD_REGISTER_KERNEL(layer_norm, + GPU, + ALL_LAYOUT, + phi::LayerNormKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} +#else +PD_REGISTER_KERNEL(layer_norm, + GPU, + ALL_LAYOUT, + phi::LayerNormKernel, + float, + double, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/layer_norm_grad_kernel.h b/paddle/phi/kernels/layer_norm_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..c32be63db4178f92d9564f357c30bb28fb415516 --- /dev/null +++ b/paddle/phi/kernels/layer_norm_grad_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LayerNormGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& mean, + const DenseTensor& variance, + paddle::optional scale, + paddle::optional bias, + const DenseTensor& out_grad, + float epsilon, + int begin_norm_axis, + bool is_test, + DenseTensor* x_grad, + DenseTensor* scale_grad, + DenseTensor* bias_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/layer_norm_kernel.h b/paddle/phi/kernels/layer_norm_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..c9679420bda5cf6beffb56b7ec319c1b80ac4eda --- /dev/null +++ b/paddle/phi/kernels/layer_norm_kernel.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/gpu/gpu_decls.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LayerNormKernel(const Context& ctx, + const DenseTensor& x, + paddle::optional scale, + paddle::optional bias, + float epsilon, + int begin_norm_axis, + bool is_test, + DenseTensor* out, + DenseTensor* mean, + DenseTensor* variance); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +template +class LayerNormDirectCUDAFunctor { + public: + void operator()(gpuStream_t stream, + const T* input, + std::vector input_shape, + const T* bias, + const T* scale, + T* output, + T* mean, + T* variance, + int begin_norm_axis, + float eps); +}; +#endif + +} // namespace phi diff --git a/paddle/phi/ops/compat/layer_norm_sig.cc b/paddle/phi/ops/compat/layer_norm_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..17a81e9ec012f2c116762ff2d653bb96f0e1c4f4 --- /dev/null +++ b/paddle/phi/ops/compat/layer_norm_sig.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LayerNormOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("layer_norm", + {"X", "Scale", "Bias"}, + {"epsilon", "begin_norm_axis", "is_test"}, + {"Y", "Mean", "Variance"}); +} + +KernelSignature LayerNormGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "layer_norm_grad", + {"X", "Mean", "Variance", "Scale", "Bias", GradVarName("Y")}, + {"epsilon", "begin_norm_axis", "is_test"}, + {GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(layer_norm, phi::LayerNormOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(layer_norm_grad, + phi::LayerNormGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py index ca9a489c7496f33cb084f1cd43158cebc7a1add6..b75dc2c964ca0b22219de1b33cdbfc3d74c19e45 100644 --- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py @@ -215,6 +215,8 @@ class TestLayerNormOp(unittest.TestCase): for name in ['x', 'scale', 'bias', 'y@GRAD'] }, fetch_list=fetch_list) + # print(y) + # print(out[0]) self.__assert_close(y, out[0], "y") self.__assert_close(mean, out[1], "mean") self.__assert_close(variance, out[2], "variance", 1e-3) @@ -238,6 +240,7 @@ class TestLayerNormOp(unittest.TestCase): def test_check_forward_backward_with_scale_and_bias(self): self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1) + self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1) self.check_forward_backward( shape=[2, 3, 4, 5], @@ -432,4 +435,5 @@ class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase): if __name__ == '__main__': + paddle.enable_static() unittest.main()