From 31ed9c9eed2e939fc6160a7af173082dea453e1f Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 1 Feb 2021 00:36:10 +0800 Subject: [PATCH] Fleet distributed strategy support pure fp16 (#30754) --- .../framework/distributed_strategy.proto | 2 + .../fleet/base/distributed_strategy.py | 25 +++++- .../distributed/fleet/base/fleet_base.py | 18 +++- .../fleet/base/strategy_compiler.py | 3 + .../fleet/meta_optimizers/amp_optimizer.py | 11 ++- .../graph_execution_optimizer.py | 4 +- .../fluid/tests/unittests/CMakeLists.txt | 2 + .../unittests/fleet_meta_optimizer_base.py | 15 ++++ .../tests/unittests/test_fleet_amp_init.py | 82 +++++++++++++++---- .../test_fleet_amp_meta_optimizer.py | 15 ++++ .../tests/unittests/test_fleet_base_single.py | 2 + ...est_fleet_gradient_merge_meta_optimizer.py | 17 ++++ python/paddle/optimizer/adam.py | 2 +- 13 files changed, 178 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 07ea824dc7a..8754c3a0c43 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -44,6 +44,8 @@ message AMPConfig { repeated string custom_white_list = 7; repeated string custom_black_list = 8; repeated string custom_black_varnames = 9; + optional bool use_pure_fp16 = 10 [ default = false ]; + optional bool use_fp16_guard = 11 [ default = true ]; } message LocalSGDConfig { diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index f7a28f15e9b..186d9263dc5 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -49,6 +49,9 @@ def assign_configs_value(msg, config): for key in config: for f in fields: if key == f.name: + # LABEL_OPTIONAL = 1 + # LABEL_REPEATED = 3 + # LABEL_REQUIRED = 2 if f.label == 3: getattr(msg, f.name).extend(config[f.name]) elif f.label == 1 or f.label == 2: @@ -366,7 +369,14 @@ class DistributedStrategy(object): custom_black_list(list[str]): Users' custom black list which forbidden execution fp16. - Examples: + custom_black_varnames(list[str]): Users' custom black varibles' names. + + use_pure_fp16(bool): Whether to use the pure fp16 training. Default False. + + use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program. + Default True. Only takes effect when `use_pure_fp16` is turned on. + + Examples 1: .. code-block:: python @@ -376,6 +386,19 @@ class DistributedStrategy(object): strategy.amp_configs = { "init_loss_scaling": 32768, "custom_white_list": ['conv2d']} + + Examples 2: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.amp = True + # pure fp16 + strategy.amp_configs = { + "init_loss_scaling": 32768, + "use_pure_fp16": True + } """ return get_msg_dict(self.strategy.amp_configs) diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 0e4559e6bc6..f4d62b9bf1b 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -196,6 +196,7 @@ class Fleet(object): else: if isinstance(role_maker, RoleMakerBase): self._role_maker = role_maker + self._is_collective = role_maker._is_collective else: raise ValueError( "`role_maker` should be subclass of `RoleMakerBase`, but got {}". @@ -1018,9 +1019,22 @@ class Fleet(object): if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: run_example_code() """ + # imitate target optimizer retrieval - return self.user_defined_optimizer.amp_init(place, scope, test_program, - use_fp16_test) + amp_optimizer = None + for optimizer in self.strategy_compiler._get_applied_meta_optimizer(): + if hasattr(optimizer, 'amp_init'): + amp_optimizer = optimizer + break + + if amp_optimizer is None: + if hasattr(self.user_defined_optimizer, 'amp_init'): + amp_optimizer = self.user_defined_optimizer + + assert amp_optimizer is not None, \ + "amp_init can only be used when the amp(auto mixed precision) strategy is turned on." + + return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test) def _final_strategy(self): if "valid_strategy" not in self._context: diff --git a/python/paddle/distributed/fleet/base/strategy_compiler.py b/python/paddle/distributed/fleet/base/strategy_compiler.py index 1d6fcee5442..7b146318abe 100644 --- a/python/paddle/distributed/fleet/base/strategy_compiler.py +++ b/python/paddle/distributed/fleet/base/strategy_compiler.py @@ -129,6 +129,9 @@ class StrategyCompiler(StrategyCompilerBase): self._meta_optimizer_candidates = [] self._graph_optimizer_candidates = [] + def _get_applied_meta_optimizer(self): + return self._meta_optimizers + def _get_applied_meta_list(self): return [type(opt).__name__ for opt in self._meta_optimizers] diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py index c751e229cbb..dba3c944f70 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py @@ -50,7 +50,8 @@ class AMPOptimizer(MetaOptimizerBase): self.inner_opt, amp_lists, config['init_loss_scaling'], config['incr_every_n_steps'], config['decr_every_n_nan_or_inf'], config['incr_ratio'], config['decr_ratio'], - config['use_dynamic_loss_scaling']) + config['use_dynamic_loss_scaling'], config['use_pure_fp16'], + config['use_fp16_guard']) # if worker_num > 1, all cards will communication with each other, # add is_distributed to optimize amp, overlap communication and @@ -112,3 +113,11 @@ class AMPOptimizer(MetaOptimizerBase): self.wrapped_opt.minimize(loss, startup_program, parameter_list, no_grad_set) return optimize_ops, params_grads + + def amp_init(self, + place, + scope=None, + test_program=None, + use_fp16_test=False): + return self.wrapped_opt.amp_init(place, scope, test_program, + use_fp16_test) diff --git a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py index 7ee184cfc5e..dd73577ae2e 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py @@ -165,7 +165,9 @@ class GraphExecutionOptimizer(MetaOptimizerBase): main_program._hierarchical_allreduce_inter_nranks = local_build_strategy.hierarchical_allreduce_inter_nranks # TODO(guru4elephant): should be an independent optimizer - self._setup_nccl_op(startup_program, main_program, local_build_strategy) + if worker_num > 1: + self._setup_nccl_op(startup_program, main_program, + local_build_strategy) local_build_strategy.num_trainers = self.role_maker._worker_num() local_build_strategy.trainer_id = self.role_maker._worker_index() diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 88027e46d27..d23b255a38f 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -47,6 +47,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_3) list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer) +list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_init) list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_sharding_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer) @@ -487,6 +488,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_sharding_meta_optimizer MODULES test_fleet_sharding_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS ${dist_ENVS}) + py_test_modules(test_fleet_amp_init MODULES test_fleet_amp_init ENVS ${dist_ENVS}) py_test_modules(test_fleet_fp16_allreduce_meta_optimizer MODULES test_fleet_fp16_allreduce_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS}) py_test_modules(test_fleet_meta_optimizer_base MODULES test_fleet_meta_optimizer_base ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py index b5eacecd003..1c74a11cc4d 100755 --- a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py +++ b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py @@ -88,6 +88,21 @@ class TestFleetMetaOptimizer(unittest.TestCase): "custom_white_list": ['softmax'], "custom_black_list": ['tanh'], } + elif name == 'pure_fp16': + strategy.amp = True + strategy.amp_configs = { + "init_loss_scaling": 32768, + "decr_every_n_nan_or_inf": 2, + "incr_every_n_steps": 1000, + "incr_ratio": 2.0, + "use_dynamic_loss_scaling": True, + "decr_ratio": 0.5, + "custom_white_list": ['softmax'], + "custom_black_list": ['tanh'], + "use_pure_fp16": True, + "use_fp16_guard": False, + } + elif name == 'dgc': strategy.dgc = True strategy.dgc_configs = { diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py index 2fa6bf54769..869ca41a192 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py @@ -46,34 +46,88 @@ class TestFleetAMPInit(unittest.TestCase): def test_fleet_amp_init(self): if not fluid.core.is_compiled_with_cuda(): return - input_x = paddle.static.data( - name="x", shape=[None, 32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[None, 1], dtype='int64') - cost = mlp(input_x, input_y) - optimizer = paddle.optimizer.Momentum( - learning_rate=0.001, - momentum=0.9, - weight_decay=fluid.regularizer.L2Decay(1e-4), - multi_precision=True) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) - optimizer = paddle.static.amp.decorate(optimizer) - optimizer = fleet.distributed_optimizer(optimizer) - optimizer.minimize(cost) + with paddle.static.program_guard(main_program, startup_program): + input_x = paddle.static.data( + name="x", shape=[None, 32], dtype='float32') + input_y = paddle.static.data( + name="y", shape=[None, 1], dtype='int64') + + cost = mlp(input_x, input_y) + optimizer = paddle.optimizer.Momentum( + learning_rate=0.001, + momentum=0.9, + weight_decay=fluid.regularizer.L2Decay(1e-4), + multi_precision=True) + + optimizer = paddle.static.amp.decorate(optimizer) + optimizer = fleet.distributed_optimizer(optimizer) + optimizer.minimize(cost) + place = paddle.CUDAPlace(0) exe = paddle.static.Executor(place) - exe.run(paddle.static.default_startup_program()) + exe.run(startup_program) optimizer.amp_init(place) step = 1 for i in range(step): - cost_val = exe.run(program=paddle.static.default_main_program(), + cost_val = exe.run(program=main_program, + feed=gen_data(), + fetch_list=[cost.name]) + + def test_fleet_amp_meta_optimizer_init(self): + if not fluid.core.is_compiled_with_cuda(): + return + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + + with paddle.static.program_guard(main_program, startup_program): + input_x = paddle.static.data( + name="x", shape=[None, 32], dtype='float32') + input_y = paddle.static.data( + name="y", shape=[None, 1], dtype='int64') + + cost = mlp(input_x, input_y) + optimizer = paddle.optimizer.Momentum( + learning_rate=0.001, + momentum=0.9, + weight_decay=fluid.regularizer.L2Decay(1e-4), + multi_precision=True) + + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.amp = True + strategy.amp_configs = {'use_pure_fp16': True} + strategy.gradient_merge = True + strategy.gradient_merge_configs = {"k_steps": 2} + + optimizer = fleet.distributed_optimizer(optimizer, strategy) + optimizer.minimize(cost) + + print(fleet._get_applied_meta_list()) + + place = paddle.CUDAPlace(0) + + exe = paddle.static.Executor(place) + exe.run(startup_program) + optimizer.amp_init(place) + + step = 3 + for i in range(step): + cost_val = exe.run(program=main_program, feed=gen_data(), fetch_list=[cost.name]) + print(cost_val) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py index 30f6607df9d..982ec4eb5c7 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py @@ -93,6 +93,21 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer): self.assertIn('cast', ops) self.assertIn('check_finite_and_unscale', ops) + def test_pure_fp16_optimizer(self): + """ test pure fp16 """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'pure_fp16') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + params = train_prog.all_parameters() + for param in train_prog.all_parameters(): + self.assertEqual(param.dtype, fluid.core.VarDesc.VarType.FP16) + + ops = [op.type for op in avg_cost.block.ops] + self.assertIn('cast', ops) + self.assertIn('check_finite_and_unscale', ops) + def test_amp_distributed_optimizer(self): """ test amp when distributed """ train_prog, startup_prog = fluid.Program(), fluid.Program() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base_single.py b/python/paddle/fluid/tests/unittests/test_fleet_base_single.py index 03e29399482..42b30e45b68 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base_single.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base_single.py @@ -78,6 +78,7 @@ class TestFleetBaseSingleRunCollective(unittest.TestCase): } def test_single_run_collective_minimize(self): + paddle.enable_static() input_x = paddle.static.data(name="x", shape=[-1, 32], dtype='float32') input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64') @@ -114,6 +115,7 @@ class TestFleetBaseSingleRunPS(unittest.TestCase): } def test_single_run_ps_minimize(self): + paddle.enable_static() input_x = paddle.static.data(name="x", shape=[-1, 32], dtype='float32') input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64') diff --git a/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py index 2d03b267fe9..efe62a32fc3 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py @@ -53,8 +53,25 @@ class TestFleetGradientMergeMetaOptimizer(TestFleetMetaOptimizer): self.set_strategy(strategy, 'gradient_merge') self.set_strategy(strategy, 'amp') self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + vars = [x.name for x in train_prog.list_vars()] + self.assertIn('@GradientMerge', ''.join(vars)) + self.assertIn('cast', ''.join(vars)) + + def test_gm_pure_fp16_optimizer(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'gradient_merge') + self.set_strategy(strategy, 'pure_fp16') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) print(train_prog) + params = train_prog.all_parameters() + for param in train_prog.all_parameters(): + self.assertEqual(param.dtype, + paddle.fluid.core.VarDesc.VarType.FP16) + vars = [x.name for x in train_prog.list_vars()] self.assertIn('@GradientMerge', ''.join(vars)) self.assertIn('cast', ''.join(vars)) diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index cd6156d105b..b0c05cf8de7 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -244,7 +244,7 @@ class Adam(Optimizer): if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision: warnings.warn( "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." - "Consider using multi_precision=True option of the Momentum optimizer." + "Consider using multi_precision=True option of the Adam optimizer." ) self._add_moments_pows(p) -- GitLab