# test_fleet_lars_meta_optimizer.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle import fluid

# The tests below build programs with the static-graph (fluid) API, so
# static mode is switched on once at import time.
paddle.enable_static()


class TestFleetLarsMetaOptimizer(unittest.TestCase):
    """Tests for the LARS meta optimizer applied through ``fleet``.

    Each test initialises fleet in collective mode, builds a small
    fully-connected classifier, wraps an optimizer with
    ``fleet.distributed_optimizer`` and inspects the op types that
    ``minimize`` leaves in the loss variable's block.
    """

    def setUp(self):
        """Configure the environment as trainer 1 of a two-trainer job.

        The variables are presumably consumed by ``PaddleCloudRoleMaker``
        when each test calls ``fleet.init``. Their previous values are
        restored after every test so they do not leak into other test cases
        running in the same process.
        """
        trainer_env = {
            "PADDLE_TRAINER_ID": "1",
            "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002",
        }
        for key, value in trainer_env.items():
            # Register the restore before mutating so it always runs.
            self.addCleanup(self._restore_env, key, os.environ.get(key))
            os.environ[key] = value

    @staticmethod
    def _restore_env(key, old_value):
        """Reset ``os.environ[key]`` to *old_value*, or unset it if None."""
        if old_value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = old_value

    def net(self, main_prog, startup_prog):
        """Build a 3-layer FC classifier and a LARS-enabled strategy.

        Args:
            main_prog: program the network ops are added to.
            startup_prog: startup program paired with *main_prog* in the
                ``program_guard``.

        Returns:
            ``(avg_cost, strategy)``: the mean cross-entropy loss variable and
            a ``DistributedStrategy`` with ``lars`` enabled.
        """
        with fluid.program_guard(main_prog, startup_prog):
            # Fresh name scope so parameter names are stable across tests.
            with fluid.unique_name.guard():
                input_x = paddle.fluid.layers.data(
                    name="x", shape=[32], dtype='float32'
                )
                input_y = paddle.fluid.layers.data(
                    name="y", shape=[1], dtype='int64'
                )

                fc_1 = paddle.fluid.layers.fc(
                    input=input_x, size=64, act='tanh'
                )
                fc_2 = paddle.fluid.layers.fc(input=fc_1, size=256, act='tanh')
                prediction = paddle.fluid.layers.fc(
                    input=[fc_2], size=2, act='softmax'
                )
                # The last fc already applies softmax, hence use_softmax=False.
                cost = paddle.nn.functional.cross_entropy(
                    input=prediction,
                    label=input_y,
                    reduction='none',
                    use_softmax=False,
                )
                avg_cost = paddle.mean(x=cost)

                strategy = paddle.distributed.fleet.DistributedStrategy()
                strategy.lars = True
                strategy.lars_configs = {
                    "lars_coeff": 0.001,
                    "lars_weight_decay": 0.0005,
                    "epsilon": 0,
                    # Parameters whose names contain one of these substrings
                    # should get no weight decay (presumably fc biases such
                    # as "fc_0.b_0" — see test_lars_exclude_fn).
                    "exclude_from_weight_decay": ["batch_norm", ".b"],
                }

        return avg_cost, strategy

    def test_lars_optimizer(self):
        """Momentum + ``strategy.lars`` yields ``lars_momentum`` ops."""
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        startup_prog = fluid.Program()
        train_prog = fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        optimizer = paddle.fluid.optimizer.Momentum(
            learning_rate=0.01, momentum=0.9
        )
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)

        ops = [op.type for op in avg_cost.block.ops]
        self.assertIn('lars_momentum', ops)

    def test_lars_not_apply_with_adam(self):
        """With Adam as the inner optimizer no ``lars_momentum`` op appears."""
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        startup_prog = fluid.Program()
        train_prog = fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)

        ops = [op.type for op in avg_cost.block.ops]
        self.assertNotIn('lars_momentum', ops)

    def test_lars_exclude_fn(self):
        """Excluded parameters get ``lars_weight_decay`` of 0."""
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        startup_prog = fluid.Program()
        train_prog = fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        optimizer = paddle.fluid.optimizer.Momentum(
            learning_rate=0.01, momentum=0.9
        )

        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)

        # op_role_var[0] is presumably the parameter name the op updates.
        ops_without_wd = [
            op
            for op in avg_cost.block.ops
            if op.type == 'lars_momentum'
            and (
                "batch_norm" in op.attr('op_role_var')[0]
                or ".b" in op.attr('op_role_var')[0]
            )
        ]
        # Guard against a vacuous pass: the loop below checks nothing if the
        # filter above matches no op (e.g. if parameter naming changes).
        self.assertTrue(ops_without_wd)
        for op in ops_without_wd:
            # The attribute is indexed with [0]; presumably it is a list
            # holding this op's per-parameter decay.
            self.assertEqual(op.attr('lars_weight_decay')[0], 0)

    def test_lars_apply_with_amp(self):
        """LARS composes with AMP: LARS and AMP ops both appear.

        Unlike the other tests, the network here is built without a
        ``program_guard``, i.e. in the default main program.
        """
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        input_x = paddle.fluid.layers.data(
            name="x", shape=[32], dtype='float32'
        )
        input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')

        fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
        fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
        prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
        cost = paddle.nn.functional.cross_entropy(
            input=prediction, label=input_y, reduction='none', use_softmax=False
        )
        avg_cost = paddle.mean(x=cost)

        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.amp = True
        strategy.amp_configs = {
            "init_loss_scaling": 32768,
            "decr_every_n_nan_or_inf": 2,
            "incr_every_n_steps": 1000,
            "incr_ratio": 2.0,
            "use_dynamic_loss_scaling": True,
            "decr_ratio": 0.5,
            "custom_white_list": ['softmax'],
            "custom_black_list": ['tanh'],
        }
        strategy.lars = True
        strategy.lars_configs = {
            "lars_coeff": 0.001,
            "lars_weight_decay": 0.0005,
            "epsilon": 0,
            "exclude_from_weight_decay": ["batch_norm", ".b"],
        }

        optimizer = paddle.fluid.optimizer.Momentum(
            learning_rate=0.01, momentum=0.9
        )
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)

        ops = [op.type for op in avg_cost.block.ops]
        self.assertIn('lars_momentum', ops)
        # 'cast' and 'check_finite_and_unscale' are expected from the AMP
        # rewrite (presumably precision casts and dynamic loss scaling).
        self.assertIn('cast', ops)
        self.assertIn('check_finite_and_unscale', ops)


if __name__ == "__main__":
    # Only runs when the file is executed directly, not when imported by a
    # test collector; the module passed here is this script itself.
    unittest.main(module=__name__)