未验证 提交 cabf3921 编写于 作者: Y Yichen Zhang 提交者: GitHub

Add group_norm composite rule (#51874)

* add group_norm composite rule

* add test for scale_grad and bias_grad

* resolve conflicts

* remove amp in composite_rule.py

* add float16 test

* deal with NHWC format

* keep the composite rule in float16 identical as original kernel

* resolve conflicts
上级 548d5522
......@@ -501,8 +501,20 @@ void GroupNormInferMeta(const MetaTensor& x,
y->set_dims(x_dim);
y->set_dtype(x.dtype());
y->share_lod(x);
mean->set_dims({batch_size, groups});
variance->set_dims({batch_size, groups});
phi::DataType x_dtype = x.dtype();
phi::DataType param_type =
(x_dtype == phi::DataType::BFLOAT16 || x_dtype == phi::DataType::FLOAT16)
? phi::DataType::FLOAT32
: x_dtype;
if (mean) {
mean->set_dims({batch_size, groups});
mean->set_dtype(param_type);
}
if (variance) {
variance->set_dims({batch_size, groups});
variance->set_dtype(param_type);
}
}
void LayerNormInferMeta(const MetaTensor& x,
......
......@@ -1203,7 +1203,8 @@ set(TEST_CINN_OPS
test_meshgrid_op
test_gather_op
test_cast_op
test_dropout_op)
test_dropout_op
test_group_norm_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN)
......
......@@ -558,3 +558,43 @@ def rsqrt_composite(x):
# rsqrt(x) = x^(-0.5)
y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype)
return pow(x, y)
@REGISTER_COMPOSITE('group_norm')
def group_norm_composite(x, scale, bias, epsilon, groups, data_layout):
    """
    Composite (primitive-op) rule for op group_norm.

    y = ((x - mean) / sqrt(var + epsilon)) * scale + bias

    mean and var are computed per (sample, group) over the remaining
    (C // groups) * H * W elements.

    Args:
        x: input tensor of shape (N, C, H, W).
        scale: optional per-channel scale of shape (C,), or None.
        bias: optional per-channel bias of shape (C,), or None.
        epsilon: small constant added to the variance for numerical stability.
        groups: number of groups C is split into (C must be divisible).
        data_layout: must be 'NCHW'; the original op cannot support NHWC here.

    Returns:
        (out, mean, var): out has x's shape and dtype; mean and var have
        shape (N, groups) and are kept in float32 even for float16 input.
    """
    # original GroupNorm op cannot support NHWC format
    assert data_layout == 'NCHW'
    N, C, H, W = x.shape

    from paddle.fluid.data_feeder import convert_dtype

    # When the input is float16, compute in float32 to keep the composite
    # rule numerically identical to the original kernel, then cast back.
    is_amp = convert_dtype(x.dtype) == "float16"
    if is_amp:
        x = cast(x, "float32")
        # scale/bias are optional inputs — guard against None before casting
        # (the affine step below already treats them as optional).
        if scale is not None:
            scale = cast(scale, "float32")
        if bias is not None:
            bias = cast(bias, "float32")

    # Fold each (sample, group) slice into one row to take statistics over it.
    x = reshape(x, (N * groups, -1))
    mean_ = mean(x, axis=1, keepdim=True)
    # var = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    var_ = mean(x * x, axis=1, keepdim=True) - mean_ * mean_
    var_ = maximum(var_, zeros_like(var_))
    var_inv = 1 / sqrt(var_ + epsilon)
    out = (x - mean_) * var_inv
    out = reshape(out, (N, C, H, W))

    # Optional per-channel affine transform, broadcast over (H, W).
    if scale is not None:
        out = out * reshape(scale, (-1, 1, 1))
    if bias is not None:
        out = out + reshape(bias, (-1, 1, 1))

    ret_mean_ = reshape(mean_, (N, groups))
    ret_var_ = reshape(var_, (N, groups))

    # return output in float16, mean and var in float32
    if is_amp:
        out = cast(out, "float16")
    return out, ret_mean_, ret_var_
......@@ -132,5 +132,6 @@ others = [
'uniform',
'greater_equal',
'zeros_like',
'transpose',
]
"""
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册