Unverified commit 069bb2d9 authored by cyber-pioneer, committed by GitHub

[Prim] Support prim vjp of operator group_norm (#52663)

* add gn vjp

* fix 0

* fix args num

* fix type

* debug2

* remove unused expand

* support fp16

* fix typo

* fix reshape bug

* test3

* test4

* fix bug3

* add comment
Parent 417e5baf
@@ -17,6 +17,10 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
@@ -158,6 +162,59 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> {
  }
};

class GroupNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 public:
  void Apply() override {
    // inputs and outputs of group_norm
    paddle::Tensor x = this->GetSingleForwardInput("X");
    paddle::optional<paddle::Tensor> scale =
        this->GetOptionalSingleForwardInput("Scale");
    paddle::optional<paddle::Tensor> bias =
        this->GetOptionalSingleForwardInput("Bias");
    paddle::Tensor y = this->GetSingleForwardOutput("Y");
    paddle::Tensor mean = this->GetSingleForwardOutput("Mean");
    paddle::Tensor variance = this->GetSingleForwardOutput("Variance");
    paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
    paddle::Tensor x_grad = this->GetSingleInputGrad("X");
    paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
    paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");
    auto dx_ptr = this->GetOutputPtr(&x_grad);
    std::string dx_name = this->GetOutputName(x_grad);
    auto dscale_ptr = this->GetOutputPtr(&scale_grad);
    std::string dscale_name = this->GetOutputName(scale_grad);
    auto dbias_ptr = this->GetOutputPtr(&bias_grad);
    std::string dbias_name = this->GetOutputName(bias_grad);
    // attrs of group_norm
    auto groups = this->Attr<int>("groups");
    auto epsilon = this->Attr<float>("epsilon");
    auto data_layout = this->Attr<std::string>("data_layout");
    VLOG(3) << "Running group_norm composite func";
    prim::group_norm_grad<prim::DescTensor>(x,
                                            scale,
                                            bias,
                                            y,
                                            mean,
                                            variance,
                                            y_grad,
                                            epsilon,
                                            groups,
                                            data_layout,
                                            dx_ptr,
                                            dscale_ptr,
                                            dbias_ptr);
    this->RecoverOutputName(x_grad, dx_name);
    this->RecoverOutputName(scale_grad, dscale_name);
    this->RecoverOutputName(bias_grad, dbias_name);
  }
};
DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInferer,
                           {framework::GradVarName("Y"),
                            framework::GradVarName("X")});
@@ -186,6 +243,7 @@ REGISTER_OPERATOR(group_norm,
                  ops::GroupNormOpInferVarType,
                  ops::GroupNormGradMaker<paddle::framework::OpDesc>,
                  ops::GroupNormGradMaker<paddle::imperative::OpBase>,
                  ops::GroupNormCompositeGradOpMaker,
                  GroupNormInferShapeFunctor);
REGISTER_OPERATOR(group_norm_grad,
                  ops::GroupNormGradOp,
......
@@ -954,6 +954,156 @@ void slice_grad(const Tensor& input,
  }
}

template <typename T>
void group_norm_grad(const Tensor& x,
                     const paddle::optional<Tensor>& scale,
                     const paddle::optional<Tensor>& bias,
                     const Tensor& y,
                     const Tensor& mean,
                     const Tensor& variance,
                     const Tensor& out_grad,
                     float epsilon,
                     int groups,
                     const std::string& data_layout,
                     Tensor* x_grad,
                     Tensor* scale_grad,
                     Tensor* bias_grad) {
  // x.shape = [n, c, h, w]
  // y.shape = [n, c, h, w]
  // g_size = c / g
  // scale.shape = [c]
  // mean, var: shape = [n, g]
  // inv_std = rsqrt(var + epsilon)
  // ds = sum(dy * x, axes=(2,3))
  // db = sum(dy, axes=(2,3))
  //
  // compute d_x:
  // s = g / (h*w*c)
  // if scale:
  //   ds_val = sum((ds * scale).reshape(n, g, g_size), axes=2)
  //   db_val = sum((db * scale).reshape(n, g, g_size), axes=2)
  //   p1 = (inv_std.reshape(n, g, 1)) * (scale.reshape(1, g, g_size))
  // else:
  //   ds_val = sum(ds.reshape(n, g, g_size), axes=2)
  //   db_val = sum(db.reshape(n, g, g_size), axes=2)
  //   p1 = (inv_std.reshape(n, g, 1)) * (ones(1, g, g_size))
  // p2 = (db_val * mean - ds_val) * inv_std * inv_std * inv_std * s
  // p3 = -p2 * mean - db_val * inv_std * s
  // p1.reshape(n, g, g_size, 1)
  // p2.reshape(n, g, 1, 1)
  // p3.reshape(n, g, 1, 1)
  // d_x = dy.reshape(n, g, g_size, h*w) * p1
  //     + x.reshape(n, g, g_size, h*w) * p2 + p3
  //
  // compute d_scale:
  // temp = ds.reshape(n, g, g_size)
  //      - db.reshape(n, g, g_size) * mean.reshape(n, g, 1)
  // d_scale = sum(temp * inv_std.reshape(n, g, 1), axes=0).reshape(c)
  //
  // compute d_bias:
  // d_bias = sum(dy, axes=(0,2,3))
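  // (these formulas are restated in standard notation after this function)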
  DataLayout data_layout_ = phi::StringToDataLayout(data_layout);
  if (data_layout_ != DataLayout::kNCHW) {
    PADDLE_THROW(phi::errors::InvalidArgument("Unsupported storage order: %s",
                                              data_layout));
  }
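  // For FP16 inputs, upcast to FP32 so the reductions below run in full
  // precision; x_grad is cast back to FP16 before set_output.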
  Tensor x_data = x;
  Tensor out_grad_data = out_grad;
  if (x.dtype() == phi::DataType::FLOAT16) {
    x_data = cast<T>(x, phi::DataType::FLOAT32);
  }
  if (out_grad.dtype() == phi::DataType::FLOAT16) {
    out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
  }
  std::vector<int64_t> x_dims = phi::vectorize<int64_t>(x.dims());
  auto add_axis = std::vector<int64_t>({-1});
  const int N = x_dims[0];
  const int C = x_dims[1];
  const int hw = x_dims[2] * x_dims[3];
  const int g_num = C / groups;
  auto reduce_axis = IntArray(std::vector<int64_t>({2, 3}));
  auto shape_group = IntArray(std::vector<int64_t>({N, groups, g_num}));
  auto whole_group_shape =
      IntArray(std::vector<int64_t>({N, groups, g_num, hw}));
  auto scale_ptr = scale.get_ptr();
  auto bias_ptr = bias.get_ptr();
  auto inv_std = sqrt<T>(1.0 / (variance + epsilon));
  auto inv_std_mul_s = inv_std / hw / g_num;
  auto dtype = x_data.dtype();
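  // ds and db from the derivation above:
  //   sum_y_grad_mul_x = ds = sum(dy * x, axes=(2,3))
  //   sum_y_grad       = db = sum(dy, axes=(2,3))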
  auto sum_y_grad_mul_x =
      sum<T>(out_grad_data * x_data, reduce_axis, dtype, false);
  auto sum_y_grad = sum<T>(out_grad_data, reduce_axis, dtype, false);
  if (x_grad) {
    Tensor d1;
    Tensor d2;
    Tensor p1;
    if (scale_ptr) {
      auto scale_data = scale.get();
      if (scale_data.dtype() == phi::DataType::FLOAT16) {
        scale_data = cast<T>(scale_data, phi::DataType::FLOAT32);
      }
      d1 = (reshape<T>(sum_y_grad_mul_x * scale_data, shape_group))
               .sum(std::vector<int64_t>({2}), dtype, false);
      d2 = (reshape<T>(sum_y_grad * scale_data, shape_group))
               .sum(std::vector<int64_t>({2}), dtype, false);
      p1 = reshape<T>(inv_std, std::vector<int64_t>({N, groups, 1})) *
           reshape<T>(scale_data, std::vector<int64_t>({1, groups, g_num}));
    } else {
      d1 = (reshape<T>(sum_y_grad_mul_x, shape_group))
               .sum(std::vector<int64_t>({2}), dtype, false);
      d2 = (reshape<T>(sum_y_grad, shape_group))
               .sum(std::vector<int64_t>({2}), dtype, false);
      p1 = (reshape<T>(inv_std, std::vector<int64_t>({N, groups, 1})))
               .expand(IntArray(shape_group));
    }
    auto p2 = (d2 * mean - d1) * (inv_std_mul_s * inv_std * inv_std);
    auto p3 = -p2 * mean - d2 * inv_std_mul_s;
    p1 = unsqueeze<T>(p1, std::vector<int64_t>({3}));
    p2 = unsqueeze<T>(p2, std::vector<int64_t>({2, 3}));
    p3 = unsqueeze<T>(p3, std::vector<int64_t>({2, 3}));
    auto tmp_1 = reshape<T>(out_grad_data, whole_group_shape) * p1;
    auto tmp_2 = reshape<T>(x_data, whole_group_shape) * p2 + p3;
    auto x_grad_data = tmp_1 + tmp_2;
    x_grad_data = reshape<T>(x_grad_data, x.shape());
    if (x.dtype() == phi::DataType::FLOAT16) {
      x_grad_data = cast<T>(x_grad_data, x.dtype());
    }
    set_output<T>(x_grad_data, x_grad);
  }
  if (scale_grad) {
    if (scale_ptr) {
      auto tmp1 = (reshape<T>(sum_y_grad_mul_x, shape_group) -
                   reshape<T>(sum_y_grad, shape_group) *
                       unsqueeze<T>(mean, std::vector<int64_t>({2}))) *
                  unsqueeze<T>(inv_std, std::vector<int64_t>({2}));
      auto scale_grad_tmp =
          reshape<T>(tmp1.sum(std::vector<int64_t>({0}), dtype, false),
                     IntArray(std::vector<int64_t>({C})));
      set_output<T>(scale_grad_tmp, scale_grad);
    } else {
      scale_grad = nullptr;
    }
  }
  if (bias_grad) {
    if (bias_ptr) {
      auto bias_grad_tmp =
          sum_y_grad.sum(std::vector<int64_t>({0}), dtype, false);
      set_output<T>(bias_grad_tmp, bias_grad);
    } else {
      bias_grad = nullptr;
    }
  }
}
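
For reference, the decomposition documented in the comments of group_norm_grad can be restated in standard notation. This is an editorial sketch derived from those comments: \mu and \sigma^2 are the per-(n, group) mean and variance, r (inv_std) and m are introduced symbols, and \gamma is Scale, taken as 1 when Scale is absent.

r = (\sigma^2 + \epsilon)^{-1/2}, \qquad m = (C/g)\,H\,W, \qquad s = 1/m
D_s = \sum_{\text{group}} \gamma\, dy\, x, \qquad D_b = \sum_{\text{group}} \gamma\, dy
p_1 = r\,\gamma, \qquad p_2 = (D_b\,\mu - D_s)\, r^3 s, \qquad p_3 = -p_2\,\mu - D_b\, r\, s
dx = p_1\, dy + p_2\, x + p_3
d\gamma_c = \sum_{n} r\,\big(\sum_{h,w} dy\, x - \mu \sum_{h,w} dy\big), \qquad d\beta_c = \sum_{n,h,w} dy
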
template <typename T>
void layer_norm_grad(const Tensor& x,
                     const paddle::optional<Tensor>& scale,
......
@@ -439,6 +439,7 @@
  kernel :
    func : group_norm_grad
    data_type : y_grad
  composite : group_norm_grad(x, scale, bias, y, mean, variance, y_grad, epsilon, groups, data_layout)
  optional: scale, bias
  inplace : (y_grad -> x_grad)
......