From 7b81092b498274255ad92e0ac88e6029aa506758 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com>
Date: Tue, 16 May 2023 14:17:17 +0800
Subject: [PATCH] [static op generation] InstanceNorm (#53340)

* mv InstanceNorm

* modify op_version.yaml

* modify add Operator:: in get_expected_kernel_func.cc

* rm gradexpectedkernel

* add extra

* add float epsilon=1e-5
---
 .../generator/get_expected_kernel_func.cc  |  30 +++
 .../generator/get_expected_kernel_func.h   |   4 +
 paddle/fluid/operators/instance_norm_op.cc | 238 ------------------
 paddle/fluid/operators/instance_norm_op.h  | 116 ---------
 paddle/phi/api/yaml/backward.yaml          |  24 ++
 paddle/phi/api/yaml/legacy_backward.yaml   |  24 --
 paddle/phi/api/yaml/legacy_ops.yaml        |  12 -
 paddle/phi/api/yaml/op_compat.yaml         |   4 +
 paddle/phi/api/yaml/op_version.yaml        |  11 +
 paddle/phi/api/yaml/ops.yaml               |  12 +
 paddle/phi/ops/compat/instance_norm_sig.cc |  56 -----
 11 files changed, 85 insertions(+), 446 deletions(-)
 delete mode 100644 paddle/fluid/operators/instance_norm_op.cc
 delete mode 100644 paddle/fluid/operators/instance_norm_op.h
 delete mode 100644 paddle/phi/ops/compat/instance_norm_sig.cc

diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.cc b/paddle/fluid/operators/generator/get_expected_kernel_func.cc
index 5c654d942c2..49697a48a17 100644
--- a/paddle/fluid/operators/generator/get_expected_kernel_func.cc
+++ b/paddle/fluid/operators/generator/get_expected_kernel_func.cc
@@ -181,5 +181,35 @@ phi::KernelKey GetUniqueExpectedKernelType(
   }
 }
 
+phi::KernelKey GetInstanceNormExpectedKernelType(
+    const framework::ExecutionContext& ctx,
+    const framework::OperatorWithKernel* op_ptr) {
+  auto input_data_type =
+      op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "X");
+  // By default, the dtype of the scale, bias, mean, and variance tensors
+  // should all be float (for float or float16 input tensors) or double
+  // (for double input tensors).
+  auto in_param_type = framework::proto::VarType::FP32;
+  if (input_data_type == framework::proto::VarType::FP64) {
+    in_param_type = framework::proto::VarType::FP64;
+  }
+  if (ctx.HasInput("Scale")) {
+    PADDLE_ENFORCE_EQ(in_param_type,
+                      framework::TransToProtoVarType(
+                          ctx.Input<phi::DenseTensor>("Scale")->dtype()),
+                      platform::errors::InvalidArgument(
+                          "Scale input should be of float type"));
+  }
+  if (ctx.HasInput("Bias")) {
+    PADDLE_ENFORCE_EQ(in_param_type,
+                      framework::TransToProtoVarType(
+                          ctx.Input<phi::DenseTensor>("Bias")->dtype()),
+                      platform::errors::InvalidArgument(
+                          "Bias input should be of float type"));
+  }
+
+  return phi::KernelKey(input_data_type, ctx.GetPlace());
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.h b/paddle/fluid/operators/generator/get_expected_kernel_func.h
index a6fe0fc4b22..4ef88909984 100644
--- a/paddle/fluid/operators/generator/get_expected_kernel_func.h
+++ b/paddle/fluid/operators/generator/get_expected_kernel_func.h
@@ -48,6 +48,10 @@ phi::KernelKey GetUniqueExpectedKernelType(
     const framework::ExecutionContext& ctx,
     const framework::OperatorWithKernel* op_ptr);
 
+phi::KernelKey GetInstanceNormExpectedKernelType(
+    const framework::ExecutionContext& ctx,
+    const framework::OperatorWithKernel* op_ptr);
+
 phi::KernelKey GetYoloLossExpectedKernelType(
     const framework::ExecutionContext& ctx,
     const framework::OperatorWithKernel* op_ptr);
diff --git a/paddle/fluid/operators/instance_norm_op.cc b/paddle/fluid/operators/instance_norm_op.cc
deleted file mode 100644
index 8d76b46968b..00000000000
--- a/paddle/fluid/operators/instance_norm_op.cc
+++ /dev/null
@@ -1,238 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/instance_norm_op.h"
-
-#include <memory>
-#include <string>
-#include <unordered_map>
-
-#include "paddle/fluid/framework/data_layout.h"
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
-#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
-#include "paddle/fluid/prim/utils/static/desc_tensor.h"
-#include "paddle/phi/core/infermeta_utils.h"
-#include "paddle/phi/infermeta/backward.h"
-#include "paddle/phi/infermeta/ternary.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-phi::KernelKey InstanceNormOp::GetExpectedKernelType(
-    const framework::ExecutionContext &ctx) const {
-  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-  // By default, the dtype of the scale, bias, mean, and variance tensors
-  // should all be float (for float or float16 input tensors) or double
-  // (for double input tensors).
-  auto in_param_type = framework::proto::VarType::FP32;
-  if (input_data_type == framework::proto::VarType::FP64) {
-    in_param_type = framework::proto::VarType::FP64;
-  }
-  if (ctx.HasInput("Scale")) {
-    PADDLE_ENFORCE_EQ(in_param_type,
-                      framework::TransToProtoVarType(
-                          ctx.Input<phi::DenseTensor>("Scale")->dtype()),
-                      platform::errors::InvalidArgument(
-                          "Scale input should be of float type"));
-  }
-  if (ctx.HasInput("Bias")) {
-    PADDLE_ENFORCE_EQ(in_param_type,
-                      framework::TransToProtoVarType(
-                          ctx.Input<phi::DenseTensor>("Bias")->dtype()),
-                      platform::errors::InvalidArgument(
-                          "Bias input should be of float type"));
-  }
-
-  return phi::KernelKey(input_data_type, ctx.GetPlace());
-}
-
-void InstanceNormOpMaker::Make() {
-  AddAttr<float>("epsilon", "")
-      .SetDefault(1e-5)
-      .AddCustomChecker([](const float &epsilon) {
-        PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
-                          true,
-                          platform::errors::InvalidArgument(
-                              "'epsilon' should be between 0.0 and 0.001."));
-      });
-  AddInput("X", "The input tensor");
-  AddInput("Scale",
-           "Scale is a 1-dimensional tensor of size C "
-           "that is applied to the output")
-      .AsDispensable();
-  AddInput("Bias",
-           "Bias is a 1-dimensional tensor of size C "
-           "that is applied to the output")
-      .AsDispensable();
-  AddOutput("Y", "result after normalization");
-  AddOutput("SavedMean",
-            "Mean of the current mini batch, "
-            "will apply to output when training")
-      .AsIntermediate()
-      .AsExtra();
-  AddOutput("SavedVariance",
-            "Variance of the current mini batch, "
-            "will apply to output when training")
-      .AsIntermediate()
-      .AsExtra();
-  AddComment(R"DOC(
-Instance Normalization.
-
-Instance Norm has been implemented as discussed in the paper:
-https://arxiv.org/pdf/1607.08022.pdf
-Can be used as a normalizer function for conv2d and fully_connected operations.
-The required data format for this layer is as follows:
-NCHW `[batch, in_channels, in_height, in_width]`
-
-)DOC");
-}
-
-phi::KernelKey InstanceNormGradOp::GetExpectedKernelType(
-    const framework::ExecutionContext &ctx) const {
-  const auto *var = ctx.InputVar(framework::GradVarName("Y"));
-  if (var == nullptr) {
-    PADDLE_THROW(
-        platform::errors::NotFound("cannot find gradient variable of Y"));
-  }
-  const phi::DenseTensor *t = nullptr;
-  if (var->IsType<phi::DenseTensor>()) {
-    t = &var->Get<phi::DenseTensor>();
-  } else if (var->IsType<phi::DenseTensor>()) {
-    t = &var->Get<phi::DenseTensor>();
-  }
-  if (t == nullptr) {
-    PADDLE_THROW(
-        platform::errors::InvalidArgument("gradient variable of Y is empty"));
-  }
-  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-                        ctx.GetPlace());
-}
-
-phi::KernelKey InstanceNormDoubleGradOp::GetExpectedKernelType(
-    const framework::ExecutionContext &ctx) const {
-  const auto *var = ctx.InputVar("DY");
-  if (var == nullptr) {
-    PADDLE_THROW(
-        platform::errors::NotFound("cannot find gradient variable of Y"));
-  }
-  const phi::DenseTensor *t = nullptr;
-  if (var->IsType<phi::DenseTensor>()) {
-    t = &var->Get<phi::DenseTensor>();
-  } else if (var->IsType<phi::DenseTensor>()) {
-    t = &var->Get<phi::DenseTensor>();
-  }
-  if (t == nullptr) {
-    PADDLE_THROW(
-        platform::errors::InvalidArgument("gradient variable of Y is empty"));
-  }
-  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-                        ctx.GetPlace());
-}
-
-class InstanceNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
-  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
-
- public:
-  void Apply() override {
-    // inputs and outputs of instance_norm
-    paddle::Tensor x = this->GetSingleForwardInput("X");
-    paddle::Tensor scale = this->GetSingleForwardInput("Scale");
-    paddle::Tensor saved_mean = this->GetSingleForwardOutput("SavedMean");
-    paddle::Tensor saved_variance =
-        this->GetSingleForwardOutput("SavedVariance");
-
-    paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
-    paddle::Tensor x_grad = this->GetSingleInputGrad("X");
-    paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
-    paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");
-
-    auto x_grad_ptr = this->GetOutputPtr(&x_grad);
-    std::string x_grad_name = this->GetOutputName(x_grad);
-    auto scale_grad_ptr = this->GetOutputPtr(&scale_grad);
-    std::string scale_grad_name = this->GetOutputName(scale_grad);
-    auto bias_grad_ptr = this->GetOutputPtr(&bias_grad);
-    std::string bias_grad_name = this->GetOutputName(bias_grad);
-
-    auto epsilon = this->Attr<float>("epsilon");
-
-    VLOG(3) << "Running instance_norm composite func";
-    prim::instance_norm_grad<prim::DescTensor>(x,
-                                               scale,
-                                               saved_mean,
-                                               saved_variance,
-                                               y_grad,
-                                               epsilon,
-                                               x_grad_ptr,
-                                               scale_grad_ptr,
-                                               bias_grad_ptr);
-    this->RecoverOutputName(x_grad, x_grad_name);
-    this->RecoverOutputName(scale_grad, scale_grad_name);
-    this->RecoverOutputName(bias_grad, bias_grad_name);
-  }
-};
-
-DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInferer,
-                           {"DY", "DDY"});
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-DECLARE_INFER_SHAPE_FUNCTOR(instance_norm,
-                            InstanceNormInferShapeFunctor,
-                            PD_INFER_META(phi::InstanceNormInferMeta));
-DECLARE_INFER_SHAPE_FUNCTOR(instance_norm_grad,
-                            InstanceNormGradInferShapeFunctor,
-                            PD_INFER_META(phi::InstanceNormGradInferMeta));
-DECLARE_INFER_SHAPE_FUNCTOR(
-    instance_norm_grad_grad,
-    InstanceNormDoubleGradInferShapeFunctor,
-    PD_INFER_META(phi::InstanceNormDoubleGradInferMeta));
-REGISTER_OPERATOR(instance_norm,
-                  ops::InstanceNormOp,
-                  ops::InstanceNormOpMaker,
-                  ops::InstanceNormOpInferVarType,
-                  ops::InstanceNormGradMaker<paddle::framework::OpDesc>,
-                  ops::InstanceNormGradMaker<paddle::imperative::OpBase>,
-                  InstanceNormInferShapeFunctor,
-                  ops::InstanceNormCompositeGradOpMaker);
-REGISTER_OPERATOR(instance_norm_grad,
-                  ops::InstanceNormGradOp,
-                  ops::InstanceNormDoubleGradMaker<paddle::framework::OpDesc>,
-                  ops::InstanceNormDoubleGradMaker<paddle::imperative::OpBase>,
-                  InstanceNormGradInferShapeFunctor);
-REGISTER_OPERATOR(instance_norm_grad_grad,
-                  ops::InstanceNormDoubleGradOp,
-                  ops::InstanceNormDoubleGradOpInplaceInferer,
-                  InstanceNormDoubleGradInferShapeFunctor);
-
-REGISTER_OP_VERSION(instance_norm)
-    .AddCheckpoint(
-        R"ROC(
-      Change dispensable of attribute from False to True in instance_norm.
-    )ROC",
-        paddle::framework::compatible::OpVersionDesc()
-            .ModifyAttr(
-                "Bias",
-                "The arg 'dispensable' of Input 'Bias' is changed: from "
-                "'False' to 'True'.",
-                true)
-            .ModifyAttr(
-                "Scale",
-                "The arg 'dispensable' of Input 'Scale' is changed: from "
-                "'False' to 'True'.",
-                true));
diff --git a/paddle/fluid/operators/instance_norm_op.h b/paddle/fluid/operators/instance_norm_op.h
deleted file mode 100644
index 9a885e47e40..00000000000
--- a/paddle/fluid/operators/instance_norm_op.h
+++ /dev/null
@@ -1,116 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <memory>
-#include <string>
-#include <unordered_map>
-
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace paddle {
-namespace operators {
-
-using DataLayout = phi::DataLayout;
-
-class InstanceNormOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override;
-};
-
-class InstanceNormGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override;
-};
-
-class InstanceNormDoubleGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override;
-};
-
-class InstanceNormOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override;
-};
-
-template <typename T>
-class InstanceNormGradMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> op) const override {
-    op->SetType("instance_norm_grad");
-    op->SetInput("X", this->Input("X"));
-    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
-
-    op->SetInput("Scale", this->Input("Scale"));
-    op->SetInput("SavedMean", this->Output("SavedMean"));
-    op->SetInput("SavedVariance", this->Output("SavedVariance"));
-
-    op->SetAttrMap(this->Attrs());
-    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
-    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
-  }
-};
-
-template <typename T>
-class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> op) const override {
-    op->SetType("instance_norm_grad_grad");
-    op->SetInput("X", this->Input("X"));
-    op->SetInput("Scale", this->Input("Scale"));
-    op->SetInput("SavedMean", this->Input("SavedMean"));
-    op->SetInput("SavedVariance", this->Input("SavedVariance"));
-    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
-    op->SetInput("DDScale", this->OutputGrad(framework::GradVarName("Scale")));
-    op->SetInput("DDBias", this->OutputGrad(framework::GradVarName("Bias")));
-    op->SetInput("DY", this->Input(framework::GradVarName("Y")));
-
-    op->SetAttrMap(this->Attrs());
-    op->SetOutput("DX", this->InputGrad("X"));
-    op->SetOutput("DScale", this->InputGrad("Scale"));
-    op->SetOutput("DDY", this->InputGrad(framework::GradVarName("Y")));
-  }
-};
-
-class InstanceNormOpInferVarType
-    : public framework::PassInDtypeAndVarTypeToOutput {
- protected:
-  std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
-      const override {
-    static std::unordered_map<std::string, std::string> m{{"X", "Y"}};
-    return m;
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 92ce63e4709..a75fe0e596a 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -851,6 +851,30 @@
     data_type : out_grad
   no_need_buffer : x
 
+- backward_op : instance_norm_double_grad
+  forward : instance_norm_grad(Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, float epsilon) -> Tensor(grad_x), Tensor(grad_scale), Tensor(grad_bias)
+  args : (Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, Tensor grad_x_grad, Tensor grad_scale_grad, Tensor grad_bias_grad, float epsilon)
+  output : Tensor(x_grad), Tensor(fwd_scale_grad), Tensor(grad_y_grad)
+  infer_meta :
+    func : InstanceNormDoubleGradInferMeta
+  kernel :
+    func : instance_norm_double_grad
+    data_type : x
+  optional : fwd_scale, grad_x_grad, grad_scale_grad, grad_bias_grad
+
+- backward_op : instance_norm_grad
+  forward : instance_norm(Tensor x, Tensor scale, Tensor bias, float epsilon) -> Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
+  args : (Tensor x, Tensor scale, Tensor saved_mean, Tensor saved_variance, Tensor y_grad, float epsilon=1e-5)
+  output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
+  infer_meta :
+    func : InstanceNormGradInferMeta
+  kernel :
+    func : instance_norm_grad
+    data_type : x
+  optional : scale
+  backward : instance_norm_double_grad
+  composite: instance_norm_grad(x, scale, saved_mean, saved_variance, y_grad, epsilon, x_grad, scale_grad, bias_grad)
+
 - backward_op : inverse_grad
   forward : inverse(Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 37a3626861c..4c5106f963b 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -476,30 +476,6 @@
   kernel :
     func : hsigmoid_loss_grad
 
-- backward_op : instance_norm_double_grad
-  forward : instance_norm_grad(Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, float epsilon) -> Tensor(grad_x), Tensor(grad_scale), Tensor(grad_bias)
-  args : (Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, Tensor grad_x_grad, Tensor grad_scale_grad, Tensor grad_bias_grad, float epsilon)
-  output : Tensor(x_grad), Tensor(fwd_scale_grad), Tensor(grad_y_grad)
-  infer_meta :
-    func : InstanceNormDoubleGradInferMeta
-  kernel :
-    func : instance_norm_double_grad
-    data_type : x
-  optional : fwd_scale, grad_x_grad, grad_scale_grad, grad_bias_grad
-
-- backward_op : instance_norm_grad
-  forward : instance_norm(Tensor x, Tensor scale, Tensor bias, float epsilon) -> Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
-  args : (Tensor x, Tensor scale, Tensor saved_mean, Tensor saved_variance, Tensor y_grad, float epsilon)
-  output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
-  infer_meta :
-    func : InstanceNormGradInferMeta
-  kernel :
-    func : instance_norm_grad
-    data_type : x
-  optional : scale
-  backward : instance_norm_double_grad
-  composite: instance_norm_grad(x, scale, saved_mean, saved_variance, y_grad, epsilon, x_grad, scale_grad, bias_grad)
-
 - backward_op : layer_norm_grad
   forward : layer_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis) -> Tensor(out), Tensor(mean), Tensor(variance)
   args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, Tensor out_grad, float epsilon, int begin_norm_axis)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index cf85f1b8bfe..408fe94987c 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -597,18 +597,6 @@
     func : increment
   inplace : (x -> out)
 
-- op : instance_norm
-  args : (Tensor x, Tensor scale, Tensor bias, float epsilon)
-  output : Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
-  infer_meta :
-    func : InstanceNormInferMeta
-  kernel :
-    func : instance_norm
-    data_type : x
-  optional : scale, bias
-  intermediate : saved_mean, saved_variance
-  backward : instance_norm_grad
-
 - op : layer_norm
   args : (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis)
   output : Tensor(out), Tensor(mean), Tensor(variance)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 5ad73f4e104..ab8274ac601 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -1171,6 +1171,10 @@
     y : Y
     saved_mean : SavedMean
     saved_variance : SavedVariance
+  extra:
+    outputs: [ saved_mean, saved_variance ]
+  get_expected_kernel_type:
+    instance_norm: GetInstanceNormExpectedKernelType
 
 - op : inverse
   inputs :
diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml
index 9d9cb6187b2..5d6904ee589 100644
--- a/paddle/phi/api/yaml/op_version.yaml
+++ b/paddle/phi/api/yaml/op_version.yaml
@@ -167,6 +167,17 @@
         comment : In order to specify interpolation mode
         default : std::string("bilinear")
 
+- op : instance_norm
+  version :
+    - checkpoint : Change dispensable of attribute from False to True in instance_norm.
+      action :
+        - modify_attr : Bias
+          comment : "The arg 'dispensable' of Input 'Bias' is changed: from 'False' to 'True'."
+          default : "true"
+        - modify_attr : Scale
+          comment : "The arg 'dispensable' of Input 'Scale' is changed: from 'False' to 'True'."
+          default : "true"
+
 - op : lamb
   version :
     - checkpoint : Upgrade lamb, add two new outputs [Beta1PowOut] and [Beta2PowOut].
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index b8f99bbe106..7fc3acc5614 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -989,6 +989,18 @@
     data_type : x
   backward : index_select_grad
 
+- op : instance_norm
+  args : (Tensor x, Tensor scale, Tensor bias, float epsilon=1e-5)
+  output : Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
+  infer_meta :
+    func : InstanceNormInferMeta
+  kernel :
+    func : instance_norm
+    data_type : x
+  optional : scale, bias
+  intermediate : saved_mean, saved_variance
+  backward : instance_norm_grad
+
 - op : inverse
   args : (Tensor x)
   output : Tensor(out)
diff --git a/paddle/phi/ops/compat/instance_norm_sig.cc b/paddle/phi/ops/compat/instance_norm_sig.cc
deleted file mode 100644
index 6ccf1209798..00000000000
--- a/paddle/phi/ops/compat/instance_norm_sig.cc
+++ /dev/null
@@ -1,56 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/core/compat/op_utils.h"
-
-namespace phi {
-
-KernelSignature InstanceNormOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature("instance_norm",
-                         {"X", "Scale", "Bias"},
-                         {"epsilon"},
-                         {"Y", "SavedMean", "SavedVariance"});
-}
-
-KernelSignature InstanceNormGradOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature("instance_norm_grad",
-                         {"X", "Scale", "SavedMean", "SavedVariance", "Y@GRAD"},
-                         {"epsilon"},
-                         {"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
-}
-
-KernelSignature InstanceNormDoubleGradOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature("instance_norm_double_grad",
-                         {"X",
-                          "Scale",
-                          "SavedMean",
-                          "SavedVariance",
-                          "DY",
-                          "DDX",
-                          "DDScale",
-                          "DDBias"},
-                         {"epsilon"},
-                         {"DX", "DScale", "DDY"});
-}
-}  // namespace phi
-
-PD_REGISTER_BASE_KERNEL_NAME(instance_norm_grad_grad,
-                             instance_norm_double_grad);
-PD_REGISTER_ARG_MAPPING_FN(instance_norm, phi::InstanceNormOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(instance_norm_grad,
-                           phi::InstanceNormGradOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(instance_norm_grad_grad,
-                           phi::InstanceNormDoubleGradOpArgumentMapping);
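
A quick way to exercise the migrated op end to end is to run instance_norm through the static-graph path via the public Python API. The snippet below is a minimal illustrative sketch, not part of the patch: it assumes a Paddle build that includes this change, and it relies only on public APIs (paddle.static, paddle.nn.InstanceNorm2D, whose epsilon default matches the 1e-5 default declared in ops.yaml above).

# Illustrative smoke test (assumes a Paddle build containing this patch).
# Builds a static program that invokes the generated instance_norm op.
import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # NCHW input, the data format required by the op.
    x = paddle.static.data(name="x", shape=[2, 3, 8, 8], dtype="float32")
    # epsilon defaults to 1e-5, matching the new ops.yaml signature.
    y = paddle.nn.InstanceNorm2D(num_features=3)(x)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
(out,) = exe.run(
    main_prog,
    feed={"x": np.random.rand(2, 3, 8, 8).astype("float32")},
    fetch_list=[y],
)
print(out.shape)  # (2, 3, 8, 8); each (N, C) slice is normalized over H*W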