From f7b80adab06ffadf4265e5f366ddad4726740a96 Mon Sep 17 00:00:00 2001
From: Xiaoxu Chen
Date: Tue, 18 Apr 2023 09:37:04 +0800
Subject: [PATCH] [prim add instance_norm custom vjp] (#52935)

---
 paddle/fluid/operators/instance_norm_op.cc | 48 ++++++++++++++++-
 .../composite_backward_api.h               | 52 +++++++++++++++++++
 paddle/phi/api/yaml/legacy_backward.yaml   |  1 +
 3 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/instance_norm_op.cc b/paddle/fluid/operators/instance_norm_op.cc
index 289df565b88..8d76b46968b 100644
--- a/paddle/fluid/operators/instance_norm_op.cc
+++ b/paddle/fluid/operators/instance_norm_op.cc
@@ -21,6 +21,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/ternary.h"
@@ -140,6 +143,48 @@ phi::KernelKey InstanceNormDoubleGradOp::GetExpectedKernelType(
                         ctx.GetPlace());
 }
 
+class InstanceNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    // inputs and outputs of instance_norm
+    paddle::Tensor x = this->GetSingleForwardInput("X");
+    paddle::Tensor scale = this->GetSingleForwardInput("Scale");
+    paddle::Tensor saved_mean = this->GetSingleForwardOutput("SavedMean");
+    paddle::Tensor saved_variance =
+        this->GetSingleForwardOutput("SavedVariance");
+
+    paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
+    paddle::Tensor x_grad = this->GetSingleInputGrad("X");
+    paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
+    paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");
+
+    auto x_grad_ptr = this->GetOutputPtr(&x_grad);
+    std::string x_grad_name = this->GetOutputName(x_grad);
+    auto scale_grad_ptr = this->GetOutputPtr(&scale_grad);
+    std::string scale_grad_name = this->GetOutputName(scale_grad);
+    auto bias_grad_ptr = this->GetOutputPtr(&bias_grad);
+    std::string bias_grad_name = this->GetOutputName(bias_grad);
+
+    auto epsilon = this->Attr<float>("epsilon");
+
+    VLOG(3) << "Running instance_norm composite func";
+    prim::instance_norm_grad<prim::DescTensor>(x,
+                                               scale,
+                                               saved_mean,
+                                               saved_variance,
+                                               y_grad,
+                                               epsilon,
+                                               x_grad_ptr,
+                                               scale_grad_ptr,
+                                               bias_grad_ptr);
+    this->RecoverOutputName(x_grad, x_grad_name);
+    this->RecoverOutputName(scale_grad, scale_grad_name);
+    this->RecoverOutputName(bias_grad, bias_grad_name);
+  }
+};
+
 DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInferer,
                            {"DY", "DDY"});
 
@@ -163,7 +208,8 @@ REGISTER_OPERATOR(instance_norm,
                   ops::InstanceNormOpInferVarType,
                   ops::InstanceNormGradMaker<paddle::framework::OpDesc>,
                   ops::InstanceNormGradMaker<paddle::imperative::OpBase>,
-                  InstanceNormInferShapeFunctor);
+                  InstanceNormInferShapeFunctor,
+                  ops::InstanceNormCompositeGradOpMaker);
 REGISTER_OPERATOR(instance_norm_grad,
                   ops::InstanceNormGradOp,
                   ops::InstanceNormDoubleGradMaker<paddle::framework::OpDesc>,
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index c0830b2a754..230914559fb 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -1493,6 +1493,58 @@ void batch_norm_grad(const Tensor& x,
   }
 }
 
+template <typename T>
+void instance_norm_grad(const Tensor& x,
+                        const paddle::optional<Tensor>& scale,
+                        const Tensor& saved_mean,
+                        const Tensor& saved_variance,
+                        const Tensor& y_grad,
+                        float epsilon,
+                        Tensor* x_grad,
+                        Tensor* scale_grad,
+                        Tensor* bias_grad) {
+  const int n = x.dims()[0];
+  const int c = x.dims()[1];
+  const int h = x.dims()[2];
+  const int w = x.dims()[3];
+
+  Tensor x_hat;
+  Tensor std_inv;
+  if (scale_grad || x_grad) {
+    auto mean = reshape<T>(saved_mean, IntArray({n, c, 1, 1}))
+                    .tile(IntArray({1, 1, h, w}));
+    std_inv = reshape<T>(saved_variance, IntArray({n, c, 1, 1}))
+                  .tile(IntArray({1, 1, h, w}));
+    x_hat = (x - mean) * std_inv;
+  }
+
+  // x_grad = scale * std_inv * (y_grad - y_grad.mean((2, 3)) -
+  //                             x_hat * (y_grad * x_hat).mean((2, 3)))
+  if (x_grad) {
+    auto scale_t =
+        reshape<T>(scale.get_ptr() ? scale.get()
+                                   : full<T>(IntArray({c}), 1., x.dtype()),
+                   IntArray({1, c, 1, 1}))
+            .tile(IntArray({n, 1, h, w}));
+    set_output<T>(
+        (scale_t * std_inv) *
+            (y_grad -
+             y_grad.sum(IntArray({2, 3}), y_grad.dtype(), true) / (h * w) -
+             (x_hat *
+              ((y_grad * x_hat).sum(IntArray({2, 3}), y_grad.dtype(), true) /
+               (h * w)))),
+        x_grad);
+  }
+  // scale_grad = (y_grad * x_hat).sum((0, 2, 3))
+  if (scale_grad) {
+    set_output<T>((y_grad * x_hat).sum(IntArray({0, 2, 3})), scale_grad);
+  }
+  // bias_grad = y_grad.sum((0, 2, 3))
+  if (bias_grad) {
+    set_output<T>(y_grad.sum(IntArray({0, 2, 3})), bias_grad);
+  }
+}
+
 template <typename T>
 void gelu_grad(const Tensor& x,
                const Tensor& out_grad,
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 9d323fe3a06..626e0ed6f62 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -496,6 +496,7 @@
     data_type : x
   optional : scale
   backward : instance_norm_double_grad
+  composite: instance_norm_grad(x, scale, saved_mean, saved_variance, y_grad, epsilon, x_grad, scale_grad, bias_grad)
 
 - backward_op : layer_norm_grad
   forward : layer_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis) -> Tensor(out), Tensor(mean), Tensor(variance)
-- 
GitLab
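
Editor's note: as a sanity check on the composite rule added above, the standalone C++ sketch below (illustrative only, not part of the patch or the Paddle build) re-implements the x_grad expression from instance_norm_grad with plain loops over an NCHW tensor, assuming scale = 1 and bias = 0, and compares it against a central finite difference of loss = sum(y_grad * instance_norm(x)). The instance_norm helper, the tensor sizes, and the synthetic x / y_grad values are all invented for the check; saved_variance is treated as the saved inverse standard deviation 1/sqrt(var + epsilon), which is how the composite formula uses it, and is simply recomputed from x here.

// Hypothetical standalone check (not part of the patch): re-implements the
// composite x_grad formula with plain loops and compares it against a
// central finite difference of loss = sum(g * instance_norm(x)), where g
// plays the role of y_grad.  Assumes NCHW layout, scale = 1, bias = 0.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Naive NCHW instance norm forward with scale = 1, bias = 0.
static std::vector<double> instance_norm(const std::vector<double>& x,
                                         int n, int c, int h, int w,
                                         double eps) {
  std::vector<double> y(x.size());
  const int hw = h * w;
  for (int i = 0; i < n * c; ++i) {
    double mean = 0.0, var = 0.0;
    for (int j = 0; j < hw; ++j) mean += x[i * hw + j];
    mean /= hw;
    for (int j = 0; j < hw; ++j) {
      const double d = x[i * hw + j] - mean;
      var += d * d;
    }
    var /= hw;
    const double std_inv = 1.0 / std::sqrt(var + eps);
    for (int j = 0; j < hw; ++j)
      y[i * hw + j] = (x[i * hw + j] - mean) * std_inv;
  }
  return y;
}

int main() {
  const int n = 2, c = 3, h = 4, w = 5, hw = h * w;
  const double eps = 1e-5;
  std::vector<double> x(n * c * hw), g(x.size());
  for (std::size_t k = 0; k < x.size(); ++k) {
    x[k] = std::sin(0.37 * k);  // arbitrary input
    g[k] = std::cos(0.11 * k);  // arbitrary upstream gradient (y_grad)
  }

  // Composite rule from the patch, per (n, c) slice, with scale = 1:
  //   x_grad = std_inv * (g - mean_hw(g) - x_hat * mean_hw(g * x_hat))
  // std_inv is recomputed here; in the patch it comes from saved_variance.
  std::vector<double> x_grad(x.size());
  for (int i = 0; i < n * c; ++i) {
    double mean = 0.0, var = 0.0;
    for (int j = 0; j < hw; ++j) mean += x[i * hw + j];
    mean /= hw;
    for (int j = 0; j < hw; ++j) {
      const double d = x[i * hw + j] - mean;
      var += d * d;
    }
    var /= hw;
    const double std_inv = 1.0 / std::sqrt(var + eps);
    double g_mean = 0.0, gx_mean = 0.0;
    for (int j = 0; j < hw; ++j) {
      const double x_hat = (x[i * hw + j] - mean) * std_inv;
      g_mean += g[i * hw + j];
      gx_mean += g[i * hw + j] * x_hat;
    }
    g_mean /= hw;
    gx_mean /= hw;
    for (int j = 0; j < hw; ++j) {
      const double x_hat = (x[i * hw + j] - mean) * std_inv;
      x_grad[i * hw + j] =
          std_inv * (g[i * hw + j] - g_mean - x_hat * gx_mean);
    }
  }

  // Central finite difference of loss = sum(g * instance_norm(x)).
  const double delta = 1e-6;
  double max_err = 0.0;
  for (std::size_t k = 0; k < x.size(); ++k) {
    std::vector<double> xp = x, xm = x;
    xp[k] += delta;
    xm[k] -= delta;
    const auto yp = instance_norm(xp, n, c, h, w, eps);
    const auto ym = instance_norm(xm, n, c, h, w, eps);
    double lp = 0.0, lm = 0.0;
    for (std::size_t t = 0; t < x.size(); ++t) {
      lp += g[t] * yp[t];
      lm += g[t] * ym[t];
    }
    const double fd = (lp - lm) / (2.0 * delta);
    max_err = std::max(max_err, std::fabs(fd - x_grad[k]));
  }
  std::printf("max |composite - finite difference| = %.3e\n", max_err);
  return 0;
}

With double precision and delta = 1e-6 the printed maximum discrepancy should be small (roughly 1e-8 or below). scale_grad and bias_grad are plain reductions of y_grad * x_hat and y_grad over (N, H, W) and can be checked the same way.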