# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import paddle
import paddle.fluid as fluid
from paddle.framework import core
from paddle.optimizer.lr import LRScheduler
from paddle.optimizer.lr import (ExponentialDecay, NoamDecay, PiecewiseDecay,
                                 NaturalExpDecay, InverseTimeDecay)
from paddle.fluid.layers.learning_rate_scheduler import (
    exponential_decay, noam_decay, piecewise_decay, natural_exp_decay,
    inverse_time_decay)

from ..ps.utils.public import *
from .pass_base import PassBase, register_pass


@register_pass("add_lr_decay_table_pass")
class AddLrDecayTablePass(PassBase):
    """Builds a learning-rate decay program and registers it as a
    GlobalStepTable tensor table so the decay can run on the pserver side."""

    def __init__(self):
        super(AddLrDecayTablePass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _add_tensor_table(self,
                          attrs,
                          feed_var_name,
                          fetch_var_name="",
                          startup_program=None,
                          main_program=None,
                          tensor_table_class=""):
        # Record the tensor table description in attrs; downstream server
        # configuration reads attrs['tensor_table'] to set up the table.
        tensor_table_dict = {}
        tensor_table_dict[feed_var_name] = {}
        tensor_table_dict[feed_var_name]["feed_var_name"] = feed_var_name
        tensor_table_dict[feed_var_name]["fetch_var_name"] = fetch_var_name
        tensor_table_dict[feed_var_name]["startup_program"] = startup_program
        tensor_table_dict[feed_var_name]["main_program"] = main_program
        tensor_table_dict[feed_var_name][
            "tensor_table_class"] = tensor_table_class
        attrs['tensor_table'] = tensor_table_dict

    def _get_lr_sheduler_program(self, lr_sheduler, lr_decay_steps):
        # Build a standalone (main, startup) program pair that computes the
        # decayed learning rate with the fluid decay op matching the
        # trainer-side LRScheduler.
        schedler_decay = [
            'NoamDecay', 'NaturalExpDecay', 'InverseTimeDecay',
            'ExponentialDecay'
        ]

        decay_main_program = fluid.framework.Program()
        decay_startup_program = fluid.framework.Program()
        lr_name = ""
        if isinstance(lr_sheduler, ExponentialDecay):
            with fluid.program_guard(decay_main_program,
                                     decay_startup_program):
                lr = exponential_decay(1.0, lr_decay_steps, lr_sheduler.gamma,
                                       True)
                lr_name = lr.name
                logging.warning(
                    "ExponentialDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follows: \n"
                    "\t strategy = paddle.distributed.fleet.DistributedStrategy() \n"
                    "\t strategy.a_sync = True \n"
                    "\t strategy.a_sync_configs = { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
                    % lr_decay_steps)
        elif isinstance(lr_sheduler, NoamDecay):
            with fluid.program_guard(decay_main_program,
                                     decay_startup_program):
                lr = noam_decay(lr_sheduler.d_model, lr_sheduler.warmup_steps,
                                1.0)
                lr_name = lr.name
                logging.warning("NoamDecay is set, warmup steps is [ %d ]" %
                                lr_sheduler.warmup_steps)
        elif isinstance(lr_sheduler, NaturalExpDecay):
            with fluid.program_guard(decay_main_program,
                                     decay_startup_program):
                lr = natural_exp_decay(1.0, lr_decay_steps, lr_sheduler.gamma,
                                       True)
                lr_name = lr.name
                logging.warning(
                    "NaturalExpDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follows: \n"
                    "\t strategy = paddle.distributed.fleet.DistributedStrategy() \n"
                    "\t strategy.a_sync = True \n"
                    "\t strategy.a_sync_configs = { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
                    % lr_decay_steps)
        elif isinstance(lr_sheduler, InverseTimeDecay):
            with fluid.program_guard(decay_main_program,
                                     decay_startup_program):
                lr = inverse_time_decay(1.0, lr_decay_steps,
                                        lr_sheduler.gamma, True)
                lr_name = lr.name
                logging.warning(
                    "InverseTimeDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follows: \n"
                    "\t strategy = paddle.distributed.fleet.DistributedStrategy() \n"
                    "\t strategy.a_sync = True \n"
                    "\t strategy.a_sync_configs = { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
                    % lr_decay_steps)
        else:
            raise ValueError(
                "The current LearningRate strategy is not supported, please use one of the following decay strategies: {}".
                format(schedler_decay))

        return decay_main_program, decay_startup_program, lr_name

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        # Nothing to do when the origin program carries no LR scheduler.
        if not hasattr(attrs['origin_main_program'], 'lr_sheduler'):
            return

        assert isinstance(attrs['origin_main_program'].lr_sheduler,
                          LRScheduler), "must be LRScheduler"

        ops = get_optimize_ops(attrs['origin_main_program'])
        lr_decay_main_program, lr_decay_startup_program, lr_name = self._get_lr_sheduler_program(
            attrs['origin_main_program'].lr_sheduler, attrs['lr_decay_steps'])
        self._add_tensor_table(attrs, "@LR_DECAY_COUNTER@", lr_name,
                               lr_decay_startup_program,
                               lr_decay_main_program, "GlobalStepTable")
        return


@register_pass("add_listen_and_serv_pass")
class AddListenAndServPass(PassBase):
    """Appends the listen_and_serv op that starts the parameter server loop."""

    def __init__(self):
        super(AddListenAndServPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        opt = {
            "grad_to_block_id": None,
            "sparse_grad_to_param": None,
            "lr_decay_block_id": None,
            "dense_optimize_blocks": None,
            "sparse_optimize_blocks": None,

            # runtime attribute
            "endpoint": get_ps_endpoint(attrs['role_maker']),
            "pserver_id": get_role_id(attrs['role_maker']),
            "Fanin": get_trainers(attrs['role_maker']),
            "distributed_mode": attrs['ps_mode'],
            "rpc_get_thread_num": -1,
            "rpc_send_thread_num": -1,
            "rpc_prefetch_thread_num": -1
        }
        main_program.global_block().append_op(
            type="listen_and_serv", inputs={'X': []}, outputs={}, attrs=opt)


@register_pass("add_rpc_global_flags_pass")
class AddRpcGlobalFlagsPass(PassBase):
    def __init__(self):
        super(AddRpcGlobalFlagsPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        pass


@register_pass("add_optimizer_pass")
class AddOptimizerPass(PassBase):
    def __init__(self):
        super(AddOptimizerPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        pass


@register_pass("add_geo_optimizer_pass")
class AddGeoOptimizerPass(PassBase):
    def __init__(self):
        super(AddGeoOptimizerPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        pass


@register_pass("build_pserver_startup_program_pass")
class BuildPserverStartupProgramPass(PassBase):
    def __init__(self):
        super(BuildPserverStartupProgramPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        pass


@register_pass("delete_unused_in_startup_pass")
class DeleteUnusedInStartupPass(PassBase):
    def __init__(self):
        super(DeleteUnusedInStartupPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        pass
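
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pass implementations
# above). The passes are looked up by the names given to @register_pass and
# applied through the pass framework; `new_pass` and `PassContext` come from
# the passes package / .pass_base. The `attrs` dict shown here is an assumed
# example of the server-side attributes (role_maker, origin_main_program,
# lr_decay_steps, ps_mode, ...) that the caller is expected to provide.
#
#   from paddle.distributed.passes import new_pass, PassContext
#
#   pass_ctx = PassContext()
#   pass_ctx._attrs = attrs  # hypothetical attrs dict built by the caller
#
#   lr_decay_pass = new_pass("add_lr_decay_table_pass", attrs)
#   lr_decay_pass.apply([main_program], [startup_program], pass_ctx)
#
#   listen_and_serv_pass = new_pass("add_listen_and_serv_pass", attrs)
#   listen_and_serv_pass.apply([main_program], [startup_program], pass_ctx)
# ---------------------------------------------------------------------------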