# test_fleet_util.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import os


class TestFleetUtil(unittest.TestCase):
    def test_util_base(self):
        import paddle.fleet as fleet
        util = fleet.UtilBase()
        strategy = fleet.DistributedStrategy()
        util._set_strategy(strategy)
        role_maker = None  # should be fleet.PaddleCloudRoleMaker()
        util._set_role_maker(role_maker)

    def test_util_factory(self):
        import paddle.fleet as fleet
        factory = fleet.base.util_factory.UtilFactory()
        strategy = fleet.DistributedStrategy()
        role_maker = None  # should be fleet.PaddleCloudRoleMaker()
        optimize_ops = []
        params_grads = []
        util = factory._create_util(strategy, role_maker, optimize_ops,
                                    params_grads)
        self.assertEqual(util.role_maker, None)

    def test_get_util(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        default_util = fleet.util
        self.assertEqual(default_util, None)

    def test_set_user_defined_util(self):
        import paddle.fleet as fleet

        class UserDefinedUtil(fleet.UtilBase):
            def __init__(self):
                super(UserDefinedUtil, self).__init__()

            def get_user_id(self):
                return 10

        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        my_util = UserDefinedUtil()
        fleet.util = my_util
        user_id = fleet.util.get_user_id()
        self.assertEqual(user_id, 10)


if "__main__" == __name__:
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()