#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from dist_fleet_simnet_bow import train_network
import paddle
import paddle.fluid as fluid
import paddle.incubate.distributed.fleet.role_maker as role_maker
from paddle.fluid.transpiler.distribute_transpiler import (
    DistributeTranspilerConfig,
)
from paddle.incubate.distributed.fleet.collective import CollectiveOptimizer

# from paddle.incubate.distributed.fleet.parameter_server import TranspilerOptimizer
from paddle.incubate.distributed.fleet.parameter_server.distribute_transpiler import (
    fleet,
)
from paddle.incubate.distributed.fleet.role_maker import (
    Role,
    UserDefinedCollectiveRoleMaker,
    UserDefinedRoleMaker,
)


class DistributeTranspilerConfigTest(unittest.TestCase):
    def set_runtime_split_send_recv(self, config, value):
        config.runtime_split_send_recv = value

    def set_sync_mode(self, config, value):
        config.sync_mode = value

    def testConfig(self):
        config = DistributeTranspilerConfig()
        self.assertRaises(Exception, self.set_sync_mode, config, None)
48 49 50 51 52 53
        self.assertRaises(
            Exception, self.set_runtime_split_send_recv, config, None
        )
        self.assertRaises(
            Exception, self.set_runtime_split_send_recv, config, True
        )
54 55 56 57 58 59 60 61 62 63 64
        self.set_sync_mode(config, False)
        self.assertFalse(config.sync_mode)
        self.set_runtime_split_send_recv(config, True)
        self.assertRaises(Exception, self.set_sync_mode, config, True)


class FleetTest(unittest.TestCase):
    """Covers misuse of the fleet API: invalid inputs, invalid executors and
    programs for the save_* entry points, and distributed-optimizer usage
    with and without role initialization."""

    def testInvalidInputs(self):
        with self.assertRaises(Exception):
            fleet.split_files("files")
        with self.assertRaises(Exception):
            fleet.init("pserver")

        # Build a trivial network so there is a loss to save/restore around.
        data = paddle.static.data(name='X', shape=[-1, 1], dtype='float32')
        hidden = paddle.static.nn.fc(x=data, size=10)
        loss = paddle.mean(hidden)
        adam = fluid.optimizer.Adam()
        adam.minimize(loss)
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        pe = fluid.ParallelExecutor(use_cuda=False, loss_name=loss.name)

        # A ParallelExecutor is not an acceptable executor for saving.
        with self.assertRaises(Exception):
            fleet.save_inference_model(
                dirname='/tmp/',
                feeded_var_names=['X'],
                target_vars=[loss],
                executor=pe,
            )
        # Neither is an arbitrary string.
        with self.assertRaises(Exception):
            fleet.save_inference_model(
                dirname='/tmp/',
                feeded_var_names=['X'],
                target_vars=[loss],
                executor="executor",
            )

        compiled_prog = fluid.compiler.CompiledProgram(
            fluid.default_main_program()
        )
        # A CompiledProgram cannot be passed as main_program.
        with self.assertRaises(Exception):
            fleet.save_inference_model(
                dirname='/tmp/',
                feeded_var_names=['X'],
                target_vars=[loss],
                executor=exe,
                main_program=compiled_prog,
            )

        # save_persistables applies the same executor/program validation.
        with self.assertRaises(Exception):
            fleet.save_persistables(executor=pe, dirname='/tmp/')
        with self.assertRaises(Exception):
            fleet.save_persistables(executor="executor", dirname='/tmp/')
        with self.assertRaises(Exception):
            fleet.save_persistables(
                executor=exe, dirname='/tmp/', main_program=compiled_prog
            )
        # self.assertRaises(Exception, fleet._transpile, "config")

    def set_program(self, avg_cost, strategy):
        """Build the distributed program for ``avg_cost`` under ``strategy``
        inside a fresh scope so no state leaks between test cases."""
        with fluid.scope_guard(fluid.Scope()):
            base_opt = fluid.optimizer.SGD(0.1)
            dist_opt = fleet.distributed_optimizer(base_opt, strategy)
            dist_opt.minimize(avg_cost)

    def test_init_role(self):
        role = role_maker.UserDefinedRoleMaker(
            current_id=0,
            role=role_maker.Role.SERVER,
            worker_num=2,
            server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"],
        )
        # for test optimizer without init(role)
        # fleet.init(role)
        batch_size = 128
        is_sparse = True
        is_distribute = False

        strategy = DistributeTranspilerConfig()
        strategy.sync_mode = False
        strategy.geo_sgd_mode = True
        strategy.geo_sgd_need_push_nums = 5

        avg_cost, _, _, _ = train_network(batch_size, is_distribute, is_sparse)

        # Without fleet.init(role), building the program must fail.
        with self.assertRaises(Exception):
            self.set_program(avg_cost, strategy)

    def test_transpile(self):
        role = role_maker.UserDefinedRoleMaker(
            current_id=0,
            role=role_maker.Role.SERVER,
            worker_num=2,
            server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"],
        )
        # for test optimizer without init(role)
        fleet.init(role)
        batch_size = 128
        is_sparse = True
        is_distribute = False

        strategy = DistributeTranspilerConfig()
        strategy.sync_mode = False
        strategy.runtime_split_send_recv = True

        avg_cost, _, _, _ = train_network(batch_size, is_distribute, is_sparse)

        # Transpile once with runtime_split_send_recv on, then off.
        self.set_program(avg_cost, strategy)
        strategy.runtime_split_send_recv = False
        self.set_program(avg_cost, strategy)

