#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from dist_simnet_bow import train_network

import paddle
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.base.role_maker import (
    Role,
    UserDefinedCollectiveRoleMaker,
    UserDefinedRoleMaker,
)
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer
from paddle.fluid.transpiler.distribute_transpiler import (
    DistributeTranspilerConfig,
)
from paddle.incubate.fleet.parameter_server import TranspilerOptimizer
from paddle.incubate.fleet.parameter_server.distribute_transpiler import fleet


class DistributeTranspilerConfigTest(unittest.TestCase):
    """Checks which attribute values DistributeTranspilerConfig rejects."""

    def set_runtime_split_send_recv(self, config, value):
        """Assign ``value`` to ``config.runtime_split_send_recv``.

        Wrapped in a method so the assignment can be handed to
        ``assertRaises`` as a callable.
        """
        config.runtime_split_send_recv = value

    def set_sync_mode(self, config, value):
        """Assign ``value`` to ``config.sync_mode``.

        Wrapped in a method so the assignment can be handed to
        ``assertRaises`` as a callable.
        """
        config.sync_mode = value

    def testConfig(self):
        """Walk through rejected and accepted values of both settings."""
        config = DistributeTranspilerConfig()

        # None is rejected for either setting.
        self.assertRaises(Exception, self.set_sync_mode, config, None)
        self.assertRaises(
            Exception, self.set_runtime_split_send_recv, config, None
        )

        # On a freshly constructed config, turning on runtime_split_send_recv
        # is rejected.
        self.assertRaises(
            Exception, self.set_runtime_split_send_recv, config, True
        )

        # Once sync_mode has been switched off, the same assignment succeeds.
        self.set_sync_mode(config, False)
        self.assertFalse(config.sync_mode)
        self.set_runtime_split_send_recv(config, True)

        # With runtime_split_send_recv on, switching sync_mode back on is
        # rejected.
        self.assertRaises(Exception, self.set_sync_mode, config, True)


