diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/dist_fleet_sync_batch_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fd71ca9e58ff7520cfe9009674b932e1866e7d8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dist_fleet_sync_batch_norm.py
@@ -0,0 +1,142 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import ast
+import os
+import random
+
+import numpy as np
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.distributed import fleet
+from paddle.static import Executor, Program, program_guard
+
+
+def get_program(args):
+    main, startup = Program(), Program()
+    main.random_seed = 10
+    startup.random_seed = 10
+    with fluid.unique_name.guard():
+        with program_guard(main, startup):
+            data = paddle.static.data(
+                name='input',
+                shape=args.dshape,
+                dtype=args.dtype,
+            )
+            data.desc.set_need_check_feed(False)
+            conv = paddle.static.nn.conv2d(
+                input=data,
+                num_filters=32,
+                filter_size=1,
+                param_attr=fluid.ParamAttr(name='conv2d_weight'),
+                bias_attr=False,
+                use_cudnn=args.use_cudnn,
+            )
+            bn = paddle.static.nn.batch_norm(
+                conv,
+                param_attr=fluid.ParamAttr(name='bn_scale'),
+                bias_attr=fluid.ParamAttr(name='bn_bias'),
+                moving_mean_name='bn_moving_mean',
+                moving_variance_name='bn_moving_variance',
+                data_layout=args.layout,
+                is_test=args.only_forward,
+            )
+            if core.is_compiled_with_rocm():
+                bn = paddle.cast(bn, 'float32')
+            else:
+                bn = paddle.cast(bn, 'float64')
+            sigmoid = paddle.nn.functional.sigmoid(bn)
+            out = paddle.sum(sigmoid)
+            if not args.only_forward:
+                sgd_opt = fluid.optimizer.SGD(learning_rate=0.0)
+                opt = fleet.distributed_optimizer(sgd_opt)
+                opt.minimize(out)
+    return main, startup, [out, conv, bn]
+
+
+def train(args):
+
+    build_strategy = fluid.BuildStrategy()
+    build_strategy.sync_batch_norm = True
+    build_strategy.enable_inplace = False
+    build_strategy.memory_optimize = False
+
+    distributed_strategy = fleet.DistributedStrategy()
+    distributed_strategy.build_strategy = build_strategy
+    distributed_strategy.without_graph_optimization = True
+    distributed_strategy.fuse_all_reduce_ops = True
+    distributed_strategy.fuse_grad_size_in_num = 8
+
+    fleet.init(is_collective=True, strategy=distributed_strategy)
+    main, startup, outs = get_program(args)
+    exe = Executor()
+    exe.run(startup)
+
+    for nm in args.fetch_list:
+        fv = fluid.framework._get_var(str(nm), program=main)
+        fv.persistable = True
+
+    fetch_list = [v.name for v in outs] + args.fetch_list
+
+    rank = paddle.distributed.get_rank()
+    filepath = os.path.join(
+        args.data_dir,
+        'input_{}_{}_{}_{}.npy'.format(
+            rank, args.only_forward, str(args.dtype), args.layout
+        ),
+    )
+    data = np.load(filepath)
+
+    comp_prog = fluid.compiler.CompiledProgram(
+        main, build_strategy=build_strategy
+    )
+    sync_bn_fetches = exe.run(
+        program=comp_prog, feed={'input': data}, fetch_list=fetch_list
+    )
+
+    for i in range(0, len(sync_bn_fetches)):
+        file_path = os.path.join(
+            args.data_dir,
+            'output_{}_{}_{}_{}.npy'.format(
+                rank, args.only_forward, str(args.dtype), i
+            ),
+        )
+        np.save(file_path, sync_bn_fetches[i])
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_dir', type=str, required=True)
+    parser.add_argument('--dshape', type=str, required=True)
+    parser.add_argument('--dtype', type=str, required=True)
+    parser.add_argument('--layout', type=str, required=True)
+    parser.add_argument('--fetch_list', type=str, required=True)
+    parser.add_argument('--use_cudnn', action='store_true')
+    parser.add_argument('--only_forward', action='store_true')
+
+    args = parser.parse_args()
+    args.dshape = ast.literal_eval(args.dshape)
+    args.fetch_list = ast.literal_eval(args.fetch_list)
+
+    paddle.enable_static()
+
+    paddle.seed(0)
+    np.random.seed(0)
+    random.seed(0)
+
+    train(args)
diff --git a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py
index b48ce4d5cd5eeb935d6386d55242385831e46d10..4d2f69f163da203dab608b1232419b21bb17b095 100644
--- a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py
@@ -17,6 +17,9 @@ for both FP64 and FP16 input.
 """
 
 import os
+import random
+import subprocess
+import tempfile
 import unittest
 
 import numpy as np
@@ -27,7 +30,7 @@ import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle.nn as nn
-from paddle.fluid import Program, compiler, program_guard
+from paddle.fluid import Program, program_guard
 
 _set_use_system_allocator(True)
 
@@ -55,6 +58,39 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
         self.W = 32
         self.dshape = [self.N, self.C, self.H, self.W]
         self.atol = 1e-3
+        self.data_dir = tempfile.TemporaryDirectory()
+        self.fleet_log_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self) -> None:
+        self.data_dir.cleanup()
+        self.fleet_log_dir.cleanup()
+
+    def multi_device_run(self, layout, fetch_list, only_forward=False):
+        cmds = [
+            "python",
+            "-m",
+            "paddle.distributed.launch",
+        ]
+        cmds += ["--log_dir", self.fleet_log_dir.name]
+        cmds += ["dist_fleet_sync_batch_norm.py"]
+        cmds += ["--data_dir", self.data_dir.name]
+
+        dshape = [
+            self.N // core.get_cuda_device_count(),
+            self.C,
+            self.H,
+            self.W,
+        ]
+        cmds += ["--dshape", str(dshape)]
+        cmds += ["--dtype", str(self.dtype.__name__)]
+        cmds += ["--layout", layout]
+        cmds += ["--fetch_list", str(fetch_list)]
+        if only_forward:
+            cmds += ["--only_forward"]
+        if self.dtype == np.float16:
+            cmds += ["--use_cudnn"]
+        p = subprocess.run(cmds)
+        assert p.returncode == 0, f"Fleet train: Failed: {p}"
 
     def _build_program(
         self, place, layout, seed, sync_bn=False, only_forward=False
@@ -108,8 +144,18 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
         """Compare results."""
         seed = 10
         os.environ['FLAGS_cudnn_deterministic'] = "1"
+        paddle.enable_static()
         scope = core.Scope()
         data = np.random.random(size=self.dshape).astype(self.dtype) * 4.0 - 2
+        stride = self.N // core.get_cuda_device_count()
+        for id in range(core.get_cuda_device_count()):
+            filepath = os.path.join(
+                self.data_dir.name,
+                'input_{}_{}_{}_{}.npy'.format(
+                    id, only_forward, str(self.dtype.__name__), layout
+                ),
+            )
+            np.save(filepath, data[id * stride : (id + 1) * stride])
         data = create_or_get_tensor(
             scope, "input", OpTest.np_dtype_to_fluid_dtype(data), place
         )
@@ -143,12 +189,8 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
         #####################################################################
         # Multi-GPUs, self.N / core.get_cuda_device_count() per GPU
         assert core.get_cuda_device_count() > 1
-        main, startup, outs = self._build_program(
-            place, layout, seed, True, only_forward
-        )
-        exe = fluid.Executor(place)
-        exe.run(startup)
-        fetch_names = [v.name for v in outs] + [
+
+        fetch_names = [
             'bn_moving_mean',
             'bn_moving_variance',
             'bn_scale',
@@ -164,26 +206,24 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
             'conv2d_0.tmp_0@GRAD',
         ]
         fetch_names += others
-        for nm in fetch_names:
-            fv = fluid.framework._get_var(str(nm), program=main)
-            fv.persistable = True
-        build_strategy = fluid.BuildStrategy()
-        build_strategy.sync_batch_norm = True
-        build_strategy.enable_inplace = False
-        build_strategy.memory_optimize = False
-        comp_prog = compiler.CompiledProgram(main).with_data_parallel(
-            outs[0].name if not only_forward else None,
-            build_strategy=build_strategy,
-        )
-        sync_bn_fetches = exe.run(
-            program=comp_prog, feed={'input': data}, fetch_list=fetch_names
+
+        self.multi_device_run(
+            layout, fetch_list=fetch_names, only_forward=only_forward
         )
-        for i in range(1, len(sync_bn_fetches)):
+        fetch_names = [v.name for v in outs] + fetch_names
+
+        for i in range(1, len(bn_fetches)):
             bn_val = bn_fetches[i]
-            sync_bn_val = sync_bn_fetches[i]
+            file_path = os.path.join(
+                self.data_dir.name,
+                'output_{}_{}_{}_{}.npy'.format(
+                    0, only_forward, self.dtype.__name__, i
+                ),
+            )
+            sync_bn_val = np.load(file_path)
             if sync_bn_val.shape != bn_val.shape:
-                sync_bn_val = sync_bn_val[: bn_val.shape[0]]
+                bn_val = bn_val[:stride]
             np.testing.assert_allclose(
                 bn_val,
                 sync_bn_val,
@@ -206,7 +246,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
             places = [core.CUDAPlace(0)]
 
         for place in places:
-            for layout in ["NCHW", "NHWC"]:
+            for layout in ["NHWC", "NCHW"]:
                 self._compare(place, layout, False)
 
     def test_infer(self):
@@ -216,7 +256,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
             places = [core.CUDAPlace(0)]
 
         for place in places:
-            for layout in ["NCHW", "NHWC"]:
+            for layout in ["NHWC", "NCHW"]:
                 self._compare(place, layout, True)
 
 
@@ -232,6 +272,8 @@ class TestFP16SyncBatchNormOpTraining(TestSyncBatchNormOpTraining):
         self.W = 32
         self.dshape = [self.N, self.C, self.H, self.W]
         self.atol = 1e-2
+        self.data_dir = tempfile.TemporaryDirectory()
+        self.fleet_log_dir = tempfile.TemporaryDirectory()
 
 
 class TestDygraphSyncBatchNormAPIError(unittest.TestCase):
@@ -390,4 +432,7 @@ class TestDygraphSyncBatchNormDataFormatError(unittest.TestCase):
 
 
 if __name__ == '__main__':
+    paddle.seed(0)
+    np.random.seed(0)
+    random.seed(0)
     unittest.main()
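
For context on the change above: the unit test and the launched dist_fleet_sync_batch_norm.py script communicate only through `.npy` files in `data_dir`. The test saves one input shard per rank, each launched rank saves its fetch results back under an output prefix, and the test then loads rank 0's outputs for comparison against the single-device run. Below is a minimal standalone sketch of that file-naming handshake; the two-rank setup, shapes, dtype, and the stand-in `fetches` list are illustrative only and not part of the patch.

import os
import tempfile

import numpy as np

# Hypothetical two-rank setup mirroring the naming scheme used in the patch.
only_forward, dtype, layout = False, "float32", "NCHW"

with tempfile.TemporaryDirectory() as data_dir:
    data = np.random.random(size=[8, 16, 32, 32]).astype(dtype)
    stride = data.shape[0] // 2  # per-rank slice of the batch dimension

    # Test side: write one input shard per rank.
    for rank in range(2):
        np.save(
            os.path.join(
                data_dir, f'input_{rank}_{only_forward}_{dtype}_{layout}.npy'
            ),
            data[rank * stride : (rank + 1) * stride],
        )

    # Script side (shown for rank 0): load the shard, compute, save each
    # fetch result under its index.
    rank = 0
    shard = np.load(
        os.path.join(
            data_dir, f'input_{rank}_{only_forward}_{dtype}_{layout}.npy'
        )
    )
    fetches = [shard.sum()]  # stand-in for the real fetch_list results
    for i, value in enumerate(fetches):
        np.save(
            os.path.join(
                data_dir, f'output_{rank}_{only_forward}_{dtype}_{i}.npy'
            ),
            np.asarray(value),
        )

    # Test side again: rank 0's outputs are what get compared in _compare().
    out0 = np.load(
        os.path.join(data_dir, f'output_0_{only_forward}_{dtype}_0.npy')
    )
    print(out0)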