From 51b081230aecbec3c6614713a82ba6fdf74f6c35 Mon Sep 17 00:00:00 2001 From: wangzhen38 <41941775+wangzhen38@users.noreply.github.com> Date: Tue, 22 Nov 2022 17:25:29 +0800 Subject: [PATCH] [remove fluid] under fleet meta_optimizers_wz (#47888) * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * update * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz * [remove fluid] under fleet meta_optimizers_wz --- .../fleet/base/meta_optimizer_factory.py | 1 + .../fleet/meta_optimizers/__init__.py | 1 + .../fleet/meta_optimizers/asp_optimizer.py | 2 +- .../fleet/meta_optimizers/common.py | 8 +- .../fleet/meta_optimizers/dgc_optimizer.py | 423 +++++++++++++++- .../fp16_allreduce_optimizer.py | 6 +- .../graph_execution_optimizer.py | 5 +- .../fleet/meta_optimizers/lamb_optimizer.py | 2 +- .../fleet/meta_optimizers/lars_optimizer.py | 3 +- .../meta_optimizers/localsgd_optimizer.py | 69 +-- .../parameter_server_graph_optimizer.py | 4 +- .../parameter_server_optimizer.py | 7 +- .../meta_optimizers/pipeline_optimizer.py | 4 +- .../fleet/meta_optimizers/ps_optimizer.py | 9 +- python/paddle/fluid/optimizer.py | 472 ------------------ .../collective/fleet/test_dgc_optimizer.py | 4 +- .../fluid/tests/unittests/dist_mnist.py | 2 +- .../fluid/tests/unittests/dist_se_resnext.py | 16 +- .../unittests/test_imperative_optimizer.py | 3 +- .../unittests/test_imperative_optimizer_v2.py | 3 +- 20 files changed, 505 insertions(+), 539 deletions(-) diff --git a/python/paddle/distributed/fleet/base/meta_optimizer_factory.py b/python/paddle/distributed/fleet/base/meta_optimizer_factory.py index dd4611fc0a..2577df9380 100755 --- a/python/paddle/distributed/fleet/base/meta_optimizer_factory.py +++ b/python/paddle/distributed/fleet/base/meta_optimizer_factory.py @@ -24,6 +24,7 @@ meta_optimizer_names = list( # should be removed meta_optimizer_names.remove("HybridParallelOptimizer") meta_optimizer_names.remove("HeterParallelOptimizer") +meta_optimizer_names.remove("DGCMomentumOptimizer") class MetaOptimizerFactory: diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index 1eae4be579..feb7b125ad 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -24,6 +24,7 @@ from .localsgd_optimizer import AdaptiveLocalSGDOptimizer from .lars_optimizer import LarsOptimizer from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer from .dgc_optimizer import DGCOptimizer +from .dgc_optimizer import DGCMomentumOptimizer from .lamb_optimizer import LambOptimizer from .fp16_allreduce_optimizer import FP16AllReduceOptimizer from .sharding_optimizer import ShardingOptimizer diff --git a/python/paddle/distributed/fleet/meta_optimizers/asp_optimizer.py 
b/python/paddle/distributed/fleet/meta_optimizers/asp_optimizer.py index 637fa31a6b..a2f494e4a8 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/asp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/asp_optimizer.py @@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -from paddle.fluid.contrib.sparsity.asp import ASPHelper from .meta_optimizer_base import MetaOptimizerBase +from paddle.fluid.contrib.sparsity.asp import ASPHelper __all__ = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/common.py b/python/paddle/distributed/fleet/meta_optimizers/common.py index 03ed84563b..bbcd1d8215 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/common.py +++ b/python/paddle/distributed/fleet/meta_optimizers/common.py @@ -13,9 +13,9 @@ # limitations under the License. import os - -import paddle.fluid as fluid -from paddle.fluid import core, unique_name +import paddle +from paddle.framework import core +from paddle.utils import unique_name from ..base.private_helper_function import wait_server_ready __all__ = [] @@ -62,7 +62,7 @@ class CollectiveHelper: def update_startup_program(self, startup_program=None): self.startup_program = startup_program if startup_program is None: - self.startup_program = fluid.default_startup_program() + self.startup_program = paddle.static.default_startup_program() endpoints = self.role_maker._get_trainer_endpoints() current_endpoint = endpoints[self.role_maker._worker_index()] diff --git a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py index a8861f12cc..1c728ed16e 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py @@ -11,12 +11,433 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -from paddle.fluid.optimizer import Momentum, DGCMomentumOptimizer +from functools import reduce from .meta_optimizer_base import MetaOptimizerBase import logging __all__ = [] +from paddle.fluid.layers import tensor +import paddle +from paddle import framework +from paddle.framework import core +from paddle.common_ops_import import LayerHelper +from paddle.fluid.clip import GradientClipByNorm, append_gradient_clip_ops +from paddle.fluid.optimizer import Optimizer, Momentum +from paddle.fluid.dygraph import base as imperative_base + + +class DGCMomentumOptimizer(Optimizer): + _u_velocity_acc_str = "_dgc_u_" + _v_velocity_acc_str = "_dgc_v_" + + def __init__( + self, + learning_rate, + momentum, + rampup_begin_step, + rampup_step=1, + sparsity=[0.999], + parameter_list=None, + use_nesterov=False, + num_trainers=None, + regularization=None, + grad_clip=None, + name=None, + ): + if framework._non_static_mode(): + raise Exception("In dygraph, don't support DGCMomentumOptimizer.") + + assert ( + core.is_compiled_with_cuda() + ), "Paddle is not compiled with CUDA. DGC is only support GPU for now." 
+ + assert learning_rate is not None + assert momentum is not None + super().__init__( + learning_rate=learning_rate, + parameter_list=parameter_list, + regularization=regularization, + grad_clip=grad_clip, + name=name, + ) + self.type = "dgc_momentum" + self._momentum = momentum + self._use_nesterov = bool(use_nesterov) + + assert rampup_begin_step >= 0, "rampup_begin_step must >= 0" + self._rampup_begin_step = rampup_begin_step + self._rampup_step = rampup_step + self._sparsity = sparsity + + self._rampup_begin_step_var = None + self._global_step_var = None + + self._dgc_clip_norm = None + if grad_clip is not None: + if not isinstance(grad_clip, GradientClipByNorm): + raise TypeError( + "The type of grad_clip should be 'GradientClipByNorm', because DGCMomentumOptimizer only support GradientClipByNorm" + ) + assert isinstance(num_trainers, int), ( + "The type of num_trainers should be 'int', but received %s" + % type(num_trainers) + ) + assert ( + num_trainers > 0 + ), "The value of num_trainers should be greater than 0!" + + self._num_trainers = num_trainers + self._dgc_clip_norm = grad_clip.clip_norm * (num_trainers**-0.5) + + self.regular_type, self.regular_coeff = self._get_regularization_param( + self.regularization + ) + + def _get_regularization_param(self, regularization): + regular_type = 0 + regular_coeff = 0.0 + + if regularization is not None: + regular_coeff = regularization._regularization_coeff + from paddle.fluid.regularizer import L1Decay, L2Decay + + if isinstance(regularization, L1Decay): + regular_type = 1 + elif isinstance(regularization, L2Decay): + regular_type = 2 + else: + assert False, 'regularization must be None|L1Decay|L2Deacy' + return regular_type, regular_coeff + + def _is_use_dgc(self, param_var, grad_var): + var_numel = abs(reduce(lambda x, y: x * y, param_var.shape)) + if ( + var_numel < 16384 + or param_var.type == core.VarDesc.VarType.SELECTED_ROWS + or grad_var.type == core.VarDesc.VarType.SELECTED_ROWS + or param_var.dtype != core.VarDesc.VarType.FP32 + ): + return False + return True + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, paddle.fluid.framework.Block) + velocity_acc = self._get_accumulator( + self._u_velocity_acc_str, param_and_grad[0] + ) + assert velocity_acc is not None + + inputs = { + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Velocity": velocity_acc, + "LearningRate": self._create_param_lr(param_and_grad), + } + outputs = { + "ParamOut": param_and_grad[0], + "VelocityOut": velocity_acc, + } + attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov} + + if not self._is_use_dgc(param_and_grad[0], param_and_grad[1]): + type = "momentum" + else: + type = "dgc_momentum" + inputs.update( + { + "current_step": self._global_step_var, + "nranks": self._nranks_var, + } + ) + outputs.update({'Grad_out': param_and_grad[1]}) + attrs.update({"rampup_begin_step": float(self._rampup_begin_step)}) + + # create the dgc momentum optimize op + dgc_momentum_op = block.append_op( + type=type, + inputs=inputs, + outputs=outputs, + attrs=attrs, + stop_gradient=True, + ) + return dgc_momentum_op + + def _add_auto_increment_var(self, counter_name, begin, step=1): + helper = LayerHelper('global_step_counter') + counter, is_new_var = helper.create_or_get_global_variable( + name=counter_name, dtype='float32', shape=[1], persistable=True + ) + if is_new_var: + helper.set_variable_initializer( + counter, + initializer=paddle.fluid.initializer.Constant( + value=float(begin - 1), force_cpu=True + ), + ) 
+ helper.main_program.global_block()._prepend_op( + type='increment', + inputs={'X': [counter]}, + outputs={'Out': [counter]}, + attrs={'step': float(step)}, + stop_gradient=True, + ) + counter.stop_gradient = True + + return counter + + def _add_nranks_var(self, name, value=-1): + helper = LayerHelper('global_step_counter') + counter, is_new_var = helper.create_or_get_global_variable( + name=name, dtype='float32', shape=[1], persistable=True + ) + if is_new_var: + helper.set_variable_initializer( + counter, + initializer=paddle.fluid.initializer.Constant( + value=float(value), force_cpu=True + ), + ) + counter.stop_gradient = True + + return counter + + def _append_dgc_ops(self, param_and_grads): + main_program = paddle.static.default_main_program() + main_program._enable_dgc = True + + # step counter + self._global_step_var = self._add_auto_increment_var( + counter_name=core.dgc.kDGCCounterName(), begin=0 + ) + + self._nranks_var = self._add_nranks_var( + name=core.dgc.kDGCNRanksName(), value=-1 + ) + + # rampup begin step var for all_reduce_op_handle + self._rampup_begin_step_var = tensor.create_global_var( + shape=[1], + dtype=core.VarDesc.VarType.FP32, + persistable=True, + name=core.dgc.kDGCRampUpBeginStepName(), + value=self._rampup_begin_step * 1.0, + force_cpu=True, + ) + + self.helper = LayerHelper(self.__class__.__name__) + + for param_var, grad_var in param_and_grads: + # reuse velocity in dgc_op and dgc_momentum_op + u_var = self._add_accumulator(self._u_velocity_acc_str, param_var) + + if not self._is_use_dgc(param_var, grad_var): + continue + + v_var = self._add_accumulator(self._v_velocity_acc_str, param_var) + + k_var = tensor.create_global_var( + shape=[1], + dtype=param_var.dtype, + persistable=True, + name=param_var.name + core.dgc.kDGCKName(), + value=0.0, + force_cpu=True, + ) + + encoded_var = tensor.create_global_var( + shape=[1], + dtype=param_var.dtype, + persistable=True, + name=param_var.name + core.dgc.kDGCEncodedName(), + value=0.0, + force_cpu=False, + ) + + gather_var = tensor.create_global_var( + shape=[1], + dtype=param_var.dtype, + persistable=True, + name=param_var.name + core.dgc.kDGCGatherName(), + value=0.0, + force_cpu=False, + ) + + # del back oprolevarname + op_maker = core.op_proto_and_checker_maker + backward = core.op_proto_and_checker_maker.OpRole.Backward + for op in main_program.global_block().ops: + if not self._is_the_backward_op(op): + continue + + var_attr = op.all_attrs()[op_maker.kOpRoleVarAttrName()] + if param_var.name not in var_attr: + continue + + var_attr.remove(param_var.name) + var_attr.remove(grad_var.name) + if len(var_attr) > 1: + op._set_attr(op_maker.kOpRoleVarAttrName(), var_attr) + else: + op._remove_attr(op_maker.kOpRoleVarAttrName()) + + clip_var = grad_var + if self._dgc_clip_norm is not None: + clip_var = self._append_clip_norm(grad_var, self._dgc_clip_norm) + self._dgc_op( + param_var, + clip_var, + grad_var, + u_var, + v_var, + k_var, + encoded_var, + gather_var, + ) + + def _is_the_backward_op(self, op): + op_maker = core.op_proto_and_checker_maker + backward = core.op_proto_and_checker_maker.OpRole.Backward + if op_maker.kOpRoleVarAttrName() in op.attr_names and int( + op.all_attrs()[op_maker.kOpRoleAttrName()] + ) == int(backward): + return True + return False + + def _clip_by_norm(self, x, max_norm, name=None): + args = {'x': x, 'max_norm': max_norm, 'name': name} + + helper = LayerHelper("dgc_clip_by_norm_op", **args) + + if name is None: + name = paddle.fluid.unique_name.generate_with_ignorable_key( + 
".".join([helper.name, 'tmp']) + ) + + out = helper.create_variable( + type=x.type, name=name, dtype=x.dtype, persistable=False + ) + + helper.append_op( + type="dgc_clip_by_norm", + inputs={"X": x, "current_step": self._global_step_var}, + attrs={ + "max_norm": max_norm, + "rampup_begin_step": float(self._rampup_begin_step), + }, + outputs={"Out": out}, + ) + return out + + def _append_clip_norm(self, grad_var, clip_norm): + with grad_var.block.program._backward_role_guard(): + return self._clip_by_norm( + x=grad_var, max_norm=clip_norm, name=grad_var.name + ) + + def _dgc_op( + self, + param_var, + clip_var, + grad_var, + u_var, + v_var, + k_var, + encoded_var, + gather_var, + ): + block = paddle.static.default_main_program().global_block() + op_maker = core.op_proto_and_checker_maker + + regular_type = self.regular_type + regular_coeff = self.regular_coeff + # The regularizer of the Parameters have higher priority + if param_var.regularizer is not None: + regular_type, regular_coeff = self._get_regularization_param( + param_var.regularizer + ) + + dgc_op = block.append_op( + type="dgc", + inputs={ + "U": u_var, + "V": v_var, + "Grad": clip_var, + "Param": param_var, + "current_step": self._global_step_var, + "nranks": self._nranks_var, + }, + outputs={ + "U_out": u_var, + "V_out": v_var, + "EncodeGrad": encoded_var, + "k": k_var, + "Grad_out": grad_var, + "GatherBuff": gather_var, + }, + attrs={ + "m": self._momentum, + "sparsity": self._sparsity, + "use_nesterov": self._use_nesterov, + "rampup_begin_step": float(self._rampup_begin_step), + "rampup_step": float(self._rampup_step), + "regular_coeff": float(regular_coeff), + "regular_type": int(regular_type), + }, + stop_gradient=True, + ) + + backward = op_maker.OpRole.Backward + dgc_op._set_attr(op_maker.kOpRoleAttrName(), backward) + dgc_op._set_attr( + op_maker.kOpRoleVarAttrName(), [param_var.name, grad_var.name] + ) + + @imperative_base.no_grad() + def apply_gradients(self, params_grads): + # Note: since we can't use all_reduce_op now, + # dgc_op should be the last op of one grad. + # Maybe need a grad allreduce pass. 
+ self._append_dgc_ops(params_grads) + + params_grads = sorted(params_grads, key=lambda x: x[0].name) + ( + params_grads, + table_param_and_grad, + table_optimize_op, + ) = self._process_distribute_lookuptable(params_grads) + + not_dgc_params_grads = [] + dgc_params_grads = [] + # DGC clip and regularization in optimizer.backward + for param, grad in params_grads: + if not self._is_use_dgc(param, grad): + not_dgc_params_grads.append((param, grad)) + else: + dgc_params_grads.append((param, grad)) + + # 'optimizer(grad_clip)' or 'set_gradient_clip' + if self._grad_clip is not None: + not_dgc_params_grads = self._grad_clip(not_dgc_params_grads) + else: + not_dgc_params_grads = append_gradient_clip_ops( + not_dgc_params_grads + ) + + not_dgc_params_grads = self.append_regularization_ops( + not_dgc_params_grads, self.regularization + ) + + params_grads = not_dgc_params_grads + dgc_params_grads + params_grads = sorted(params_grads, key=lambda x: x[0].name) + + optimize_ops = self._create_optimization_pass(params_grads) + if table_optimize_op is not None: + optimize_ops.append(table_optimize_op) + params_grads.append(table_param_and_grad) + + return optimize_ops + class DGCOptimizer(MetaOptimizerBase): def __init__(self, optimizer): diff --git a/python/paddle/distributed/fleet/meta_optimizers/fp16_allreduce_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/fp16_allreduce_optimizer.py index 0ab95830ba..1a29448e02 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/fp16_allreduce_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/fp16_allreduce_optimizer.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -from paddle.fluid import core, framework, unique_name +from paddle.framework import core +from paddle.utils import unique_name from .meta_optimizer_base import MetaOptimizerBase +import paddle __all__ = [] @@ -133,7 +135,7 @@ class FP16AllReduceOptimizer(MetaOptimizerBase): with block.program._optimized_guard( [param, grad] - ), framework.name_scope('fp16_allreduce'): + ), paddle.static.name_scope('fp16_allreduce'): cast_op = block.append_op( type="cast", inputs={"X": grad}, diff --git a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py index 1dfdce6f6c..ccc4fecbb5 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py @@ -13,8 +13,7 @@ import copy import paddle -from paddle.fluid.framework import core -from paddle.fluid import compiler +from paddle.framework import core from .meta_optimizer_base import MetaOptimizerBase from ..base.private_helper_function import wait_server_ready import logging @@ -247,7 +246,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase): ) local_build_strategy.enable_backward_optimizer_op_deps = True - self._compiled_program = compiler.CompiledProgram(main_program) + self._compiled_program = paddle.static.CompiledProgram(main_program) self._compiled_program.with_data_parallel( loss_name=loss.name, diff --git a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py index 0f1ba5d29d..b160c5f6fa 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py +++ 
b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -from paddle.fluid.optimizer import AdamOptimizer from paddle.fluid.optimizer import LambOptimizer as LAMB from .meta_optimizer_base import MetaOptimizerBase +from paddle.fluid.optimizer import AdamOptimizer import logging __all__ = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py index 0eb4be0ca8..5c716bd375 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -from paddle.fluid.optimizer import Momentum, LarsMomentumOptimizer +from paddle.fluid.optimizer import Momentum +from paddle.fluid.optimizer import LarsMomentumOptimizer from .meta_optimizer_base import MetaOptimizerBase import logging diff --git a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py index 62ff253fb7..e73d3c6b4b 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py @@ -13,8 +13,8 @@ # limitations under the License. import paddle -from paddle.fluid import program_guard, layers, default_main_program -from paddle.fluid import default_startup_program +from paddle.static import program_guard, default_main_program +from paddle.static import default_startup_program from .meta_optimizer_base import MetaOptimizerBase from .common import CollectiveHelper, OP_ROLE_KEY, OpRole @@ -83,7 +83,7 @@ class LocalSGDOptimizer(MetaOptimizerBase): def init_snapshot_vars(self, startup_program, param2snapshot): with program_guard(startup_program): for param, snapshot in param2snapshot: - layers.assign(param, snapshot) + paddle.assign(param, snapshot) def minimize_impl( self, loss, startup_program=None, parameter_list=None, no_grad_set=None @@ -109,8 +109,8 @@ class LocalSGDOptimizer(MetaOptimizerBase): p2s = self.create_snapshot_vars(main_block.program) with program_guard(main_block.program, startup_program): - step = layers.autoincreased_step_counter(begin=1) - k_steps = layers.create_global_var( + step = paddle.fluid.layers.autoincreased_step_counter(begin=1) + k_steps = paddle.static.create_global_var( name="k_steps", shape=[1], value=k_steps_value, @@ -118,7 +118,7 @@ class LocalSGDOptimizer(MetaOptimizerBase): persistable=True, ) - begin_step = layers.create_global_var( + begin_step = paddle.static.create_global_var( name="begin_step", shape=[1], value=begin_step_value, @@ -126,7 +126,7 @@ class LocalSGDOptimizer(MetaOptimizerBase): persistable=True, ) - last_step = layers.create_global_var( + last_step = paddle.static.create_global_var( name="last_step", shape=[1], value=begin_step_value, @@ -194,12 +194,14 @@ class LocalSGDOptimizer(MetaOptimizerBase): outputs={'Out': [snapshot]}, attrs={OP_ROLE_KEY: OpRole.Optimize}, ) - layers.assign(step, last_step) + paddle.assign(step, last_step) def begin_localsgd(): - layers.cond(step - last_step == k_steps, communicate) + paddle.static.nn.cond(step - last_step == k_steps, communicate) - layers.cond(step > begin_step, 
begin_localsgd, communicate) + paddle.static.nn.cond( + step > begin_step, begin_localsgd, communicate + ) return minimized @@ -225,7 +227,7 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): return False return ( - isinstance(self.inner_opt, paddle.optimizer.momentum.Momentum) + isinstance(self.inner_opt, paddle.optimizer.Momentum) or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum) or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD) or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD) @@ -268,7 +270,7 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): def init_snapshot_vars(self, startup_program, param2snapshot): with program_guard(startup_program): for param, snapshot in param2snapshot: - layers.assign(param, snapshot) + paddle.assign(param, snapshot) def _generate_avg_loss(self, program_block, loss, avg_loss): program_block.append_op( @@ -324,9 +326,9 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): p2s = self.create_snapshot_vars(main_block.program) with program_guard(main_block.program, startup_program): - step = layers.autoincreased_step_counter(begin=1) + step = paddle.fluid.layers.autoincreased_step_counter(begin=1) - k_steps = layers.create_global_var( + k_steps = paddle.static.create_global_var( name="k_steps", shape=[1], value=int(init_k_steps), @@ -334,7 +336,7 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): persistable=True, ) - begin_step = layers.create_global_var( + begin_step = paddle.static.create_global_var( name="begin_step", shape=[1], value=int(begin_step_value), @@ -342,7 +344,7 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): persistable=True, ) - last_step = layers.create_global_var( + last_step = paddle.static.create_global_var( name="last_step", shape=[1], value=int(0), @@ -350,7 +352,7 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): persistable=True, ) - avg_loss = layers.create_global_var( + avg_loss = paddle.static.create_global_var( name="avg_loss", shape=[1], value=float(0), @@ -358,7 +360,7 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): persistable=True, ) - lr_0 = layers.create_global_var( + lr_0 = paddle.static.create_global_var( name="lr_0", shape=[1], value=float(0), @@ -366,7 +368,7 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): persistable=True, ) - loss_0 = layers.create_global_var( + loss_0 = paddle.static.create_global_var( name="loss_0", shape=[1], value=float(0), @@ -378,10 +380,10 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): def initialize(): self._generate_avg_loss(main_block, loss, avg_loss) - layers.assign(avg_loss, loss_0) - layers.assign(global_lr, lr_0) + paddle.assign(avg_loss, loss_0) + paddle.assign(global_lr, lr_0) - layers.cond(step == 1, initialize) + paddle.static.nn.cond(step == 1, initialize) def communicate(): sub_block = default_main_program().current_block() @@ -443,12 +445,13 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): outputs={'Out': [snapshot]}, attrs={OP_ROLE_KEY: OpRole.Optimize}, ) - layers.assign(step, last_step) + paddle.assign(step, last_step) def communicate_avg_loss(): communicate() self._generate_avg_loss(main_block, loss, avg_loss) - next_local_steps = layers.cast( + + next_local_steps = paddle.cast( paddle.ceil( paddle.sqrt( lr_0 @@ -459,11 +462,11 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): ), dtype='int64', ) - max_local_steps = layers.fill_constant( - shape=[1], dtype='int64', value=16 + max_local_steps = paddle.full( + shape=[1], dtype='int64', fill_value=16 ) - min_local_steps = layers.fill_constant( - 
shape=[1], dtype='int64', value=1 + min_local_steps = paddle.full( + shape=[1], dtype='int64', fill_value=1 ) next_local_steps = paddle.minimum( next_local_steps, max_local_steps @@ -471,11 +474,15 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): next_local_steps = paddle.maximum( next_local_steps, min_local_steps ) - layers.assign(next_local_steps, k_steps) + paddle.assign(next_local_steps, k_steps) def begin_localsgd(): - layers.cond(step - last_step == k_steps, communicate_avg_loss) + paddle.static.nn.cond( + step - last_step == k_steps, communicate_avg_loss + ) - layers.cond(step > begin_step, begin_localsgd, communicate) + paddle.static.nn.cond( + step > begin_step, begin_localsgd, communicate + ) return minimized diff --git a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py index 22a1b82541..74d57fe59b 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -from paddle.fluid import compiler from .parameter_server_optimizer import ParameterServerOptimizer +import paddle __all__ = [] @@ -56,7 +56,7 @@ class ParameterServerGraphOptimizer(ParameterServerOptimizer): build_strategy = dist_strategy.get_build_strategy() exec_strategy = dist_strategy.get_execute_strategy() - self._compiled_program = compiler.CompiledProgram(main_program) + self._compiled_program = paddle.static.CompiledProgram(main_program) self._compiled_program.with_data_parallel( loss_name=loss.name, diff --git a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py index 2ea83ada81..362dec4e62 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and +import paddle from paddle import fluid from .meta_optimizer_base import MetaOptimizerBase -from paddle.fluid import core +from paddle.framework import core import subprocess import re import os @@ -185,8 +186,8 @@ class ParameterServerOptimizer(MetaOptimizerBase): return _main, _startup def _build_pserver_programs(self, compiled_config): - _main = fluid.Program() - _startup = fluid.Program() + _main = paddle.static.Program() + _startup = paddle.static.Program() from paddle.fluid.incubate.fleet.parameter_server.ir import ( pserver_pass as server, diff --git a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py index 655670f305..45dde10b1e 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and -import paddle.fluid as fluid +import paddle from paddle.fluid.optimizer import PipelineOptimizer as PO from .meta_optimizer_base import MetaOptimizerBase from .common import ( @@ -210,7 +210,7 @@ class PipelineOptimizer(MetaOptimizerBase): orig_startup_program = ( startup_program if startup_program - else fluid.default_startup_program() + else paddle.static.default_startup_program() ) block = loss.block program = block.program diff --git a/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py index 66b8acb4d7..31fcf3450d 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -from paddle import fluid import paddle.distributed.passes from .meta_optimizer_base import MetaOptimizerBase -from paddle.fluid import core +from paddle.framework import core import subprocess import re import os @@ -111,8 +110,8 @@ class ParameterServerOptimizer(MetaOptimizerBase): build_var_distributed(attrs) # server - attrs['_main_server'] = fluid.Program() - attrs['_startup_server'] = fluid.Program() + attrs['_main_server'] = paddle.static.Program() + attrs['_startup_server'] = paddle.static.Program() attrs['tensor_table'] = {} self.pass_ctx._attrs = attrs @@ -203,7 +202,7 @@ class ParameterServerOptimizer(MetaOptimizerBase): % (platform.system()) ) - if not isinstance(self.inner_opt, fluid.optimizer.SGDOptimizer): + if not isinstance(self.inner_opt, paddle.fluid.optimizer.SGDOptimizer): return False free = get_sys_free_mem() diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 8e030a54d8..c724a0f348 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -1744,478 +1744,6 @@ class MomentumOptimizer(Optimizer): return momentum_op -class DGCMomentumOptimizer(Optimizer): - r""" - :api_attr: Static Graph - - DGC (Deep Gradient Compression) Momentum Optimizer. Original paper is https://arxiv.org/abs/1712.01887 - - DGC reduces the communication bandwidth by sending only the important gradients (sparse update):\ - only gradients larger than a threshold are transmitted. - - To avoid losing information, DGC accumulates the rest of the gradients locally. - - Eventually, these gradients become large enough to be transmitted. - - Thus, DGC sends the large gradients immediately but eventually sends all of the gradients over time. - - To ensure no loss of accuracy, DGC employs momentum correction and local gradient clipping on top of the gradient sparsification to maintain model performance. - - DGC also uses momentum factor masking and warmup training to overcome the staleness problem caused by reduced communication. - - This optimizer will do two things: - - 1. Compress the gradient by get TopK import value from tensor \ - and use it for allreduce to reduce network bandwidth. - - 2. Call momentum to optimize the cost. - - Args: - learning_rate (float|Variable): The learning rate used to update parameters. \ - It can be a float value or a Variable with one float value as a data element. - momentum (float): Momentum factor. - rampup_begin_step (int): The beginning step from which gradient compression is implemented. 
- rampup_step (int): Time steps used in sparsity warm-up periods. Default is 1. - For example, if the sparsity is [0.75, 0.9375, 0.984375, 0.996, 0.999], and the rampup_step is 100, \ - it will use 0.75 at 0~19 steps, and 0.9375 at 20~39 steps, and so on. \ - And when reach sparsity array ends, it will use 0.999 then and after. - sparsity (list[float]): Get top important element from gradient tensor, the ratio is (1 - current sparsity). \ - Default is [0.999]. For example, if the sparsity is [0.99, 0.999], \ - the top [1%, 0.1%] important element will be transmitted. - parameter_list (Iterable, optional): Iterable of ``Variable`` names to update to minimize ``loss``. \ - This parameter is required in dygraph mode. \ - The default value is None in static mode, at this time all parameters will be updated. - use_nesterov (bool): Enables Nesterov momentum. True means use Nesterov. Default is False. - regularization (WeightDecayRegularizer, optional): The strategy of regularization. There are two method: \ - :ref:`api_fluid_regularizer_L1Decay` , :ref:`api_fluid_regularizer_L2Decay` . If a parameter has set \ - regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ - ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ - Default None, meaning there is no regularization. - grad_clip (GradientClipByNorm, optional): Gradient cliping strategy. ``DGCMomentumOptimizer`` only support - :ref:`api_fluid_clip_GradientClipByNorm` , and if not, it will raise TypeError. Default None, - meaning there is no gradient clipping. - name (str, optional): This parameter is used by developers to print debugging information. \ - For details, please refer to :ref:`api_guide_Name`. Default is None. - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - optimizer = fluid.optimizer.DGCMomentumOptimizer( - learning_rate=0.0001, - momentum=0.9, - rampup_step=1000, - rampup_begin_step=1252, - sparsity=[0.999, 0.999]) - - """ - _u_velocity_acc_str = "_dgc_u_" - _v_velocity_acc_str = "_dgc_v_" - - def __init__( - self, - learning_rate, - momentum, - rampup_begin_step, - rampup_step=1, - sparsity=[0.999], - parameter_list=None, - use_nesterov=False, - num_trainers=None, - regularization=None, - grad_clip=None, - name=None, - ): - if framework._non_static_mode(): - raise Exception("In dygraph, don't support DGCMomentumOptimizer.") - - assert ( - core.is_compiled_with_cuda() - ), "Paddle is not compiled with CUDA. DGC is only support GPU for now." 
- - assert learning_rate is not None - assert momentum is not None - super().__init__( - learning_rate=learning_rate, - parameter_list=parameter_list, - regularization=regularization, - grad_clip=grad_clip, - name=name, - ) - self.type = "dgc_momentum" - self._momentum = momentum - self._use_nesterov = bool(use_nesterov) - - assert rampup_begin_step >= 0, "rampup_begin_step must >= 0" - self._rampup_begin_step = rampup_begin_step - self._rampup_step = rampup_step - self._sparsity = sparsity - - self._rampup_begin_step_var = None - self._global_step_var = None - - self._dgc_clip_norm = None - if grad_clip is not None: - if not isinstance(grad_clip, GradientClipByNorm): - raise TypeError( - "The type of grad_clip should be 'GradientClipByNorm', because DGCMomentumOptimizer only support GradientClipByNorm" - ) - assert isinstance(num_trainers, int), ( - "The type of num_trainers should be 'int', but received %s" - % type(num_trainers) - ) - assert ( - num_trainers > 0 - ), "The value of num_trainers should be greater than 0!" - - self._num_trainers = num_trainers - self._dgc_clip_norm = grad_clip.clip_norm * (num_trainers**-0.5) - - self.regular_type, self.regular_coeff = self._get_regularization_param( - self.regularization - ) - - def _get_regularization_param(self, regularization): - regular_type = 0 - regular_coeff = 0.0 - - if regularization is not None: - regular_coeff = regularization._regularization_coeff - from .regularizer import L1Decay, L2Decay - - if isinstance(regularization, L1Decay): - regular_type = 1 - elif isinstance(regularization, L2Decay): - regular_type = 2 - else: - assert False, 'regularization must be None|L1Decay|L2Deacy' - return regular_type, regular_coeff - - def _is_use_dgc(self, param_var, grad_var): - var_numel = abs(reduce(lambda x, y: x * y, param_var.shape)) - if ( - var_numel < 16384 - or param_var.type == core.VarDesc.VarType.SELECTED_ROWS - or grad_var.type == core.VarDesc.VarType.SELECTED_ROWS - or param_var.dtype != core.VarDesc.VarType.FP32 - ): - return False - return True - - def _append_optimize_op(self, block, param_and_grad): - assert isinstance(block, framework.Block) - velocity_acc = self._get_accumulator( - self._u_velocity_acc_str, param_and_grad[0] - ) - assert velocity_acc is not None - - inputs = { - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "Velocity": velocity_acc, - "LearningRate": self._create_param_lr(param_and_grad), - } - outputs = { - "ParamOut": param_and_grad[0], - "VelocityOut": velocity_acc, - } - attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov} - - if not self._is_use_dgc(param_and_grad[0], param_and_grad[1]): - type = "momentum" - else: - type = "dgc_momentum" - inputs.update( - { - "current_step": self._global_step_var, - "nranks": self._nranks_var, - } - ) - outputs.update({'Grad_out': param_and_grad[1]}) - attrs.update({"rampup_begin_step": float(self._rampup_begin_step)}) - - # create the dgc momentum optimize op - dgc_momentum_op = block.append_op( - type=type, - inputs=inputs, - outputs=outputs, - attrs=attrs, - stop_gradient=True, - ) - return dgc_momentum_op - - def _add_auto_increment_var(self, counter_name, begin, step=1): - helper = LayerHelper('global_step_counter') - counter, is_new_var = helper.create_or_get_global_variable( - name=counter_name, dtype='float32', shape=[1], persistable=True - ) - if is_new_var: - helper.set_variable_initializer( - counter, - initializer=Constant(value=float(begin - 1), force_cpu=True), - ) - helper.main_program.global_block()._prepend_op( - 
type='increment', - inputs={'X': [counter]}, - outputs={'Out': [counter]}, - attrs={'step': float(step)}, - stop_gradient=True, - ) - counter.stop_gradient = True - - return counter - - def _add_nranks_var(self, name, value=-1): - helper = LayerHelper('global_step_counter') - counter, is_new_var = helper.create_or_get_global_variable( - name=name, dtype='float32', shape=[1], persistable=True - ) - if is_new_var: - helper.set_variable_initializer( - counter, - initializer=Constant(value=float(value), force_cpu=True), - ) - counter.stop_gradient = True - - return counter - - def _append_dgc_ops(self, param_and_grads): - main_program = default_main_program() - main_program._enable_dgc = True - - # step counter - self._global_step_var = self._add_auto_increment_var( - counter_name=core.dgc.kDGCCounterName(), begin=0 - ) - - self._nranks_var = self._add_nranks_var( - name=core.dgc.kDGCNRanksName(), value=-1 - ) - - # rampup begin step var for all_reduce_op_handle - self._rampup_begin_step_var = tensor.create_global_var( - shape=[1], - dtype=core.VarDesc.VarType.FP32, - persistable=True, - name=core.dgc.kDGCRampUpBeginStepName(), - value=self._rampup_begin_step * 1.0, - force_cpu=True, - ) - - self.helper = LayerHelper(self.__class__.__name__) - - for param_var, grad_var in param_and_grads: - # reuse velocity in dgc_op and dgc_momentum_op - u_var = self._add_accumulator(self._u_velocity_acc_str, param_var) - - if not self._is_use_dgc(param_var, grad_var): - continue - - v_var = self._add_accumulator(self._v_velocity_acc_str, param_var) - - k_var = tensor.create_global_var( - shape=[1], - dtype=param_var.dtype, - persistable=True, - name=param_var.name + core.dgc.kDGCKName(), - value=0.0, - force_cpu=True, - ) - - encoded_var = tensor.create_global_var( - shape=[1], - dtype=param_var.dtype, - persistable=True, - name=param_var.name + core.dgc.kDGCEncodedName(), - value=0.0, - force_cpu=False, - ) - - gather_var = tensor.create_global_var( - shape=[1], - dtype=param_var.dtype, - persistable=True, - name=param_var.name + core.dgc.kDGCGatherName(), - value=0.0, - force_cpu=False, - ) - - # del back oprolevarname - op_maker = core.op_proto_and_checker_maker - backward = core.op_proto_and_checker_maker.OpRole.Backward - for op in main_program.global_block().ops: - if not self._is_the_backward_op(op): - continue - - var_attr = op.all_attrs()[op_maker.kOpRoleVarAttrName()] - if param_var.name not in var_attr: - continue - - var_attr.remove(param_var.name) - var_attr.remove(grad_var.name) - if len(var_attr) > 1: - op._set_attr(op_maker.kOpRoleVarAttrName(), var_attr) - else: - op._remove_attr(op_maker.kOpRoleVarAttrName()) - - clip_var = grad_var - if self._dgc_clip_norm is not None: - clip_var = self._append_clip_norm(grad_var, self._dgc_clip_norm) - self._dgc_op( - param_var, - clip_var, - grad_var, - u_var, - v_var, - k_var, - encoded_var, - gather_var, - ) - - def _is_the_backward_op(self, op): - op_maker = core.op_proto_and_checker_maker - backward = core.op_proto_and_checker_maker.OpRole.Backward - if op_maker.kOpRoleVarAttrName() in op.attr_names and int( - op.all_attrs()[op_maker.kOpRoleAttrName()] - ) == int(backward): - return True - return False - - def _clip_by_norm(self, x, max_norm, name=None): - args = {'x': x, 'max_norm': max_norm, 'name': name} - - helper = LayerHelper("dgc_clip_by_norm_op", **args) - - if name is None: - name = unique_name.generate_with_ignorable_key( - ".".join([helper.name, 'tmp']) - ) - - out = helper.create_variable( - type=x.type, name=name, dtype=x.dtype, 
persistable=False - ) - - helper.append_op( - type="dgc_clip_by_norm", - inputs={"X": x, "current_step": self._global_step_var}, - attrs={ - "max_norm": max_norm, - "rampup_begin_step": float(self._rampup_begin_step), - }, - outputs={"Out": out}, - ) - return out - - def _append_clip_norm(self, grad_var, clip_norm): - with grad_var.block.program._backward_role_guard(): - return self._clip_by_norm( - x=grad_var, max_norm=clip_norm, name=grad_var.name - ) - - def _dgc_op( - self, - param_var, - clip_var, - grad_var, - u_var, - v_var, - k_var, - encoded_var, - gather_var, - ): - block = framework.default_main_program().global_block() - op_maker = core.op_proto_and_checker_maker - - regular_type = self.regular_type - regular_coeff = self.regular_coeff - # The regularizer of the Parameters have higher priority - if param_var.regularizer is not None: - regular_type, regular_coeff = self._get_regularization_param( - param_var.regularizer - ) - - dgc_op = block.append_op( - type="dgc", - inputs={ - "U": u_var, - "V": v_var, - "Grad": clip_var, - "Param": param_var, - "current_step": self._global_step_var, - "nranks": self._nranks_var, - }, - outputs={ - "U_out": u_var, - "V_out": v_var, - "EncodeGrad": encoded_var, - "k": k_var, - "Grad_out": grad_var, - "GatherBuff": gather_var, - }, - attrs={ - "m": self._momentum, - "sparsity": self._sparsity, - "use_nesterov": self._use_nesterov, - "rampup_begin_step": float(self._rampup_begin_step), - "rampup_step": float(self._rampup_step), - "regular_coeff": float(regular_coeff), - "regular_type": int(regular_type), - }, - stop_gradient=True, - ) - - backward = op_maker.OpRole.Backward - dgc_op._set_attr(op_maker.kOpRoleAttrName(), backward) - dgc_op._set_attr( - op_maker.kOpRoleVarAttrName(), [param_var.name, grad_var.name] - ) - - @imperative_base.no_grad - def apply_gradients(self, params_grads): - # Note: since we can't use all_reduce_op now, - # dgc_op should be the last op of one grad. - # Maybe need a grad allreduce pass. 
- self._append_dgc_ops(params_grads) - - params_grads = sorted(params_grads, key=lambda x: x[0].name) - ( - params_grads, - table_param_and_grad, - table_optimize_op, - ) = self._process_distribute_lookuptable(params_grads) - - not_dgc_params_grads = [] - dgc_params_grads = [] - # DGC clip and regularization in optimizer.backward - for param, grad in params_grads: - if not self._is_use_dgc(param, grad): - not_dgc_params_grads.append((param, grad)) - else: - dgc_params_grads.append((param, grad)) - - # 'optimizer(grad_clip)' or 'set_gradient_clip' - if self._grad_clip is not None: - not_dgc_params_grads = self._grad_clip(not_dgc_params_grads) - else: - not_dgc_params_grads = append_gradient_clip_ops( - not_dgc_params_grads - ) - - not_dgc_params_grads = self.append_regularization_ops( - not_dgc_params_grads, self.regularization - ) - - params_grads = not_dgc_params_grads + dgc_params_grads - params_grads = sorted(params_grads, key=lambda x: x[0].name) - - optimize_ops = self._create_optimization_pass(params_grads) - if table_optimize_op is not None: - optimize_ops.append(table_optimize_op) - params_grads.append(table_param_and_grad) - - return optimize_ops - - class LarsMomentumOptimizer(Optimizer): r""" Momentum optimizer with LARS support diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dgc_optimizer.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dgc_optimizer.py index 0da05a377b..335916e520 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dgc_optimizer.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dgc_optimizer.py @@ -24,7 +24,9 @@ paddle.enable_static() class TestDGCMomentumOptimizer(unittest.TestCase): - class MockDGCMomentum(optimizer.DGCMomentumOptimizer): + class MockDGCMomentum( + paddle.distributed.fleet.meta_optimizers.DGCMomentumOptimizer + ): def get_accumulators(self): return self._accumulators diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py index 856ac1b930..2df9549918 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist.py @@ -93,7 +93,7 @@ class TestDistMnist2x2(TestDistRunnerBase): if not use_dgc: opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9) else: - opt = fluid.optimizer.DGCMomentumOptimizer( + opt = paddle.distributed.fleet.meta_optimizers.DGCMomentumOptimizer( learning_rate=self.lr, momentum=0.9, rampup_begin_step=2 ) diff --git a/python/paddle/fluid/tests/unittests/dist_se_resnext.py b/python/paddle/fluid/tests/unittests/dist_se_resnext.py index 0d8ed873f0..ae7fb207d2 100644 --- a/python/paddle/fluid/tests/unittests/dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/dist_se_resnext.py @@ -247,13 +247,15 @@ class DistSeResneXt2x2(TestDistRunnerBase): regularization=fluid.regularizer.L2Decay(1e-4), ) else: - optimizer = fluid.optimizer.DGCMomentumOptimizer( - learning_rate=fluid.layers.piecewise_decay( - boundaries=bd, values=lr - ), - momentum=0.9, - rampup_begin_step=0, - regularization=fluid.regularizer.L2Decay(1e-4), + optimizer = ( + paddle.distributed.fleet.meta_optimizers.DGCMomentumOptimizer( + learning_rate=fluid.layers.piecewise_decay( + boundaries=bd, values=lr + ), + momentum=0.9, + rampup_begin_step=0, + regularization=fluid.regularizer.L2Decay(1e-4), + ) ) optimizer.minimize(avg_cost) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py 
b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py index a75208d88d..917876c974 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py @@ -35,7 +35,6 @@ from paddle.fluid.optimizer import ( ) from paddle.fluid.optimizer import ( ModelAverage, - DGCMomentumOptimizer, ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, @@ -45,6 +44,8 @@ from paddle.fluid.dygraph import Linear from test_imperative_base import new_program_scope from paddle.fluid.framework import _test_eager_guard +from paddle.distributed.fleet.meta_optimizers import DGCMomentumOptimizer + # Note(wangzhongpu) # In dygraph, don't support ModelAverage, DGCMomentumOptimizer, ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, RecomputeOptimizer. diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py index 4023d3596b..0c6853ce65 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py @@ -32,7 +32,6 @@ from paddle.fluid.optimizer import ( ) from paddle.fluid.optimizer import ( ModelAverage, - DGCMomentumOptimizer, ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, @@ -42,6 +41,8 @@ from paddle.fluid.dygraph import Linear from test_imperative_base import new_program_scope from paddle.fluid.framework import _test_eager_guard +from paddle.distributed.fleet.meta_optimizers import DGCMomentumOptimizer + # Note(wangzhongpu) # In dygraph, don't support ModelAverage, DGCMomentumOptimizer, ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, RecomputeOptimizer. -- GitLab
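
Note (not part of the applied patch): after this change, DGCMomentumOptimizer is imported from paddle.distributed.fleet.meta_optimizers rather than paddle.fluid.optimizer. Below is a minimal usage sketch mirroring the constructor arguments from the docstring example removed from python/paddle/fluid/optimizer.py and the updated unit tests above. It assumes a CUDA build of Paddle with DGC compiled in and static-graph mode (the class raises in dygraph and asserts on non-CUDA builds); the toy network is illustrative only and not taken from the patch.

import paddle
from paddle.distributed.fleet.meta_optimizers import DGCMomentumOptimizer

paddle.enable_static()

# Illustrative placeholder network, only to give minimize() a loss to work on.
x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
y = paddle.static.data(name='y', shape=[None, 1], dtype='float32')
pred = paddle.static.nn.fc(x, size=1)
loss = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))

# Same hyper-parameters as the example in the removed fluid docstring.
opt = DGCMomentumOptimizer(
    learning_rate=0.0001,
    momentum=0.9,
    rampup_begin_step=1252,
    rampup_step=1000,
    sparsity=[0.999, 0.999],
)
opt.minimize(loss)

Callers that previously used fluid.optimizer.DGCMomentumOptimizer (see the dist_mnist.py and dist_se_resnext.py hunks above) only need to change the import path; the constructor signature is unchanged by this move.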