Unverified commit 3e0f0a00, authored by cyber-pioneer, committed by GitHub

[Prim] Fix batch_norm bias_grad loss in cinn (#54751)

* fix batch_norm grad kernel nhwc error

* fix batch_norm bias_grad loss in cinn

* disable cinn

* fix cinn_atol
Parent 82eea3b9
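For orientation: in the composite backward rule patched below, bias_grad is simply out_grad reduced over every non-channel axis, and scale_grad is the same reduction of out_grad weighted by the normalized input. A minimal NumPy sketch of that math for the NHWC layout (illustrative names only, not the Paddle implementation):

```python
import numpy as np

def bn_scale_bias_grad_nhwc(x, out_grad, mean, var, eps=1e-5):
    # Normalized input x_hat = (x - mean) / sqrt(var + eps), per channel.
    x_hat = (x - mean) / np.sqrt(var + eps)
    # NHWC: channel is the last axis, so reduce over all leading axes.
    reduce_axes = tuple(range(out_grad.ndim - 1))
    bias_grad = out_grad.sum(axis=reduce_axes)             # d out / d bias
    scale_grad = (out_grad * x_hat).sum(axis=reduce_axes)  # d out / d scale
    return scale_grad, bias_grad

x = np.random.randn(16, 16, 16, 8)
out_grad = np.random.randn(16, 16, 16, 8)
mean, var = x.mean(axis=(0, 1, 2)), x.var(axis=(0, 1, 2))
scale_grad, bias_grad = bn_scale_bias_grad_nhwc(x, out_grad, mean, var)
print(scale_grad.shape, bias_grad.shape)  # (8,) (8,)
```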
@@ -1307,7 +1307,7 @@ void batch_norm_grad(const Tensor& x,
set_output<T>(scale_grad_data, scale_grad);
}
if (bias_grad) {
set_output<T>(nhwc_out_grad_sum, bias_grad);
set_output<T>(assign<T>(nhwc_out_grad_sum), bias_grad);
}
break;
}
@@ -1343,7 +1343,7 @@ void batch_norm_grad(const Tensor& x,
set_output<T>(scale_grad_data, scale_grad);
}
if (bias_grad) {
set_output<T>(out_grad_data_sum, bias_grad);
set_output<T>(assign<T>(out_grad_data_sum), bias_grad);
}
}
break;
......
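Both hunks above are the same one-line fix, once for the NHWC branch (nhwc_out_grad_sum) and once for the NCHW branch (out_grad_data_sum): instead of binding the output directly to the intermediate reduction result, the result is first routed through assign<T>, Paddle's identity/copy primitive, so bias_grad gets its own producing op. Consistent with the commit title, the bias gradient was evidently being dropped under CINN when the output aliased the intermediate directly. A toy Python stand-in for the before/after pattern (this assign is a hedged stand-in, not the C++ prim API):

```python
import numpy as np

def assign(t):
    # Stand-in for Paddle's assign primitive: an identity op that
    # materializes a new tensor with the same values.
    return t.copy()

nhwc_out_grad_sum = np.random.randn(8)
# before: bias_grad was bound directly to the intermediate sum
# after:  bias_grad is a distinct tensor produced by its own (copy) op
bias_grad = assign(nhwc_out_grad_sum)
assert np.array_equal(bias_grad, nhwc_out_grad_sum)
assert bias_grad is not nhwc_out_grad_sum
```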
@@ -953,7 +953,8 @@ if(WITH_NV_JETSON)
set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 1200)
set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 1200)
set_tests_properties(test_norm_op PROPERTIES TIMEOUT 1200)
set_tests_properties(test_batch_norm_op_prim PROPERTIES TIMEOUT 1500)
set_tests_properties(test_batch_norm_op_prim_nchw PROPERTIES TIMEOUT 1500)
set_tests_properties(test_batch_norm_op_prim_nhwc PROPERTIES TIMEOUT 1500)
set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 1500)
set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 1500)
else()
@@ -962,7 +963,8 @@ else()
set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_norm_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_batch_norm_op_prim PROPERTIES TIMEOUT 250)
set_tests_properties(test_batch_norm_op_prim_nchw PROPERTIES TIMEOUT 250)
set_tests_properties(test_batch_norm_op_prim_nhwc PROPERTIES TIMEOUT 250)
set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 250)
set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150)
endif()
@@ -1185,7 +1187,8 @@ set(TEST_CINN_OPS
test_clip_op
test_scatter_op
test_gather_op
test_batch_norm_op_prim
test_batch_norm_op_prim_nchw
test_batch_norm_op_prim_nhwc
test_layer_norm_op
test_cast_op
test_dropout_op
......
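The three CMake hunks above track a file split rather than a behavioral change: test_batch_norm_op_prim is renamed to test_batch_norm_op_prim_nchw, a sibling test_batch_norm_op_prim_nhwc is added with the same TIMEOUT properties on both the Jetson and default branches, and the two new names replace the old entry in the TEST_CINN_OPS list so the NHWC cases are also exercised under CINN.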
@@ -108,9 +108,18 @@ class TestBatchNormOp(OpTest):
)
def test_check_grad_scale_bias(self):
self.enable_cinn = False
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
if self.data_format == "NCHW":
self.enable_cinn = False
if self.dtype == "float32":
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.cinn_atol = 1e-3
self.cinn_rtol = 1e-3
elif self.dtype == "float64":
self.rev_comp_atol = 1e-12
self.rev_comp_rtol = 1e-12
self.cinn_atol = 1e-12
self.cinn_rtol = 1e-12
if self.dtype not in ("uint16", "float16"):
self.check_grad_with_place(
core.CPUPlace(),
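The widened branch above selects comparison tolerances by dtype: 1e-3 for float32 and 1e-12 for float64, applied to both the reverse-composite and the CINN checks. The scale tracks machine precision, as a quick check shows (a motivation sketch, not part of the test):

```python
import numpy as np
# float32 carries roughly 7 decimal digits, float64 roughly 16, so the
# composite-vs-kernel comparisons use correspondingly looser or tighter bounds.
print(np.finfo(np.float32).eps)  # ~1.19e-07
print(np.finfo(np.float64).eps)  # ~2.22e-16
```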
@@ -252,21 +261,6 @@ class TestBatchNormOpNCHWTestMode(TestBatchNormOp):
self.use_global_stats = True
class TestBatchNormOpNHWCTestMode(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = True
class TestBatchNormOpNCHWFp64(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-11
@@ -297,21 +291,6 @@ class TestBatchNormOpNCHWTestModeFp64(TestBatchNormOp):
self.use_global_stats = None
class TestBatchNormOpNHWCTestModeFp64(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-15
self.fw_comp_rtol = 1e-15
self.rev_comp_atol = 1e-15
self.rev_comp_rtol = 1e-15
self.dtype = "float64"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNCHWFp16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
@@ -342,21 +321,6 @@ class TestBatchNormOpNCHWTestModeFp16(TestBatchNormOp):
self.use_global_stats = None
class TestBatchNormOpNHWCTestModeFp16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.dtype = "float16"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
@@ -401,95 +365,6 @@ class TestBatchNormOpNCHWTestModebf16(TestBatchNormOp):
self.use_global_stats = None
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestBatchNormOpNHWCTestModebf16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.cinn_atol = 1e-3
self.cinn_rtol = 1e-3
self.dtype = "uint16"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWC(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCFp64(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-11
self.fw_comp_rtol = 1e-11
self.rev_comp_atol = 1e-11
self.rev_comp_rtol = 1e-11
self.dtype = "float64"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCFp16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.dtype = "float16"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestBatchNormOpNHWCbf16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.cinn_atol = 1e-3
self.cinn_rtol = 1e-3
self.dtype = "uint16"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNCHWShape2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
@@ -535,51 +410,6 @@ class TestBatchNormOpNCHWEps2(TestBatchNormOp):
self.use_global_stats = None
class TestBatchNormOpNHWCShape2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [4, 8, 16, 32]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCMomentum2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.9
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCEps2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-06
self.data_format = "NHWC"
self.use_global_stats = None
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
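Everything from here down is the newly added NHWC test file (every line below is an addition): the classes removed above reappear here unchanged, with the shared TestBatchNormOp base imported from the renamed NCHW module.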
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from eager_op_test import _set_use_system_allocator
from test_batch_norm_op_prim_nchw import TestBatchNormOp
import paddle
from paddle.fluid import core
paddle.enable_static()
np.random.seed(123)
paddle.seed(123)
_set_use_system_allocator(True)
class TestBatchNormOpNHWCTestMode(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = True
class TestBatchNormOpNHWCTestModeFp64(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-15
self.fw_comp_rtol = 1e-15
self.rev_comp_atol = 1e-15
self.rev_comp_rtol = 1e-15
self.dtype = "float64"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCTestModeFp16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.dtype = "float16"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestBatchNormOpNHWCTestModebf16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.cinn_atol = 1e-3
self.cinn_rtol = 1e-3
self.dtype = "uint16"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWC(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCFp64(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-11
self.fw_comp_rtol = 1e-11
self.rev_comp_atol = 1e-11
self.rev_comp_rtol = 1e-11
self.dtype = "float64"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCFp16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.dtype = "float16"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestBatchNormOpNHWCbf16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.cinn_atol = 1e-3
self.cinn_rtol = 1e-3
self.dtype = "uint16"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCShape2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [4, 8, 16, 32]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCMomentum2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.9
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCEps2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-06
self.data_format = "NHWC"
self.use_global_stats = None
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
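All of the variants in the new file follow a single pattern: subclass TestBatchNormOp and override initConfig with the layout, dtype, shape, and tolerances under test. A hypothetical extra variant (illustrative only, not part of this commit) would look like:

```python
class TestBatchNormOpNHWCShape3(TestBatchNormOp):  # hypothetical example
    def initConfig(self):
        self.fw_comp_atol = 1e-5
        self.fw_comp_rtol = 1e-5
        self.rev_comp_atol = 1e-5
        self.rev_comp_rtol = 1e-5
        self.dtype = "float32"
        self.shape = [2, 8, 8, 4]  # any NHWC shape with channel last
        self.training = True
        self.momentum = 0.1
        self.epsilon = 1e-05
        self.data_format = "NHWC"
        self.use_global_stats = None
```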