未验证 提交 b0f17d05 编写于 作者: C chenjian 提交者: GitHub

[Prim] Add instance_norm composite rule (#52203)

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* isamp

* gpu

* cpu

* noamp

* fix instance_norm

* fix

* fix unit test

* fix unit test

* add unit test

* fix

* add big data tests

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* add test case

* fix

* fix

* fix

* fix

* fix

* remove amp test

---------
Co-authored-by: Nheyanru01 <429520051@qq.com>
上级 f05c870b
......@@ -178,6 +178,36 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
return out, mean_, variance
@REGISTER_COMPOSITE('instance_norm')
def instancenorm_composite(x, scale, bias, epsilon):
"""
define composite rule of op instance_norm
out = (x - mean(x)) / sqrt(var + epsilon))
var = mean((x-mean(x))^2)
"""
n, c, h, w = x.shape
axis = tuple(range(2, len(x.shape)))
mean_ = mean(x, axis=axis, keepdim=True)
difference = x - mean_
var_tmp1 = difference * difference
variance = mean(var_tmp1, axis=axis, keepdim=True)
var_tmp3 = variance + epsilon
sqrt_var = pow(var_tmp3, full([], 0.5, dtype=var_tmp3.dtype))
out = difference / sqrt_var
if scale is not None:
scale_tile = reshape(scale, [1, c, 1, 1])
out = out * scale_tile
if bias is not None:
bias_tile = reshape(bias, [1, c, 1, 1])
out = out + bias_tile
mean_ = reshape(mean_, [-1])
saved_variance = 1 / sqrt_var
saved_variance = reshape(saved_variance, [-1])
return out, mean_, saved_variance
@REGISTER_COMPOSITE('gelu')
def gelu_composite(x, approximate):
"""define composite rule of op gelu"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册