diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 1fc29ad0428832d1a302fb996d689e69b36c4987..c7798b15c67fe82c660e7b17417d99c5909b4856 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -744,13 +744,13 @@ class DistributedStrategy(object): strategy.adaptive_localsgd = True # by default this is false """ - return self.strategy.localsgd + return self.strategy.adaptive_localsgd @adaptive_localsgd.setter @is_strict_auto def adaptive_localsgd(self, flag): if isinstance(flag, bool): - self.strategy.localsgd = flag + self.strategy.adaptive_localsgd = flag else: print("WARNING: adaptive_localsgd should have value of bool type") diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py index ad96e1426694f090943bdd08902e5e2219d32eda..283589c5f332089ecb1e4e97e326c7314ee437c3 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py @@ -19,16 +19,14 @@ class AMPOptimizer(MetaOptimizerBase): def __init__(self, optimizer): super(AMPOptimizer, self).__init__(optimizer) self.inner_opt = optimizer - self.amp_opt = None + self.wrapped_opt = None # we do not allow meta optimizer to be inner optimizer currently self.meta_optimizers_white_list = [ "LarsOptimizer", "LambOptimizer", "RecomputeOptimizer", - "LocalSGDOptimizer", "GradientMergeOptimizer", "GraphExecutionOptimizer", - "AdaptiveLocalSGDOptimizer", ] self.meta_optimizers_black_list = ["DGCOptimizer"] @@ -37,6 +35,24 @@ class AMPOptimizer(MetaOptimizerBase): super(AMPOptimizer, self)._set_basic_info( loss, role_maker, user_defined_optimizer, user_defined_strategy) + def _init_wrapped_opt(self): + if self.wrapped_opt is not None: + return + + config = self.user_defined_strategy.amp_configs + + custom_white_list = set(config['custom_white_list']) + custom_black_list = set(config['custom_black_list']) + custom_black_varnames = set(config['custom_black_varnames']) + amp_lists = mixed_precision.AutoMixedPrecisionLists( + custom_white_list, custom_black_list, custom_black_varnames) + + self.wrapped_opt = mixed_precision.decorate( + self.inner_opt, amp_lists, config['init_loss_scaling'], + config['incr_every_n_steps'], config['decr_every_n_nan_or_inf'], + config['incr_ratio'], config['decr_ratio'], + config['use_dynamic_loss_scaling']) + def _can_apply(self): if not self.role_maker._is_collective: return False @@ -60,26 +76,31 @@ class AMPOptimizer(MetaOptimizerBase): "use_dynamic_loss_scaling": True } + def backward(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + callbacks=None): + # maybe inner_opt of other meta optimizer + self._init_wrapped_opt() + return self.wrapped_opt.backward(loss, startup_program, parameter_list, + no_grad_set, callbacks) + + def apply_gradients(self, params_grads): + return self.wrapped_opt.apply_gradients(params_grads=params_grads) + + def apply_optimize(self, loss, startup_program, params_grads): + return self.wrapped_opt.apply_optimize( + loss, startup_program=startup_program, params_grads=params_grads) + def minimize_impl(self, loss, startup_program=None, parameter_list=None, no_grad_set=None): - if self.amp_opt is None: - config = self.user_defined_strategy.amp_configs - custom_white_list = set(config['custom_white_list']) - 
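The AMP hunks above swap the eagerly built `amp_opt` for a `wrapped_opt` that is created on first use and then reused, so `backward()` can be driven directly by an outer meta optimizer before `minimize_impl()` ever runs. A framework-free sketch of that lazy-wrapping idiom follows; the class and helper names are illustrative, not Paddle APIs.

class LazyWrappingOptimizer:
    """Illustrative stand-in for AMPOptimizer's lazy decoration."""

    def __init__(self, inner_opt):
        self.inner_opt = inner_opt
        self.wrapped_opt = None                 # mirrors AMPOptimizer.wrapped_opt

    def _init_wrapped_opt(self):
        if self.wrapped_opt is not None:        # already decorated, reuse it
            return
        # The real code calls mixed_precision.decorate(self.inner_opt, ...) here.
        self.wrapped_opt = ("decorated", self.inner_opt)

    def backward(self, loss):
        # May be the first entry point when another meta optimizer wraps this one,
        # so the wrapper has to be ready here, not only inside minimize().
        self._init_wrapped_opt()
        return [(p, p + "@GRAD") for p in ("fc_w", "fc_b")]

    def minimize(self, loss):
        self._init_wrapped_opt()                # idempotent, same wrapper as backward()
        return self.backward(loss)


print(LazyWrappingOptimizer("sgd").minimize("avg_cost"))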
custom_black_list = set(config['custom_black_list']) - custom_black_varnames = set(config['custom_black_varnames']) - amp_lists = mixed_precision.AutoMixedPrecisionLists( - custom_white_list, custom_black_list, custom_black_varnames) - - self.amp_opt = mixed_precision.decorate( - self.inner_opt, amp_lists, config['init_loss_scaling'], - config['incr_every_n_steps'], config['decr_every_n_nan_or_inf'], - config['incr_ratio'], config['decr_ratio'], - config['use_dynamic_loss_scaling']) - + self._init_wrapped_opt() optimize_ops, params_grads = \ - self.amp_opt.minimize(loss, startup_program, + self.wrapped_opt.minimize(loss, startup_program, parameter_list, no_grad_set) return optimize_ops, params_grads diff --git a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py index 6806a479d30f467bd8b6f6d5c6832dda63af4055..9990021c8506a386d0084811ae73b97f2ac37ca4 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py @@ -85,6 +85,13 @@ class DGCOptimizer(MetaOptimizerBase): return self.dgc_opt.backward(loss, startup_program, parameter_list, no_grad_set, callbacks) + def apply_gradients(self, params_grads): + return self.dgc_opt.apply_gradients(params_grads=params_grads) + + def apply_optimize(self, loss, startup_program, params_grads): + return self.dgc_opt.apply_optimize( + loss, startup_program=startup_program, params_grads=params_grads) + def minimize_impl(self, loss, startup_program=None, diff --git a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py index df9887759e16fddb0579abdcdf3ef5f9024825e7..64d54ae3bab03b4511340c3ae222001aa7942f9c 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py @@ -98,6 +98,10 @@ class LambOptimizer(MetaOptimizerBase): def apply_gradients(self, params_grads): return self.lamb_opt.apply_gradients(params_grads=params_grads) + def apply_optimize(self, loss, startup_program, params_grads): + return self.lamb_opt.apply_optimize( + loss, startup_program=startup_program, params_grads=params_grads) + def minimize_impl(self, loss, startup_program=None, diff --git a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py index 609d8b85e714c1c7247898f8d506f9dadab9f499..32c6be505a5467b2fe6cc3f155cc8df7e21bfeca 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py @@ -85,6 +85,10 @@ class LarsOptimizer(MetaOptimizerBase): def apply_gradients(self, params_grads): return self.lars_opt.apply_gradients(params_grads=params_grads) + def apply_optimize(self, loss, startup_program, params_grads): + return self.lars_opt.apply_optimize( + loss, startup_program=startup_program, params_grads=params_grads) + def minimize_impl(self, loss, startup_program=None, diff --git a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py index 9f094978d842a8ba194742b527dc6f3cd19234cd..91030f07629343497426268106650ccb3f5011fd 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py @@ -24,7 +24,7 @@ 
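The DGC, LAMB and LARS hunks above add the same forwarding method (the AMP and recompute optimizers get it elsewhere in this patch). The reason is the two-phase protocol used when one meta optimizer becomes the `inner_opt` of another: the outer wrapper calls `backward()` and later `apply_optimize()` instead of a single `minimize()`, so every wrapper has to expose both halves and route them to its specialised optimizer. A minimal, framework-free sketch of that protocol; all names are illustrative.

class InnerMetaOpt:
    """Stands in for e.g. DGCOptimizer: each phase is callable on its own."""

    def backward(self, loss):
        return [("w", "w@GRAD")]                 # pretend (param, grad) pairs

    def apply_optimize(self, loss, startup_program, params_grads):
        return ["update op for %s" % p for p, _ in params_grads]


class OuterMetaOpt:
    """Stands in for e.g. RecomputeOptimizer wrapping another meta optimizer."""

    def __init__(self, inner):
        self.inner = inner

    def minimize(self, loss, startup_program=None):
        # The outer optimizer may rewrite the program between the two phases
        # (recompute segments, gradient compression, ...), which is why it needs
        # backward() and apply_optimize() as separate entry points on the inner one.
        params_grads = self.inner.backward(loss)
        ops = self.inner.apply_optimize(loss, startup_program, params_grads)
        return ops, params_grads


print(OuterMetaOpt(InnerMetaOpt()).minimize("avg_cost"))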
class LocalSGDOptimizer(MetaOptimizerBase): def __init__(self, optimizer): super(LocalSGDOptimizer, self).__init__(optimizer) self.inner_opt = optimizer - self.meta_optimizers_white_list = [] + self.meta_optimizers_white_list = ['AMPOptimizer'] self.meta_optimizers_black_list = [ "GraphExecutionOptimizer", "AdaptiveLocalSGDOptimizer", @@ -195,7 +195,7 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): def __init__(self, optimizer): super(AdaptiveLocalSGDOptimizer, self).__init__(optimizer) self.inner_opt = optimizer - self.meta_optimizers_white_list = [] + self.meta_optimizers_white_list = ['AMPOptimizer'] self.meta_optimizers_black_list = [ "GraphExecutionOptimizer", "LocalSGDOptimizer" ] diff --git a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py index 59ca7e633099e8688a57fa9024575e29008c0341..ea2b67ac4bd1f647718cf454d85e8888141bdf83 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py @@ -18,15 +18,14 @@ from .meta_optimizer_base import MetaOptimizerBase class RecomputeOptimizer(MetaOptimizerBase): def __init__(self, optimizer): super(RecomputeOptimizer, self).__init__(optimizer) - #self.inner_opt = RO(optimizer) self.inner_opt = optimizer - self.wrapped_opt = RO(optimizer) + self.wrapped_opt = None # we do not allow meta optimizer to be inner optimizer currently self.meta_optimizers_white_list = [ "LarsOptimizer", "LambOptimizer", - "GradientMergeOptimizer", "GraphExecutionOptimizer", + "DGCOptimizer", ] self.meta_optimizers_black_list = [] @@ -34,8 +33,15 @@ class RecomputeOptimizer(MetaOptimizerBase): user_defined_strategy): super(RecomputeOptimizer, self)._set_basic_info( loss, role_maker, user_defined_optimizer, user_defined_strategy) - self.wrapped_opt._set_checkpoints( - list(user_defined_strategy.recompute_configs["checkpoints"])) + + def _init_wrapped_opt(self): + if self.wrapped_opt is not None: + return + + configs = self.user_defined_strategy.recompute_configs + + self.wrapped_opt = RO(self.inner_opt) + self.wrapped_opt._set_checkpoints(list(configs["checkpoints"])) def _can_apply(self): if not self.role_maker._is_collective: @@ -62,14 +68,24 @@ class RecomputeOptimizer(MetaOptimizerBase): parameter_list=None, no_grad_set=None, callbacks=None): + # maybe inner_opt of other meta optimizer + self._init_wrapped_opt() return self.wrapped_opt.backward(loss, startup_program, parameter_list, no_grad_set, callbacks) + def apply_gradients(self, params_grads): + return self.wrapped_opt.apply_gradients(params_grads=params_grads) + + def apply_optimize(self, loss, startup_program, params_grads): + return self.wrapped_opt.apply_optimize( + loss, startup_program=startup_program, params_grads=params_grads) + def minimize_impl(self, loss, startup_program=None, parameter_list=None, no_grad_set=None): + self._init_wrapped_opt() optimize_ops, params_grads = \ self.wrapped_opt.minimize(loss, startup_program, parameter_list, no_grad_set) diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index c9112ac849ce0506b7afd941b2213710e06bd1c6..529c664e7083ccd86d65464302dbaac7bffaab3c 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -16,6 +16,7 @@ from ... import default_main_program from ... import default_startup_program from ... 
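With `AMPOptimizer` added to the LocalSGD white lists and `DGCOptimizer` added to the recompute white list, combinations such as localsgd+amp and recompute+dgc become legal pairings, while the user still only toggles `DistributedStrategy` flags. A sketch of that user-facing side, assembled from the calls used in the tests later in this patch; it assumes a single-trainer collective environment (the same PADDLE_TRAINER_* variables the tests set), and the checkpoint names depend on the actual network.

import os
import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker

paddle.enable_static()
os.environ["PADDLE_TRAINER_ID"] = "0"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
fleet.init(role_maker.PaddleCloudRoleMaker(is_collective=True))

# Small fc network, same shape as TestFleetMetaOptimizer.net().
input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = fluid.layers.fc(input=fc_1, size=256, act='tanh')
prediction = fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
avg_cost = fluid.layers.mean(x=cost)

strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.amp = True                      # enable mixed precision
strategy.recompute = True                # enable recompute; combined per the white lists
strategy.recompute_configs = {"checkpoints": ["fc_0.tmp_2", "fc_1.tmp_2"]}

optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)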
import layers from ... import unique_name +from ... import program_guard from . import fp16_utils from .fp16_utils import rewrite_program from .fp16_utils import update_role_var_grad @@ -58,21 +59,40 @@ class OptimizerWithMixedPrecision(object): self._optimizer = optimizer self._amp_lists = amp_lists self._param_grads = None - self._train_program = default_main_program() - self._startup_prog = default_startup_program() + self._train_program = None + self._scaled_loss = None - self._loss_scaling = layers.create_global_var( - name=unique_name.generate("loss_scaling"), - shape=[1], - value=init_loss_scaling, - dtype='float32', - persistable=True) + self._loss_scaling = None + self._init_loss_scaling = init_loss_scaling self._use_dynamic_loss_scaling = use_dynamic_loss_scaling if self._use_dynamic_loss_scaling: self._incr_every_n_steps = incr_every_n_steps self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf self._incr_ratio = incr_ratio self._decr_ratio = decr_ratio + self._num_good_steps = None + self._num_bad_steps = None + + def get_loss_scaling(self): + """Return the real-time loss scaling factor. + """ + return self._loss_scaling + + def get_scaled_loss(self): + """Return the scaled loss. + It's useful when you feed customed loss into executor. + """ + return self._scaled_loss + + def _init_amp_var(self): + self._loss_scaling = layers.create_global_var( + name=unique_name.generate("loss_scaling"), + shape=[1], + value=self._init_loss_scaling, + dtype='float32', + persistable=True) + + if self._use_dynamic_loss_scaling: self._num_good_steps = layers.create_global_var( name=unique_name.generate("num_good_steps"), shape=[1], @@ -86,28 +106,16 @@ class OptimizerWithMixedPrecision(object): dtype='int32', persistable=True) - # Ensure the data type of learning rate vars is float32 (same as the + # Ensure the data type of learning rate vars is float32 (same as the # master parameter dtype) - if isinstance(optimizer._learning_rate, float): - optimizer._learning_rate_map[default_main_program()] = \ - layers.create_global_var( - name=unique_name.generate("learning_rate"), - shape=[1], - value=float(optimizer._learning_rate), - dtype='float32', - persistable=True) - - def get_loss_scaling(self): - """Return the real-time loss scaling factor. - """ - return self._loss_scaling - - def get_scaled_loss(self): - """Return the scaled loss. - It's useful when you feed customed loss into executor. - """ - - return self._scaled_loss + if isinstance(self._optimizer._learning_rate, float): + self._optimizer._learning_rate_map[default_main_program()] = \ + layers.create_global_var( + name=unique_name.generate("learning_rate"), + shape=[1], + value=float(self._optimizer._learning_rate), + dtype='float32', + persistable=True) def backward(self, loss, @@ -131,16 +139,21 @@ class OptimizerWithMixedPrecision(object): A list of (param, grad), which is a tuple of a parameter and its gradient respectively, and the scaled loss. """ - rewrite_program(self._train_program, self._amp_lists) - self._scaled_loss = loss * self._loss_scaling - self._params_grads = self._optimizer.backward( - self._scaled_loss, startup_program, parameter_list, no_grad_set, - callbacks) - # Change the op_role_var attr for some ops, so that gradients - # transferred across GPUs can be FP16. 
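In the decorator, the loss-scaling variable, the good/bad step counters and the FP32 learning-rate copy are no longer created in `__init__` against `default_main_program()`; `backward()` now takes the program from `loss.block.program` and creates them inside `program_guard(train_program, startup_program)`, so AMP also works when the user builds the graph in freshly created `fluid.Program()` objects, as the rewritten tests do. A small sketch of the guard's effect on `create_global_var`; the variable name is made up for the demo.

import paddle
import paddle.fluid as fluid
from paddle.fluid import layers

paddle.enable_static()

main_prog, startup_prog = fluid.Program(), fluid.Program()

# Under the guard the variable is registered in main_prog/startup_prog;
# without it, it would silently land in the default main program instead,
# which is the situation the refactor above avoids.
with fluid.program_guard(main_prog, startup_prog):
    loss_scaling = layers.create_global_var(
        name="loss_scaling_demo",
        shape=[1],
        value=32768.0,
        dtype='float32',
        persistable=True)

in_user_prog = any(v.name == "loss_scaling_demo" for v in main_prog.list_vars())
in_default_prog = any(v.name == "loss_scaling_demo"
                      for v in fluid.default_main_program().list_vars())
print(in_user_prog, in_default_prog)    # expected: True False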
- update_role_var_grad(self._train_program, self._params_grads) - - return self._params_grads + train_program = loss.block.program + self._train_program = train_program + + with program_guard(train_program, startup_program): + self._init_amp_var() + + rewrite_program(train_program, self._amp_lists) + self._scaled_loss = loss * self._loss_scaling + params_grads = self._optimizer.backward( + self._scaled_loss, startup_program, parameter_list, no_grad_set, + callbacks) + # Change the op_role_var attr for some ops, so that gradients + # transferred across GPUs can be FP16. + update_role_var_grad(train_program, params_grads) + return params_grads def apply_gradients(self, params_grads): """ @@ -182,6 +195,12 @@ class OptimizerWithMixedPrecision(object): return optimize_ops + def apply_optimize(self, loss, startup_program, params_grads): + program = loss.block.program + with program_guard(program, startup_program): + optimize_ops = self.apply_gradients(params_grads) + return optimize_ops + def minimize(self, loss, startup_program=None, @@ -207,7 +226,8 @@ class OptimizerWithMixedPrecision(object): parameter_list=parameter_list, no_grad_set=no_grad_set) - optimize_ops = self.apply_gradients(scaled_params_grads) + optimize_ops = self.apply_optimize(loss, startup_program, + scaled_params_grads) return optimize_ops, scaled_params_grads diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 4a9ce4454af0be86f784a8ea9bcbc81564d9a383..367be181f4725a6cf72adc633ffc817066d7c5d6 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -731,9 +731,6 @@ class Optimizer(object): outputs={"ParamOut": param_and_grad[0]}) return new_param_grads, (table_param, table_grad), sgd_op - def _append_dgc_ops(self, param_and_grad): - pass - def backward(self, loss, startup_program=None, @@ -801,9 +798,6 @@ class Optimizer(object): with program_guard(program, startup_program): params_grads = append_backward(loss, parameter_list, act_no_grad_set, callbacks) - # Note: since we can't use all_reduce_op now, - # dgc_op should be the last op of one grad. - self._append_dgc_ops(params_grads) return params_grads def apply_gradients(self, params_grads): @@ -1569,6 +1563,11 @@ class DGCMomentumOptimizer(Optimizer): @imperative_base.no_grad def apply_gradients(self, params_grads): + # Note: since we can't use all_reduce_op now, + # dgc_op should be the last op of one grad. + # Maybe need a grad allreduce pass. + self._append_dgc_ops(params_grads) + params_grads = sorted(params_grads, key=lambda x: x[0].name) params_grads, table_param_and_grad, table_optimize_op = \ self._process_distribute_lookuptable(params_grads) @@ -4784,10 +4783,6 @@ class RecomputeOptimizer(Optimizer): params_grads = append_backward( loss, parameter_list, no_grad_set, checkpoints=checkpoint_vars) - # Note: since we can't use all_reduce_op now, - # dgc_op should be the last op of one grad. - if hasattr(self._optimizer, "_append_dgc_ops"): - self._optimizer._append_dgc_ops(params_grads) return params_grads def apply_optimize(self, loss, startup_program, params_grads): diff --git a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py new file mode 100755 index 0000000000000000000000000000000000000000..e7cdd49a32c2683394cc08c6d3027084d082c117 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py @@ -0,0 +1,122 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
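On the fluid side, the `_append_dgc_ops` hook is removed from the generic `Optimizer.backward` (and from the fluid `RecomputeOptimizer.backward`) and the call moves to the top of `DGCMomentumOptimizer.apply_gradients`. That preserves the requirement in the old comment, that the dgc op is the last op producing each gradient, even when an outer AMP or recompute wrapper owns the backward pass on a scaled loss: the compression is appended once, right before the update ops. A framework-free sketch of that ordering; the classes only mimic the call structure, not the real op insertion.

class PlainOptimizer:
    def backward(self, loss):
        # Generic autodiff only; no DGC-specific hook here any more.
        return [("w", "w@GRAD")]

    def apply_gradients(self, params_grads):
        return ["momentum(%s, %s)" % pg for pg in params_grads]


class DGCMomentumLike(PlainOptimizer):
    def _append_dgc_ops(self, params_grads):
        # Stand-in for inserting the sparsify/communication ops so that the
        # dgc op is the last thing touching each gradient.
        return [(p, g + " -> dgc") for p, g in params_grads]

    def apply_gradients(self, params_grads):
        params_grads = self._append_dgc_ops(params_grads)   # moved here from backward()
        return super(DGCMomentumLike, self).apply_gradients(params_grads)


opt = DGCMomentumLike()
grads = opt.backward("scaled_loss")        # may be issued by an AMP/recompute wrapper
print(opt.apply_gradients(grads))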
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +from paddle import fluid +import os +import paddle.distributed.fleet as fleet +import paddle.distributed.fleet.base.role_maker as role_maker + + +class TestFleetMetaOptimizer(unittest.TestCase): + def setUp(self): + os.environ["PADDLE_TRAINER_ID"] = "1" + os.environ[ + "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002" + + def net(self, main_prog, startup_prog): + with fluid.program_guard(main_prog, startup_prog): + with fluid.unique_name.guard(): + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data( + name="y", shape=[1], dtype='int64') + + fc_1 = paddle.fluid.layers.fc(input=input_x, + size=64, + act='tanh') + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=256, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], + size=2, + act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + strategy = paddle.distributed.fleet.DistributedStrategy() + return avg_cost, strategy + + def optimizer(self, + loss, + strategy, + train_prog, + startup_prog, + name='momentum'): + with fluid.program_guard(train_prog, startup_prog): + with fluid.unique_name.guard(): + if name == 'momentum': + optimizer = paddle.fluid.optimizer.Momentum( + learning_rate=0.01, momentum=0.9) + elif name == 'adam': + optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01) + optimizer = fleet.distributed_optimizer( + optimizer, strategy=strategy) + optimizer.minimize(loss) + + def set_strategy(self, strategy, name): + if name == 'amp': + strategy.amp = True + strategy.amp_configs = { + "init_loss_scaling": 32768, + "decr_every_n_nan_or_inf": 2, + "incr_every_n_steps": 1000, + "incr_ratio": 2.0, + "use_dynamic_loss_scaling": True, + "decr_ratio": 0.5, + "custom_white_list": ['softmax'], + "custom_black_list": ['tanh'], + } + elif name == 'dgc': + strategy.dgc = True + strategy.dgc_configs = { + "rampup_begin_step": 128, + "rampup_step": 100, + "sparsity": [0.996, 0.999] + } + elif name == 'recompute': + strategy.recompute = True + strategy.recompute_configs = { + "checkpoints": ["fc_0.tmp_2", "fc_1.tmp_2"] + } + elif name == 'lars': + strategy.lars = True + strategy.lars_configs = { + "lars_coeff": 0.001, + "lars_weight_decay": 0.0005, + "epsilon": 0, + "exclude_from_weight_decay": ["batch_norm", ".b"], + } + elif name == 'lamb': + strategy.lamb = True + strategy.lamb_configs = { + 'lamb_weight_decay': 0.01, + 'exclude_from_weight_decay': [], + } + elif name == 'localsgd': + strategy.localsgd = True + strategy.localsgd_configs = { + 'k_steps': 1, + 'begin_step': 1, + } + elif name == 'adaptive_localsgd': + strategy.adaptive_localsgd = True + strategy.adaptive_localsgd_configs = { + 'init_k_steps': 1, + 'begin_step': 1, + } + else: + raise NotImplementedError() diff --git 
a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py index 49b93e0dfaaacddc9916f91a9ccd6c7e8bbd1714..d615f7cb7044e588557b1b14dbb54a881bdb8730 100644 --- a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py @@ -16,12 +16,14 @@ from __future__ import print_function import unittest +import paddle import paddle.fluid.framework as framework import paddle.fluid.optimizer as optimizer import paddle.fluid.regularizer as regularizer import paddle.fluid.clip as clip import paddle.compat as cpt from paddle.fluid.backward import append_backward +paddle.enable_static() class TestDGCMomentumOptimizer(unittest.TestCase): @@ -86,13 +88,17 @@ class TestDGCMomentumOptimizer(unittest.TestCase): block.append_op( type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) # params_grads = append_backward(mean_out) - params_grads = dgc_momentum_optimizer.backward(mean_out) + params_grads = dgc_momentum_optimizer.backward( + mean_out, startup_program=init_program) + + with framework.program_guard(program, init_program): + opts = dgc_momentum_optimizer.apply_gradients(params_grads) + accumulator_count = 1 if name == "momentum" else 2 self.assertEqual(len(params_grads), 1) self.assertEqual( len(dgc_momentum_optimizer.get_accumulators()), accumulator_count) - with framework.program_guard(program, init_program): - opts = dgc_momentum_optimizer.apply_gradients(params_grads) + self.assertEqual(len(opts), 2) sgd_op = opts[-1] self.assertEqual([op.type for op in opts], ["scale", name]) @@ -108,8 +114,11 @@ class TestDGCMomentumOptimizer(unittest.TestCase): self.assertTrue(mul_x.name in velocity_acc) # Check init_program + # dgc not apply include: lr, dgc(count, nranks, begin step), (u,) + # dgc apply include: lr, dgc(count, nranks, begin_step), (u,v,k,encode,gather) + init_ops_count = 5 if name == "momentum" else 9 init_ops = init_program.global_block().ops - self.assertEqual(len(init_ops), 1) + self.assertEqual(len(init_ops), init_ops_count) self.assertEqual(init_ops[0].type, "fill_constant") self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py index 362428631e68cc7ac88be93d7ba1ff449a035822..6bc1a310d0aea0b5e7af0b5536fad8e4403d892f 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py @@ -12,57 +12,97 @@ # See the License for the specific language governing permissions and # limitations under the License. 
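The rewritten tests below all derive from the new `TestFleetMetaOptimizer` base class introduced above, which factors out the role-maker environment, the small fc network and the per-strategy presets. A new test then only builds a pair of programs, picks strategies and asserts on the resulting ops. A sketch of such a derived test, assuming `fleet_meta_optimizer_base.py` is importable from the test directory as in the files below; the class and test names here are made up.

import unittest
import paddle
import paddle.fluid as fluid
from fleet_meta_optimizer_base import TestFleetMetaOptimizer

paddle.enable_static()


class TestMyStrategyCombo(TestFleetMetaOptimizer):
    def test_amp_with_recompute(self):
        train_prog, startup_prog = fluid.Program(), fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)

        # Presets defined in TestFleetMetaOptimizer.set_strategy().
        self.set_strategy(strategy, 'amp')
        self.set_strategy(strategy, 'recompute')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)

        ops = [op.type for op in avg_cost.block.ops]
        self.assertIn('cast', ops)                      # inserted by AMP
        self.assertIn('check_finite_and_unscale', ops)  # AMP loss-scaling check


if __name__ == "__main__":
    unittest.main()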
-import paddle.distributed.fleet as fleet -import paddle.distributed.fleet.base.role_maker as role_maker import unittest import paddle +import paddle.fluid as fluid +import paddle.distributed.fleet as fleet +from paddle.distributed.fleet.meta_optimizers import AMPOptimizer import os +from fleet_meta_optimizer_base import TestFleetMetaOptimizer paddle.enable_static() -class TestFleetAMPOptimizer(unittest.TestCase): - def setUp(self): - os.environ["PADDLE_TRAINER_ID"] = "0" - os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001" +class TestFleetAMPOptimizer(TestFleetMetaOptimizer): + def test_amp_optimizer_backward(self): + """ test amp optimizer backward """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + + opt = fluid.optimizer.MomentumOptimizer( + learning_rate=0.001, momentum=0.9) + opt = AMPOptimizer(opt) + opt.user_defined_strategy = strategy + params_grads = opt.backward(avg_cost, startup_prog) + + ops = [op.type for op in avg_cost.block.ops] + self.assertIn('cast', ops) + self.assertNotIn('check_finite_and_unscale', ops) + + def test_amp_optimizer_backward_gradients(self): + """ test amp optimizer backward + gradients""" + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + + opt = fluid.optimizer.MomentumOptimizer( + learning_rate=0.001, momentum=0.9) + opt = AMPOptimizer(opt) + opt.user_defined_strategy = strategy + params_grads = opt.backward(avg_cost, startup_prog) + with fluid.program_guard(train_prog, startup_prog): + opt.apply_gradients(params_grads) + + ops = [op.type for op in avg_cost.block.ops] + self.assertIn('cast', ops) + self.assertIn('check_finite_and_unscale', ops) + + def test_amp_optimizer_backward_optimize(self): + """ test amp optimizer backward + optimizer """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + + opt = fluid.optimizer.MomentumOptimizer( + learning_rate=0.001, momentum=0.9) + opt = AMPOptimizer(opt) + opt.user_defined_strategy = strategy + params_grads = opt.backward(avg_cost, startup_prog) + opt.apply_optimize(avg_cost, startup_prog, params_grads) + + ops = [op.type for op in avg_cost.block.ops] + self.assertIn('cast', ops) + self.assertIn('check_finite_and_unscale', ops) def test_amp_optimizer(self): - role = role_maker.PaddleCloudRoleMaker(is_collective=True) - fleet.init(role) - input_x = paddle.fluid.layers.data( - name="x", shape=[32], dtype='float32') - input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') - - fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') - fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') - prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') - cost = paddle.fluid.layers.cross_entropy( - input=prediction, label=input_y) - avg_cost = paddle.fluid.layers.mean(x=cost) - - strategy = paddle.distributed.fleet.DistributedStrategy() - strategy.amp = True - strategy.amp_configs = { - "init_loss_scaling": 32768, - "decr_every_n_nan_or_inf": 2, - "incr_every_n_steps": 1000, - "incr_ratio": 2.0, - "use_dynamic_loss_scaling": True, - "decr_ratio": 0.5, - "custom_white_list": ['softmax'], - "custom_black_list": ['tanh'], - } - - optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) + """ test amp """ + train_prog, startup_prog = fluid.Program(), 
fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'amp') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + ops = [op.type for op in avg_cost.block.ops] + self.assertIn('cast', ops) + self.assertIn('check_finite_and_unscale', ops) + + def test_amp_recompute_optimizer(self): + """ test amp + recompute """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'amp') + self.set_strategy(strategy, 'recompute') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) strategy = fleet._final_strategy() ops = [op.type for op in avg_cost.block.ops] + outs = [ + op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul' + ] self.assertIn('cast', ops) self.assertIn('check_finite_and_unscale', ops) + # recompute + self.assertIn('subprog', ''.join(outs)) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py index 55d4ff7726aace09e486156d26efdecf22b310a5..0faafd76a799d038c175e8ce5758f77374bfd37e 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py @@ -17,65 +17,82 @@ import paddle from paddle import fluid import os import paddle.distributed.fleet as fleet +from fleet_meta_optimizer_base import TestFleetMetaOptimizer +from paddle.distributed.fleet.meta_optimizers import DGCOptimizer import paddle.distributed.fleet.base.role_maker as role_maker +paddle.enable_static() -class TestFleetDGCOptimizer(unittest.TestCase): - def setUp(self): - os.environ["PADDLE_TRAINER_ID"] = "1" - os.environ[ - "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002" - - def net(self, main_prog, startup_prog): - with fluid.program_guard(main_prog, startup_prog): - with fluid.unique_name.guard(): - role = role_maker.PaddleCloudRoleMaker(is_collective=True) - fleet.init(role) - input_x = paddle.fluid.layers.data( - name="x", shape=[32], dtype='float32') - input_y = paddle.fluid.layers.data( - name="y", shape=[1], dtype='int64') - - fc_1 = paddle.fluid.layers.fc(input=input_x, - size=64, - act='tanh') - fc_2 = paddle.fluid.layers.fc(input=fc_1, size=256, act='tanh') - prediction = paddle.fluid.layers.fc(input=[fc_2], - size=2, - act='softmax') - cost = paddle.fluid.layers.cross_entropy( - input=prediction, label=input_y) - avg_cost = paddle.fluid.layers.mean(x=cost) - - strategy = paddle.distributed.fleet.DistributedStrategy() - strategy.dgc = True - strategy.dgc_configs = { - "rampup_begin_step": 128, - "rampup_step": 100, - "sparsity": [0.996, 0.999] - } - return avg_cost, strategy + +class TestFleetDGCOptimizer(TestFleetMetaOptimizer): + def test_dgc_optimizer_backward(self): + """ test dgc optimizer backward """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + + self.set_strategy(strategy, 'dgc') + opt = fluid.optimizer.MomentumOptimizer( + learning_rate=0.001, momentum=0.9) + dgc_opt = DGCOptimizer(opt) + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + dgc_opt._set_basic_info(avg_cost, role, opt, strategy) + params_grads = dgc_opt.backward(avg_cost, startup_prog) + + ops = [op.type for op in avg_cost.block.ops] + self.assertNotIn('dgc', ops) + + def test_dgc_optimizer_gradients(self): + """ test dgc optimizer 
backward + gradients """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + + self.set_strategy(strategy, 'dgc') + opt = fluid.optimizer.MomentumOptimizer( + learning_rate=0.001, momentum=0.9) + dgc_opt = DGCOptimizer(opt) + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + dgc_opt._set_basic_info(avg_cost, role, opt, strategy) + params_grads = dgc_opt.backward(avg_cost, startup_prog) + with fluid.program_guard(train_prog, startup_prog): + dgc_opt.apply_gradients(params_grads) + + ops = [op.type for op in avg_cost.block.ops] + self.assertIn('dgc', ops) + self.assertIn('dgc_momentum', ops) + + def test_dgc_optimizer_optimize(self): + """ test dgc optimizer backward + optimize """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + + self.set_strategy(strategy, 'dgc') + opt = fluid.optimizer.MomentumOptimizer( + learning_rate=0.001, momentum=0.9) + dgc_opt = DGCOptimizer(opt) + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + dgc_opt._set_basic_info(avg_cost, role, opt, strategy) + params_grads = dgc_opt.backward(avg_cost, startup_prog) + dgc_opt.apply_optimize(avg_cost, startup_prog, params_grads) + + ops = [op.type for op in avg_cost.block.ops] + self.assertIn('dgc', ops) + self.assertIn('dgc_momentum', ops) def test_dgc_optimizer(self): - startup_prog = fluid.Program() - train_prog = fluid.Program() + train_prog, startup_prog = fluid.Program(), fluid.Program() avg_cost, strategy = self.net(train_prog, startup_prog) - optimizer = paddle.fluid.optimizer.Momentum( - learning_rate=0.01, momentum=0.9) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) + self.set_strategy(strategy, 'dgc') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) ops = [op.type for op in avg_cost.block.ops] self.assertIn('dgc', ops) self.assertIn('dgc_momentum', ops) def test_dgc_not_apply_with_adam(self): - startup_prog = fluid.Program() - train_prog = fluid.Program() + train_prog, startup_prog = fluid.Program(), fluid.Program() avg_cost, strategy = self.net(train_prog, startup_prog) - optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) + self.set_strategy(strategy, 'dgc') + self.optimizer(avg_cost, strategy, train_prog, startup_prog, 'adam') ops = [op.type for op in avg_cost.block.ops] self.assertNotIn('dgc', ops) @@ -85,18 +102,32 @@ class TestFleetDGCOptimizer(unittest.TestCase): os.environ["PADDLE_TRAINER_ID"] = "0" os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001" - startup_prog = fluid.Program() - train_prog = fluid.Program() + train_prog, startup_prog = fluid.Program(), fluid.Program() avg_cost, strategy = self.net(train_prog, startup_prog) - optimizer = paddle.fluid.optimizer.Momentum( - learning_rate=0.01, momentum=0.9) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) + self.set_strategy(strategy, 'dgc') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) ops = [op.type for op in avg_cost.block.ops] self.assertNotIn('dgc', ops) self.assertNotIn('dgc_momentum', ops) + def test_dgc_recompute_optimizer(self): + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'dgc') + self.set_strategy(strategy, 'recompute') + 
self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + ops = [op.type for op in avg_cost.block.ops] + outs = [ + op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul' + ] + self.assertIn('dgc', ops) + self.assertIn('dgc_momentum', ops) + + # recompute + self.assertIn('subprog', ''.join(outs)) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_localsgd_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_localsgd_meta_optimizer.py index f5347b0c665e2a162f7f8210171ec415afee4599..bafb2419123b0b348542eeede6af9ded9925fdcc 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_localsgd_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_localsgd_meta_optimizer.py @@ -16,71 +16,87 @@ import unittest import paddle import os +import paddle +import paddle.fluid as fluid import paddle.distributed.fleet as fleet import paddle.distributed.fleet.base.role_maker as role_maker +from fleet_meta_optimizer_base import TestFleetMetaOptimizer +paddle.enable_static() -class TestFleetLocalSGDMetaOptimizer(unittest.TestCase): - def setUp(self): - os.environ["PADDLE_TRAINER_ID"] = "1" - os.environ[ - "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002" +class TestFleetLocalSGDMetaOptimizer(TestFleetMetaOptimizer): def test_localsgd_optimizer(self): - role = role_maker.PaddleCloudRoleMaker(is_collective=True) - fleet.init(role) - input_x = paddle.fluid.layers.data( - name="x", shape=[32], dtype='float32') - input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') - - fc = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') - prediction = paddle.fluid.layers.fc(input=[fc], size=2, act='softmax') - cost = paddle.fluid.layers.cross_entropy( - input=prediction, label=input_y) - avg_cost = paddle.fluid.layers.mean(x=cost) - - strategy = paddle.distributed.fleet.DistributedStrategy() - strategy.localsgd = True - strategy.auto = True - config = strategy.localsgd_configs - config['k_steps'] = 1 - config['begin_step'] = 1 - strategy.localsgd_configs = config - - optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) - - -class TestFleetAdaptiveLocalSGDMetaOptimizer(unittest.TestCase): - def setUp(self): - os.environ["PADDLE_TRAINER_ID"] = "1" - os.environ[ - "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002" - + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'localsgd') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + ops = [op.type for op in avg_cost.block.ops] + outs = [ + ''.join(op.output('Out')) for op in avg_cost.block.ops + if op.type == 'conditional_block' + ] + + self.assertIn('conditional_block', ops) + self.assertIn('@SNAPSHOT', ''.join(outs)) + + def test_localsgd_amp_optimizer(self): + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'localsgd') + self.set_strategy(strategy, 'amp') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + ops = [op.type for op in avg_cost.block.ops] + outs = [ + ''.join(op.output('Out')) for op in avg_cost.block.ops + if op.type == 'conditional_block' + ] + + self.assertIn('conditional_block', ops) + self.assertIn('@SNAPSHOT', ''.join(outs)) + + # amp + self.assertIn('cast', ops) + 
self.assertIn('check_finite_and_unscale', ops) + + +class TestFleetAdaptiveLocalSGDMetaOptimizer(TestFleetMetaOptimizer): def test_adaptive_localsgd_optimizer(self): - role = role_maker.PaddleCloudRoleMaker(is_collective=True) - fleet.init(role) - input_x = paddle.fluid.layers.data( - name="x", shape=[32], dtype='float32') - input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') - - fc = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') - prediction = paddle.fluid.layers.fc(input=[fc], size=2, act='softmax') - cost = paddle.fluid.layers.cross_entropy( - input=prediction, label=input_y) - avg_cost = paddle.fluid.layers.mean(x=cost) - - strategy = paddle.distributed.fleet.DistributedStrategy() - strategy.adaptive_localsgd = True - config = strategy.adaptive_localsgd_configs - config['init_k_steps'] = 1 - config['begin_step'] = 1 - strategy.adaptive_localsgd_configs = config - - optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'adaptive_localsgd') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + ops = [op.type for op in avg_cost.block.ops] + outs = [ + ''.join(op.output('Out')) for op in avg_cost.block.ops + if op.type == 'conditional_block' + ] + + self.assertIn('conditional_block', ops) + self.assertIn('@SNAPSHOT', ''.join(outs)) + + def test_localsgd_amp_optimizer(self): + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'adaptive_localsgd') + self.set_strategy(strategy, 'amp') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + ops = [op.type for op in avg_cost.block.ops] + outs = [ + ''.join(op.output('Out')) for op in avg_cost.block.ops + if op.type == 'conditional_block' + ] + + self.assertIn('conditional_block', ops) + self.assertIn('@SNAPSHOT', ''.join(outs)) + + # amp + self.assertIn('cast', ops) + self.assertIn('check_finite_and_unscale', ops) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py index a42010a4eaa5066821adb817e7a5df2b81bedf7c..42b60cd3fad5a76aee851620c2348d2de2e024e3 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py @@ -14,40 +14,144 @@ import unittest import paddle +import paddle.fluid as fluid import os +from fleet_meta_optimizer_base import TestFleetMetaOptimizer +from paddle.distributed.fleet.meta_optimizers import RecomputeOptimizer +paddle.enable_static() -class TestFleetRecomputeMetaOptimizer(unittest.TestCase): - def setUp(self): - os.environ["POD_IP"] = "127.0.0.1" - os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001" - os.environ["PADDLE_TRAINERS_NUM"] = "2" - os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ - "127.0.0.1:36001,127.0.0.2:36001" + +class TestFleetRecomputeMetaOptimizer(TestFleetMetaOptimizer): + def test_recompute_optimizer_backward(self): + """ test recompute optimizer backward """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + + self.set_strategy(strategy, 'recompute') + opt = 
fluid.optimizer.MomentumOptimizer( + learning_rate=0.001, momentum=0.9) + opt = RecomputeOptimizer(opt) + opt.user_defined_strategy = strategy + params_grads = opt.backward(avg_cost, startup_prog) + + outs = [ + op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul' + ] + self.assertIn('subprog', ''.join(outs)) + + def test_recompute_optimizer_backward_gradients(self): + """ test recompute optimizer backward + gradients """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + + self.set_strategy(strategy, 'recompute') + opt = fluid.optimizer.MomentumOptimizer( + learning_rate=0.001, momentum=0.9) + opt = RecomputeOptimizer(opt) + opt.user_defined_strategy = strategy + params_grads = opt.backward(avg_cost, startup_prog) + with fluid.program_guard(train_prog, startup_prog): + opt.apply_gradients(params_grads) + + outs = [ + op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul' + ] + self.assertIn('subprog', ''.join(outs)) + + def test_recompute_optimizer_backward_optimize(self): + """ test recompute optimizer backward + optimize """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + + self.set_strategy(strategy, 'recompute') + opt = fluid.optimizer.MomentumOptimizer( + learning_rate=0.001, momentum=0.9) + opt = RecomputeOptimizer(opt) + opt.user_defined_strategy = strategy + params_grads = opt.backward(avg_cost, startup_prog) + opt.apply_optimize(avg_cost, startup_prog, params_grads) + + outs = [ + op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul' + ] + self.assertIn('subprog', ''.join(outs)) + + def test_recompute_optimizer_backward(self): + """ test recompute optimizer backward """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + + self.set_strategy(strategy, 'recompute') + opt = fluid.optimizer.MomentumOptimizer( + learning_rate=0.001, momentum=0.9) + opt = RecomputeOptimizer(opt) + opt.user_defined_strategy = strategy + params_grads = opt.backward(avg_cost, startup_prog) + + outs = [ + op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul' + ] + self.assertIn('subprog', ''.join(outs)) + + def test_recompute_optimizer_backward(self): + """ test recompute optimizer backward """ + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + + self.set_strategy(strategy, 'recompute') + opt = fluid.optimizer.MomentumOptimizer( + learning_rate=0.001, momentum=0.9) + opt = RecomputeOptimizer(opt) + opt.user_defined_strategy = strategy + params_grads = opt.backward(avg_cost, startup_prog) + + outs = [ + op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul' + ] + self.assertIn('subprog', ''.join(outs)) def test_recompute_optimizer(self): - import paddle.distributed.fleet as fleet - import paddle.distributed.fleet.base.role_maker as role_maker - role = role_maker.PaddleCloudRoleMaker(is_collective=True) - fleet.init(role) - input_x = paddle.fluid.layers.data( - name="x", shape=[32], dtype='float32') - input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') - - fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') - fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') - prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') - cost = paddle.fluid.layers.cross_entropy( - input=prediction, label=input_y) 
- avg_cost = paddle.fluid.layers.mean(x=cost) - - strategy = paddle.distributed.fleet.DistributedStrategy() - strategy.recompute = True - strategy.recompute_configs = {"checkpoints": ["fc_1.tmp_0"]} - - optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'recompute') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + outs = [ + op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul' + ] + + self.assertIn('subprog', ''.join(outs)) + + def test_recompute_lars_optimizer(self): + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'recompute') + self.set_strategy(strategy, 'lars') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + ops = [op.type for op in avg_cost.block.ops] + outs = [ + op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul' + ] + + self.assertIn('subprog', ''.join(outs)) + self.assertIn('lars_momentum', ops) + + def test_recompute_lamb_optimizer(self): + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'recompute') + self.set_strategy(strategy, 'lamb') + self.optimizer(avg_cost, strategy, train_prog, startup_prog, 'adam') + + ops = [op.type for op in avg_cost.block.ops] + outs = [ + op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul' + ] + + self.assertIn('subprog', ''.join(outs)) + self.assertIn('lamb', ops) if __name__ == "__main__":
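Taken together, the rewritten tests exercise two ways of driving a meta optimizer: the normal fleet path (`fleet.distributed_optimizer(...).minimize(...)`) and a manual path made possible by the new `backward`/`apply_optimize` methods, where a single meta optimizer is instantiated directly and handed a strategy. A condensed sketch of the manual path, using only calls that appear in the tests above; the helper function name is illustrative, and `avg_cost`, the programs and `strategy` are assumed to be built as in `TestFleetMetaOptimizer.net()` / `set_strategy()`.

import paddle
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers import RecomputeOptimizer

paddle.enable_static()


def run_manual_recompute(avg_cost, train_prog, startup_prog, strategy):
    inner = fluid.optimizer.MomentumOptimizer(learning_rate=0.001, momentum=0.9)
    opt = RecomputeOptimizer(inner)
    opt.user_defined_strategy = strategy        # normally wired up by fleet

    # Phase 1: build the (recomputed) backward graph for the loss.
    params_grads = opt.backward(avg_cost, startup_prog)

    # Phase 2: append the parameter update ops for exactly those gradients.
    opt.apply_optimize(avg_cost, startup_prog, params_grads)
    return params_grads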