# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid.communicator import FLCommunicator
from paddle.distributed.fleet.proto import the_one_ps_pb2
from google.protobuf import text_format
from paddle.distributed.ps.utils.public import is_distributed_env
from paddle.distributed import fleet
import time
import abc
import os
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
    fmt='%(asctime)s %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s'
)
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)


class ClientInfoAttr:
    CLIENT_ID = 0
    DEVICE_TYPE = 1
    COMPUTE_CAPACITY = 2
    BANDWIDTH = 3


class FLStrategy:
    JOIN = 0
    WAIT = 1
    FINISH = 2
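    # NOTE: the strategy messages exchanged with the coordinator carry the next
    # state as a plain string (e.g. "JOIN", "FINISH"); the constants above name
    # the same set of states.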


class ClientSelectorBase(abc.ABC):
    def __init__(self, fl_clients_info_mp):
        self.fl_clients_info_mp = fl_clients_info_mp
        self.clients_info = {}
        self.fl_strategy = {}

    def parse_from_string(self):
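        # Parse each client's text-format FLClientInfo message into
        # self.clients_info[client_id], keyed by the ClientInfoAttr fields.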
        if not self.fl_clients_info_mp:
            logger.warning("fl-ps > fl_clients_info_mp is null!")

        for client_id, info in self.fl_clients_info_mp.items():
            self.fl_client_info_desc = the_one_ps_pb2.FLClientInfo()
            text_format.Parse(
                bytes(info, encoding="utf8"), self.fl_client_info_desc
            )
            self.clients_info[client_id] = {}
            self.clients_info[client_id][
                ClientInfoAttr.DEVICE_TYPE
            ] = self.fl_client_info_desc.device_type
            self.clients_info[client_id][
                ClientInfoAttr.COMPUTE_CAPACITY
            ] = self.fl_client_info_desc.compute_capacity
            self.clients_info[client_id][
                ClientInfoAttr.BANDWIDTH
            ] = self.fl_client_info_desc.bandwidth

    @abc.abstractmethod
    def select(self):
        pass


class ClientSelector(ClientSelectorBase):
    def __init__(self, fl_clients_info_mp):
        super().__init__(fl_clients_info_mp)
        self.__fl_strategy = {}

    def select(self):
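        # Demo policy: every reporting client is assigned a hard-coded "JOIN"
        # strategy; a real selector would rank clients by the device type,
        # compute capacity and bandwidth gathered in parse_from_string().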
        self.parse_from_string()
        for client_id in self.clients_info:
            logger.info(
                "fl-ps > client {} info : {}".format(
                    client_id, self.clients_info[client_id]
                )
            )
            # ......... to implement ...... #
            fl_strategy_desc = the_one_ps_pb2.FLStrategy()
            fl_strategy_desc.iteration_num = 99
            fl_strategy_desc.client_id = 0
            fl_strategy_desc.next_state = "JOIN"
            str_msg = text_format.MessageToString(fl_strategy_desc)
            self.__fl_strategy[client_id] = str_msg
        return self.__fl_strategy


class FLClientBase(abc.ABC):
    def __init__(self):
        pass

    def set_basic_config(self, role_maker, config, metrics):
        self.role_maker = role_maker
        self.config = config
        self.total_train_epoch = int(self.config.get("runner.epochs"))
        self.train_statical_info = dict()
        self.train_statical_info['speed'] = []
        self.epoch_idx = 0
        self.worker_index = fleet.worker_index()
        self.main_program = paddle.static.default_main_program()
        self.startup_program = paddle.static.default_startup_program()
        self._client_ptr = fleet.get_fl_client()
        self._coordinators = self.role_maker._get_coordinator_endpoints()
        logger.info(
            "fl-ps > coordinator enpoints: {}".format(self._coordinators)
        )
        self.strategy_handlers = dict()
        self.exe = None
        self.use_cuda = int(self.config.get("runner.use_gpu"))
        self.place = paddle.CUDAPlace(0) if self.use_cuda else paddle.CPUPlace()
        self.print_step = int(self.config.get("runner.print_interval"))
        self.debug = self.config.get("runner.dataset_debug", False)
        self.reader_type = self.config.get("runner.reader_type", "QueueDataset")
        self.set_executor()
        self.make_save_model_path()
        self.set_metrics(metrics)

    def set_train_dataset_info(self, train_dataset, train_file_list):
        self.train_dataset = train_dataset
        self.train_file_list = train_file_list
        logger.info(
            "fl-ps > {}, data_feed_desc:\n {}".format(
                type(self.train_dataset), self.train_dataset._desc()
            )
        )

    def set_test_dataset_info(self, test_dataset, test_file_list):
        self.test_dataset = test_dataset
        self.test_file_list = test_file_list

    def set_train_example_num(self, num):
        self.train_example_nums = num

    def load_dataset(self):
        if self.reader_type == "InmemoryDataset":
            self.train_dataset.load_into_memory()

    def release_dataset(self):
        if self.reader_type == "InmemoryDataset":
            self.train_dataset.release_memory()

    def set_executor(self):
        self.exe = paddle.static.Executor(self.place)

    def make_save_model_path(self):
        self.save_model_path = self.config.get("runner.model_save_path")
        if self.save_model_path and (not os.path.exists(self.save_model_path)):
            os.makedirs(self.save_model_path)

    def set_dump_fields(self):
        # DumpField
        # TrainerDesc -> SetDumpParamVector -> DumpParam -> DumpWork
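        # The dump settings below are written into main_program._fleet_opt so
        # that the trainer dumps the configured fields / params while training.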
        if self.config.get("runner.need_dump"):
            self.debug = True
            dump_fields_path = "{}/epoch_{}".format(
                self.config.get("runner.dump_fields_path"), self.epoch_idx
            )
            dump_fields = self.config.get("runner.dump_fields", [])
            dump_param = self.config.get("runner.dump_param", [])
            persist_vars_list = self.main_program.all_parameters()
            persist_vars_name = [
                str(param).split(":")[0].strip().split()[-1]
                for param in persist_vars_list
            ]
            logger.info(
                "fl-ps > persist_vars_list: {}".format(persist_vars_name)
            )

            if dump_fields_path is not None:
                self.main_program._fleet_opt[
                    'dump_fields_path'
                ] = dump_fields_path
            if dump_fields is not None:
                self.main_program._fleet_opt["dump_fields"] = dump_fields
            if dump_param is not None:
                self.main_program._fleet_opt["dump_param"] = dump_param

    def set_metrics(self, metrics):
        self.metrics = metrics
        self.fetch_vars = [var for _, var in self.metrics.items()]


class FLClient(FLClientBase):
    def __init__(self):
        super().__init__()

    def __build_fl_client_info_desc(self, state_info):
        # ......... to implement ...... #
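        # Demo values: the incoming state_info is discarded and replaced with
        # hard-coded device / compute / bandwidth figures for now.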
        state_info = {
            ClientInfoAttr.DEVICE_TYPE: "Andorid",
            ClientInfoAttr.COMPUTE_CAPACITY: 10,
            ClientInfoAttr.BANDWIDTH: 100,
        }
        client_info = the_one_ps_pb2.FLClientInfo()
        client_info.device_type = state_info[ClientInfoAttr.DEVICE_TYPE]
        client_info.compute_capacity = state_info[
            ClientInfoAttr.COMPUTE_CAPACITY
        ]
        client_info.bandwidth = state_info[ClientInfoAttr.BANDWIDTH]
        str_msg = text_format.MessageToString(client_info)
        return str_msg

    def run(self):
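        # Client lifecycle: register handlers, initialize model parameters and
        # the PS worker, load the dataset, run the epoch loop, then clean up.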
        self.register_default_handlers()
        self.print_program()
        self.strategy_handlers['initialize_model_params']()
        self.strategy_handlers['init_worker']()
        self.load_dataset()
        self.train_loop()
        self.release_dataset()
        self.strategy_handlers['finish']()

    def train_loop(self):
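        # Per epoch: train, save the model, synchronize with the other workers,
        # push local state to the coordinator, then act on the strategy it
        # returns ("JOIN" -> run inference, "FINISH" -> stop the worker).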
        while self.epoch_idx < self.total_train_epoch:
            logger.info("fl-ps > curr epoch idx: {}".format(self.epoch_idx))
            self.strategy_handlers['train']()
            self.strategy_handlers['save_model']()
            self.barrier()
            state_info = {
                "client id": self.worker_index,
                "auc": 0.9,
                "epoch": self.epoch_idx,
            }
            self.push_fl_client_info_sync(state_info)
            strategy_dict = self.pull_fl_strategy()
            logger.info("fl-ps > recved fl strategy: {}".format(strategy_dict))
            # ......... to implement ...... #
            if strategy_dict['next_state'] == "JOIN":
                self.strategy_handlers['infer']()
            elif strategy_dict['next_state'] == "FINISH":
                self.strategy_handlers['finish']()

    def push_fl_client_info_sync(self, state_info):
        str_msg = self.__build_fl_client_info_desc(state_info)
        self._client_ptr.push_fl_client_info_sync(str_msg)
        return

    def pull_fl_strategy(self):
        strategy_dict = {}
        fl_strategy_str = (
            self._client_ptr.pull_fl_strategy()
        )  # block: wait until the coordinator's strategy arrives
        logger.info(
            "fl-ps > fl client recved fl_strategy(str):\n{}".format(
                fl_strategy_str
            )
        )
        fl_strategy_desc = the_one_ps_pb2.FLStrategy()
        text_format.Parse(
            bytes(fl_strategy_str, encoding="utf8"), fl_strategy_desc
        )
        strategy_dict["next_state"] = fl_strategy_desc.next_state
        return strategy_dict

    def barrier(self):
        fleet.barrier_worker()

    def register_handlers(self, strategy_type, callback_func):
        self.strategy_handlers[strategy_type] = callback_func

    def register_default_handlers(self):
        self.register_handlers('train', self.callback_train)
        self.register_handlers('infer', self.callback_infer)
        self.register_handlers('finish', self.callback_finish)
        self.register_handlers(
            'initialize_model_params', self.callback_initialize_model_params
        )
        self.register_handlers('init_worker', self.callback_init_worker)
        self.register_handlers('save_model', self.callback_save_model)

    def callback_init_worker(self):
        fleet.init_worker()

    def callback_initialize_model_params(self):
        if self.exe is None or self.main_program is None:
            raise AssertionError("exe or main_program not set")
        self.exe.run(self.startup_program)

    def callback_train(self):
        epoch_start_time = time.time()
        self.set_dump_fields()
        fetch_info = [
            "Epoch {} Var {}".format(self.epoch_idx, var_name)
            for var_name in self.metrics
        ]
        self.exe.train_from_dataset(
            program=self.main_program,
            dataset=self.train_dataset,
            fetch_list=self.fetch_vars,
            fetch_info=fetch_info,
            print_period=self.print_step,
            debug=self.debug,
        )
        self.epoch_idx += 1
        epoch_time = time.time() - epoch_start_time
        epoch_speed = self.train_example_nums / epoch_time
        self.train_statical_info["speed"].append(epoch_speed)
        logger.info("fl-ps > callback_train finished")

    def callback_infer(self):
        fetch_info = [
            "Epoch {} Var {}".format(self.epoch_idx, var_name)
            for var_name in self.metrics
        ]
        self.exe.infer_from_dataset(
            program=self.main_program,
            dataset=self.test_dataset,
            fetch_list=self.fetch_vars,
            fetch_info=fetch_info,
            print_period=self.print_step,
            debug=self.debug,
        )

    def callback_save_model(self):
        model_dir = "{}/{}".format(self.save_model_path, self.epoch_idx)
        if fleet.is_first_worker() and self.save_model_path:
            if is_distributed_env():
                fleet.save_persistables(self.exe, model_dir)  # save all params
            else:
                raise ValueError("it is not distributed env")

    def callback_finish(self):
        fleet.stop_worker()

    def print_program(self):
        with open(
            "./{}_worker_main_program.prototxt".format(self.worker_index), 'w+'
        ) as f:
            f.write(str(self.main_program))
        with open(
            "./{}_worker_startup_program.prototxt".format(self.worker_index),
            'w+',
        ) as f:
            f.write(str(self.startup_program))

    def print_train_statical_info(self):
        with open("./train_statical_info.txt", 'w+') as f:
            f.write(str(self.train_statical_info))


class Coordinator(object):
    def __init__(self, ps_hosts):
        self._communicator = FLCommunicator(ps_hosts)
        self._client_selector = None

    def start_coordinator(self, self_endpoint, trainer_endpoints):
        self._communicator.start_coordinator(self_endpoint, trainer_endpoints)

    def make_fl_strategy(self):
        logger.info("fl-ps > running make_fl_strategy(loop) in coordinator\n")
        while True:
            # 1. get all fl clients reported info
            str_map = (
                self._communicator.query_fl_clients_info()
            )  # block: wait until all fl clients have reported their info
            # 2. generate fl strategy
            self._client_selector = ClientSelector(str_map)
            fl_strategy = self._client_selector.select()
            # 3. save fl strategy from python to c++
            self._communicator.save_fl_strategy(fl_strategy)
            time.sleep(5)
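

# A minimal wiring sketch (illustrative only; the `ps_hosts`, `self_endpoint`,
# `trainer_endpoints`, `role_maker`, `config`, `metrics` and dataset variables
# below are assumed to be prepared by the caller's fleet setup and are not
# defined in this file):
#
#   # coordinator role
#   coordinator = Coordinator(ps_hosts)
#   coordinator.start_coordinator(self_endpoint, trainer_endpoints)
#   coordinator.make_fl_strategy()  # blocks in an endless loop
#
#   # fl-client (trainer) role
#   client = FLClient()
#   client.set_basic_config(role_maker, config, metrics)
#   client.set_train_dataset_info(train_dataset, train_file_list)
#   client.set_train_example_num(train_example_nums)
#   client.run()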