未验证 提交 e9e95db6 编写于 作者: Z zyfncg 提交者: GitHub

fix instance_norm for cinn (#53075)

上级 0d4f4960
...@@ -125,7 +125,6 @@ class TestInstanceNormOp(OpTest): ...@@ -125,7 +125,6 @@ class TestInstanceNormOp(OpTest):
'SavedMean': ref_mean_np, 'SavedMean': ref_mean_np,
'SavedVariance': ref_var_np, 'SavedVariance': ref_var_np,
} }
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
self.check_output(check_prim=True) self.check_output(check_prim=True)
...@@ -165,6 +164,8 @@ class TestInstanceNormFP64(TestInstanceNormOp): ...@@ -165,6 +164,8 @@ class TestInstanceNormFP64(TestInstanceNormOp):
self.mean_np, self.var_np = _cal_mean_variance( self.mean_np, self.var_np = _cal_mean_variance(
self.x_np, self.epsilon, mean_shape self.x_np, self.epsilon, mean_shape
) )
self.cinn_atol = 1e-13
self.cinn_rtol = 1e-13
self.fw_comp_rtol = 1e-14 self.fw_comp_rtol = 1e-14
self.fw_comp_atol = 1e-14 self.fw_comp_atol = 1e-14
self.rev_comp_rtol = 1e-13 self.rev_comp_rtol = 1e-13
...@@ -667,7 +668,7 @@ class TestCompositeInstanceNormNorm(unittest.TestCase): ...@@ -667,7 +668,7 @@ class TestCompositeInstanceNormNorm(unittest.TestCase):
stop_gradient=False, stop_gradient=False,
) )
net = PrimGroupNorm(self.num_channels, scale_, bias_) net = PrimGroupNorm(self.num_channels, scale_, bias_)
net = apply_to_static(net, False) net = apply_to_static(net, True)
output = net(input_) output = net(input_)
grad = paddle.grad(output, input_) grad = paddle.grad(output, input_)
fwd_actual.append(output.numpy()) fwd_actual.append(output.numpy())
......
...@@ -194,7 +194,7 @@ def instancenorm_composite(x, scale, bias, epsilon): ...@@ -194,7 +194,7 @@ def instancenorm_composite(x, scale, bias, epsilon):
var_tmp1 = difference * difference var_tmp1 = difference * difference
variance = mean(var_tmp1, axis=axis, keepdim=True) variance = mean(var_tmp1, axis=axis, keepdim=True)
var_tmp3 = variance + epsilon var_tmp3 = variance + epsilon
sqrt_var = pow(var_tmp3, full([], 0.5, dtype=var_tmp3.dtype)) sqrt_var = pow(var_tmp3, full([1], 0.5, dtype=var_tmp3.dtype))
out = difference / sqrt_var out = difference / sqrt_var
if scale is not None: if scale is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册