From 069bb2d9032b20d4bf4cfbd150c0f75d7838922e Mon Sep 17 00:00:00 2001
From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com>
Date: Tue, 18 Apr 2023 15:10:00 +0800
Subject: [PATCH] [Prim] Support prim vjp of operator group_norm (#52663)

* add gn vjp
* fix 0
* fix args num
* fix type
* debug2
* remove unused expand
* support fp16
* fix typo
* fix reshape bug
* test3
* test4
* fix bug3
* add comment
---
 paddle/fluid/operators/group_norm_op.cc  |  58 +++++++
 .../composite_backward_api.h             | 150 ++++++++++++++++++
 paddle/phi/api/yaml/legacy_backward.yaml |   1 +
 3 files changed, 209 insertions(+)

diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc
index cead83df152..611b00b7c62 100644
--- a/paddle/fluid/operators/group_norm_op.cc
+++ b/paddle/fluid/operators/group_norm_op.cc
@@ -17,6 +17,10 @@ limitations under the License. */
 #include <memory>
 #include <string>
 
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
+
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/core/infermeta_utils.h"
@@ -158,6 +162,59 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class GroupNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    // inputs and outputs of group_norm
+    paddle::Tensor x = this->GetSingleForwardInput("X");
+    paddle::optional<paddle::Tensor> scale =
+        this->GetOptionalSingleForwardInput("Scale");
+    paddle::optional<paddle::Tensor> bias =
+        this->GetOptionalSingleForwardInput("Bias");
+    paddle::Tensor y = this->GetSingleForwardOutput("Y");
+    paddle::Tensor mean = this->GetSingleForwardOutput("Mean");
+    paddle::Tensor variance = this->GetSingleForwardOutput("Variance");
+
+    paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
+    paddle::Tensor x_grad = this->GetSingleInputGrad("X");
+    paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
+    paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");
+
+    auto dx_ptr = this->GetOutputPtr(&x_grad);
+    std::string dx_name = this->GetOutputName(x_grad);
+    auto dscale_ptr = this->GetOutputPtr(&scale_grad);
+    std::string dscale_name = this->GetOutputName(scale_grad);
+    auto dbias_ptr = this->GetOutputPtr(&bias_grad);
+    std::string dbias_name = this->GetOutputName(bias_grad);
+
+    // attrs of group_norm
+    auto groups = this->Attr<int>("groups");
+    auto epsilon = this->Attr<float>("epsilon");
+    auto data_layout = this->Attr<std::string>("data_layout");
+
+    VLOG(3) << "Running group_norm composite func";
+
+    prim::group_norm_grad<prim::DescTensor>(x,
+                                            scale,
+                                            bias,
+                                            y,
+                                            mean,
+                                            variance,
+                                            y_grad,
+                                            epsilon,
+                                            groups,
+                                            data_layout,
+                                            dx_ptr,
+                                            dscale_ptr,
+                                            dbias_ptr);
+    this->RecoverOutputName(x_grad, dx_name);
+    this->RecoverOutputName(scale_grad, dscale_name);
+    this->RecoverOutputName(bias_grad, dbias_name);
+  }
+};
+
 DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInferer,
                            {framework::GradVarName("Y"),
                             framework::GradVarName("X")});
@@ -186,6 +243,7 @@ REGISTER_OPERATOR(group_norm,
                   ops::GroupNormOpInferVarType,
                   ops::GroupNormGradMaker<paddle::framework::OpDesc>,
                   ops::GroupNormGradMaker<paddle::imperative::OpBase>,
+                  ops::GroupNormCompositeGradOpMaker,
                   GroupNormInferShapeFunctor);
 REGISTER_OPERATOR(group_norm_grad,
                   ops::GroupNormGradOp,
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 230914559fb..1d537a720a3 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -954,6 +954,156 @@ void slice_grad(const Tensor& input,
   }
 }
 
+template <typename T>
+void group_norm_grad(const Tensor& x,
+                     const paddle::optional<Tensor>& scale,
+                     const paddle::optional<Tensor>& bias,
+                     const Tensor& y,
+                     const Tensor& mean,
+                     const Tensor& variance,
+                     const Tensor& out_grad,
+                     float epsilon,
+                     int groups,
+                     const std::string& data_layout,
+                     Tensor* x_grad,
+                     Tensor* scale_grad,
+                     Tensor* bias_grad) {
+  // x.shape=[n,c,h,w]
+  // y.shape=[n,c,h,w]
+  // g_size = c/g
+  // scale.shape=[c]
+  // mean, var: shape=[n, g]
+  // inv_std = rsqrt(var + epsilon)
+  // ds = sum(dy * x, axes=(2,3))
+  // db = sum(dy, axes=(2,3))
+  //
+  // cal d_x:
+  // s = g / (h*w*c)
+  // if scale:
+  //   ds_val = sum((ds * scale).reshape(n, g, g_size), axes=2)
+  //   db_val = sum((db * scale).reshape(n, g, g_size), axes=2)
+  //   p1 = (inv_std.reshape(n, g, 1)) * (scale.reshape(1, g, g_size))
+  // else:
+  //   ds_val = sum(ds.reshape(n, g, g_size), axes=2)
+  //   db_val = sum(db.reshape(n, g, g_size), axes=2)
+  //   p1 = (inv_std.reshape(n, g, 1)) * (ones(1, g, g_size))
+  // p2 = (db_val * mean - ds_val) * inv_std * inv_std * inv_std * s
+  // p3 = -p2 * mean - db_val * inv_std * s
+  // p1.reshape(n, g, g_size, 1)
+  // p2.reshape(n, g, 1, 1)
+  // p3.reshape(n, g, 1, 1)
+  // d_x = dy.reshape(n, g, g_size, h*w) * p1 + x.reshape(n, g, g_size, h*w) * p2
+  //       + p3
+  //
+  // cal d_scale:
+  // temp = ds.reshape(n, g, g_size) - db.reshape(n, g, g_size) *
+  //        mean.reshape(n, g, 1)
+  // d_scale = sum(temp * inv_std.reshape(n, g, 1), axes=0).reshape(c)
+  //
+  // cal d_bias:
+  // d_bias = sum(dy, axes=(0,2,3))
+  DataLayout data_layout_ = phi::StringToDataLayout(data_layout);
+  if (data_layout_ != DataLayout::kNCHW) {
+    PADDLE_THROW(phi::errors::InvalidArgument("Unsupported storage order: %s",
+                                              data_layout));
+  }
+  Tensor x_data = x;
+  Tensor out_grad_data = out_grad;
+
+  if (x.dtype() == phi::DataType::FLOAT16) {
+    x_data = cast<T>(x, phi::DataType::FLOAT32);
+  }
+
+  if (out_grad.dtype() == phi::DataType::FLOAT16) {
+    out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
+  }
+
+  std::vector<int64_t> x_dims = phi::vectorize(x.dims());
+  auto add_axis = std::vector<int64_t>({-1});
+  const int N = x_dims[0];
+  const int C = x_dims[1];
+
+  const int hw = x_dims[2] * x_dims[3];
+  const int g_num = C / groups;
+
+  auto reduce_axis = IntArray(std::vector<int64_t>({2, 3}));
+  auto shape_group = IntArray(std::vector<int64_t>({N, groups, g_num}));
+  auto whole_group_shape =
+      IntArray(std::vector<int64_t>({N, groups, g_num, hw}));
+
+  auto scale_ptr = scale.get_ptr();
+  auto bias_ptr = bias.get_ptr();
+  auto inv_std = sqrt<T>(1.0 / (variance + epsilon));
+  auto inv_std_mul_s = inv_std / hw / g_num;
+  auto dtype = x_data.dtype();
+  auto sum_y_grad_mul_x =
+      sum<T>(out_grad_data * x_data, reduce_axis, dtype, false);
+  auto sum_y_grad = sum<T>(out_grad_data, reduce_axis, dtype, false);
+  if (x_grad) {
+    Tensor d1;
+    Tensor d2;
+    Tensor p1;
+    if (scale_ptr) {
+      auto scale_data = scale.get();
+      if (scale_data.dtype() == phi::DataType::FLOAT16) {
+        scale_data = cast<T>(scale_data, phi::DataType::FLOAT32);
+      }
+      d1 = (reshape<T>(sum_y_grad_mul_x * scale_data, shape_group))
+               .sum(std::vector<int64_t>({2}), dtype, false);
+      d2 = (reshape<T>(sum_y_grad * scale_data, shape_group))
+               .sum(std::vector<int64_t>({2}), dtype, false);
+      p1 = reshape<T>(inv_std, std::vector<int64_t>({N, groups, 1})) *
+           reshape<T>(scale_data,
+                      std::vector<int64_t>({1, groups, g_num}));
+    } else {
+      d1 = (reshape<T>(sum_y_grad_mul_x, shape_group))
+               .sum(std::vector<int64_t>({2}), dtype, false);
+      d2 = (reshape<T>(sum_y_grad, shape_group))
+               .sum(std::vector<int64_t>({2}), dtype, false);
+      p1 = (reshape<T>(inv_std, std::vector<int64_t>({N, groups, 1})))
+               .expand(IntArray(shape_group));
+    }
+
+    auto p2 = (d2 * mean - d1) * (inv_std_mul_s * inv_std * inv_std);
+    auto p3 = -p2 * mean - d2 * inv_std_mul_s;
+    p1 = unsqueeze<T>(p1, std::vector<int64_t>({3}));
+    p2 = unsqueeze<T>(p2, std::vector<int64_t>({2, 3}));
+    p3 = unsqueeze<T>(p3, std::vector<int64_t>({2, 3}));
+    auto tmp_1 = reshape<T>(out_grad_data, whole_group_shape) * p1;
+    auto tmp_2 = reshape<T>(x_data, whole_group_shape) * p2 + p3;
+    auto x_grad_data = tmp_1 + tmp_2;
+    x_grad_data = reshape<T>(x_grad_data, x.shape());
+    if (x.dtype() == phi::DataType::FLOAT16) {
+      x_grad_data = cast<T>(x_grad_data, x.dtype());
+    }
+
+    set_output<T>(x_grad_data, x_grad);
+  }
+  if (scale_grad) {
+    if (scale_ptr) {
+      auto tmp1 = (reshape<T>(sum_y_grad_mul_x, shape_group) -
+                   reshape<T>(sum_y_grad, shape_group) *
+                       unsqueeze<T>(mean, std::vector<int64_t>({2}))) *
+                  unsqueeze<T>(inv_std, std::vector<int64_t>({2}));
+      auto scale_grad_tmp =
+          reshape<T>(tmp1.sum(std::vector<int64_t>({0}), dtype, false),
+                     IntArray(std::vector<int64_t>({C})));
+      set_output<T>(scale_grad_tmp, scale_grad);
+    } else {
+      scale_grad = nullptr;
+    }
+  }
+
+  if (bias_grad) {
+    if (bias_ptr) {
+      auto bias_grad_tmp =
+          sum_y_grad.sum(std::vector<int64_t>({0}), dtype, false);
+      set_output<T>(bias_grad_tmp, bias_grad);
+    } else {
+      bias_grad = nullptr;
+    }
+  }
+}
+
 template <typename T>
 void layer_norm_grad(const Tensor& x,
                      const paddle::optional<Tensor>& scale,
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 33c6233da4d..2d4fcfb83ca 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -439,6 +439,7 @@
   kernel :
     func : group_norm_grad
     data_type : y_grad
+  composite : group_norm_grad(x, scale, bias, y, mean, variance, y_grad, epsilon, groups, data_layout)
   optional: scale, bias
   inplace : (y_grad -> x_grad)
 
-- 
GitLab
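
Note (not part of the patch): the d_x / d_scale / d_bias derivation documented in the comment block of prim::group_norm_grad can be cross-checked with a minimal NumPy sketch of the same formulas. The function name group_norm_grad_ref and its argument layout are illustrative assumptions for this sketch only; it mirrors the comment's notation (NCHW layout, ds/db, p1/p2/p3) and does not use Paddle's API.

import numpy as np

def group_norm_grad_ref(x, dy, scale, mean, var, groups, epsilon=1e-5):
    # Shapes as in the patch comment: x, dy: [n, c, h, w]; scale: [c] or None;
    # mean, var: [n, groups]; NCHW layout only.
    n, c, h, w = x.shape
    hw, g_size = h * w, c // groups
    inv_std = 1.0 / np.sqrt(var + epsilon)                # [n, groups]
    s = 1.0 / (hw * g_size)                               # = groups / (h*w*c)
    ds = (dy * x).sum(axis=(2, 3))                        # [n, c]
    db = dy.sum(axis=(2, 3))                              # [n, c]
    w_eff = scale if scale is not None else np.ones(c, dtype=x.dtype)
    d1 = (ds * w_eff).reshape(n, groups, g_size).sum(axis=2)  # ds_val
    d2 = (db * w_eff).reshape(n, groups, g_size).sum(axis=2)  # db_val
    p1 = inv_std[:, :, None] * w_eff.reshape(1, groups, g_size)
    p2 = (d2 * mean - d1) * inv_std ** 3 * s
    p3 = -p2 * mean - d2 * inv_std * s
    dx = (dy.reshape(n, groups, g_size, hw) * p1[..., None]
          + x.reshape(n, groups, g_size, hw) * p2[:, :, None, None]
          + p3[:, :, None, None]).reshape(n, c, h, w)
    d_scale = ((ds.reshape(n, groups, g_size)
                - db.reshape(n, groups, g_size) * mean[:, :, None])
               * inv_std[:, :, None]).sum(axis=0).reshape(c)
    d_bias = db.sum(axis=0)                               # sum(dy, axes=(0,2,3))
    return dx, d_scale, d_bias

These three results correspond to what the composite rule writes through dx_ptr, dscale_ptr, and dbias_ptr, with the scale and bias gradients skipped when the optional Scale or Bias input is absent.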