未验证 提交 e9e95db6 编写于 作者: Z zyfncg 提交者: GitHub

fix instance_norm for cinn (#53075)

上级 0d4f4960
......@@ -125,7 +125,6 @@ class TestInstanceNormOp(OpTest):
'SavedMean': ref_mean_np,
'SavedVariance': ref_var_np,
}
self.enable_cinn = False
def test_check_output(self):
self.check_output(check_prim=True)
......@@ -165,6 +164,8 @@ class TestInstanceNormFP64(TestInstanceNormOp):
self.mean_np, self.var_np = _cal_mean_variance(
self.x_np, self.epsilon, mean_shape
)
self.cinn_atol = 1e-13
self.cinn_rtol = 1e-13
self.fw_comp_rtol = 1e-14
self.fw_comp_atol = 1e-14
self.rev_comp_rtol = 1e-13
......@@ -667,7 +668,7 @@ class TestCompositeInstanceNormNorm(unittest.TestCase):
stop_gradient=False,
)
net = PrimGroupNorm(self.num_channels, scale_, bias_)
net = apply_to_static(net, False)
net = apply_to_static(net, True)
output = net(input_)
grad = paddle.grad(output, input_)
fwd_actual.append(output.numpy())
......
......@@ -194,7 +194,7 @@ def instancenorm_composite(x, scale, bias, epsilon):
var_tmp1 = difference * difference
variance = mean(var_tmp1, axis=axis, keepdim=True)
var_tmp3 = variance + epsilon
sqrt_var = pow(var_tmp3, full([], 0.5, dtype=var_tmp3.dtype))
sqrt_var = pow(var_tmp3, full([1], 0.5, dtype=var_tmp3.dtype))
out = difference / sqrt_var
if scale is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册