# NOTE(review): disabled test kept as a module-level string literal; it
# exercises TranspilerOptimizer, whose import is also commented out in the
# import section, so it cannot run as written.
"""
class TranspilerOptimizerTest(unittest.TestCase):
    def testInvalidInputs(self):
        self.assertRaises(Exception, TranspilerOptimizer, "Adam", None)
        self.assertRaises(
            Exception,
            TranspilerOptimizer,
            fluid.optimizer.Adam(0.001),
            "strategy",
        )

        transpiler = TranspilerOptimizer(fluid.optimizer.Adam(0.001))
        self.assertRaises(Exception, transpiler.minimize, loss=[])
        data = paddle.static.data(name='X', shape=[-1, 1], dtype='float32')
        hidden = paddle.static.nn.fc(x=data, size=10)
        loss = paddle.mean(hidden)
        self.assertRaises(
            Exception, transpiler.minimize, loss=loss.name, startup_program=[]
        )
"""


class UserDefinedRoleMakerTest(unittest.TestCase):
    """Checks the argument validation performed by UserDefinedRoleMaker."""

    # Immutable sentinel for the default endpoint list. Using a shared
    # mutable list literal as the parameter default would be the classic
    # mutable-default pitfall; a None sentinel is not usable here because
    # the tests below pass server_endpoints=None and expect it to be
    # rejected by UserDefinedRoleMaker itself.
    _DEFAULT_SERVER_ENDPOINTS = ("127.0.0.1:8080",)

    def createRoleMaker(
        self,
        current_id=0,
        role=Role.WORKER,
        worker_num=1,
        server_endpoints=_DEFAULT_SERVER_ENDPOINTS,
    ):
        """Construct a UserDefinedRoleMaker purely for its validation side
        effect; invalid arguments make the constructor raise."""
        if server_endpoints is self._DEFAULT_SERVER_ENDPOINTS:
            # Materialize a fresh list per call (role maker expects a list).
            server_endpoints = ["127.0.0.1:8080"]
        UserDefinedRoleMaker(current_id, role, worker_num, server_endpoints)

    def testRoleMaker(self):
        # Defaults are valid and must not raise.
        self.createRoleMaker()

        # test all invalid server_endpoints
        self.assertRaises(
            Exception, self.createRoleMaker, server_endpoints=None
        )  # server_endpoints must be as list
        self.assertRaises(
            Exception, self.createRoleMaker, server_endpoints=[]
        )  # server_endpoints can't be empty
        self.assertRaises(
            Exception, self.createRoleMaker, server_endpoints=[3, []]
        )  # element in server_endpoints must be as string
        self.assertRaises(
            Exception,
            self.createRoleMaker,
            server_endpoints=["127.0.0.1:8080", "127.0.0.1:8080"],
        )  # element in server_endpoints can't be duplicate

        # test all invalid current_id
        self.assertRaises(
            Exception, self.createRoleMaker, current_id="0"
        )  # current_id must be as int
        self.assertRaises(
            Exception, self.createRoleMaker, current_id=-1
        )  # current_id must be greater than or equal to 0
        self.assertRaises(
            Exception,
            self.createRoleMaker,
            current_id=1,
            role=Role.SERVER,
            server_endpoints=["127.0.0.1:8080"],
        )  # if role is server, current_id must be less than len(server_endpoints)

        # test all invalid worker_num
        self.assertRaises(
            Exception, self.createRoleMaker, worker_num="1"
        )  # worker_num must be as int
        self.assertRaises(
            Exception, self.createRoleMaker, worker_num=0
        )  # worker_num must be greater than 0

        # test all invalid role
        self.assertRaises(
            Exception, self.createRoleMaker, role=3
        )  # role must be as Role(Role.WORKER=1, Role.SERVER=2)


