From ebff68fa74c3f278b97326fec56d775a94323623 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Tue, 24 Sep 2019 20:40:37 +0800 Subject: [PATCH] Add float16 support to `sync_batch_norm_op` (#19681) * Add float16 support to `sync_batch_norm_op` test=develop * Add test for sync_bn with FP16 input test=develop --- paddle/fluid/operators/sync_batch_norm_op.cu | 297 ++++++++++-------- .../unittests/test_sync_batch_norm_op.py | 74 ++++- 2 files changed, 219 insertions(+), 152 deletions(-) diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu b/paddle/fluid/operators/sync_batch_norm_op.cu index 4f584cbb56a..fb4ae48eb07 100644 --- a/paddle/fluid/operators/sync_batch_norm_op.cu +++ b/paddle/fluid/operators/sync_batch_norm_op.cu @@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +// clang-format off #include #include +#include #include #include #include "cub/cub.cuh" @@ -32,24 +34,27 @@ using Tensor = framework::Tensor; using DataLayout = framework::DataLayout; template using CudnnDataType = platform::CudnnDataType; +template +using BatchNormParamType = typename CudnnDataType::BatchNormParamType; template -__global__ void KeLocalStats(const T *x, int N, int M, int C, T *mean_var) { - typedef cub::BlockReduce BlockReduce; +__global__ void KeLocalStats(const T *x, int N, int M, int C, + BatchNormParamType *mean_var) { + typedef cub::BlockReduce, BlockDim> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; for (int k = blockIdx.x; k < C; k += gridDim.x) { - T x_sum = 0; - T x2_sum = 0; + BatchNormParamType x_sum = 0.; + BatchNormParamType x2_sum = 0.; for (int i = threadIdx.x; i < N * M; i += BlockDim) { int id = layout == framework::DataLayout::kNCHW ? 
(i / M) * C * M + k * M + i % M : i * C + k; - T x_in = x[id]; + auto x_in = static_cast>(x[id]); x_sum += x_in; x2_sum += x_in * x_in; } __syncthreads(); - T out = BlockReduce(temp_storage).Reduce(x_sum, cub::Sum()); + auto out = BlockReduce(temp_storage).Reduce(x_sum, cub::Sum()); __syncthreads(); if (threadIdx.x == 0) { mean_var[k] = out / (N * M); @@ -61,22 +66,24 @@ __global__ void KeLocalStats(const T *x, int N, int M, int C, T *mean_var) { } } if (blockIdx.x == 0 && threadIdx.x == 0) { - mean_var[2 * C] = static_cast(1.0); + mean_var[2 * C] = static_cast>(1.0); } } template -__global__ void KeSyncAndMovingStats(T *means, T *variances, T *num_dev, - const int C, const T momentum, - const double epsilon, T *sv_mean_data, - T *sv_inv_var_data, T *moving_means, - T *moving_variances) { +__global__ void KeSyncAndMovingStats( + BatchNormParamType *means, BatchNormParamType *variances, + BatchNormParamType *num_dev, const int C, + const BatchNormParamType momentum, const double epsilon, + BatchNormParamType *sv_mean_data, BatchNormParamType *sv_inv_var_data, + BatchNormParamType *moving_means, + BatchNormParamType *moving_variances) { // sync stats across multi-devices int gid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; for (int i = gid; i < C; i += stride) { - T mean = means[i] / (*num_dev); - T var = variances[i] / (*num_dev); + auto mean = means[i] / (*num_dev); + auto var = variances[i] / (*num_dev); var = var - mean * mean; // sync stats @@ -92,15 +99,21 @@ __global__ void KeSyncAndMovingStats(T *means, T *variances, T *num_dev, } template -static __global__ void KeNormAffine(const T *x, const T *scale, const T *bias, - const T *mean, const T *variance, +static __global__ void KeNormAffine(const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const BatchNormParamType *mean, + const BatchNormParamType *variance, const double epsilon, const int C, const int M, const int num, T *y) { int gid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; for (int i = gid; i < num; i += stride) { const int c = layout == framework::DataLayout::kNCHW ? 
(i / M) % C : i % C; - y[i] = (x[i] - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c]; + auto x_i = static_cast>(x[i]); + auto y_i = + (x_i - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c]; + y[i] = static_cast(y_i); } } @@ -128,14 +141,14 @@ class SyncBatchNormKernel : public framework::OpKernel { int x_numel = x->numel(); const T *x_d = x->data(); - const T *s_d = ctx.Input("Scale")->data(); - const T *b_d = ctx.Input("Bias")->data(); + const auto *s_d = ctx.Input("Scale")->data>(); + const auto *b_d = ctx.Input("Bias")->data>(); auto *y = ctx.Output("Y"); T *y_d = y->mutable_data(ctx.GetPlace()); - const T *mean_data = nullptr; - const T *var_data = nullptr; + const BatchNormParamType *mean_data = nullptr; + const BatchNormParamType *var_data = nullptr; auto &dev_ctx = ctx.cuda_device_context(); auto stream = dev_ctx.stream(); @@ -148,51 +161,53 @@ class SyncBatchNormKernel : public framework::OpKernel { if (is_test) { const auto *est_mean = ctx.Input("Mean"); const auto *est_var = ctx.Input("Variance"); - mean_data = est_mean->data(); - var_data = est_var->data(); + mean_data = est_mean->data>(); + var_data = est_var->data>(); } else { // x, x^2, 1, here 1 is used to calc device num // device num also can be got from platform::DeviceContextPool - const int bytes = (C * 2 + 1) * sizeof(T); + const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType); alloc_ptr = memory::Alloc(dev_ctx, bytes); - T *stats = reinterpret_cast(alloc_ptr->ptr()); + auto *stats = reinterpret_cast *>(alloc_ptr->ptr()); const int threads = 256; int grid = std::min(C, (max_threads + threads - 1) / threads); if (layout == framework::DataLayout::kNCHW) { - KeLocalStats< - T, threads, - framework::DataLayout::kNCHW><<>>( - x_d, N, H * W * D, C, stats); + KeLocalStats + <<>>(x_d, N, H * W * D, C, stats); } else { - KeLocalStats< - T, threads, - framework::DataLayout::kNHWC><<>>( - x_d, N, H * W * D, C, stats); + KeLocalStats + <<>>(x_d, N, H * W * D, C, stats); } + // moving mean/variance + auto *mean_out = ctx.Output("MeanOut"); + auto *variance_out = ctx.Output("VarianceOut"); + auto *est_mean_data = + mean_out->mutable_data>(ctx.GetPlace()); + auto *est_var_data = + variance_out->mutable_data>(ctx.GetPlace()); + + auto *saved_mean = ctx.Output("SavedMean"); + auto *saved_inv_variance = ctx.Output("SavedVariance"); + auto *sv_mean_data = + saved_mean->mutable_data>(ctx.GetPlace()); + auto *sv_inv_var_data = + saved_inv_variance->mutable_data>( + ctx.GetPlace()); + Tensor c_g_st; - T *c_g_st_d = c_g_st.mutable_data({2 * C + 1}, platform::CPUPlace()); + auto *c_g_st_d = c_g_st.mutable_data>( + {2 * C + 1}, platform::CPUPlace()); auto gplace = boost::get(ctx.GetPlace()); memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0); - int dtype = platform::ToNCCLDataType(x->type()); + int dtype = platform::ToNCCLDataType(mean_out->type()); // In-place operation PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( stats, stats, 2 * C + 1, static_cast(dtype), ncclSum, comm, stream)); - // moving mean/variance - auto *mean_out = ctx.Output("MeanOut"); - auto *variance_out = ctx.Output("VarianceOut"); - T *est_mean_data = mean_out->mutable_data(ctx.GetPlace()); - T *est_var_data = variance_out->mutable_data(ctx.GetPlace()); - - auto *saved_mean = ctx.Output("SavedMean"); - auto *saved_inv_variance = ctx.Output("SavedVariance"); - T *sv_mean_data = saved_mean->mutable_data(ctx.GetPlace()); - T *sv_inv_var_data = saved_inv_variance->mutable_data(ctx.GetPlace()); - // Note, 
Input('Mean')/Input('Variance') share variable with // Output('MeanOut')/Output('VarianceOut') KeSyncAndMovingStats<<<(C + block - 1) / block, block, 0, stream>>>( @@ -205,39 +220,40 @@ class SyncBatchNormKernel : public framework::OpKernel { int grid2 = (std::min(x_numel, max_threads) + block - 1) / block; if (layout == framework::DataLayout::kNCHW) { - KeNormAffine<<>>( - x_d, s_d, b_d, mean_data, var_data, epsilon, C, H * W * D, x_numel, - y_d); + KeNormAffine + <<>>(x_d, s_d, b_d, mean_data, var_data, + epsilon, C, H * W * D, x_numel, y_d); } else { - KeNormAffine<<>>( - x_d, s_d, b_d, mean_data, var_data, epsilon, C, H * W * D, x_numel, - y_d); + KeNormAffine + <<>>(x_d, s_d, b_d, mean_data, var_data, + epsilon, C, H * W * D, x_numel, y_d); } } }; template -__global__ void KeBackwardLocalStats(const T *dy, const T *x, const T *means, - int N, int M, int C, T *sum_dy_prod) { - typedef cub::BlockReduce BlockReduce; +__global__ void KeBackwardLocalStats(const T *dy, const T *x, + const BatchNormParamType *means, int N, + int M, int C, + BatchNormParamType *sum_dy_prod) { + typedef cub::BlockReduce, BlockDim> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; for (int k = blockIdx.x; k < C; k += gridDim.x) { - T sum1 = 0; - T sum2 = 0; - T mean = means[k]; + BatchNormParamType sum1 = 0.; + BatchNormParamType sum2 = 0.; + auto mean = means[k]; for (int i = threadIdx.x; i < N * M; i += blockDim.x) { int id = layout == framework::DataLayout::kNCHW ? (i / M) * C * M + k * M + i % M : i * C + k; - T g = dy[id]; + auto g = static_cast>(dy[id]); sum1 += g; - sum2 += g * (x[id] - mean); + auto x_i = static_cast>(x[id]); + sum2 += g * (x_i - mean); } __syncthreads(); - T out = BlockReduce(temp_storage).Reduce(sum1, cub::Sum()); + auto out = BlockReduce(temp_storage).Reduce(sum1, cub::Sum()); __syncthreads(); if (threadIdx.x == 0) { sum_dy_prod[k] = out; @@ -249,72 +265,75 @@ __global__ void KeBackwardLocalStats(const T *dy, const T *x, const T *means, } } if (blockIdx.x == 0 && threadIdx.x == 0) { - sum_dy_prod[2 * C] = static_cast(1.0); + sum_dy_prod[2 * C] = 1.0; } } template -static __global__ void KeBNBackwardScaleBias(const T *dy, const T *x, - const T *mean, - const T *inv_variance, - const double epsilon, const int N, - const int C, const int HxW, - T *dscale, T *dbias) { +static __global__ void KeBNBackwardScaleBias( + const T *dy, const T *x, const BatchNormParamType *mean, + const BatchNormParamType *inv_variance, const double epsilon, + const int N, const int C, const int HxW, BatchNormParamType *dscale, + BatchNormParamType *dbias) { const int outer_size = C; const int inner_size = N * HxW; - typedef cub::BlockReduce BlockReduce; + typedef cub::BlockReduce, BlockDim> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - T ds_sum = static_cast(0); - T db_sum = static_cast(0); + BatchNormParamType ds_sum = 0.; + BatchNormParamType db_sum = 0.; - T inv_var_i = inv_variance[i]; - T mean_i = mean[i]; + auto inv_var_i = inv_variance[i]; + auto mean_i = mean[i]; for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { const int id = layout == framework::DataLayout::kNCHW ? 
((j / HxW) * C + i) * HxW + (j % HxW) : j * outer_size + i; - ds_sum += dy[id] * (x[id] - mean_i); - db_sum += dy[id]; + auto x_i = static_cast>(x[id]); + auto dy_i = static_cast>(dy[id]); + ds_sum += dy_i * (x_i - mean_i); + db_sum += dy_i; } __syncthreads(); - double os = BlockReduce(temp_storage) - .Reduce(static_cast(ds_sum), cub::Sum()); + auto os = BlockReduce(temp_storage).Reduce(ds_sum, cub::Sum()); __syncthreads(); - double ob = BlockReduce(temp_storage) - .Reduce(static_cast(db_sum), cub::Sum()); + auto ob = BlockReduce(temp_storage).Reduce(db_sum, cub::Sum()); __syncthreads(); if (threadIdx.x == 0) { - dscale[i] = static_cast(os * inv_var_i); - dbias[i] = static_cast(ob); + dscale[i] = os * inv_var_i; + dbias[i] = ob; } __syncthreads(); } } template -static __global__ void KeBNBackwardData(const T *dy, const T *x, const T *beta, - const T *mean, const T *inv_variance, - const T *g_sum_dy, - const T *g_sum_dy_prod, - const T *num_dev, const double epsilon, - const int C, const int HxW, - const int num, T *dx) { +static __global__ void KeBNBackwardData( + const T *dy, const T *x, const BatchNormParamType *gamma, + const BatchNormParamType *mean, + const BatchNormParamType *inv_variance, + const BatchNormParamType *g_sum_dy, + const BatchNormParamType *g_sum_dy_prod, + const BatchNormParamType *num_dev, const double epsilon, const int C, + const int HxW, const int num, T *dx) { int gid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; - T scale = static_cast(C) / num; - T dev_num = num_dev[0]; + auto scale = static_cast>(C) / num; + auto dev_num = num_dev[0]; for (int i = gid; i < num; i += stride) { const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C; - T inv_var = inv_variance[c]; - T s_d = beta[c]; - T gvar = -1.0 * (g_sum_dy_prod[c] / dev_num) * s_d * inv_var * - (inv_var * inv_var); - T gmean = -1.0 * (g_sum_dy[c] / dev_num) * s_d * inv_var; - - dx[i] = - dy[i] * s_d * inv_var + gmean * scale + gvar * scale * (x[i] - mean[c]); + auto inv_var = inv_variance[c]; + auto s_d = gamma[c]; + auto gvar = + -((g_sum_dy_prod[c] / dev_num) * s_d * inv_var * (inv_var * inv_var)); + auto gmean = -((g_sum_dy[c] / dev_num) * s_d * inv_var); + + auto x_i = static_cast>(x[i]); + auto dy_i = static_cast>(dy[i]); + auto dx_i = + dy_i * s_d * inv_var + gmean * scale + gvar * scale * (x_i - mean[c]); + dx[i] = static_cast(dx_i); } } @@ -348,8 +367,8 @@ class SyncBatchNormGradKernel : public framework::OpKernel { d_x->mutable_data(ctx.GetPlace()); if (d_scale && d_bias) { - d_scale->mutable_data(ctx.GetPlace()); - d_bias->mutable_data(ctx.GetPlace()); + d_scale->mutable_data>(ctx.GetPlace()); + d_bias->mutable_data>(ctx.GetPlace()); } PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL); PADDLE_ENFORCE_EQ(scale->dims()[0], C); @@ -371,11 +390,13 @@ class SyncBatchNormGradKernel : public framework::OpKernel { auto stream = dev_ctx.stream(); auto *comm = dev_ctx.nccl_comm(); - const T *saved_mean = ctx.Input("SavedMean")->data(); - const T *saved_inv_var = ctx.Input("SavedVariance")->data(); - const int bytes = (C * 2 + 1) * sizeof(T); + const auto *saved_mean = + ctx.Input("SavedMean")->data>(); + const auto *saved_inv_var = + ctx.Input("SavedVariance")->data>(); + const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType); auto alloc_ptr = memory::Alloc(dev_ctx, bytes); - T *stats = reinterpret_cast(alloc_ptr->ptr()); + auto *stats = reinterpret_cast *>(alloc_ptr->ptr()); const int threads = 256; int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); @@ 
-384,17 +405,15 @@ class SyncBatchNormGradKernel : public framework::OpKernel { int fsize = H * W * D; if (layout == framework::DataLayout::kNCHW) { - KeBackwardLocalStats< - T, threads, - framework::DataLayout::kNCHW><<>>( - dy_d, x_d, saved_mean, N, fsize, C, stats); + KeBackwardLocalStats + <<>>(dy_d, x_d, saved_mean, N, fsize, C, + stats); } else { - KeBackwardLocalStats< - T, threads, - framework::DataLayout::kNHWC><<>>( - dy_d, x_d, saved_mean, N, fsize, C, stats); + KeBackwardLocalStats + <<>>(dy_d, x_d, saved_mean, N, fsize, C, + stats); } - int dtype = platform::ToNCCLDataType(x->type()); + int dtype = platform::ToNCCLDataType(scale->type()); // In-place operation PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( stats, stats, 2 * C + 1, static_cast(dtype), ncclSum, @@ -404,33 +423,33 @@ class SyncBatchNormGradKernel : public framework::OpKernel { int grid2 = (std::min(x_numel, max_threads) + block - 1) / block; if (layout == framework::DataLayout::kNCHW) { if (d_scale && d_bias) { - KeBNBackwardScaleBias< - T, threads, - framework::DataLayout::kNCHW><<>>( - dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize, - d_scale->data(), d_bias->data()); + KeBNBackwardScaleBias + <<>>( + dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize, + d_scale->data>(), + d_bias->data>()); } if (d_x) { - KeBNBackwardData< - T, framework::DataLayout::kNCHW><<>>( - dy_d, x_d, scale->data(), saved_mean, saved_inv_var, stats, - stats + C, stats + 2 * C, epsilon, C, fsize, x->numel(), - d_x->data()); + KeBNBackwardData + <<>>( + dy_d, x_d, scale->data>(), saved_mean, + saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C, + fsize, x->numel(), d_x->data()); } } else { if (d_scale && d_bias) { - KeBNBackwardScaleBias< - T, threads, - framework::DataLayout::kNHWC><<>>( - dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize, - d_scale->data(), d_bias->data()); + KeBNBackwardScaleBias + <<>>( + dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize, + d_scale->data>(), + d_bias->data>()); } if (d_x) { - KeBNBackwardData< - T, framework::DataLayout::kNHWC><<>>( - dy_d, x_d, scale->data(), saved_mean, saved_inv_var, stats, - stats + C, stats + 2 * C, epsilon, C, fsize, x->numel(), - d_x->data()); + KeBNBackwardData + <<>>( + dy_d, x_d, scale->data>(), saved_mean, + saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C, + fsize, x->numel(), d_x->data()); } } } @@ -443,8 +462,12 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( sync_batch_norm, ops::SyncBatchNormKernel, - ops::SyncBatchNormKernel); + ops::SyncBatchNormKernel, + ops::SyncBatchNormKernel); REGISTER_OP_CUDA_KERNEL( sync_batch_norm_grad, ops::SyncBatchNormGradKernel, - ops::SyncBatchNormGradKernel); + ops::SyncBatchNormGradKernel, + ops::SyncBatchNormGradKernel); + +// clang-format on diff --git a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py index b8a2515e716..a9eccf4a210 100644 --- a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +test for sync bachnorm op. +for both FP64 and FP16 input. 
+""" from __future__ import print_function @@ -22,9 +26,24 @@ import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import compiler +from op_test import OpTest + + +def create_or_get_tensor(scope, var_name, var, place): + """Get tensor, if not found, create a new one.""" + tensor = scope.var(var_name).get_tensor() + if var is not None: + assert isinstance(var, np.ndarray) + tensor.set_recursive_sequence_lengths([]) + tensor.set(var, place) + return tensor + class TestSyncBatchNormOpTraining(unittest.TestCase): + """sync_batch_norm op test.""" + def setUp(self): + """Setup.""" #self.dtype = np.float32 self.dtype = np.float64 self.N = 32 @@ -32,17 +51,20 @@ class TestSyncBatchNormOpTraining(unittest.TestCase): self.H = 64 self.W = 32 self.dshape = [self.N, self.C, self.H, self.W] + self.atol = 1e-3 - def build_program(self, - place, - layout, - seed, - sync_bn=False, - only_forward=False): + def _build_program(self, + place, + layout, + seed, + sync_bn=False, + only_forward=False): + """Build program.""" main = fluid.Program() startup = fluid.Program() main.random_seed = seed startup.random_seed = seed + use_cudnn = self.dtype == np.float16 with fluid.unique_name.guard(): with fluid.program_guard(main, startup): data = fluid.layers.data( @@ -56,7 +78,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase): filter_size=1, param_attr=fluid.ParamAttr(name='conv2d_weight'), bias_attr=False, - use_cudnn=False) + use_cudnn=use_cudnn) bn = fluid.layers.batch_norm( conv, param_attr=fluid.ParamAttr(name='bn_scale'), @@ -65,6 +87,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase): moving_variance_name='bn_moving_variance', data_layout=layout, is_test=only_forward) + bn = fluid.layers.cast(bn, 'float64') sigmoid = fluid.layers.sigmoid(bn) out = fluid.layers.reduce_sum(sigmoid) if not sync_bn: @@ -74,13 +97,18 @@ class TestSyncBatchNormOpTraining(unittest.TestCase): sgd_opt.backward(out) return main, startup, [out, conv, bn] - def compare(self, place, layout, only_forward): + def _compare(self, place, layout, only_forward): + """Compare results.""" seed = 10 os.environ['FLAGS_cudnn_deterministic'] = "1" + scope = core.Scope() data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2 + data = create_or_get_tensor(scope, "input", + OpTest.np_dtype_to_fluid_dtype(data), place) + # Single-GPU, N = 32 per GPU - main, startup, outs = self.build_program(place, layout, seed, False, - only_forward) + main, startup, outs = self._build_program(place, layout, seed, False, + only_forward) exe = fluid.Executor(place) exe.run(startup) fetch_names = [v.name for v in outs] + [ @@ -99,8 +127,8 @@ class TestSyncBatchNormOpTraining(unittest.TestCase): ##################################################################### # Multi-GPUs, self.N / core.get_cuda_device_count() per GPU assert core.get_cuda_device_count() > 1 - main, startup, outs = self.build_program(place, layout, seed, True, - only_forward) + main, startup, outs = self._build_program(place, layout, seed, True, + only_forward) exe = fluid.Executor(place) exe.run(startup) fetch_names = [v.name for v in outs] + [ @@ -133,27 +161,43 @@ class TestSyncBatchNormOpTraining(unittest.TestCase): sync_bn_val = sync_bn_val[:bn_val.shape[0]] self.assertTrue( np.allclose( - bn_val, sync_bn_val, atol=1e-3), + bn_val, sync_bn_val, atol=self.atol), "Output (" + fetch_names[i] + ") has diff. 
\n" + "\nBN " + str(bn_val) + "\n" + "Sync BN " + str(sync_bn_val)) def test_train(self): + """Test training.""" if not core.is_compiled_with_cuda(): return places = [core.CUDAPlace(0)] for place in places: for layout in ["NCHW", "NHWC"]: - self.compare(place, layout, False) + self._compare(place, layout, False) def test_infer(self): + """Test inference.""" if not core.is_compiled_with_cuda(): return places = [core.CUDAPlace(0)] for place in places: for layout in ["NCHW", "NHWC"]: - self.compare(place, layout, True) + self._compare(place, layout, True) + + +class TestFP16SyncBatchNormOpTraining(TestSyncBatchNormOpTraining): + """sync_batch_norm op test for FP16 input.""" + + def setUp(self): + """Setup.""" + self.dtype = np.float16 + self.N = 32 + self.C = 16 + self.H = 64 + self.W = 32 + self.dshape = [self.N, self.C, self.H, self.W] + self.atol = 1e-2 if __name__ == '__main__': -- GitLab