class FleetTest(unittest.TestCase):
    """Invalid-input and transpile tests for the parameter-server ``fleet``."""

    def _make_server_role(self):
        """Build the two-endpoint SERVER role used by the role-based tests."""
        return role_maker.UserDefinedRoleMaker(
            current_id=0,
            role=role_maker.Role.SERVER,
            worker_num=2,
            server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"],
        )

    def testInvalidInputs(self):
        """Every fleet call made here is expected to raise."""
        self.assertRaises(Exception, fleet.split_files, "files")
        self.assertRaises(Exception, fleet.init, "pserver")

        # Build a tiny network so that real variables, executors and programs
        # are available to pass to the save_* entry points.
        data = paddle.static.data(name='X', shape=[-1, 1], dtype='float32')
        hidden = paddle.static.nn.fc(x=data, size=10)
        loss = paddle.mean(hidden)
        adam = fluid.optimizer.Adam()
        adam.minimize(loss)
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        pe = fluid.ParallelExecutor(use_cuda=False, loss_name=loss.name)

        # save_inference_model: a ParallelExecutor or a plain string passed as
        # ``executor`` must be rejected.
        for bad_executor in (pe, "executor"):
            self.assertRaises(
                Exception,
                fleet.save_inference_model,
                dirname='/tmp/',
                feeded_var_names=['X'],
                target_vars=[loss],
                executor=bad_executor,
            )

        compiled_prog = fluid.compiler.CompiledProgram(
            fluid.default_main_program()
        )
        # save_inference_model: a CompiledProgram passed as ``main_program``
        # must be rejected even with a regular Executor.
        self.assertRaises(
            Exception,
            fleet.save_inference_model,
            dirname='/tmp/',
            feeded_var_names=['X'],
            target_vars=[loss],
            executor=exe,
            main_program=compiled_prog,
        )

        # save_persistables: same three invalid combinations.
        for bad_executor in (pe, "executor"):
            self.assertRaises(
                Exception,
                fleet.save_persistables,
                executor=bad_executor,
                dirname='/tmp/',
            )
        self.assertRaises(
            Exception,
            fleet.save_persistables,
            executor=exe,
            dirname='/tmp/',
            main_program=compiled_prog,
        )

        self.assertRaises(Exception, fleet._transpile, "config")

    def set_program(self, avg_cost, strategy):
        """Minimize ``avg_cost`` with a fleet-wrapped SGD in a fresh scope.

        Args:
            avg_cost: loss variable handed to ``minimize``.
            strategy: strategy object handed to
                ``fleet.distributed_optimizer``.
        """
        with fluid.scope_guard(fluid.Scope()):
            optimizer = fluid.optimizer.SGD(0.1)
            optimizer = fleet.distributed_optimizer(optimizer, strategy)
            optimizer.minimize(avg_cost)

    def test_init_role(self):
        """Building the optimizer without ``fleet.init(role)`` must raise."""
        # The role is constructed but deliberately NOT passed to fleet.init:
        # this test covers the optimizer being used without init(role).
        # NOTE(review): this relies on no earlier test in the same process
        # having initialised fleet -- confirm.
        self._make_server_role()

        batch_size = 128
        is_sparse = True
        is_distribute = False
        strategy = DistributeTranspilerConfig()
        strategy.sync_mode = False
        strategy.geo_sgd_mode = True
        strategy.geo_sgd_need_push_nums = 5
        avg_cost, _, _ = train_network(batch_size, is_distribute, is_sparse)

        self.assertRaises(Exception, self.set_program, avg_cost, strategy)

    def test_transpile(self):
        """With fleet initialised, set_program runs for both split settings."""
        # Unlike test_init_role, fleet IS initialised with the role here.
        fleet.init(self._make_server_role())

        batch_size = 128
        is_sparse = True
        is_distribute = False

        strategy = DistributeTranspilerConfig()
        strategy.sync_mode = False
        strategy.runtime_split_send_recv = True
        avg_cost, _, _ = train_network(batch_size, is_distribute, is_sparse)

        # Neither call is wrapped in assertRaises: both must complete.
        self.set_program(avg_cost, strategy)
        strategy.runtime_split_send_recv = False
        self.set_program(avg_cost, strategy)


class TranspilerOptimizerTest(unittest.TestCase):
    """Invalid-input tests for ``TranspilerOptimizer``."""

    def testInvalidInputs(self):
        """Bad constructor and ``minimize`` arguments are expected to raise."""
        # Constructor: a string instead of an optimizer object.
        self.assertRaises(Exception, TranspilerOptimizer, "Adam", None)
        # Constructor: a string instead of a strategy object.
        self.assertRaises(
            Exception,
            TranspilerOptimizer,
            fluid.optimizer.Adam(0.001),
            "strategy",
        )

        transpiler = TranspilerOptimizer(fluid.optimizer.Adam(0.001))
        # minimize: a list passed as ``loss``.
        self.assertRaises(Exception, transpiler.minimize, loss=[])

        data = paddle.static.data(name='X', shape=[-1, 1], dtype='float32')
        hidden = paddle.static.nn.fc(x=data, size=10)
        loss = paddle.mean(hidden)
        # minimize: the loss *name* (a string) together with a list passed as
        # ``startup_program``.
        self.assertRaises(
            Exception, transpiler.minimize, loss=loss.name, startup_program=[]
        )


