diff --git a/test/legacy_test/test_batch_norm_op_prim_nchw.py b/test/legacy_test/test_batch_norm_op_prim_nchw.py index d343231e6a3420db2a2db1d9cd832ca865eddbd4..9d11d264908f0690bbfdb685ca853235c1723aab 100644 --- a/test/legacy_test/test_batch_norm_op_prim_nchw.py +++ b/test/legacy_test/test_batch_norm_op_prim_nchw.py @@ -108,7 +108,7 @@ class TestBatchNormOp(OpTest): ) def test_check_grad_scale_bias(self): - if self.data_format == "NCHW": + if self.data_format == "NCHW" and self.training is False: self.enable_cinn = False if self.dtype == "float32": self.rev_comp_atol = 1e-3