diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 826d0e773a8a5f7bb37f41adfbe2f821b8f2d61b..303830e928cf0a25dcbbcdcf262275a73a202d84 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -103,6 +103,48 @@ function(register_cu_kernel TARGET)
   endforeach()
 endfunction()
 
+# Only for MKLDNN kernels located in "fluid/operators/mkldnn/", such as 'layer_norm_mkldnn_op.cc'.
+# Add other file patterns here if needed in the future.
+function(register_mkldnn_kernel TARGET)
+  set(options "")
+  set(oneValueArgs "")
+  set(multiValueArgs SRCS DEPS)
+  cmake_parse_arguments(register_mkldnn_kernel "${options}" "${oneValueArgs}"
+                        "${multiValueArgs}" ${ARGN})
+
+  set(mkldnn_cc_srcs)
+  set(op_common_deps operator op_registry math_function layer
+                     common_infer_shape_functions)
+  foreach(mkldnn_src ${register_mkldnn_kernel_SRCS})
+    if(${mkldnn_src} MATCHES ".*_mkldnn_op.cc$")
+      list(APPEND mkldnn_cc_srcs mkldnn/${mkldnn_src})
+    endif()
+  endforeach()
+  list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
+  if(${mkldnn_cc_srcs_len} EQUAL 0)
+    message(
+      FATAL_ERROR
+        "The MKLDNN kernel file of ${TARGET} should contain at least one *_mkldnn_op.cc file"
+    )
+  endif()
+  if(WITH_MKLDNN)
+    cc_library(
+      ${TARGET}
+      SRCS ${mkldnn_cc_srcs}
+      DEPS ${op_library_DEPS} ${op_common_deps})
+  endif()
+  set(OP_LIBRARY
+      ${TARGET} ${OP_LIBRARY}
+      CACHE INTERNAL "op libs")
+  foreach(mkldnn_src ${mkldnn_cc_srcs})
+    set(op_name "")
+    find_register(${mkldnn_src} "REGISTER_OP_KERNEL" op_name)
+    if(NOT ${op_name} EQUAL "")
+      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, MKLDNN);\n")
+    endif()
+  endforeach()
+endfunction()
+
 function(op_library TARGET)
   # op_library is a function to create op library. The interface is same as
   # cc_library. But it handle split GPU/CPU code and link some common library
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 0cefb7e69ace31ce956c633c6e2697d3a125a199..b12ec19b9b9df6e28aa974b2904babbf188fe8c1 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -106,6 +106,10 @@ if (WITH_GPU OR WITH_ROCM)
     register_cu_kernel(class_center_sample_op SRCS class_center_sample_op.cu DEPS ${OP_HEADER_DEPS})
 endif()
 
+if (WITH_MKLDNN)
+    register_mkldnn_kernel(layer_norm_op SRCS layer_norm_mkldnn_op.cc DEPS ${OP_HEADER_DEPS})
+endif()
+
 if (WITH_GPU OR WITH_ROCM)
     op_library(activation_op SRCS activation_op.cc activation_op.kps soft_relu_op.cu DEPS ${OP_HEADER_DEPS})
 elseif (WITH_XPU_KP)
diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.cc b/paddle/fluid/operators/generator/get_expected_kernel_func.cc
index 026d57bba4f7de7ab04d46720fcc7d7847aef1f2..28518859f48552a61d0376492ebb197d1898d4c4 100644
--- a/paddle/fluid/operators/generator/get_expected_kernel_func.cc
+++ b/paddle/fluid/operators/generator/get_expected_kernel_func.cc
@@ -250,5 +250,21 @@ phi::KernelKey GetInstanceNormExpectedKernelType(
   return phi::KernelKey(input_data_type, ctx.GetPlace());
 }
 
+phi::KernelKey GetLayerNormExpectedKernelType(
+    const framework::ExecutionContext& ctx,
+    const framework::OperatorWithKernel* op_ptr) {
+  auto input_data_type =
+      op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+  // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
+  int begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
+  if (begin_norm_axis != ctx.Input<phi::DenseTensor>("X")->dims().size() - 1) {
+    op_ptr->SetDnnFallback(true);
+  }
+  // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
+
+  return phi::KernelKey(input_data_type, ctx.GetPlace());
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.h b/paddle/fluid/operators/generator/get_expected_kernel_func.h
index 7923c8d79fb2ded0f938282d7cdabba690dd1f14..bb228dbe2e97a47842e6b8eae6f736e760e82fd7 100644
--- a/paddle/fluid/operators/generator/get_expected_kernel_func.h
+++ b/paddle/fluid/operators/generator/get_expected_kernel_func.h
@@ -64,5 +64,9 @@ phi::KernelKey GetYoloLossExpectedKernelType(
     const framework::ExecutionContext& ctx,
     const framework::OperatorWithKernel* op_ptr);
 
+phi::KernelKey GetLayerNormExpectedKernelType(
+    const framework::ExecutionContext& ctx,
+    const framework::OperatorWithKernel* op_ptr);
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc
deleted file mode 100644
index facef32fa3b5c440db9af5b460ca554c3cb55768..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/layer_norm_op.cc
+++ /dev/null
@@ -1,336 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <memory>
-#include <string>
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
-#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
-#include "paddle/fluid/prim/utils/static/desc_tensor.h"
-#include "paddle/phi/core/infermeta_utils.h"
-#include "paddle/phi/infermeta/ternary.h"
-
-namespace paddle {
-namespace operators {
-
-using DataLayout = phi::DataLayout;
-
-class LayerNormOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LayerNorm");
-    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "LayerNorm");
-    OP_INOUT_CHECK(ctx->HasOutput("Mean"), "Output", "Mean", "LayerNorm");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("Variance"), "Output", "Variance", "LayerNorm");
-
-    auto x_dim = ctx->GetInputDim("X");
-    auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
-    PADDLE_ENFORCE_LT(
-        begin_norm_axis,
-        x_dim.size(),
-        platform::errors::InvalidArgument(
-            "'begin_norm_axis' must be less than the dimensions of X,"
-            "But received 'begin_norm_axis' is [%d],"
-            "received the dimensions of X is [%d].",
-            begin_norm_axis,
-            x_dim.size()));
-
-    auto matrix_dim = phi::flatten_to_2d(x_dim, begin_norm_axis);
-    int left = static_cast<int>(matrix_dim[0]);
-    int right = static_cast<int>(matrix_dim[1]);
-    if (ctx->HasInput("Scale")) {
-      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(),
-                        1,
-                        platform::errors::InvalidArgument(
-                            "The dimensions of Input(Scale) must be 1, but "
-                            "received dimensions of"
-                            "Input(Scale) is [%d]",
-                            ctx->GetInputDim("Scale").size()));
-
-      if (ctx->IsRuntime()) {
-        PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Scale")[0], - right, - platform::errors::InvalidArgument( - "The first dimension value of Input(Scale) must equal to be the" - "second dimension value of the flattened 2D matrix of Input(X)," - "But received the first dimension value of Input(Scale) is" - "[%d], the second dimension value of the flattened 2D matrix of" - " Input(Scale) is [%d].", - ctx->GetInputDim("Scale")[0], - right)); - } - } - if (ctx->HasInput("Bias")) { - PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), - 1, - platform::errors::InvalidArgument( - "The dimensions of Input(Bias) must be 1, but " - "received dimensions of" - "Input(Bias) is [%d]", - ctx->GetInputDim("Bias").size())); - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("Bias")[0], - right, - platform::errors::InvalidArgument( - "The first dimension value of Input(Bias) must equal to be the" - "second dimension value of the flattened 2D matrix of Input(X)," - "But received the first dimension value of Input(Bias) is" - "[%d], the second dimension value of the flattened 2D matrix of" - " Input(Bias) is [%d].", - ctx->GetInputDim("Scale")[0], - right)); - } - } - - ctx->SetOutputDim("Y", ctx->GetInputDim("X")); - ctx->SetOutputDim("Mean", {left}); - ctx->SetOutputDim("Variance", {left}); - ctx->ShareLoD("X", "Y"); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - - // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN - int begin_norm_axis = ctx.Attr("begin_norm_axis"); - if (begin_norm_axis != - ctx.Input("X")->dims().size() - 1) { - this->SetDnnFallback(true); - } - // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN - - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } -}; - -class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "The input tensor."); - AddInput("Scale", - "(optional) Scale is a 1-dimensional tensor of size " - "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." - "It is applied to the output.") - .AsDispensable(); - AddInput("Bias", - "(optional) Bias is a 1-dimensional tensor of size " - "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." - "It is applied to the output.") - .AsDispensable(); - AddOutput("Y", "Result after normalization."); - AddOutput("Mean", "Mean of the current mini batch.").AsIntermediate(); - AddOutput("Variance", "Variance of the current mini batch.") - .AsIntermediate(); - - AddAttr("epsilon", - "Constant for numerical stability [default 1e-5].") - .SetDefault(1e-5) - .AddCustomChecker([](const float &epsilon) { - PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, - true, - platform::errors::InvalidArgument( - "'epsilon' in Op(LayerNorm) should be between" - "0.0 and 0.001, But received [%s].", - epsilon)); - }); - AddAttr("begin_norm_axis", - "the axis of `begin_norm_axis ... Rank(X) - 1` will be " - "normalized. `begin_norm_axis` splits the tensor(`X`) to a " - "matrix [N,H]. [default 1].") - .SetDefault(1) - .AddCustomChecker([](const int &begin_norm_axis) { - PADDLE_ENFORCE_GT(begin_norm_axis, - 0, - platform::errors::InvalidArgument( - "'begin_norm_axis' in Op(LayerNorm) should be" - "greater than zero. But received [%d].", - begin_norm_axis)); - }); - AddComment(R"DOC( -Assume feature vectors exist on dimensions -:attr:`begin_norm_axis ... 
-along these dimensions for each feature vector :math:`a` with size
-:math:`H`, then normalize each feature vector using the corresponding
-statistics. After that, apply learnable gain and bias on the normalized
-tensor to scale and shift if :attr:`scale` and :attr:`shift` are set.
-
-Refer to `Layer Normalization <https://arxiv.org/abs/1607.06450>`_
-)DOC");
-  }
-};
-
-class LayerNormGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    // check input
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LayerNormGrad");
-    OP_INOUT_CHECK(ctx->HasInput("Mean"), "Input", "Mean", "LayerNormGrad");
-    OP_INOUT_CHECK(
-        ctx->HasInput("Variance"), "Input", "Variance", "LayerNormGrad");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")),
-                   "Input",
-                   framework::GradVarName("Y"),
-                   "LayerNormGrad");
-
-    // check output
-    if (ctx->HasOutput(framework::GradVarName("X"))) {
-      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
-    }
-    if (ctx->HasOutput(framework::GradVarName("Scale"))) {
-      ctx->SetOutputDim(framework::GradVarName("Scale"),
-                        ctx->GetInputDim("Scale"));
-    }
-    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
-      ctx->SetOutputDim(framework::GradVarName("Bias"),
-                        ctx->GetInputDim("Bias"));
-    }
-  }
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    const auto *var = ctx.InputVar(framework::GradVarName("Y"));
-    PADDLE_ENFORCE_NOT_NULL(
-        var,
-        platform::errors::NotFound("Y@GRAD of LayerNorm Op is not found."));
-    const phi::DenseTensor *t = nullptr;
-    if (var->IsType<phi::DenseTensor>()) {
-      t = &var->Get<phi::DenseTensor>();
-    } else if (var->IsType<phi::DenseTensor>()) {
-      t = &var->Get<phi::DenseTensor>();
-    }
-    PADDLE_ENFORCE_NOT_NULL(
-        t, platform::errors::NotFound("Y@GRAD of LayerNorm Op is not found."));
-
-    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-                          ctx.GetPlace());
-  }
-};
-
-template <typename T>
-class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> op) const override {
-    op->SetType("layer_norm_grad");
-    op->SetInput("X", this->Input("X"));
-    op->SetInput("Mean", this->Output("Mean"));
-    op->SetInput("Variance", this->Output("Variance"));
-    if (this->HasInput("Scale")) {
-      op->SetInput("Scale", this->Input("Scale"));
-      op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
-    }
-
-    if (this->HasInput("Bias")) {
-      op->SetInput("Bias", this->Input("Bias"));
-      op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
-    }
-
-    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
-    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    op->SetAttrMap(this->Attrs());
-  }
-};
-
-DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInferer,
-                                    "Bias");
-
-class LayerNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
-  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
-
- public:
-  void Apply() override {
-    // get inputs
-    paddle::Tensor x = this->GetSingleForwardInput("X");
-    paddle::Tensor mean = this->GetSingleForwardOutput("Mean");
-    paddle::Tensor var = this->GetSingleForwardOutput("Variance");
-    paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
-    paddle::optional<paddle::Tensor> scale =
-        this->GetOptionalSingleForwardInput("Scale");
-    paddle::optional<paddle::Tensor> bias =
this->GetOptionalSingleForwardInput("Bias"); - - // get Attrs - auto epsilon = this->Attr("epsilon"); - auto begin_norm_axis = this->Attr("begin_norm_axis"); - - // get outputs - paddle::Tensor x_grad = this->GetSingleInputGrad("X"); - paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale"); - paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias"); - - auto dx_ptr = this->GetOutputPtr(&x_grad); - std::string dx_name = this->GetOutputName(x_grad); - auto dscale_ptr = this->GetOutputPtr(&scale_grad); - std::string dscale_name = this->GetOutputName(scale_grad); - auto dbias_ptr = this->GetOutputPtr(&bias_grad); - std::string dbias_name = this->GetOutputName(bias_grad); - - VLOG(6) << "Runing layer_norm_grad composite func"; - prim::layer_norm_grad(x, - scale, - bias, - mean, - var, - y_grad, - epsilon, - begin_norm_axis, - dx_ptr, - dscale_ptr, - dbias_ptr); - - this->RecoverOutputName(x_grad, dx_name); - this->RecoverOutputName(scale_grad, dscale_name); - this->RecoverOutputName(bias_grad, dbias_name); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(layer_norm, - LayerNormInferShapeFunctor, - PD_INFER_META(phi::LayerNormInferMeta)); - -REGISTER_OPERATOR(layer_norm, - ops::LayerNormOp, - ops::LayerNormOpMaker, - ops::LayerNormGradOpMaker, - ops::LayerNormGradOpMaker, - ops::LayerNormCompositeGradOpMaker, - LayerNormInferShapeFunctor); - -DECLARE_INFER_SHAPE_FUNCTOR(layer_norm_grad, - LayerNormGradInferShapeFunctor, - PD_INFER_META(phi::LayerNormGradInferMeta)); - -REGISTER_OPERATOR(layer_norm_grad, - ops::LayerNormGradOp, - ops::LayerNormGradNoNeedBufferVarInferer, - LayerNormGradInferShapeFunctor); diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 740402b155fc270e4f19770bb5d541d7254e6f97..6c5f4ccbaa4ad24ed89d232586020989b0bb402e 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -135,7 +135,7 @@ register_unity_group( kron_op.cc l1_norm_op.cc label_smooth_op.cc - layer_norm_op.cc + generated_op mkldnn/layer_norm_mkldnn_op.cc mkldnn/layer_norm_mkldnn_op.cc linspace_op.cc diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 895f0ccb112af0be06a822def0b3f202b827011a..661f59ef6a7d6c40c0d6741c49642e5d6c241b54 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -973,6 +973,20 @@ kernel : func : label_smooth_grad +- backward_op : layer_norm_grad + forward : layer_norm (Tensor x, Tensor scale, Tensor bias, float epsilon = 1e-5, int begin_norm_axis = 1) -> Tensor(out), Tensor(mean), Tensor(variance) + args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, Tensor out_grad, float epsilon = 1e-5, int begin_norm_axis = 1) + output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad) + infer_meta : + func : LayerNormGradInferMeta + param : [x, scale, bias] + kernel : + func : layer_norm_grad + data_type : x + composite : layer_norm_grad(x, scale, bias, mean, variance, out_grad, epsilon, begin_norm_axis, x_grad, scale_grad, bias_grad) + no_need_buffer : bias + optional : scale, bias + - backward_op : leaky_relu_double_grad forward : leaky_relu_grad (Tensor x, Tensor grad_out, float negative_slope) -> Tensor(grad_x) args : (Tensor x, Tensor grad_x_grad, float negative_slope) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 
3b39db07136ed0a79f4eca8e98058e621a1268a3..60e5354dd5354f624f590990fdc4d0b584f614af 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -454,20 +454,6 @@ kernel : func : hsigmoid_loss_grad -- backward_op : layer_norm_grad - forward : layer_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis) -> Tensor(out), Tensor(mean), Tensor(variance) - args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, Tensor out_grad, float epsilon, int begin_norm_axis) - output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad) - infer_meta : - func : LayerNormGradInferMeta - param : [x, scale, bias] - kernel : - func : layer_norm_grad - data_type : out_grad - composite : layer_norm_grad(x, scale, bias, mean,varience, out_grad, epsilon, begin_norm_axis, x_grad, scale_grad, bias_grad) - no_need_buffer : bias - optional : scale, bias - - backward_op : logsumexp_grad forward : logsumexp(Tensor x, int64_t[] axis, bool keepdim, bool reduce_all) -> Tensor(out) args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keepdim, bool reduce_all) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 0dff2ba6c24e31fd8535d2b62f8c978cedc3112f..921c28caf4798fd1316240d8f93c98e6e8bca241 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -588,17 +588,6 @@ func : increment inplace : (x -> out) -- op : layer_norm - args : (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis) - output : Tensor(out), Tensor(mean), Tensor(variance) - infer_meta : - func : LayerNormInferMeta - kernel : - func : layer_norm - data_type : x - backward : layer_norm_grad - optional : scale, bias - - op : less_equal args : (Tensor x, Tensor y) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 7087e913e94502d1c0578ed885623a45a964679d..edf08f5bb8fe5b5d64c52e37763eb1cde692667b 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1264,6 +1264,8 @@ variance : Variance extra : attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false] + get_expected_kernel_type : + layer_norm : GetLayerNormExpectedKernelType - op : leaky_relu backward : leaky_relu_grad, leaky_relu_double_grad (leaky_relu_grad_grad) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 8928a825e46364fcdc0effd711775f2ba2808a9b..c98c9b910d5660280382fcb3f450a14790adca3b 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1135,6 +1135,18 @@ optional : master_param, skip_update, beta1_pow_out, beta2_pow_out, master_param_outs inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_outs) +- op : layer_norm + args : (Tensor x, Tensor scale, Tensor bias, float epsilon = 1e-5, int begin_norm_axis = 1) + output : Tensor(out), Tensor(mean), Tensor(variance) + infer_meta : + func : LayerNormInferMeta + kernel : + func : layer_norm + data_type : x + backward : layer_norm_grad + intermediate : mean, variance + optional : scale, bias + - op : leaky_relu args : (Tensor x, float negative_slope = 0.02f) output : Tensor diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 646767703154e1878547163773fc04289f9fa539..95a9c9546746cae83758167336ddaf26dcb2b4ae 100644 --- 
a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -529,6 +529,12 @@ void LayerNormInferMeta(const MetaTensor& x, MetaTensor* variance, MetaConfig config) { auto x_dim = x.dims(); + PADDLE_ENFORCE_GT(begin_norm_axis, + 0, + phi::errors::InvalidArgument( + "'begin_norm_axis' in Op(LayerNorm) should be" + "greater than zero. But received [%d].", + begin_norm_axis)); PADDLE_ENFORCE_LT( begin_norm_axis, x_dim.size(), @@ -588,6 +594,13 @@ void LayerNormInferMeta(const MetaTensor& x, right)); } + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, + true, + phi::errors::InvalidArgument( + "'epsilon' in Op(LayerNorm) should be between" + "0.0 and 0.001, But received [%s].", + epsilon)); + phi::DataType x_dtype = x.dtype(); out->set_dims(x_dim); out->set_dtype(x_dtype); diff --git a/paddle/phi/ops/compat/layer_norm_sig.cc b/paddle/phi/ops/compat/layer_norm_sig.cc deleted file mode 100644 index d2e75a700d2b80fdc0f858ddea6bcbdba062dd8d..0000000000000000000000000000000000000000 --- a/paddle/phi/ops/compat/layer_norm_sig.cc +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature LayerNormOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("layer_norm", - {"X", "Scale", "Bias"}, - {"epsilon", "begin_norm_axis"}, - {"Y", "Mean", "Variance"}); -} - -KernelSignature LayerNormGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature("layer_norm_grad", - {"X", "Scale", "Bias", "Mean", "Variance", "Y@GRAD"}, - {"epsilon", "begin_norm_axis"}, - {"X@GRAD", "Scale@GRAD", "Bias@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(layer_norm, phi::LayerNormOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(layer_norm_grad, - phi::LayerNormGradOpArgumentMapping); diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index ba57e53cf69f460bbd182d7650b89bbcac01f5c3..5cf69fb42b6c52e0b1dfd0d42926e4a17a81fa90 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -332,7 +332,7 @@ def layer_norm( ) if in_dygraph_mode(): - out, _, _ = _C_ops.layer_norm(x, weight, bias, epsilon, begin_norm_axis) + out = _C_ops.layer_norm(x, weight, bias, epsilon, begin_norm_axis) return out else: diff --git a/test/cpp/fluid/fused/CMakeLists.txt b/test/cpp/fluid/fused/CMakeLists.txt index 4bbcfbb93d2f8dfb9180db77c3e2df8ca0e749a0..ff239f2f0c6a19cbefda38ffa8e6a75474986442 100644 --- a/test/cpp/fluid/fused/CMakeLists.txt +++ b/test/cpp/fluid/fused/CMakeLists.txt @@ -13,7 +13,7 @@ if(WITH_GPU OR WITH_ROCM) DEPS tensor op_registry dropout_op - layer_norm_op + generated_op device_context generator memory) @@ -23,7 +23,7 @@ if(WITH_GPU OR WITH_ROCM) DEPS tensor op_registry dropout_op - layer_norm_op + generated_op device_context generator memory) @@ -33,7 +33,7 @@ if(WITH_GPU OR WITH_ROCM) DEPS tensor 
op_registry dropout_op - layer_norm_op + generated_op device_context generator memory) diff --git a/test/prim/composite_ops/test_composite_layer_norm.py b/test/prim/composite_ops/test_composite_layer_norm.py index cca9dcfc4c52abac644198b3b75fadb93a4cf4b3..f56dbea1556838f47f5568a2f2f94130c8827ef7 100644 --- a/test/prim/composite_ops/test_composite_layer_norm.py +++ b/test/prim/composite_ops/test_composite_layer_norm.py @@ -234,27 +234,27 @@ class TestCompositelayer_norm(unittest.TestCase): b_p = paddle.to_tensor(b) expect = expect_forward(x_p, n_shape, w_p, b_p) - actual = self.cal_composite(x, n_shape, w, b) - - assert expect[0].numpy().dtype == actual[0].dtype - for i in range(3): - np.testing.assert_allclose( - expect[i].numpy(), - actual[i], - rtol=attrs.get_rtol("forward"), - atol=attrs.get_atol("forward"), - ) + actual, _a_mean, _a_var = self.cal_composite(x, n_shape, w, b) + + assert expect.numpy().dtype == actual.dtype + np.testing.assert_allclose( + expect.numpy(), + actual, + rtol=attrs.get_rtol("forward"), + atol=attrs.get_atol("forward"), + ) expect_2 = expect_forward(x_p, n_shape, None, None) - actual_2 = self.cal2_composite(x, n_shape, None, None) - assert expect_2[0].numpy().dtype == actual_2[0].dtype - for i in range(3): - np.testing.assert_allclose( - expect_2[i].numpy(), - actual_2[i], - rtol=attrs.get_rtol("forward"), - atol=attrs.get_atol("forward"), - ) + actual_2, _a_mean_2, _a_var_2 = self.cal2_composite( + x, n_shape, None, None + ) + assert expect_2.numpy().dtype == actual_2.dtype + np.testing.assert_allclose( + expect_2.numpy(), + actual_2, + rtol=attrs.get_rtol("forward"), + atol=attrs.get_atol("forward"), + ) def test_forward(self): for j in self.dtypes: diff --git a/tools/enforce/count_enforce_by_file.sh b/tools/enforce/count_enforce_by_file.sh index c1e2903c092ce4124c55566679e081dbe3a03445..fafc3516904d86be73e0c1bbfcbc4db4a3ff7c25 100644 --- a/tools/enforce/count_enforce_by_file.sh +++ b/tools/enforce/count_enforce_by_file.sh @@ -51,7 +51,6 @@ if [ "$1" != "" ]; then fi FILE_WHITE_LIST="\ - layer_norm_op.cc \ box_clip_op.cc \ box_clip_op.h \ random_crop_op.h \
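
Note for readers of this patch: because `mean` and `variance` are declared as `intermediate` outputs of `layer_norm` in ops.yaml, the dygraph call now returns only the normalized tensor, which is why python/paddle/nn/functional/norm.py and the composite tests above unpack a single result. A minimal, illustrative sketch of the user-visible behaviour (the shapes and input values below are made up for this example and are not part of the patch):

    import paddle
    import paddle.nn.functional as F

    # Hypothetical example input; any float tensor works.
    x = paddle.rand([2, 3, 8])
    weight = paddle.ones([8])   # gamma, matches the normalized trailing dim
    bias = paddle.zeros([8])    # beta

    # After this change, functional layer_norm still returns a single tensor in
    # dygraph mode; Mean/Variance stay internal as intermediate outputs.
    out = F.layer_norm(x, normalized_shape=[8], weight=weight, bias=bias, epsilon=1e-5)
    print(out.shape)  # [2, 3, 8]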