class UserDefinedRoleMakerTest(unittest.TestCase):
    """Argument-validation tests for ``UserDefinedRoleMaker``."""

    def createRoleMaker(
        self,
        current_id=0,
        role=Role.WORKER,
        worker_num=1,
        server_endpoints=["127.0.0.1:8080"],
    ):
        """Construct a ``UserDefinedRoleMaker`` and return it.

        The defaults form a valid configuration, so a test can override just
        the one argument it wants to make invalid.

        The list default is kept (instead of a ``None`` sentinel) because
        ``None`` is itself one of the invalid inputs exercised by
        ``testRoleMaker``; the list is not modified by this method.
        """
        return UserDefinedRoleMaker(
            current_id, role, worker_num, server_endpoints
        )

    def testRoleMaker(self):
        """The defaults are accepted; each invalid override must raise."""
        self.createRoleMaker()

        invalid_cases = [
            # --- server_endpoints ---
            ("server_endpoints must be a list", {"server_endpoints": None}),
            ("server_endpoints can't be empty", {"server_endpoints": []}),
            (
                "every element of server_endpoints must be a string",
                {"server_endpoints": [3, []]},
            ),
            (
                "server_endpoints can't contain duplicates",
                {"server_endpoints": ["127.0.0.1:8080", "127.0.0.1:8080"]},
            ),
            # --- current_id ---
            ("current_id must be an int", {"current_id": "0"}),
            ("current_id must be >= 0", {"current_id": -1}),
            (
                "for a server, current_id must be < len(server_endpoints)",
                {
                    "current_id": 1,
                    "role": Role.SERVER,
                    "server_endpoints": ["127.0.0.1:8080"],
                },
            ),
            # --- worker_num ---
            ("worker_num must be an int", {"worker_num": "1"}),
            ("worker_num must be > 0", {"worker_num": 0}),
            # --- role ---
            (
                "role must be a Role (Role.WORKER=1, Role.SERVER=2)",
                {"role": 3},
            ),
        ]
        for reason, kwargs in invalid_cases:
            # subTest keeps checking the remaining cases after a failure and
            # names the failing case in the report.
            with self.subTest(reason=reason):
                self.assertRaises(Exception, self.createRoleMaker, **kwargs)


class UserDefinedCollectiveRoleMakerTest(unittest.TestCase):
    """Argument-validation tests for ``UserDefinedCollectiveRoleMaker``."""

    def createRoleMaker(
        self, current_id=0, worker_endpoints=["127.0.0.1:8080"]
    ):
        """Construct a ``UserDefinedCollectiveRoleMaker`` and return it.

        The defaults form a valid configuration, so a test can override just
        the one argument it wants to make invalid.

        The list default is kept (instead of a ``None`` sentinel) because
        ``None`` is itself one of the invalid inputs exercised by
        ``testRoleMaker``; the list is not modified by this method.
        """
        return UserDefinedCollectiveRoleMaker(current_id, worker_endpoints)

    def testRoleMaker(self):
        """The defaults are accepted; each invalid override must raise."""
        self.createRoleMaker()

        invalid_cases = [
            # --- worker_endpoints ---
            ("worker_endpoints must be a list", {"worker_endpoints": None}),
            ("worker_endpoints can't be empty", {"worker_endpoints": []}),
            (
                "every element of worker_endpoints must be a string",
                {"worker_endpoints": [3, []]},
            ),
            (
                "worker_endpoints can't contain duplicates",
                {"worker_endpoints": ["127.0.0.1:8080", "127.0.0.1:8080"]},
            ),
            # --- current_id ---
            ("current_id must be an int", {"current_id": "0"}),
            ("current_id must be >= 0", {"current_id": -1}),
            (
                "current_id must be < len(worker_endpoints)",
                {"current_id": 1, "worker_endpoints": ["127.0.0.1:8080"]},
            ),
        ]
        for reason, kwargs in invalid_cases:
            # subTest keeps checking the remaining cases after a failure and
            # names the failing case in the report.
            with self.subTest(reason=reason):
                self.assertRaises(Exception, self.createRoleMaker, **kwargs)


class CollectiveOptimizerTest(unittest.TestCase):
    """Construction test for ``CollectiveOptimizer``."""

    def test_ds_as_None(self):
        """Constructing with ``strategy=None`` must not raise."""
        optimizer = fluid.optimizer.AdamOptimizer()
        # There is nothing to assert on the result: the test passes as long as
        # the constructor completes without raising.
        CollectiveOptimizer(optimizer, strategy=None)


# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()