Unverified · Commit 58b4c60f authored by cyber-pioneer, committed by GitHub

[prim] move batch_norm prim test to op_test (#54458)

* move batch_norm prim test to op_test

* fix optest bug

* add test to cmake

* add cinn test case

* fix batch_norm prim grad bf16

* fix code

* add cuda check

* fix batch_norm bfloat16

* fix cpu bfloat16 bug

* skip non-bfloat16-supported platform

* fix code

* fix cinn rtol and atol in bfloat16

* fix name

* fix config
Parent f7eb03c6
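Most of the bfloat16-related bullets above ("fix batch_norm prim grad bf16", "fix cinn rtol and atol in bfloat16") come down to precision: bf16 keeps only 8 significand bits, so the bf16 test classes added below relax rtol/atol (and the CINN tolerances) to 1e-3 instead of the 1e-5 used for fp32. A rough, purely illustrative numpy estimate of why the fp32-level tolerances are unattainable in bf16 (bf16 is emulated by truncating the low 16 bits of a float32, since numpy has no native bfloat16 and real conversion may round rather than truncate):

import numpy as np

# float32 -> "bf16" -> float32 round trip via bit truncation (approximation)
x = np.random.rand(1000).astype(np.float32) + 1.0   # values in [1, 2)
bits = x.view(np.uint32) & np.uint32(0xFFFF0000)     # drop the low 16 mantissa bits
x_bf16 = bits.view(np.float32)

rel_err = np.abs(x - x_bf16) / np.abs(x)
# mean error is a few 1e-3 and the max approaches 8e-3 for values in [1, 2),
# so a single rounding step already exceeds an fp32-style 1e-5 tolerance
print(rel_err.mean(), rel_err.max())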
@@ -1210,12 +1210,17 @@ void batch_norm_grad(const Tensor& x,
   Tensor x_data = x;
   Tensor out_grad_data = out_grad;
-  if (x.dtype() == phi::DataType::FLOAT16) {
+  bool need_cast = x.dtype() == phi::DataType::FLOAT16 ||
+                   x.dtype() == phi::DataType::BFLOAT16;
+  if (need_cast) {
     x_data = cast<T>(x, phi::DataType::FLOAT32);
   }
-  if (out_grad.dtype() == phi::DataType::FLOAT16) {
+  if (out_grad.dtype() == phi::DataType::FLOAT16 ||
+      out_grad.dtype() == phi::DataType::BFLOAT16) {
     out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
   }
   auto x_dims = x_data.dims();
   const int C = (data_layout_ == DataLayout::kNCHW ? x_dims[1]
                                                    : x_dims[x_dims.size() - 1]);
@@ -1278,7 +1283,7 @@ void batch_norm_grad(const Tensor& x,
       if (use_global_stats) {
         auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad;
         auto nchw_x_grad = transpose<T>(nhwc_x_grad, nhwc_to_nchw_dim);
-        if (x.dtype() == phi::DataType::FLOAT16) {
+        if (need_cast) {
           nchw_x_grad = cast<T>(nchw_x_grad, x.dtype());
         }
         set_output<T>(nchw_x_grad, x_grad);
@@ -1291,7 +1296,7 @@ void batch_norm_grad(const Tensor& x,
         auto x_grad_data = part1 * part2;
         auto nchw_x_grad = transpose<T>(x_grad_data, nhwc_to_nchw_dim);
-        if (x.dtype() == phi::DataType::FLOAT16) {
+        if (need_cast) {
           nchw_x_grad = cast<T>(nchw_x_grad, x.dtype());
         }
         set_output<T>(nchw_x_grad, x_grad);
@@ -1314,7 +1319,7 @@ void batch_norm_grad(const Tensor& x,
           out_grad_data * (x_data - mean_data), reduce_axis, dtype, false);
       if (use_global_stats) {
         auto x_grad_data = scale * rsqrt_var * out_grad_data;
-        if (x.dtype() == phi::DataType::FLOAT16) {
+        if (need_cast) {
           x_grad_data = cast<T>(x_grad_data, x.dtype());
         }
         set_output<T>(x_grad_data, x_grad);
@@ -1328,7 +1333,7 @@ void batch_norm_grad(const Tensor& x,
             out_grad_data - mean_temp1 - (x_data - mean_data) * mean_temp2;
         auto x_grad_data = part1 * part2;
-        if (x.dtype() == phi::DataType::FLOAT16) {
+        if (need_cast) {
           x_grad_data = cast<T>(x_grad_data, x.dtype());
         }
         set_output<T>(x_grad_data, x_grad);
......
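The hunks above widen the existing float16 upcast so that bfloat16 takes the same path: low-precision operands are promoted to float32 before the composite gradient is evaluated, and the result is cast back to the input dtype only at the end. A minimal numpy sketch of that upcast-compute-downcast shape for the use_global_stats branch (illustrative only, not the Paddle C++ API; bf16 would be treated exactly like fp16 but numpy lacks a bf16 dtype):

import numpy as np

def x_grad_use_global_stats(x, out_grad, scale, rsqrt_var):
    # Same structure as the need_cast branches above:
    # widen half-precision inputs, compute in fp32, cast back at the end.
    need_cast = x.dtype == np.float16
    orig_dtype = x.dtype
    if need_cast:
        x = x.astype(np.float32)
    if out_grad.dtype == np.float16:
        out_grad = out_grad.astype(np.float32)
    x_grad = scale * rsqrt_var * out_grad  # use_global_stats x_grad formula
    return x_grad.astype(orig_dtype) if need_cast else x_grad

x = np.random.rand(2, 3, 4, 4).astype(np.float16)
g = np.random.rand(2, 3, 4, 4).astype(np.float16)
print(x_grad_use_global_stats(x, g, np.float32(0.5), np.float32(2.0)).dtype)  # float16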
@@ -79,9 +79,12 @@ def composite_batchnorm(
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype

-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
+        scale = cast(scale, "float32") if scale else scale
+        bias = cast(bias, "float32") if bias else bias

     feature_axis = (
         1 if data_layout in ('NC', 'NCL', 'NCHW', 'NCHWD') else len(x.shape) - 1
@@ -124,7 +127,7 @@ def composite_batchnorm(
     else:
         y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape)
     if is_amp:
-        y = cast(y, "float16")
+        y = cast(y, dtype)

     # add op assign to detach tensor in void unsafe change outside the rule.
     batch_mean_ = assign(batch_mean)
......
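The composite rule now remembers the converted dtype once, treats bf16 (which convert_dtype reports as "uint16") the same as fp16, optionally widens scale/bias as well, and casts the output back to the recorded dtype instead of a hard-coded "float16". A standalone numpy sketch of that cast-in/cast-out pattern (stand-in code, not the actual composite rule; only the fp16 branch is exercised because numpy has no bf16):

import numpy as np

def batchnorm_amp_sketch(x, scale, bias, eps=1e-5):
    dtype = str(x.dtype)
    is_amp = dtype in ("float16", "uint16")  # "uint16" is Paddle's tag for bf16
    if is_amp:
        x = x.astype(np.float32)
        scale = scale.astype(np.float32) if scale is not None else scale
        bias = bias.astype(np.float32) if bias is not None else bias
    # NCHW: normalize over N, H, W for each channel
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    y = scale.reshape(1, -1, 1, 1) * x_hat + bias.reshape(1, -1, 1, 1)
    # cast back to the dtype recorded before the computation
    # (a bf16 input would be restored analogously)
    return y.astype(np.float16) if is_amp else y

x = np.random.rand(2, 3, 8, 8).astype(np.float16)
y = batchnorm_amp_sketch(x, np.ones(3, np.float32), np.zeros(3, np.float32))
print(y.dtype)  # float16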
@@ -240,7 +240,9 @@ def batch_norm(
         from paddle.fluid.data_feeder import convert_dtype

         param_dtype = (
-            x.dtype if convert_dtype(x.dtype) != 'float16' else 'float32'
+            x.dtype
+            if convert_dtype(x.dtype) not in ['float16', 'uint16']
+            else 'float32'
         )
         saved_mean = helper.create_variable_for_type_inference(
             dtype=param_dtype, stop_gradient=True
......
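With this change the saved mean/variance variables keep float32 storage not only for fp16 inputs but also for bf16 inputs (again surfaced as "uint16" by convert_dtype), since accumulating batch statistics in half precision loses too much accuracy. The selection rule in isolation, as a hedged sketch rather than the Paddle helper itself:

def statistics_dtype(x_dtype: str) -> str:
    """Dtype used for SavedMean/SavedVariance, mirroring param_dtype above:
    half-precision inputs (float16, or bfloat16 reported as "uint16") fall
    back to float32; everything else keeps the input dtype."""
    return x_dtype if x_dtype not in ("float16", "uint16") else "float32"

assert statistics_dtype("float16") == "float32"
assert statistics_dtype("uint16") == "float32"   # bfloat16
assert statistics_dtype("float64") == "float64"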
@@ -953,6 +953,7 @@ if(WITH_NV_JETSON)
   set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 1200)
   set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 1200)
   set_tests_properties(test_norm_op PROPERTIES TIMEOUT 1200)
+  set_tests_properties(test_batch_norm_op_prim PROPERTIES TIMEOUT 1500)
   set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 1500)
   set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 1500)
 else()
@@ -961,6 +962,7 @@ else()
   set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 120)
   set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 120)
   set_tests_properties(test_norm_op PROPERTIES TIMEOUT 120)
+  set_tests_properties(test_batch_norm_op_prim PROPERTIES TIMEOUT 250)
   set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 250)
   set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150)
 endif()
......
@@ -1624,6 +1624,7 @@ class OpTest(unittest.TestCase):
         equal_nan=False,
         check_dygraph=True,
         check_prim=False,
+        only_check_prim=False,
         inplace_atol=None,
         check_cinn=False,
     ):
@@ -2033,6 +2034,8 @@ class OpTest(unittest.TestCase):
             # Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
             self.__class__.check_prim = True
             self.__class__.op_type = self.op_type
+            if only_check_prim:
+                return
         static_checker = StaticChecker(self, self.outputs)
         static_checker.check()
@@ -2150,6 +2153,7 @@ class OpTest(unittest.TestCase):
         check_prim=False,
         inplace_atol=None,
         check_cinn=False,
+        only_check_prim=False,
     ):
         self.__class__.op_type = self.op_type
         if self.is_mkldnn_op():
@@ -2171,9 +2175,12 @@ class OpTest(unittest.TestCase):
                 equal_nan,
                 check_dygraph=check_dygraph,
                 check_prim=check_prim,
+                only_check_prim=only_check_prim,
                 inplace_atol=inplace_atol,
                 check_cinn=check_cinn,
             )
+            if not res and only_check_prim:
+                continue
             if check_dygraph:
                 outs, dygraph_dygraph_outs, fetch_list = res
             else:
......
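The new only_check_prim flag lets a test exercise only the composite (prim) checker: once the prim check has run, the ordinary static/dygraph output comparison is skipped, and the caller tolerates the empty result. A tiny self-contained mock of that control flow (hypothetical names, no Paddle dependency; the real OpTest logic is what the hunks above add):

class MiniOpTest:
    def check_output_with_place(self, check_prim=False, only_check_prim=False):
        if check_prim:
            self._run_prim_checker()
            if only_check_prim:
                return None  # skip the static/dygraph comparison entirely
        return self._run_standard_checkers()

    def _run_prim_checker(self):
        print("prim (composite) outputs checked")

    def _run_standard_checkers(self):
        print("static / dygraph outputs checked")
        return "outs"

MiniOpTest().check_output_with_place(check_prim=True, only_check_prim=True)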
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from eager_op_test import (
OpTest,
_set_use_system_allocator,
convert_float_to_uint16,
)
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
paddle.enable_static()
np.random.seed(123)
paddle.seed(123)
_set_use_system_allocator(True)
def batch_norm_wrapper(
x,
running_mean,
running_variance,
weight,
bias,
is_test,
momentum,
epsilon,
data_format,
use_global_stats,
):
y = F.batch_norm(
x,
running_mean,
running_variance,
weight,
bias,
training=not is_test,
momentum=momentum,
epsilon=epsilon,
data_format=data_format,
use_global_stats=use_global_stats,
)
z = F.relu(y)
return z
class TestBatchNormOp(OpTest):
def setUp(self):
self.python_api = batch_norm_wrapper
self.public_python_api = batch_norm_wrapper
self.op_type = "batch_norm"
self.prim_op_type = "comp"
self.python_out_sig = ["Y"]
self.initConfig()
self.initTestCase()
def test_check_output(self):
if self.dtype not in ("uint16", "float16"):
self.check_output_with_place(
core.CPUPlace(),
no_check_set=None,
check_prim=True,
only_check_prim=True,
)
if paddle.is_compiled_with_cuda():
self.check_output_with_place(
core.CUDAPlace(0),
no_check_set=None,
check_prim=True,
only_check_prim=True,
)
def test_check_grad_x(self):
if self.dtype not in ("uint16", "float16"):
self.check_grad_with_place(
core.CPUPlace(),
["X"],
['Y'],
user_defined_grad_outputs=self.out_grad,
check_prim=True,
only_check_prim=True,
)
elif self.data_format == "NCHW" and paddle.is_compiled_with_cuda():
# The legacy batch_norm CUDA kernel's NHWC x_grad differs depending on whether
# scale_grad and bias_grad are also computed, so only NCHW is checked on GPU here.
self.check_grad_with_place(
core.CUDAPlace(0),
["X"],
['Y'],
user_defined_grad_outputs=self.out_grad,
check_prim=True,
only_check_prim=True,
)
def test_check_grad_scale_bias(self):
self.enable_cinn = False
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
if self.dtype not in ("uint16", "float16"):
self.check_grad_with_place(
core.CPUPlace(),
["X", "Scale", "Bias"],
['Y'],
user_defined_grad_outputs=self.out_grad,
check_prim=True,
only_check_prim=True,
)
if paddle.is_compiled_with_cuda():
self.check_grad_with_place(
core.CUDAPlace(0),
["X", "Scale", "Bias"],
['Y'],
user_defined_grad_outputs=self.out_grad,
check_prim=True,
only_check_prim=True,
)
# restore init config
self.initConfig()
def initConfig(self):
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.cinn_atol = 1e-5
self.cinn_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 24, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
def initTestCase(self):
if (
self.dtype in ("uint16", "float16")
and not paddle.is_compiled_with_cuda()
):
self.__class__.op_type = self.op_type
self.__class__.no_need_check_grad = True
return
np.random.seed(123)
self.C = self.shape[1] if self.data_format == "NCHW" else self.shape[-1]
if self.dtype == "uint16":
x = convert_float_to_uint16(
np.random.random(self.shape).astype("float32")
)
else:
x = np.random.random(self.shape).astype(self.dtype)
self.var_dtype = (
"float32" if self.dtype in ["float16", "uint16"] else self.dtype
)
weight = np.random.random(self.C).astype(self.var_dtype)
bias = np.random.random(self.C).astype(self.var_dtype)
running_mean = np.random.random(self.C).astype(self.var_dtype)
running_var = np.random.random(self.C).astype(self.var_dtype)
if self.dtype == "uint16":
self.out_grad = [
convert_float_to_uint16(
np.random.random(self.shape).astype("float32")
)
]
else:
self.out_grad = [np.random.random(self.shape).astype(self.dtype)]
self.inputs = {
"X": x,
"Scale": weight,
"Bias": bias,
"Mean": running_mean,
"Variance": running_var,
}
if self.use_global_stats is None:
self.use_global_stats = not self.training
trainable_statistics = False
else:
trainable_statistics = not self.use_global_stats
self.attrs = {
"momentum": self.momentum,
"epsilon": self.epsilon,
"is_test": not self.training,
"data_layout": self.data_format,
"use_global_stats": self.use_global_stats,
"trainable_statistics": trainable_statistics,
}
paddle.disable_static()
(
y,
running_mean,
running_var,
saved_mean,
saved_variance,
_,
) = paddle._C_ops.batch_norm(
paddle.to_tensor(x),
paddle.to_tensor(running_mean),
paddle.to_tensor(running_var),
paddle.to_tensor(weight),
paddle.to_tensor(bias),
not self.training,
self.momentum,
self.epsilon,
self.data_format,
self.use_global_stats,
trainable_statistics,
)
if self.dtype == "uint16":
y = convert_float_to_uint16(y)
paddle.enable_static()
self.outputs = {
"Y": y,
"MeanOut": running_mean,
"VarianceOut": running_var,
"SavedMean": saved_mean,
"SavedVariance": saved_variance,
}
class TestBatchNormOpNCHWTestMode(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = True
class TestBatchNormOpNCHWFp64(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-11
self.fw_comp_rtol = 1e-11
self.rev_comp_atol = 1e-11
self.rev_comp_rtol = 1e-11
self.dtype = "float64"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
class TestBatchNormOpNCHWTestModeFp64(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-15
self.fw_comp_rtol = 1e-15
self.rev_comp_atol = 1e-15
self.rev_comp_rtol = 1e-15
self.dtype = "float64"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
class TestBatchNormOpNCHWFp16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.dtype = "float16"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
class TestBatchNormOpNCHWTestModeFp16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.dtype = "float16"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestBatchNormOpNCHWbf16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.cinn_atol = 1e-3
self.cinn_rtol = 1e-3
self.dtype = "uint16"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestBatchNormOpNCHWTestModebf16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.cinn_atol = 1e-3
self.cinn_rtol = 1e-3
self.dtype = "uint16"
self.shape = [16, 16, 16, 8]
self.training = False
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
class TestBatchNormOpNHWC(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCFp64(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-11
self.fw_comp_rtol = 1e-11
self.rev_comp_atol = 1e-11
self.rev_comp_rtol = 1e-11
self.dtype = "float64"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCFp16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.dtype = "float16"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestBatchNormOpNHWCbf16(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-3
self.fw_comp_rtol = 1e-3
self.rev_comp_atol = 1e-3
self.rev_comp_rtol = 1e-3
self.cinn_atol = 1e-3
self.cinn_rtol = 1e-3
self.dtype = "uint16"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNCHWShape2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [4, 8, 16, 32]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
class TestBatchNormOpNCHWMomentum2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.9
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
class TestBatchNormOpNCHWEps2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-06
self.data_format = "NCHW"
self.use_global_stats = None
class TestBatchNormOpNHWCShape2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [4, 8, 16, 32]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCMomentum2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.9
self.epsilon = 1e-05
self.data_format = "NHWC"
self.use_global_stats = None
class TestBatchNormOpNHWCEps2(TestBatchNormOp):
def initConfig(self):
self.fw_comp_atol = 1e-5
self.fw_comp_rtol = 1e-5
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.dtype = "float32"
self.shape = [16, 16, 16, 8]
self.training = True
self.momentum = 0.1
self.epsilon = 1e-06
self.data_format = "NHWC"
self.use_global_stats = None
if __name__ == '__main__':
paddle.enable_static()
unittest.main()