Unverified Commit 7b81092b authored by 张春乔, committed by GitHub

[static op generation] InstanceNorm (#53340)

* mv InstanceNorm

* modify op_version.yaml

* add Operator:: in get_expected_kernel_func.cc

* rm gradexpectedkernel

* add extra

* add float epsilon=1e-5
Parent ad45b368
@@ -181,5 +181,35 @@ phi::KernelKey GetUniqueExpectedKernelType(
}
}
phi::KernelKey GetInstanceNormExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
auto input_data_type =
op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should all be float (for float or float16 input tensors)
// or double (for double input tensors).
auto in_param_type = framework::proto::VarType::FP32;
if (input_data_type == framework::proto::VarType::FP64) {
in_param_type = framework::proto::VarType::FP64;
}
if (ctx.HasInput("Scale")) {
PADDLE_ENFORCE_EQ(in_param_type,
framework::TransToProtoVarType(
ctx.Input<phi::DenseTensor>("Scale")->dtype()),
platform::errors::InvalidArgument(
"Scale input should be of float type"));
}
if (ctx.HasInput("Bias")) {
PADDLE_ENFORCE_EQ(in_param_type,
framework::TransToProtoVarType(
ctx.Input<phi::DenseTensor>("Bias")->dtype()),
platform::errors::InvalidArgument(
"Bias input should be of float type"));
}
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
} // namespace operators
} // namespace paddle
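As an aside, the dtype rule above can be restated in isolation: Scale and Bias must be FP32 when X is FP32 or FP16, and FP64 when X is FP64. The snippet below is a minimal, self-contained sketch of that check; DType and both function names are hypothetical illustrations, not part of this commit or of the Paddle API.

// Sketch only: a stand-alone restatement of the Scale/Bias dtype rule enforced
// by GetInstanceNormExpectedKernelType. All names here are hypothetical.
#include <stdexcept>
#include <string>

enum class DType { FP16, FP32, FP64 };

// FP64 inputs require FP64 parameters; FP32/FP16 inputs require FP32 parameters.
inline DType ExpectedParamDType(DType input) {
  return input == DType::FP64 ? DType::FP64 : DType::FP32;
}

inline void CheckParamDType(DType input, DType param, const std::string& name) {
  if (ExpectedParamDType(input) != param) {
    throw std::invalid_argument(name + " input should be of float type");
  }
}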
@@ -48,6 +48,10 @@ phi::KernelKey GetUniqueExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
phi::KernelKey GetInstanceNormExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
phi::KernelKey GetYoloLossExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
...
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/instance_norm_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
phi::KernelKey InstanceNormOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should all be float (for float or float16 input tensors)
// or double (for double input tensors).
auto in_param_type = framework::proto::VarType::FP32;
if (input_data_type == framework::proto::VarType::FP64) {
in_param_type = framework::proto::VarType::FP64;
}
if (ctx.HasInput("Scale")) {
PADDLE_ENFORCE_EQ(in_param_type,
framework::TransToProtoVarType(
ctx.Input<phi::DenseTensor>("Scale")->dtype()),
platform::errors::InvalidArgument(
"Scale input should be of float type"));
}
if (ctx.HasInput("Bias")) {
PADDLE_ENFORCE_EQ(in_param_type,
framework::TransToProtoVarType(
ctx.Input<phi::DenseTensor>("Bias")->dtype()),
platform::errors::InvalidArgument(
"Bias input should be of float type"));
}
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
void InstanceNormOpMaker::Make() {
AddAttr<float>("epsilon", "")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
true,
platform::errors::InvalidArgument(
"'epsilon' should be between 0.0 and 0.001."));
});
AddInput("X", "The input tensor");
AddInput("Scale",
"Scale is a 1-dimensional tensor of size C "
"that is applied to the output")
.AsDispensable();
AddInput("Bias",
"Bias is a 1-dimensional tensor of size C "
"that is applied to the output")
.AsDispensable();
AddOutput("Y", "result after normalization");
AddOutput("SavedMean",
"Mean of the current mini batch, "
"will apply to output when training")
.AsIntermediate()
.AsExtra();
AddOutput("SavedVariance",
"Variance of the current mini batch, "
"will apply to output when training")
.AsIntermediate()
.AsExtra();
AddComment(R"DOC(
Instance Normalization.
Instance Norm has been implemented as discussed in the paper:
https://arxiv.org/pdf/1607.08022.pdf
Can be used as a normalizer function for conv2d and fully_connected operations.
The required data format for this layer is as follows:
NCHW `[batch, in_channels, in_height, in_width]`
)DOC");
}
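For reference, the normalization described in the DOC string above can be written per sample n and channel c (the standard formulation; not text copied from this commit):

y_{n,c,h,w} = \gamma_c \, \frac{x_{n,c,h,w} - \mu_{n,c}}{\sqrt{\sigma^2_{n,c} + \epsilon}} + \beta_c,
\qquad
\mu_{n,c} = \frac{1}{HW} \sum_{h,w} x_{n,c,h,w},
\qquad
\sigma^2_{n,c} = \frac{1}{HW} \sum_{h,w} \left(x_{n,c,h,w} - \mu_{n,c}\right)^2

Here \gamma is Scale, \beta is Bias, and \epsilon is the epsilon attribute; SavedMean and SavedVariance hold the per-(n, c) statistics (in whatever form the kernel chooses to save them) for reuse by the backward pass.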
phi::KernelKey InstanceNormGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
PADDLE_THROW(
platform::errors::NotFound("cannot find gradient variable of Y"));
}
const phi::DenseTensor *t = nullptr;
if (var->IsType<phi::DenseTensor>()) {
t = &var->Get<phi::DenseTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
platform::errors::InvalidArgument("gradient variable of Y is empty"));
}
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
phi::KernelKey InstanceNormDoubleGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar("DY");
if (var == nullptr) {
PADDLE_THROW(
platform::errors::NotFound("cannot find gradient variable of Y"));
}
const phi::DenseTensor *t = nullptr;
if (var->IsType<phi::DenseTensor>()) {
t = &var->Get<phi::DenseTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
platform::errors::InvalidArgument("gradient variable of Y is empty"));
}
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
class InstanceNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
// inputs and outputs of instance_norm
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor scale = this->GetSingleForwardInput("Scale");
paddle::Tensor saved_mean = this->GetSingleForwardOutput("SavedMean");
paddle::Tensor saved_variance =
this->GetSingleForwardOutput("SavedVariance");
paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
paddle::Tensor x_grad = this->GetSingleInputGrad("X");
paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");
auto x_grad_ptr = this->GetOutputPtr(&x_grad);
std::string x_grad_name = this->GetOutputName(x_grad);
auto scale_grad_ptr = this->GetOutputPtr(&scale_grad);
std::string scale_grad_name = this->GetOutputName(scale_grad);
auto bias_grad_ptr = this->GetOutputPtr(&bias_grad);
std::string bias_grad_name = this->GetOutputName(bias_grad);
auto epsilon = this->Attr<float>("epsilon");
VLOG(3) << "Runing instance_norm composite func";
prim::instance_norm_grad<prim::DescTensor>(x,
scale,
saved_mean,
saved_variance,
y_grad,
epsilon,
x_grad_ptr,
scale_grad_ptr,
bias_grad_ptr);
this->RecoverOutputName(x_grad, x_grad_name);
this->RecoverOutputName(scale_grad, scale_grad_name);
this->RecoverOutputName(bias_grad, bias_grad_name);
}
};
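For context, the composite rule delegates to prim::instance_norm_grad, which is expected to expand to the standard instance-norm backward. Written per (n, c) slice with \hat{x} = (x - \mu)/\sqrt{\sigma^2 + \epsilon}, m = HW, and g = \partial L/\partial y (the textbook derivation, not taken from this commit):

\frac{\partial L}{\partial \beta_c} = \sum_{n,h,w} g_{n,c,h,w},
\qquad
\frac{\partial L}{\partial \gamma_c} = \sum_{n,h,w} g_{n,c,h,w}\,\hat{x}_{n,c,h,w},

\frac{\partial L}{\partial x_{n,c,h,w}} = \frac{\gamma_c}{\sqrt{\sigma^2_{n,c}+\epsilon}}
\left( g_{n,c,h,w} - \frac{1}{m}\sum_{h',w'} g_{n,c,h',w'} - \hat{x}_{n,c,h,w}\,\frac{1}{m}\sum_{h',w'} g_{n,c,h',w'}\,\hat{x}_{n,c,h',w'} \right)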
DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInferer,
{"DY", "DDY"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(instance_norm,
InstanceNormInferShapeFunctor,
PD_INFER_META(phi::InstanceNormInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(instance_norm_grad,
InstanceNormGradInferShapeFunctor,
PD_INFER_META(phi::InstanceNormGradInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(
instance_norm_grad_grad,
InstanceNormDoubleGradInferShapeFunctor,
PD_INFER_META(phi::InstanceNormDoubleGradInferMeta));
REGISTER_OPERATOR(instance_norm,
ops::InstanceNormOp,
ops::InstanceNormOpMaker,
ops::InstanceNormOpInferVarType,
ops::InstanceNormGradMaker<paddle::framework::OpDesc>,
ops::InstanceNormGradMaker<paddle::imperative::OpBase>,
InstanceNormInferShapeFunctor,
ops::InstanceNormCompositeGradOpMaker);
REGISTER_OPERATOR(instance_norm_grad,
ops::InstanceNormGradOp,
ops::InstanceNormDoubleGradMaker<paddle::framework::OpDesc>,
ops::InstanceNormDoubleGradMaker<paddle::imperative::OpBase>,
InstanceNormGradInferShapeFunctor);
REGISTER_OPERATOR(instance_norm_grad_grad,
ops::InstanceNormDoubleGradOp,
ops::InstanceNormDoubleGradOpInplaceInferer,
InstanceNormDoubleGradInferShapeFunctor);
REGISTER_OP_VERSION(instance_norm)
.AddCheckpoint(
R"ROC(
Change dispensable of attribute from False to True in instance_norm.
)ROC",
paddle::framework::compatible::OpVersionDesc()
.ModifyAttr(
"Bias",
"The arg 'dispensable' of Input 'Bias' is changed: from "
"'False' to 'True'.",
true)
.ModifyAttr(
"Scale",
"The arg 'dispensable' of Input 'Scale' is changed: from "
"'False' to 'True'.",
true));
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using DataLayout = phi::DataLayout;
class InstanceNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
};
class InstanceNormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
};
class InstanceNormDoubleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
};
class InstanceNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
template <typename T>
class InstanceNormGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("instance_norm_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetInput("Scale", this->Input("Scale"));
op->SetInput("SavedMean", this->Output("SavedMean"));
op->SetInput("SavedVariance", this->Output("SavedVariance"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
}
};
template <typename T>
class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("instance_norm_grad_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Scale", this->Input("Scale"));
op->SetInput("SavedMean", this->Input("SavedMean"));
op->SetInput("SavedVariance", this->Input("SavedVariance"));
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetInput("DDScale", this->OutputGrad(framework::GradVarName("Scale")));
op->SetInput("DDBias", this->OutputGrad(framework::GradVarName("Bias")));
op->SetInput("DY", this->Input(framework::GradVarName("Y")));
op->SetAttrMap(this->Attrs());
op->SetOutput("DX", this->InputGrad("X"));
op->SetOutput("DScale", this->InputGrad("Scale"));
op->SetOutput("DDY", this->InputGrad(framework::GradVarName("Y")));
}
};
class InstanceNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", "Y"}};
return m;
}
};
} // namespace operators
} // namespace paddle
@@ -851,6 +851,30 @@
    data_type : out_grad
  no_need_buffer : x
- backward_op : instance_norm_double_grad
  forward : instance_norm_grad(Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, float epsilon) -> Tensor(grad_x), Tensor(grad_scale), Tensor(grad_bias)
  args : (Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, Tensor grad_x_grad, Tensor grad_scale_grad, Tensor grad_bias_grad, float epsilon)
  output : Tensor(x_grad), Tensor(fwd_scale_grad), Tensor(grad_y_grad)
  infer_meta :
    func : InstanceNormDoubleGradInferMeta
  kernel :
    func : instance_norm_double_grad
    data_type : x
  optional : fwd_scale, grad_x_grad, grad_scale_grad, grad_bias_grad

- backward_op : instance_norm_grad
  forward : instance_norm(Tensor x, Tensor scale, Tensor bias, float epsilon) -> Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
  args : (Tensor x, Tensor scale, Tensor saved_mean, Tensor saved_variance, Tensor y_grad, float epsilon=1e-5)
  output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
  infer_meta :
    func : InstanceNormGradInferMeta
  kernel :
    func : instance_norm_grad
    data_type : x
  optional : scale
  backward : instance_norm_double_grad
  composite: instance_norm_grad(x, scale, saved_mean, saved_variance, y_grad, epsilon, x_grad, scale_grad, bias_grad)
- backward_op : inverse_grad
  forward : inverse(Tensor x) -> Tensor(out)
  args : (Tensor out, Tensor out_grad)
...
@@ -476,30 +476,6 @@
  kernel :
    func : hsigmoid_loss_grad
- backward_op : instance_norm_double_grad
  forward : instance_norm_grad(Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, float epsilon) -> Tensor(grad_x), Tensor(grad_scale), Tensor(grad_bias)
  args : (Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, Tensor grad_x_grad, Tensor grad_scale_grad, Tensor grad_bias_grad, float epsilon)
  output : Tensor(x_grad), Tensor(fwd_scale_grad), Tensor(grad_y_grad)
  infer_meta :
    func : InstanceNormDoubleGradInferMeta
  kernel :
    func : instance_norm_double_grad
    data_type : x
  optional : fwd_scale, grad_x_grad, grad_scale_grad, grad_bias_grad

- backward_op : instance_norm_grad
  forward : instance_norm(Tensor x, Tensor scale, Tensor bias, float epsilon) -> Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
  args : (Tensor x, Tensor scale, Tensor saved_mean, Tensor saved_variance, Tensor y_grad, float epsilon)
  output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
  infer_meta :
    func : InstanceNormGradInferMeta
  kernel :
    func : instance_norm_grad
    data_type : x
  optional : scale
  backward : instance_norm_double_grad
  composite: instance_norm_grad(x, scale, saved_mean, saved_variance, y_grad, epsilon, x_grad, scale_grad, bias_grad)
- backward_op : layer_norm_grad
  forward : layer_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis) -> Tensor(out), Tensor(mean), Tensor(variance)
  args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, Tensor out_grad, float epsilon, int begin_norm_axis)
...
@@ -597,18 +597,6 @@
    func : increment
  inplace : (x -> out)
- op : instance_norm
  args : (Tensor x, Tensor scale, Tensor bias, float epsilon)
  output : Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
  infer_meta :
    func : InstanceNormInferMeta
  kernel :
    func : instance_norm
    data_type : x
  optional : scale, bias
  intermediate : saved_mean, saved_variance
  backward : instance_norm_grad
- op : layer_norm
  args : (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis)
  output : Tensor(out), Tensor(mean), Tensor(variance)
...
@@ -1171,6 +1171,10 @@
    y : Y
    saved_mean : SavedMean
    saved_variance : SavedVariance
  extra:
    outputs: [ saved_mean, saved_variance ]
  get_expected_kernel_type:
    instance_norm: GetInstanceNormExpectedKernelType
- op : inverse
  inputs :
...
@@ -167,6 +167,17 @@
          comment : In order to specify interpolation mode
          default : std::string("bilinear")
- op : instance_norm
  version :
    - checkpoint : Change dispensable of attribute from False to True in instance_norm.
      action :
        - modify_attr : Bias
          comment : "The arg 'dispensable' of Input 'Bias' is changed: from 'False' to 'True'."
          default : "true"
        - modify_attr : Scale
          comment : "The arg 'dispensable' of Input 'Scale' is changed: from 'False' to 'True'."
          default : "true"
- op : lamb
  version :
    - checkpoint : Upgrade lamb, add two new outputs [Beta1PowOut] and [Beta2PowOut].
...
@@ -989,6 +989,18 @@
    data_type : x
  backward : index_select_grad
- op : instance_norm
  args : (Tensor x, Tensor scale, Tensor bias, float epsilon=1e-5)
  output : Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
  infer_meta :
    func : InstanceNormInferMeta
  kernel :
    func : instance_norm
    data_type : x
  optional : scale, bias
  intermediate : saved_mean, saved_variance
  backward : instance_norm_grad
- op : inverse
  args : (Tensor x)
  output : Tensor(out)
...
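To make the contract of the instance_norm entry above concrete, here is a minimal single-threaded reference of the forward computation for NCHW input. This is only a sketch under the usual instance-norm convention: the function name, the flat std::vector layout, and the choice to store the plain variance in saved_variance are assumptions for illustration, not details of the generated Paddle kernel.

// Reference sketch: y = scale * (x - mean) / sqrt(var + eps) + bias,
// with mean/var computed over H*W for every (sample, channel) pair.
#include <cmath>
#include <vector>

void InstanceNormRef(const std::vector<float>& x,      // [N, C, H, W] in NCHW order
                     const std::vector<float>& scale,  // [C]; empty means scale = 1
                     const std::vector<float>& bias,   // [C]; empty means bias = 0
                     float epsilon,                    // defaults to 1e-5 in the YAML entry
                     int N, int C, int H, int W,
                     std::vector<float>* y,                 // [N, C, H, W]
                     std::vector<float>* saved_mean,        // one value per (n, c)
                     std::vector<float>* saved_variance) {  // one value per (n, c)
  const int HW = H * W;
  y->assign(x.size(), 0.0f);
  saved_mean->assign(static_cast<size_t>(N) * C, 0.0f);
  saved_variance->assign(static_cast<size_t>(N) * C, 0.0f);
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      const int base = (n * C + c) * HW;
      // Mean and variance over the spatial dims of one (sample, channel) slice.
      float mean = 0.0f;
      for (int i = 0; i < HW; ++i) mean += x[base + i];
      mean /= HW;
      float var = 0.0f;
      for (int i = 0; i < HW; ++i) {
        const float d = x[base + i] - mean;
        var += d * d;
      }
      var /= HW;
      (*saved_mean)[n * C + c] = mean;
      (*saved_variance)[n * C + c] = var;
      // Normalize, then apply the optional per-channel affine transform.
      const float inv_std = 1.0f / std::sqrt(var + epsilon);
      const float g = scale.empty() ? 1.0f : scale[c];
      const float b = bias.empty() ? 0.0f : bias[c];
      for (int i = 0; i < HW; ++i) {
        (*y)[base + i] = g * (x[base + i] - mean) * inv_std + b;
      }
    }
  }
}

The empty-vector handling mirrors the "optional : scale, bias" line above, and the per-(n, c) statistics correspond to the SavedMean/SavedVariance intermediates.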
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature InstanceNormOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("instance_norm",
{"X", "Scale", "Bias"},
{"epsilon"},
{"Y", "SavedMean", "SavedVariance"});
}
KernelSignature InstanceNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("instance_norm_grad",
{"X", "Scale", "SavedMean", "SavedVariance", "Y@GRAD"},
{"epsilon"},
{"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
}
KernelSignature InstanceNormDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("instance_norm_double_grad",
{"X",
"Scale",
"SavedMean",
"SavedVariance",
"DY",
"DDX",
"DDScale",
"DDBias"},
{"epsilon"},
{"DX", "DScale", "DDY"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(instance_norm_grad_grad,
instance_norm_double_grad);
PD_REGISTER_ARG_MAPPING_FN(instance_norm, phi::InstanceNormOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(instance_norm_grad,
phi::InstanceNormGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(instance_norm_grad_grad,
phi::InstanceNormDoubleGradOpArgumentMapping);