class UserDefinedCollectiveRoleMakerTest(unittest.TestCase):
    """Checks the argument validation of UserDefinedCollectiveRoleMaker."""

    # Immutable sentinel replacing the shared mutable list default
    # (mutable-default pitfall). None cannot serve as the sentinel: the
    # tests pass worker_endpoints=None and expect the role maker to reject
    # it.
    _DEFAULT_WORKER_ENDPOINTS = ("127.0.0.1:8080",)

    def createRoleMaker(
        self, current_id=0, worker_endpoints=_DEFAULT_WORKER_ENDPOINTS
    ):
        """Construct a UserDefinedCollectiveRoleMaker purely for its
        validation side effect; invalid arguments raise."""
        if worker_endpoints is self._DEFAULT_WORKER_ENDPOINTS:
            # Fresh list per call (role maker expects a list).
            worker_endpoints = ["127.0.0.1:8080"]
        UserDefinedCollectiveRoleMaker(current_id, worker_endpoints)

    def testRoleMaker(self):
        # Defaults are valid and must not raise.
        self.createRoleMaker()

        # test all invalid worker_endpoints
        self.assertRaises(
            Exception, self.createRoleMaker, worker_endpoints=None
        )  # worker_endpoints must be as list
        self.assertRaises(
            Exception, self.createRoleMaker, worker_endpoints=[]
        )  # worker_endpoints can't be empty
        self.assertRaises(
            Exception, self.createRoleMaker, worker_endpoints=[3, []]
        )  # element worker_endpoints must be as string
        self.assertRaises(
            Exception,
            self.createRoleMaker,
            worker_endpoints=["127.0.0.1:8080", "127.0.0.1:8080"],
        )  # element in worker_endpoints can't be duplicate

        # test all invalid current_id
        self.assertRaises(
            Exception, self.createRoleMaker, current_id="0"
        )  # current_id must be as int
        self.assertRaises(
            Exception, self.createRoleMaker, current_id=-1
        )  # current_id must be greater than or equal to 0
        self.assertRaises(
            Exception,
            self.createRoleMaker,
            current_id=1,
            worker_endpoints=["127.0.0.1:8080"],
        )  # current_id must be less than len(worker_endpoints)


class CollectiveOptimizerTest(unittest.TestCase):
    """Smoke test for CollectiveOptimizer construction."""

    def test_ds_as_None(self):
        # Wrapping a plain Adam optimizer with strategy=None must
        # construct without raising.
        base_optimizer = fluid.optimizer.AdamOptimizer()
        CollectiveOptimizer(base_optimizer, strategy=None)


# Standard unittest entry point: run every TestCase in this module.
if __name__ == '__main__':
    unittest.main()