diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 551d1342edeb335d1cad4782f85ae9f94f8739bd..8d0093388b484a5c16bd4c6a0d1aeae52bb200ab 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -52,6 +52,8 @@ message DGCConfig { message LarsConfig { optional float lars_coeff = 1 [ default = 0.001 ]; optional float lars_weight_decay = 2 [ default = 0.0005 ]; + optional float epsilon = 3 [ default = 0.0 ]; + repeated string exclude_from_weight_decay = 4; } message LambConfig { diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc old mode 100644 new mode 100755 index 5f0500d2faa77f7c2e901c0d30ab2c42036d2a86..479f9643749d63c673158ad055409a0925f3d576 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc @@ -48,6 +48,9 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("lars_weight_decay", "(float, default 0.0005) LARS weight decay") .SetDefault(0.0005); + AddAttr("epsilon", + "(float, default 0.0) epsilon to avoid Division by Zero.") + .SetDefault(0.0); AddComment(R"DOC( Lars Momentum Optimizer. 
diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 1dace4ed6ab3e17b348035e34f6d9ea6d31edae9..eb0111ae4de2f066359e26406f6c7ec3eb54d5fc 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -23,14 +23,16 @@ __global__ void MomentumLarsKernel(const T* p, const T* g, const T* v, const T* learning_rate, const T mu, const int64_t num, const T lars_coeff, const T lars_weight_decay, const T* p_norm, - const T* g_norm, T* p_out, T* v_out) { + const T* g_norm, T* p_out, T* v_out, + const T epsilon) { T lr = learning_rate[0]; T local_lr = learning_rate[0]; CUDA_KERNEL_LOOP(i, num) { - if (p_norm[0] > 0 && g_norm[0] > 0) { + if (lars_weight_decay > 0 && p_norm[0] > 0 && g_norm[0] > 0) { local_lr = lr * lars_coeff * p_norm[0] / - (g_norm[0] + lars_weight_decay * p_norm[0]); + (g_norm[0] + lars_weight_decay * p_norm[0] + epsilon); } + T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]); v_out[i] = v_new; p_out[i] = p[i] - v_new; @@ -54,6 +56,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { T mu = static_cast(ctx.Attr("mu")); T lars_coeff = ctx.Attr("lars_coeff"); T lars_weight_decay = ctx.Attr("lars_weight_decay"); + T epsilon = ctx.Attr("epsilon"); auto* p = param->data(); auto* v = velocity->data(); @@ -79,7 +82,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { eg_norm.device(*place) = eigen_g.square().sum().sqrt(); MomentumLarsKernel<<>>( p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay, - p_norm_data, g_norm_data, p_out, v_out); + p_norm_data, g_norm_data, p_out, v_out, epsilon); } }; diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.h b/paddle/fluid/operators/optimizers/lars_momentum_op.h old mode 100644 new mode 100755 index e0064c201825b1f074eb53c591dc3abdd7bc1e1b..b579b5143ddbe6221738f9864f13fb7bea4ac509 --- 
a/paddle/fluid/operators/optimizers/lars_momentum_op.h +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.h @@ -39,6 +39,7 @@ class LarsMomentumOpKernel : public framework::OpKernel { T mu = static_cast(ctx.Attr("mu")); T lars_coeff = ctx.Attr("lars_coeff"); T lars_weight_decay = ctx.Attr("lars_weight_decay"); + T epsilon = ctx.Attr("epsilon"); auto p_out = framework::EigenVector::Flatten(*param_out); auto v_out = framework::EigenVector::Flatten(*velocity_out); @@ -59,9 +60,9 @@ class LarsMomentumOpKernel : public framework::OpKernel { ep_norm = p.square().sum().sqrt(); eg_norm = g.square().sum().sqrt(); T local_lr = lr[0]; - if (ep_norm(0) > 0 && eg_norm(0) > 0) { + if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) { local_lr = lr[0] * lars_coeff * ep_norm(0) / - (eg_norm(0) + lars_weight_decay * ep_norm(0)); + (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon); } v_out = v * mu + local_lr * (g + lars_weight_decay * p); p_out = p - v_out; diff --git a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py index 3a9f2be533b8bc176b2361eaffbc74d4b834749c..bfa186a1e7c46c9fbd5276880965ef3764d3abc3 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py @@ -91,6 +91,10 @@ class LambOptimizer(MetaOptimizerBase): return self.lamb_opt.backward(loss, startup_program, parameter_list, no_grad_set, callbacks) + # the following function will be used by AMP if both LAMB and AMP are turned on together.
+ def apply_gradients(self, params_grads): + return self.lamb_opt.apply_gradients(params_grads=params_grads) + def minimize_impl(self, loss, startup_program=None, diff --git a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py index cb12154ddc564687539d953c21b9e0597a8bf893..ec7a7eb18bcdfefe30d43fd9b71909e6eb827d99 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py @@ -44,13 +44,16 @@ class LarsOptimizer(MetaOptimizerBase): parameter_list=opt._parameter_list, regularization=opt.regularization, grad_clip=opt._grad_clip, - name=opt._name) + name=opt._name, + exclude_from_weight_decay=configs['exclude_from_weight_decay'], + epsilon=configs['epsilon']) def _can_apply(self): if self.user_defined_strategy.lars: if not isinstance(self.inner_opt, Momentum): logging.warn( - "lars need the inner optimizer to be Momentum optimizer.") + "lars need the inner optimizer to be Momentum optimizer but got {}.". + format(self.inner_opt.type)) return False return True return False @@ -75,6 +78,10 @@ class LarsOptimizer(MetaOptimizerBase): return self.lars_opt.backward(loss, startup_program, parameter_list, no_grad_set, callbacks) + # the following function will be used by AMP if both LARS and AMP are turned on together.
+ def apply_gradients(self, params_grads): + return self.lars_opt.apply_gradients(params_grads=params_grads) + def minimize_impl(self, loss, startup_program=None, diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py old mode 100644 new mode 100755 index 8b37cfef3890eace0ff5141eeb91d85e78f1c964..192effd2e42dc937fbf47efdd1d772a4c078f888 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -1604,7 +1604,7 @@ class LarsMomentumOptimizer(Optimizer): & local\_learning\_rate = learning\_rate * lars\_coeff * \\ - \\frac{||param||}{||gradient|| + lars\_weight\_decay * ||param||} + \\frac{||param||}{||gradient|| + lars\_weight\_decay * ||param|| + epsilon} & velocity = mu * velocity + local\_learning\_rate * (gradient + lars\_weight\_decay * param) & param = param - velocity @@ -1628,7 +1628,9 @@ class LarsMomentumOptimizer(Optimizer): :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. name (str, optional): This parameter is used by developers to print debugging information. \ For details, please refer to :ref:`api_guide_Name`. Default is None. - + exclude_from_weight_decay (list[str], optional): Name strings of layers which will be excluded from lars weight decay. Default is None. + epsilon (float, optional): Epsilon to avoid division by zero when calculating the local learning rate. Default is 0. + Examples: ..
code-block:: python @@ -1659,7 +1661,9 @@ class LarsMomentumOptimizer(Optimizer): parameter_list=None, regularization=None, grad_clip=None, - name=None): + name=None, + exclude_from_weight_decay=None, + epsilon=0): assert learning_rate is not None assert momentum is not None super(LarsMomentumOptimizer, self).__init__( @@ -1672,6 +1676,11 @@ class LarsMomentumOptimizer(Optimizer): self._momentum = momentum self._lars_coeff = float(lars_coeff) self._lars_weight_decay = float(lars_weight_decay) + self._epsilon = float(epsilon) + if exclude_from_weight_decay is None: + self._exclude_from_weight_decay = [] + else: + self._exclude_from_weight_decay = exclude_from_weight_decay def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) @@ -1682,6 +1691,14 @@ class LarsMomentumOptimizer(Optimizer): def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) + _lars_weight_decay = self._lars_weight_decay + param_name = param_and_grad[0].name + if len(self._exclude_from_weight_decay) > 0: + for name in self._exclude_from_weight_decay: + if name in param_name: + _lars_weight_decay = 0.0 + break + velocity_acc = self._get_accumulator(self._velocity_acc_str, param_and_grad[0]) # create the momentum optimize op @@ -1700,7 +1717,8 @@ class LarsMomentumOptimizer(Optimizer): attrs={ "mu": self._momentum, "lars_coeff": self._lars_coeff, - "lars_weight_decay": self._lars_weight_decay + "lars_weight_decay": _lars_weight_decay, + "epsilon": self._epsilon }, stop_gradient=True) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_lamb_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_lamb_meta_optimizer.py index 3f140f53b043b1949572f3728ca8a0c556317783..ff305fb95231b96b6d8f951b2943a0ab47060ce0 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_lamb_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_lamb_meta_optimizer.py @@ -22,11 +22,9 @@ import 
paddle.distributed.fleet.base.role_maker as role_maker class TestFleetLambMetaOptimizer(unittest.TestCase): def setUp(self): - os.environ["POD_IP"] = "127.0.0.1" - os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001" - os.environ["PADDLE_TRAINERS_NUM"] = "2" - os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ - "127.0.0.1:36001,127.0.0.2:36001" + os.environ["PADDLE_TRAINER_ID"] = "1" + os.environ[ + "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002" def net(self, main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog): @@ -97,13 +95,54 @@ class TestFleetLambMetaOptimizer(unittest.TestCase): optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) optimizer.minimize(avg_cost) - ops_with_bias = [ + ops_without_wd = [ op for op in avg_cost.block.ops if op.type == 'lamb' and op.attr('op_role_var')[0].endswith('.b_0') ] - for op in ops_with_bias: + for op in ops_without_wd: self.assertEqual(op.attr('weight_decay'), 0) + def test_lamb_apply_with_amp(self): + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') + + fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.amp = True + strategy.amp_configs = { + "init_loss_scaling": 32768, + "decr_every_n_nan_or_inf": 2, + "incr_every_n_steps": 1000, + "incr_ratio": 2.0, + "use_dynamic_loss_scaling": True, + "decr_ratio": 0.5, + "custom_white_list": ['softmax'], + "custom_black_list": ['tanh'], + } + strategy.lamb = True + strategy.lamb_configs = { + 
'lamb_weight_decay': 0.01, + 'exclude_from_weight_decay': [], + } + + optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + ops = [op.type for op in avg_cost.block.ops] + self.assertIn('lamb', ops) + self.assertIn('cast', ops) + self.assertIn('isfinite', ops) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py index 3caa1a4eac0bf191b13e6708b1a9adffdb111ca7..34ab423e064eebb9c93010fbc869adedb42bd6fa 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py @@ -22,11 +22,9 @@ import paddle.distributed.fleet.base.role_maker as role_maker class TestFleetLarsMetaOptimizer(unittest.TestCase): def setUp(self): - os.environ["POD_IP"] = "127.0.0.1" - os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001" - os.environ["PADDLE_TRAINERS_NUM"] = "2" - os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ - "127.0.0.1:36001,127.0.0.2:36001" + os.environ["PADDLE_TRAINER_ID"] = "1" + os.environ[ + "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002" def net(self, main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog): @@ -52,6 +50,8 @@ class TestFleetLarsMetaOptimizer(unittest.TestCase): strategy.lars_configs = { "lars_coeff": 0.001, "lars_weight_decay": 0.0005, + "epsilon": 0, + "exclude_from_weight_decay": ["batch_norm", ".b"], } return avg_cost, strategy @@ -83,6 +83,70 @@ class TestFleetLarsMetaOptimizer(unittest.TestCase): ops = [op.type for op in avg_cost.block.ops] self.assertNotIn('lars_momentum', ops) + def test_lars_exclude_fn(self): + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + startup_prog = fluid.Program() + train_prog = fluid.Program() + avg_cost, 
strategy = self.net(train_prog, startup_prog) + optimizer = paddle.fluid.optimizer.Momentum( + learning_rate=0.01, momentum=0.9) + + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + ops_without_wd = [ + op for op in avg_cost.block.ops + if op.type == 'lars_momentum' and ("batch_norm" in op.attr( + 'op_role_var')[0] or ".b" in op.attr('op_role_var')[0]) + ] + for op in ops_without_wd: + self.assertEqual(op.attr('lars_weight_decay'), 0) + + def test_lars_apply_with_amp(self): + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') + + fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.amp = True + strategy.amp_configs = { + "init_loss_scaling": 32768, + "decr_every_n_nan_or_inf": 2, + "incr_every_n_steps": 1000, + "incr_ratio": 2.0, + "use_dynamic_loss_scaling": True, + "decr_ratio": 0.5, + "custom_white_list": ['softmax'], + "custom_black_list": ['tanh'], + } + strategy.lars = True + strategy.lars_configs = { + "lars_coeff": 0.001, + "lars_weight_decay": 0.0005, + "epsilon": 0, + "exclude_from_weight_decay": ["batch_norm", ".b"], + } + + optimizer = paddle.fluid.optimizer.Momentum( + learning_rate=0.01, momentum=0.9) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + ops = [op.type for op in avg_cost.block.ops] + self.assertIn('lars_momentum', ops) + self.assertIn('cast', ops) + self.assertIn('isfinite', ops) + if __name__ == "__main__": 
unittest.main()