the_one_ps.py 56.6 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
T
tangwei12 已提交
16 17 18 19 20 21
import warnings

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.executor import Executor
22
from paddle.fluid.framework import Program
T
tangwei12 已提交
23
from paddle.fluid.parallel_executor import ParallelExecutor
24

T
tangwei12 已提交
25
from ..base.private_helper_function import wait_server_ready
26
from .runtime_base import RuntimeBase
T
tangwei12 已提交
27

28 29
# Empty, so ``from <this module> import *`` exports no names.
__all__ = []

T
tangwei12 已提交
30 31 32 33 34

def conv_indent(indent):
    """Return a string made of ``indent`` spaces (empty when ``indent <= 0``)."""
    return " " * indent


T
tangwei12 已提交
35
# File-name suffix; where it is used is outside this view (presumably for
# pserver-side saved shards — TODO confirm against its users).
PSERVER_SAVE_SUFFIX = ".shard"
36 37


T
Thunderbrook 已提交
38
def parse_table_class(varname, o_main_program):
    """Return the table class configured for the sparse parameter ``varname``.

    Scans the global block of ``o_main_program`` for a sparse or
    distributed-sparse lookup op (``lookup_table`` / ``lookup_table_v2``)
    whose ``W`` input is ``varname``.

    Args:
        varname: name of the embedding parameter.
        o_main_program: origin main program to search.

    Returns:
        The matching op's ``table_class`` attribute when it is set and not
        ``"none"``, otherwise ``"MemorySparseTable"``. ``None`` when no
        matching lookup op exists.
    """
    from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
        is_distributed_sparse_op,
        is_sparse_op,
    )

    for op in o_main_program.global_block().ops:
        if not is_distributed_sparse_op(op) and not is_sparse_op(op):
            continue

        param_name = op.input("W")[0]

        # The name check must guard both op types; an unparenthesized
        # ``a and b or c`` would match every lookup_table_v2 op regardless of
        # which parameter it reads.
        if param_name == varname and op.type in (
            "lookup_table",
            "lookup_table_v2",
        ):
            if op.has_attr('table_class') and op.attr("table_class") != "none":
                return op.attr('table_class')
            return "MemorySparseTable"

    return None
T
Thunderbrook 已提交
59 60


61 62 63 64 65 66 67
def get_default_accessor_proto(accessor, varname, o_main_program):
    """Fill the unset fields of a sparse ``accessor`` proto with defaults.

    A scalar field is written only when ``HasField`` reports it unset. A
    ``weight_bounds`` list is written only when it is empty. Explicitly
    configured values are therefore never overwritten.

    Args:
        accessor: accessor proto message, updated in place.
        varname: embedding variable whose ``shape[1]`` becomes the default
            ``fea_dim``; the default ``embedx_dim`` is that value minus 3.
            When ``varname`` is not found among
            ``o_main_program.list_vars()`` the dim falls back to 0.
        o_main_program: program searched for ``varname``.
    """
    embedding_dim = next(
        (v.shape[1] for v in o_main_program.list_vars() if v.name == varname),
        0,
    )

    def fill_unset(message, defaults):
        # Assign each default only where the proto holds no explicit value.
        for field_name, default in defaults:
            if not message.HasField(field_name):
                setattr(message, field_name, default)

    fill_unset(
        accessor,
        (
            ("accessor_class", "CtrCommonAccessor"),
            ("fea_dim", embedding_dim),
            ("embedx_dim", embedding_dim - 3),
            ("embedx_threshold", 0),
        ),
    )
    fill_unset(
        accessor.ctr_accessor_param,
        (
            ("nonclk_coeff", 0.1),
            ("click_coeff", 1.0),
            ("base_threshold", 0),
            ("delta_threshold", 0),
            ("delta_keep_days", 16),
            ("show_click_decay_rate", 1),
            ("delete_threshold", 0),
            ("delete_after_unseen_days", 30),
            ("ssd_unseenday_threshold", 1),
        ),
    )

    adagrad_defaults = (
        ("learning_rate", 0.05),
        ("initial_g2sum", 3.0),
        ("initial_range", 0.0001),
    )
    # SGD rule name -> (name of the rule's sub-message, its scalar defaults).
    rule_defaults = {
        "SparseAdaGradSGDRule": ("adagrad", adagrad_defaults),
        "StdAdaGradSGDRule": ("adagrad", adagrad_defaults),
        "SparseNaiveSGDRule": (
            "naive",
            (("learning_rate", 0.05), ("initial_range", 0.0001)),
        ),
        "SparseAdamSGDRule": (
            "adam",
            (
                ("learning_rate", 0.001),
                ("initial_range", 0.0001),
                ("beta1_decay_rate", 0.9),
                ("beta2_decay_rate", 0.999),
                ("ada_epsilon", 1e-08),
            ),
        ),
    }

    for sgd_param in (accessor.embed_sgd_param, accessor.embedx_sgd_param):
        if not sgd_param.HasField("name"):
            sgd_param.name = "SparseAdaGradSGDRule"
        rule = rule_defaults.get(sgd_param.name)
        if rule is None:
            # Rule names without known defaults are left untouched.
            continue
        sub_field, scalar_defaults = rule
        rule_param = getattr(sgd_param, sub_field)
        fill_unset(rule_param, scalar_defaults)
        if len(rule_param.weight_bounds) == 0:
            rule_param.weight_bounds.extend([-10.0, 10.0])
132 133 134 135 136 137 138 139 140


def check_embedding_dim(accessor, varname, o_main_program):
    """Validate a sparse ``accessor``'s dims against its embedding variable.

    The expected ``fea_dim`` is ``shape[1]`` of the variable named
    ``varname`` in ``o_main_program`` (0 when it is not found), and the
    expected ``embedx_dim`` is that value minus 3.

    Raises:
        ValueError: on the first of ``fea_dim`` / ``embedx_dim`` that does
            not match its expected value.
    """
    embedding_dim = next(
        (v.shape[1] for v in o_main_program.list_vars() if v.name == varname),
        0,
    )

    # (accessor field, label used in the message, expected value), checked
    # in order so fea_dim errors are reported first.
    expectations = (
        ("fea_dim", "sparse_embedding_dim", embedding_dim),
        ("embedx_dim", "sparse_embedding_dim - 3", embedding_dim - 3),
    )
    for field_name, label, expected in expectations:
        actual = getattr(accessor, field_name)
        if actual != expected:
            raise ValueError(
                "The {} is wrong, it will be {}: {}, but got {}".format(
                    field_name, label, expected, actual
                )
            )
154 155


T
tangwei12 已提交
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
class Accessor:
    """Text-format ``accessor { ... }`` section of a table configuration.

    Attributes:
        accessor_class: emitted as ``accessor_class``.
        optimizer: optional object whose ``to_string(indent)`` output is
            appended inside the section.
        feature_dim: emitted as ``fea_dim``.
        embedding_dim: emitted as ``embedx_dim``.
    """

    def __init__(self):
        self.accessor_class = ""
        self.optimizer = None
        self.feature_dim = -1
        self.embedding_dim = -1

    def to_string(self, indent):
        """Render the section indented by ``indent`` spaces."""
        accessor_str = "{}accessor {{{}\n{}}}"
        attrs = ""
        attrs += "accessor_class: \"{}\" ".format(self.accessor_class)
        attrs += "fea_dim: {} ".format(self.feature_dim)
        attrs += "embedx_dim: {} ".format(self.embedding_dim)
        attrs += "\n"
        if self.optimizer is not None:
            attrs += self.optimizer.to_string(indent)
        return accessor_str.format(
            conv_indent(indent), attrs, conv_indent(indent)
        )
T
tangwei12 已提交
176 177 178 179 180 181


class CommonAccessor:
    """Builds the ``common { ... }`` section describing a table's optimizer.

    ``parse_by_optimizer`` inspects the origin programs to fill ``params``,
    ``dims``, ``initializers`` and ``attrs``. ``to_string`` renders them in
    protobuf text format.
    """

    def __init__(self):
        self.accessor_class = ""
        self.table_name = None
        self.entry = None
        self.attrs = []
        self.params = []
        self.dims = []
        self.trainer_num = 0
        self.sync = "false"
        self.table_num = None
        self.table_dim = None
        self.initializers = []
        self.opt_input_map = {}
        self.opt_attr_map = {}
        self.opt_init_map = {}
        self.define_optimize_map()

    def define_optimize_map(self):
        """Populate the lookup tables used by ``parse_by_optimizer``.

        * ``opt_input_map``: rule name -> list of ``(input_name, dim)``; a
          ``dim`` of ``None`` means the table's own dim is used.
        * ``opt_attr_map``: rule name -> list of ``(attr_name, type_code)``.
        * ``opt_init_map``: initializer op type -> attrs serialized into the
          initializer string.
        """
        opt_input_map = {}
        opt_input_map["sgd"] = [("Param", None), ("LearningRate", 1)]
        opt_input_map["adam"] = [
            ("Param", None),
            ("Moment1", None),
            ("Moment2", None),
            ("Beta1Pow", 1),
            ("Beta2Pow", 1),
            ("LearningRate", 1),
        ]
        opt_input_map["adam_d2sum"] = [
            ("Param", None),
            ("D2Sum", None),
            ("G2Sum", None),
            ("Moment", None),
            ("MomentDecayRate", 1),
            ("AdaDecayRate", 1),
            ("AdaEpsilon", 1),
            ("LearningRate", 1),
        ]
        opt_input_map["sum"] = [("Param", None)]
        opt_input_map["naive_adagrad"] = [
            ("Param", None),
            ("G2Sum", 1),
            ("LearningRate", 1),
        ]

        opt_attr_map = {}
        opt_attr_map["sgd"] = []
        opt_attr_map["sum"] = []
        opt_attr_map["naive_adagrad"] = []
        opt_attr_map["adam"] = [
            ("beta1", "f"),
            ("beta2", "f"),
            ("epsilon", "f"),
        ]
        opt_attr_map["adam_d2sum"] = [
            ("beta1", "f"),
            ("beta2", "f"),
            ("epsilon", "f"),
        ]

        opt_init_map = {}
        opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
        opt_init_map["fill_constant"] = ["value"]
        opt_init_map["uniform_random"] = ["seed", "min", "max"]
        opt_init_map["truncated_gaussian_random"] = ["seed", "mean", "std"]

        self.opt_attr_map = opt_attr_map
        self.opt_input_map = opt_input_map
        self.opt_init_map = opt_init_map

    def parse_entry(self, varname, o_main_program):
        """Set ``self.entry`` from the first sparse lookup op reading ``varname``.

        A ``lookup_table`` op contributes its ``entry`` attribute. A
        ``lookup_table_v2`` op yields ``"none"``. When no such op exists,
        ``self.entry`` is left unchanged.
        """
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            is_distributed_sparse_op,
            is_sparse_op,
        )

        for op in o_main_program.global_block().ops:
            if not is_distributed_sparse_op(op) and not is_sparse_op(op):
                continue

            param_name = op.input("W")[0]

            if param_name == varname and op.type == "lookup_table":
                self.entry = op.attr('entry')
                break

            if param_name == varname and op.type == "lookup_table_v2":
                self.entry = "none"
                break

    def get_shard(self, total_dim, shard_num, pserver_id):
        """Return how many of ``total_dim`` rows pserver ``pserver_id`` holds.

        Rows are split into contiguous blocks of
        ``total_dim // shard_num + 1``. The last non-empty shard receives
        the remainder and any shard beyond it receives 0.
        """
        # Integer floor division avoids the float rounding that
        # ``int(total_dim / shard_num + 1)`` suffers for very large dims.
        blocksize = int(total_dim // shard_num) + 1
        start = blocksize * pserver_id

        if start + blocksize <= total_dim:
            return blocksize
        if start < total_dim:
            return total_dim - start
        return 0

    def get_initializer_attr(self, value_name, o_startup_program):
        """Describe how ``value_name`` is initialized in the startup program.

        Returns:
            ``"<op_type>&<attr>&..."`` for the first startup op whose type is
            a key of ``opt_init_map`` and whose ``Out`` is ``value_name``
            (attrs in ``opt_init_map`` order), or ``""`` if there is none.
        """
        for op in o_startup_program.global_block().ops:
            if (
                op.type in self.opt_init_map
                and op.output("Out")[0] == value_name
            ):
                init_attr = [op.type]
                init_attr.extend(
                    str(op.attr(attr)) for attr in self.opt_init_map[op.type]
                )
                return "&".join(init_attr)
        return ""

    def _resolve_dim(
        self, shape, is_sparse, single_dim, size, pserver_num, pserver_id
    ):
        """Return ``shape`` or, when it is None, the table's own dim.

        The table's own dim is ``single_dim`` for sparse tables and this
        pserver's shard of ``size`` for dense ones.
        """
        if shape is not None:
            return shape
        if is_sparse:
            return single_dim
        return self.get_shard(size, pserver_num, pserver_id)

    def _get_optimizer_input_var(self, main_program, oop, formal_name):
        """Return the main-program variable fed to ``oop``'s ``formal_name``.

        A LearningRate other than ``learning_rate_0`` (e.g. a decayed one) is
        not supported yet. It falls back to ``learning_rate_0`` with a
        warning.
        """
        block_vars = main_program.global_block().vars
        param = block_vars[oop.input(formal_name)[0]]
        # TODO: for dense learning_rate, can be different from sparse lr
        if formal_name == "LearningRate" and param.name != "learning_rate_0":
            warnings.warn("will support decay soon")
            param = block_vars["learning_rate_0"]
        return param

    def parse_by_optimizer(
        self,
        grad_name,
        is_sparse,
        size,
        single_dim,
        compiled_strategy,
        adam_d2sum,
    ):
        """Fill params/dims/initializers/attrs from the optimizer of a grad.

        Args:
            grad_name: gradient variable name. It is mapped to its parameter
                through ``compiled_strategy.grad_name_to_param_name``.
            is_sparse: whether the table is sparse.
            size: total size of a dense table; sharded across pservers with
                ``get_shard``. Also stored as ``table_num``.
            single_dim: dim recorded for sparse tables when an optimizer
                input has no fixed dim. Also stored as ``table_dim``.
            compiled_strategy: supplies the origin programs, role id, pserver
                endpoints, trainer count and mode flags.
            adam_d2sum: request the ``adam_d2sum`` rule. It is only honored
                for ``adam`` optimizers on dense tables.

        Raises:
            ValueError: if no optimizer op updates the parameter.
        """
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_optimize_ops,
        )

        param_name = compiled_strategy.grad_name_to_param_name[grad_name]
        main_program, startup_program = compiled_strategy.get_origin_programs()
        pserver_id = compiled_strategy.get_role_id()
        pserver_num = len(compiled_strategy.get_ps_endpoints())
        optimizer_ops = _get_optimize_ops(main_program)

        # The optimizer op that updates ``param_name``.
        oop = None
        for op in optimizer_ops:
            if "Param" in op.input_names and op.input("Param")[0] == param_name:
                oop = op
                break

        if oop is None:
            raise ValueError("can not find optimizer for {}".format(grad_name))

        self.trainer_num = compiled_strategy.get_trainers()
        self.table_num = size
        self.table_dim = single_dim

        if oop.type != 'adam' and adam_d2sum:
            # Library code reports the fallback through ``warnings`` rather
            # than printing to stdout.
            warnings.warn(
                'optimization algorithm is not adam, set adam_d2sum False'
            )
            adam_d2sum = False

        # Geo mode and GPU-PS sparse tables override the program's optimizer;
        # adam_d2sum applies to dense tables only.
        if compiled_strategy.is_geo_mode():
            rule, accessor_class = "sum", "sum"
        elif compiled_strategy.use_ps_gpu and is_sparse:
            rule, accessor_class = "naive_adagrad", "sgd"
        elif adam_d2sum and not is_sparse:
            rule, accessor_class = "adam_d2sum", "adam_d2sum"
        else:
            rule, accessor_class = oop.type, oop.type
        param_varnames = self.opt_input_map[rule]
        attr_varnames = self.opt_attr_map[rule]
        self.accessor_class = accessor_class

        # Constant initializers for adam_d2sum's auxiliary inputs; any other
        # input that is not Param/LearningRate starts at 0.
        d2sum_constants = {
            "MomentDecayRate": "fill_constant&0.99",
            "AdaDecayRate": "fill_constant&0.9999",
            "AdaEpsilon": "fill_constant&1.0e-8",
        }

        params = []
        dims = []
        initializers = []
        for formal_name, shape in param_varnames:
            params.append(formal_name)
            if accessor_class == "adam_d2sum":
                dims.append(
                    self._resolve_dim(
                        shape, is_sparse, single_dim, size, pserver_num, pserver_id
                    )
                )
                if formal_name in ("Param", "LearningRate"):
                    param = self._get_optimizer_input_var(
                        main_program, oop, formal_name
                    )
                    initializers.append(
                        self.get_initializer_attr(param.name, startup_program)
                    )
                else:
                    initializers.append(
                        d2sum_constants.get(formal_name, "fill_constant&0")
                    )
            elif formal_name == "G2Sum":
                # G2Sum is a single zero-initialized value.
                dims.append(1)
                initializers.append("fill_constant&0")
            else:
                param = self._get_optimizer_input_var(
                    main_program, oop, formal_name
                )
                dims.append(
                    self._resolve_dim(
                        shape, is_sparse, single_dim, size, pserver_num, pserver_id
                    )
                )
                initializers.append(
                    self.get_initializer_attr(param.name, startup_program)
                )

        attrs = [
            "&".join([attr_varname, type_, str(oop.attr(attr_varname))])
            for attr_varname, type_ in attr_varnames
        ]

        self.params = params
        self.dims = dims
        self.initializers = initializers
        self.attrs = attrs

    def to_string(self, indent):
        """Render the ``common { ... }`` section indented by ``indent`` spaces.

        ``table_name``, ``entry``, ``table_num`` and ``table_dim`` are only
        emitted when truthy.
        """
        accessor_str = "{}common {{{}\n{}}}"
        attrs = ""
        attrs += "name: \"{}\" ".format(self.accessor_class)

        if self.table_name:
            attrs += "table_name: \"{}\" ".format(self.table_name)

        if self.entry:
            attrs += "entry: \"{}\" ".format(self.entry)
        attrs += "trainer_num: {} ".format(self.trainer_num)
        attrs += "sync: {} ".format(self.sync)
        if self.table_num:
            attrs += "table_num: {} ".format(self.table_num)
        if self.table_dim:
            attrs += "table_dim: {} ".format(self.table_dim)

        for param in self.params:
            attrs += "params: \"{}\" ".format(param)

        for dim in self.dims:
            attrs += "dims: {} ".format(dim)

        for initializer in self.initializers:
            attrs += "initializers: \"{}\" ".format(initializer)

        attrs += "\n"
        return accessor_str.format(
            conv_indent(indent), attrs, conv_indent(indent)
        )
T
tangwei12 已提交
467 468


469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484
class Tensor:
    """Renders a ``tensor { ... }`` section describing a tensor table."""

    def __init__(self):
        self.main_program_id = None
        self.startup_program_id = None
        self.feed_var_name = None
        self.fetch_var_name = None
        self.tensor_table_class = False

    def to_string(self, indent):
        """Render the section indented by ``indent`` spaces."""
        pad = conv_indent(indent)
        fields = (
            ("feed_var_name", "\"{}\"".format(str(self.feed_var_name))),
            ("fetch_var_name", "\"{}\"".format(str(self.fetch_var_name))),
            ("startup_program_id", str(self.startup_program_id)),
            ("main_program_id", str(self.main_program_id)),
            (
                "tensor_table_class",
                "\"{}\"".format(str(self.tensor_table_class)),
            ),
        )
        body = "".join("{}: {} ".format(key, value) for key, value in fields)
        # The body line is followed by an empty line before the closing brace.
        return "{0}tensor {{{1}\n\n{0}}}".format(pad, body)
491 492


T
tangwei12 已提交
493 494 495 496 497 498 499 500
class Table:
    """One ``downpour_table_param { ... }`` entry of a server/worker config.

    Attributes:
        id, table_class, shard_num, type: emitted as ``table_id``,
            ``table_class``, ``shard_num`` and ``type``.
        accessor_proto: pre-rendered accessor body text; when set it is used
            instead of ``accessor``.
        accessor, tensor, common: optional sub-sections rendered through
            their own ``to_string``.
    """

    def __init__(self):
        self.id = -1
        self.table_class = None
        self.shard_num = -1
        self.type = None
        self.accessor = None
        self.common = None
        self.tensor = None
        self.accessor_proto = None

    def to_string(self, indent):
        """Render the table; nested sections are indented by ``indent + 2``."""
        table_str = "{}downpour_table_param {{{}\n{}}}"

        attrs = ""
        attrs += "table_id: {} ".format(self.id)
        attrs += "table_class: \"{}\" ".format(self.table_class)
        attrs += "shard_num: {} ".format(self.shard_num)
        attrs += "type: {}".format(self.type)
        attrs += "\n"
        # The increased indent is also used for this table's own wrapper below.
        indent += 2

        if self.accessor_proto is not None:
            accessor_str = "{}accessor {{{}\n{}}}"
            accessor_str = accessor_str.format(
                conv_indent(indent), self.accessor_proto, conv_indent(indent)
            )
            attrs += accessor_str + "\n"
        elif self.accessor is not None:
            attrs += self.accessor.to_string(indent)
            attrs += "\n"

        if self.tensor is not None:
            attrs += self.tensor.to_string(indent)
            attrs += "\n"

        if self.common is not None:
            attrs += self.common.to_string(indent)
            attrs += "\n"

        return table_str.format(conv_indent(indent), attrs, conv_indent(indent))


class Service:
    """``service_param { ... }`` section naming the brpc PS classes."""

    def __init__(self):
        self.server_class = "BrpcPsServer"
        self.client_class = "BrpcPsClient"
        self.service_class = "BrpcPsService"
        self.start_server_port = 0
        self.server_thread_num = 12

    def to_string(self, indent):
        """Render the section indented by ``indent`` spaces."""
        pad = conv_indent(indent)
        body = (
            "server_class: \"{}\" "
            "client_class: \"{}\" "
            "service_class: \"{}\" "
            "start_server_port: {} "
            "server_thread_num: {} "
        ).format(
            self.server_class,
            self.client_class,
            self.service_class,
            self.start_server_port,
            self.server_thread_num,
        )
        return "{0}service_param {{{1}\n{0}}}".format(pad, body)
T
tangwei12 已提交
562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588


class DownpourServer:
    """``downpour_server_param`` section: one service followed by tables."""

    def __init__(self):
        self.service = None
        self.tables = []

    def set_service_param(self, service):
        self.service = service

    def append_tables(self, table):
        if isinstance(table, Table):
            self.tables.append(table)
            return
        raise ValueError("only support instance Table")

    def to_string(self, indent):
        """Render the section; contents and wrapper use ``indent + 2``."""
        inner = indent + 2
        pad = conv_indent(inner)
        # The service comes first, then each table, each on its own line.
        sections = [self.service.to_string(inner)]
        sections.extend(table.to_string(inner) for table in self.tables)
        body = "".join("\n" + section for section in sections)
        return "{0}downpour_server_param {{{1}\n{0}}}".format(pad, body)
T
tangwei12 已提交
592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630


class Server:
    """Top-level ``server_param { ... }`` block of DownpourServer entries."""

    def __init__(self):
        self.servers = []

    def add_server(self, server):
        if isinstance(server, DownpourServer):
            self.servers.append(server)
        else:
            raise ValueError("only support instance DownpourServer")

    def __str__(self):
        # Every nested server is rendered at indent 2 on its own line.
        nested = "".join(
            "\n" + downpour.to_string(2) for downpour in self.servers
        )
        return "server_param {{{}\n}}".format(nested)


class DownpourWorker:
    """``downpour_worker_param`` section listing the worker-side tables."""

    def __init__(self):
        self.tables = []

    def append_tables(self, table):
        if isinstance(table, Table):
            self.tables.append(table)
            return
        raise ValueError("only support instance Table")

    def to_string(self, indent):
        """Render the section; contents and wrapper use ``indent + 2``."""
        inner = indent + 2
        pad = conv_indent(inner)
        body = "".join("\n" + table.to_string(inner) for table in self.tables)
        return "{0}downpour_worker_param {{{1}\n{0}}}".format(pad, body)
T
tangwei12 已提交
634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655


class Worker:
    """Top-level ``worker_param { ... }`` block of DownpourWorker entries."""

    def __init__(self):
        self.workers = []

    def add_worker(self, worker):
        if isinstance(worker, DownpourWorker):
            self.workers.append(worker)
        else:
            raise ValueError("only support instance DownpourWorker")

    def __str__(self):
        # Every nested worker is rendered at indent 2 on its own line.
        nested = "".join(
            "\n" + downpour.to_string(2) for downpour in self.workers
        )
        return "worker_param {{{}\n}}".format(nested)


656 657 658 659 660 661 662 663 664 665
class fsClient:
    """Wraps an fs client proto and renders it as ``fs_client_param``."""

    def __init__(self, proto):
        self.proto = proto
        # Mirror the connection fields of the proto as plain attributes.
        for field_name in ("uri", "user", "passwd", "hadoop_bin"):
            setattr(self, field_name, getattr(proto, field_name))

    def to_string(self):
        """Return the proto in text format wrapped in ``fs_client_param``.

        An empty string is returned when the proto renders to empty text.
        """
        from google.protobuf import text_format

        proto_txt = text_format.MessageToString(self.proto)
        if not proto_txt:
            return ""
        return "fs_client_param {{\n{}}}".format(proto_txt)


T
tangwei12 已提交
675 676
class TheOnePSRuntime(RuntimeBase):
    def __init__(self):
        """Create the runtime with empty handles.

        The communicator, server and heter client are filled in later by
        ``_init_worker`` / ``_init_server``.
        """
        super().__init__()
        # Set by _init_worker (communicator, heter client) / _init_server.
        self._communicator = None
        self._server = None
        self._heter_client = None
        # Worker-side fleet wrapper used by the save paths of this class.
        self._worker = fluid.core.DistFleetWrapper()
        # Program descs handed to init_server; appended by _get_fleet_proto.
        self._server_sub_program = []

    def _set_basic_info(self, context):
        """Record the fleet *context* and derive both strategies from it.

        Copies ``role_maker``, ``origin_main_program`` and
        ``origin_startup_program`` from *context* onto ``self``, then builds
        ``async_strategy`` followed by ``compiled_strategy`` (the latter
        reads the former, so the order matters).
        """
        self.context = context
        for key in (
            "role_maker",
            "origin_main_program",
            "origin_startup_program",
        ):
            setattr(self, key, context[key])
        self.async_strategy = self._get_distributed_strategy()
        self.compiled_strategy = self.build_compiled_startegy()

    def _get_distributed_strategy(self):
        """Build the transpiler strategy matching ``context["valid_strategy"]``.

        Mapping from ``a_sync`` and ``a_sync_configs["k_steps"]``:

        * ``a_sync=False, k_steps == 0`` -> sync strategy
        * ``a_sync=True,  k_steps == 0`` -> async strategy
        * ``a_sync=True,  k_steps > 0``  -> geo strategy with ``k_steps``

        ``a_sync_configs["use_ps_gpu"]`` is copied onto the result.

        Raises:
            ValueError: for any other (a_sync, k_steps) combination.
        """
        from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            StrategyFactory,
        )

        dist_strategy = self.context["valid_strategy"]
        k_steps = dist_strategy.a_sync_configs["k_steps"]

        # The three cases are mutually exclusive.
        strategy = None
        if not dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_sync_strategy()
        elif dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_async_strategy()
        elif dist_strategy.a_sync and k_steps > 0:
            strategy = StrategyFactory.create_geo_strategy(k_steps)

        if strategy is None:
            raise ValueError(
                "invalid k_steps value {} with a_sync={}, please check "
                "a_sync_configs".format(k_steps, dist_strategy.a_sync)
            )

        if dist_strategy.a_sync_configs["use_ps_gpu"]:
            strategy.use_ps_gpu = True

        return strategy

    def build_compiled_startegy(self):
        """Create the CompileTimeStrategy for the origin program.

        ``use_ps_gpu`` is propagated from ``self.async_strategy``.

        NOTE(review): the origin *main* program is passed as both of the
        first two arguments; ``self.origin_startup_program`` is not used
        here. Presumably the second parameter is the startup program.
        Confirm whether this is intentional before changing it.
        """
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            CompileTimeStrategy,
        )

        main_program = self.origin_main_program
        config = CompileTimeStrategy(
            main_program, main_program, self.async_strategy, self.role_maker
        )
        if self.async_strategy.use_ps_gpu:
            config.use_ps_gpu = True
        return config

    def _init_worker(self):
        """Initialize the trainer side of the parameter server.

        Steps, in order:
          1. build the worker + server proto text and write it to the file
             ``proto_txt`` in the current working directory;
          2. create the Communicator from the send/recv contexts;
          3. exchange client info and create client-to-client connections;
          4. init (unless TEST_MODE) and pull the dense parameters;
          5. start the communicator and, if the launch barrier is enabled,
             wait for the servers (and, in heter-PS mode, create the
             HeterClient).

        Raises:
            ValueError: if ``use_ps_gpu`` is enabled but the
                ``FLAGS_selected_gpus`` environment variable is unset.
        """
        from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            SyncStrategy,
        )

        is_sync = self.compiled_strategy.is_sync_mode()
        worker = self._get_fleet_proto(is_server=False, is_sync=is_sync)
        server = self._get_fleet_proto(is_server=True, is_sync=is_sync)

        dist_strategy = self.context["valid_strategy"]
        use_ps_gpu = dist_strategy.a_sync_configs["use_ps_gpu"]
        if use_ps_gpu:
            main_program = self.context['loss'].block.program
            if not main_program._fleet_opt:
                main_program._fleet_opt = {}
            main_program._fleet_opt["use_ps_gpu"] = True
            gpus_env = os.getenv("FLAGS_selected_gpus")
            if gpus_env is None:
                # Previously failed with an opaque AttributeError on None.
                raise ValueError(
                    "use_ps_gpu is enabled but the FLAGS_selected_gpus "
                    "environment variable is not set"
                )
            main_program._fleet_opt["worker_places"] = [
                int(s) for s in gpus_env.split(",")
            ]

        proto_txt = str(worker) + "\n" + str(server)
        # NOTE(review): unconditionally dumps the proto into the cwd.
        with open('proto_txt', 'w') as f:
            f.write(proto_txt)

        debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
        if debug:
            print("worker: \n{}".format(proto_txt))

        endpoints = self.compiled_strategy.get_ps_endpoints()

        string_hosts = []
        for idx, ep in enumerate(endpoints):
            host, port = ep.split(":")
            pshost = fluid.core.PSHost(host, int(port), idx)
            string_hosts.append(pshost.serialize_to_string())

        dense_map = self.compiled_strategy.get_the_one_recv_context(
            split_dense_table=self.role_maker._is_heter_parameter_server_mode
        )
        send_ctx = self.compiled_strategy.get_the_one_send_context(
            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            use_origin_program=self.role_maker._is_heter_parameter_server_mode,
            ep_list=endpoints,
        )
        trainer_config = self.async_strategy.get_trainer_runtime_config()

        if debug:
            # The worker proto was already printed above.
            print("communicator send_ctx:")
            for key in send_ctx:
                print("{}: {}".format(key, send_ctx[key]))
            for key in dense_map:
                print("{}: {}".format(key, dense_map[key]))

        kwargs = {}
        kwargs['need_global_step'] = "0"
        kwargs["trainer_id"] = self.role_maker._role_id()
        kwargs["trainers"] = self.role_maker._worker_num()
        # if self.role_maker._is_heter_worker():
        #    kwargs["trainer_id"] += kwargs["trainers"]

        for table in server.servers[0].tables:
            if table.table_class == "BarrierTable":
                kwargs["barrier_table_id"] = table.id
                break

        if isinstance(self.async_strategy, SyncStrategy):
            # Sync mode needs the pserver list and overrides trainer_id.
            kwargs[
                "pserver_endpoints"
            ] = self.role_maker._get_pserver_endpoints()
            kwargs["trainer_id"] = self.role_maker._worker_index()

        from paddle.fluid.communicator import Communicator, HeterClient

        self._communicator = Communicator(
            trainer_config.mode, kwargs, trainer_config.get_communicator_flags()
        )
        self._communicator.init_with_ctx(
            send_ctx, dense_map, proto_txt, string_hosts, fluid.global_scope()
        )

        import paddle.distributed.fleet as fleet

        fleet.util.barrier()
        info = self._communicator.get_client_info()
        if isinstance(info, list) and len(info) > 0:
            all_info = self.role_maker._all_gather(info[0])
            # for unittest
            if not isinstance(all_info, list):
                warnings.warn("gloo may not initialize correctly")
                all_info = [all_info]
            self._communicator.set_clients(all_info)
            # create_c2c_connection default param:
            #  pserver_timeout_ms=500000
            #  pserver_connect_timeout_ms=10000
            #  max_retry=3
            self._communicator.create_client_to_client_connection()
            print('create c2c connection done')
        else:
            print('cannot create c2c connection')

        is_test = bool(int(os.getenv("TEST_MODE", "0")))

        if (
            self.role_maker._is_first_worker()
            and self.role_maker._is_heter_parameter_server_mode
        ):
            # for ps-heter mode load all parameters on first_worker
            init_params = self.compiled_strategy.get_the_one_recv_context(
                split_dense_table=True, use_origin_program=True
            )
        else:
            init_params = dense_map

        if not is_test:
            self._communicator.init_params(init_params)
            fleet.util.barrier()
        self._communicator.pull_dense(init_params)
        fleet.util.barrier()

        if not self._communicator.is_running():
            self._communicator.start()
        else:
            warnings.warn("communicator has been initialized, skip")

        launch_barrier = dist_strategy.a_sync_configs["launch_barrier"]
        launch_barrier_flag = int(os.getenv("FLAGS_LAUNCH_BARRIER", "1"))
        if launch_barrier and launch_barrier_flag:
            # for trainer wait server ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())
            if (
                self.role_maker._is_heter_parameter_server_mode
                and self.role_maker._get_next_trainers() != []
            ):
                wait_server_ready(self.role_maker._get_next_trainers())
            # NOTE(review): the HeterClient is only created when the launch
            # barrier is enabled, yet _stop_worker asserts it exists in
            # heter-PS mode. Confirm this coupling is intended.
            if self.role_maker._is_heter_parameter_server_mode:
                previous_trainers = []
                if self.role_maker._get_previous_trainers() != []:
                    previous_trainers = self.role_maker._get_previous_trainers()
                next_trainers = []
                if self.role_maker._get_next_trainers() != []:
                    next_trainers = self.role_maker._get_next_trainers()
                self._heter_client = HeterClient(
                    next_trainers, previous_trainers, self.role_maker._role_id()
                )

    def _push_sparse_param(self, var_name, table_id=-1, scope=None):
        """Push the sparse parameter *var_name* from *scope* via the communicator.

        Args:
            var_name: name of the sparse parameter variable.
            table_id: forwarded unchanged to ``push_sparse_param``; the
                meaning of the default -1 is defined by the communicator.
            scope: scope holding the variable. ``None`` (the default) means
                ``fluid.global_scope()`` looked up at call time. The previous
                default was evaluated once when the class was defined, so it
                could differ from the scope active at call time (e.g. under
                ``scope_guard``).
        """
        if scope is None:
            scope = fluid.global_scope()
        self._communicator.push_sparse_param(var_name, table_id, scope)

    def _get_executor(self):
        """Return an Executor for this role's device.

        A CPU executor is the default. Heter workers in heter-PS mode use
        ``_heter_device_type()``: GPU/XPU build a CUDA/XPU executor on the
        card index from ``FLAGS_selected_gpus`` / ``FLAGS_selected_xpus``
        (default "0"); CPU keeps the CPU executor.

        Raises:
            ValueError: for a heter device type other than GPU, XPU or CPU.
        """
        executor = fluid.Executor(fluid.CPUPlace())
        heter_worker = (
            self.role_maker._is_heter_parameter_server_mode
            and self.role_maker._is_heter_worker()
        )
        if not heter_worker:
            return executor

        device = self.role_maker._heter_device_type().upper()
        if device not in ("GPU", "XPU", "CPU"):
            raise ValueError(
                "Heter Worker Not Support Device {}".format(device)
            )
        if device == "GPU":
            card = int(os.getenv("FLAGS_selected_gpus", "0"))
            return Executor(fluid.CUDAPlace(card))
        if device == "XPU":
            card = int(os.getenv("FLAGS_selected_xpus", "0"))
            return Executor(fluid.XPUPlace(card))
        return executor

920
    def _get_fleet_proto(self, is_server, is_sync, **kwargs):
        """Build the server- or worker-side fleet proto object.

        Args:
            is_server: build a ``Server`` wrapping a ``DownpourServer`` when
                True, otherwise a ``Worker`` wrapping a ``DownpourWorker``.
            is_sync: stored as each table's ``common.sync`` ("true"/"false").
            **kwargs: accepted for compatibility; unused.

        Side effects:
            Appends program descs to ``self._server_sub_program`` (tensor
            table sub-programs, or one empty program when there are none),
            and adds an entry to ``user_defined_strategy.sparse_table_configs``
            for a non-geo sparse table that has no matching user config.
        """

        def _build_merge_accessor(ctx):
            accessor = Accessor()
            accessor.accessor_class = "CommMergeAccessor"
            accessor.optimizer = None

            if ctx.is_sparse():
                accessor.feature_dim = ctx.sections()[0]
                accessor.embedding_dim = ctx.sections()[1]
            else:
                accessor.feature_dim = ctx.sections()[0]
                accessor.embedding_dim = 1

            return accessor

        def _build_barrier_table(idx):
            table = Table()
            table.id = idx
            table.type = "PS_OTHER_TABLE"
            table.table_class = "BarrierTable"
            table.shard_num = 256

            accessor = Accessor()
            accessor.accessor_class = "CommMergeAccessor"
            accessor.optimizer = None
            accessor.feature_dim = 0
            accessor.embedding_dim = 0
            table.accessor = accessor

            common = CommonAccessor()
            common.table_name = "barrier_table"
            trainer_num = self.compiled_strategy.get_trainers()
            if self.role_maker._is_heter_parameter_server_mode:
                trainer_num += len(
                    self.role_maker._get_heter_worker_endpoints()
                )
            common.trainer_num = trainer_num
            common.attrs = ""
            common.dims = []
            common.params = []
            table.common = common
            return table

        def _build_tensor_table(idx, tensor_dict):
            table = Table()
            table.id = idx
            table.type = "PS_OTHER_TABLE"
            table.table_class = tensor_dict["tensor_table_class"]
            table.shard_num = 256

            accessor = Accessor()
            accessor.accessor_class = "CommMergeAccessor"
            accessor.optimizer = None
            accessor.feature_dim = 0
            accessor.embedding_dim = 0
            table.accessor = accessor

            common = CommonAccessor()
            common.table_name = tensor_dict["feed_var_name"]
            common.trainer_num = self.compiled_strategy.get_trainers()
            common.attrs = ""
            common.dims = []
            common.params = []
            table.common = common

            tensor = Tensor()
            tensor.main_program_id = tensor_dict["main_program_id"]
            tensor.startup_program_id = tensor_dict["startup_program_id"]
            tensor.feed_var_name = tensor_dict["feed_var_name"]
            tensor.fetch_var_name = tensor_dict["fetch_var_name"]
            tensor.tensor_table_class = tensor_dict["tensor_table_class"]
            table.tensor = tensor

            return table

        def _add_tensor_table(tables):
            # NOTE(review): program ids start at 0 although earlier calls of
            # this method may already have appended to _server_sub_program.
            tensor_table_dict = self.compiled_strategy.get_tensor_table_dict()
            program_idx = 0
            for table_name in tensor_table_dict:
                tensor_dict = tensor_table_dict[table_name]
                if tensor_dict["startup_program"] is not None:
                    tensor_dict["startup_program_id"] = program_idx
                    self._server_sub_program.append(
                        tensor_dict["startup_program"].desc
                    )
                    program_idx += 1
                if tensor_dict["main_program"] is not None:
                    tensor_dict["main_program_id"] = program_idx
                    self._server_sub_program.append(
                        tensor_dict["main_program"].desc
                    )
                    program_idx += 1
                # Todo: Hard code for lr_decay table apply table id
                new_table = _build_tensor_table(len(tables), tensor_dict)
                tables.append(new_table)
            return tables

        def _configure_sparse_table(table, table_name):
            # Fill table_class / shard_num / accessor_proto of a non-geo
            # sparse table from the user's sparse_table_configs.
            all_table_proto = self.context[
                "user_defined_strategy"
            ].sparse_table_configs
            # Search first and add only when nothing matches. Calling add()
            # up front left a stray empty config behind on every call, even
            # when the user had configured the table.
            table_proto = None
            for proto in all_table_proto:
                if proto.table_name == table_name:
                    table_proto = proto
                    break
            if table_proto is None:
                table_proto = all_table_proto.add()

            if table_proto.HasField("table_class"):
                table.table_class = table_proto.table_class
            else:
                table.table_class = parse_table_class(
                    table_name, self.origin_main_program
                )
            if table.table_class != 'MemorySparseTable':
                table.table_class = 'MemorySparseTable'
                warnings.warn("The PS mode must use MemorySparseTable.")

            if table_proto.HasField("shard_num"):
                table.shard_num = table_proto.shard_num
            else:
                table.shard_num = 1000
                warnings.warn(
                    "The shard_num of sparse table is not set, use default value 1000."
                )

            if table_proto.accessor.ByteSize() == 0:
                warnings.warn(
                    "The accessor of sparse table is not set, use default value."
                )
            get_default_accessor_proto(
                table_proto.accessor,
                table_name,
                self.origin_main_program,
            )
            check_embedding_dim(
                table_proto.accessor,
                table_name,
                self.origin_main_program,
            )
            from google.protobuf import text_format

            table.accessor_proto = text_format.MessageToString(
                table_proto.accessor
            )

        def _get_tables():
            send_ctx = self.compiled_strategy.get_the_one_send_context(
                use_origin_program=True,
                split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            )

            tables = []
            for ctx in send_ctx.values():
                if ctx.is_tensor_table() or len(ctx.origin_varnames()) < 1:
                    continue

                table = Table()
                table.id = ctx.table_id()
                common = CommonAccessor()

                if ctx.is_sparse():
                    table.type = "PS_SPARSE_TABLE"
                    table.shard_num = 256

                    common.table_name = (
                        self.compiled_strategy.grad_name_to_param_name[
                            ctx.origin_varnames()[0]
                        ]
                    )

                    if self.compiled_strategy.is_geo_mode():
                        table.table_class = "MemorySparseGeoTable"
                    else:
                        _configure_sparse_table(table, common.table_name)
                else:
                    table.type = "PS_DENSE_TABLE"
                    table.table_class = "MemoryDenseTable"
                    table.shard_num = 256
                    common.table_name = "MergedDense"

                adam_d2sum = self.context["user_defined_strategy"].adam_d2sum
                common.parse_by_optimizer(
                    ctx.origin_varnames()[0],
                    ctx.is_sparse(),
                    ctx.sections()[0],
                    ctx.sections()[1] if ctx.is_sparse() else 1,
                    self.compiled_strategy,
                    adam_d2sum,
                )

                if ctx.is_sparse():
                    common.parse_entry(
                        common.table_name, self.origin_main_program
                    )

                common.sync = "true" if is_sync else "false"
                table.common = common

                if table.table_class != 'MemorySparseTable':
                    table.accessor = _build_merge_accessor(ctx)
                tables.append(table)

            tensor_table_dict = self.compiled_strategy.get_tensor_table_dict()
            if len(tensor_table_dict) > 0:
                tables = _add_tensor_table(tables)
            else:
                empty_program = Program()
                self._server_sub_program.append(empty_program.desc)

            barrier_table = _build_barrier_table(len(tables))
            tables.append(barrier_table)
            return tables

        if is_server:
            server = Server()
            downpour_server = DownpourServer()

            service = Service()
            dist_strategy = self.context["valid_strategy"]
            use_ps_gpu = dist_strategy.a_sync_configs["use_ps_gpu"]
            if use_ps_gpu:
                service.server_class = "PsLocalServer"
                service.client_class = "PsLocalClient"
            downpour_server.set_service_param(service)

            tables = _get_tables()
            downpour_server.tables = tables
            server.add_server(downpour_server)
            return server
        else:
            worker = Worker()
            downpour_worker = DownpourWorker()

            tables = _get_tables()
            downpour_worker.tables = tables
            worker.add_worker(downpour_worker)
            return worker

    def _init_server(self, dirname=None, var_names=None, **kwargs):
        """Create the PS server and optionally load sparse tables from disk.

        Args:
            dirname: directory to load sparse tables from; nothing is loaded
                when it is None.
            var_names: sparse/distributed variable names to load; defaults
                to all of them found in the origin main program.
            **kwargs: accepted for compatibility; unused.

        Raises:
            ValueError: if ``var_names`` contains a name that is not a
                sparse/distributed variable of the origin main program.
        """
        role_id = self.compiled_strategy.get_role_id()
        endpoints = self.compiled_strategy.get_ps_endpoints()
        is_sync = self.compiled_strategy.is_sync_mode()

        trainers = self.compiled_strategy.get_trainers()
        if self.role_maker._is_heter_parameter_server_mode:
            trainers += len(self.role_maker._get_heter_worker_endpoints())

        server = self._get_fleet_proto(is_server=True, is_sync=is_sync)
        proto_txt = str(server)
        fs_client = fsClient(
            self.context["user_defined_strategy"].fs_client_param
        )
        proto_txt = proto_txt + "\n" + fs_client.to_string()

        debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
        if debug:
            print("server: \n{}".format(proto_txt))

        string_hosts = []
        for idx, ep in enumerate(endpoints):
            host, port = ep.split(":")
            pshost = fluid.core.PSHost(host, int(port), idx)
            string_hosts.append(pshost.serialize_to_string())

        self._server = fluid.core.DistFleetWrapper()
        self._server.init_server(
            proto_txt, string_hosts, role_id, trainers, self._server_sub_program
        )

        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            get_sparse_tablenames,
        )

        dist_varnames = get_sparse_tablenames(self.origin_main_program, True)
        sparse_varnames = get_sparse_tablenames(self.origin_main_program, False)

        distributed_varnames = dist_varnames + sparse_varnames

        if var_names is None:
            load_varnames = distributed_varnames
        else:
            for var_name in var_names:
                if var_name not in distributed_varnames:
                    raise ValueError(
                        "fleet.init server can only load sparse variables in {}".format(
                            distributed_varnames
                        )
                    )
            load_varnames = var_names

        if dirname is None or not load_varnames:
            return

        # table_name -> table id for every sparse table in the server proto
        sparse_table_maps = {}
        for table in server.servers[0].tables:
            if table.type == "PS_SPARSE_TABLE" and table.common is not None:
                sparse_table_maps[table.common.table_name] = table.id

        dirname = os.path.normpath(dirname)

        for var_name in load_varnames:
            table_id = sparse_table_maps[var_name]
            # NOTE(review): meaning of the "0" argument is defined by
            # DistFleetWrapper.load_sparse; confirm before changing it.
            self._server.load_sparse(dirname, "0", table_id)
T
tangwei12 已提交
1233 1234 1235 1236 1237 1238 1239 1240

    def _run_server(self):
        """Start serving on this pserver's own "host:port" endpoint."""
        host, port = self.compiled_strategy.get_ps_endpoint().split(":")
        self._server.run_server(host, int(port))

    def _stop_worker(self):
        """Stop the communicator and, in heter-PS mode, the heter client."""
        self._communicator.stop()
        if not self.role_maker._is_heter_parameter_server_mode:
            return
        assert (
            self._heter_client is not None
        ), "heter client should not be None in heterps mode"
        self._heter_client.stop()
T
tangwei12 已提交
1248 1249 1250 1251 1252 1253 1254

    @staticmethod
    def __exclude_vars(exclude_var_names=None):
        """Return a predicate selecting the persistable variables to save.

        The predicate returns False for:

        * variables whose name is in *exclude_var_names*;
        * gradients (origin name ending with ``@GRAD``);
        * ``learning_rate_0``;
        * FEED_MINIBATCH, FETCH_LIST and READER variables.

        Otherwise it returns ``var.persistable``.

        Args:
            exclude_var_names: iterable of variable names to reject. ``None``
                (the default, replacing a shared mutable ``[]``) rejects none
                by name.
        """
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_varname_parts,
        )

        # A set gives O(1) membership per variable instead of a list scan.
        excluded = set(exclude_var_names) if exclude_var_names else set()
        skipped_types = (
            core.VarDesc.VarType.FEED_MINIBATCH,
            core.VarDesc.VarType.FETCH_LIST,
            core.VarDesc.VarType.READER,
        )

        def is_valid(var):
            if var.name in excluded:
                return False

            origin_varname, _, _ = _get_varname_parts(var.name)
            if origin_varname.endswith("@GRAD"):
                return False

            if origin_varname == "learning_rate_0":
                return False

            if var.desc.type() in skipped_types:
                return False
            return var.persistable

        return is_valid

1276 1277 1278 1279 1280 1281 1282
    def _get_inference_model_path(self, dirname):
        """Return the ``dnn_plugin`` path used for locally saved tables.

        Remote ``afs:``/``hdfs:`` dirnames map to the local ``./dnn_plugin``;
        any other dirname maps to ``<dirname>/dnn_plugin``.
        """
        if dirname.startswith(("afs:", "hdfs:")):
            return "./dnn_plugin"
        return os.path.join(dirname, "dnn_plugin")

1283 1284 1285 1286 1287 1288 1289
    def _save_sparse_params(
        self, executor, dirname, context, main_program, mode
    ):
        """Save every sparse table listed in *context*.

        For entries whose first variable is not a distributed sparse table,
        a local copy is also requested via ``recv_and_save_model`` into
        ``_get_inference_model_path(dirname)``. This is best effort: a
        failure is reported with ``warnings.warn`` instead of being silently
        swallowed. Every entry is then saved on the servers with
        ``save_one_model(table_id, dirname, mode)``.

        Args:
            executor: unused; kept for interface compatibility.
            dirname: target directory.
            context: mapping of id -> list of variable names; the id is
                passed straight to the worker's save calls.
            main_program: unused; kept for interface compatibility.
            mode: save mode forwarded to ``save_one_model``.

        Returns:
            list: all variable names from *context*, in iteration order.
        """
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            get_sparse_tablenames,
        )

        distributed_varnames = get_sparse_tablenames(
            self.compiled_strategy.origin_main_program, True
        )
        values = []
        model_path = self._get_inference_model_path(dirname)
        for table_id, names in context.items():
            if names[0] not in distributed_varnames:
                # only save sparse param to local
                try:
                    self._worker.recv_and_save_model(table_id, model_path)
                except Exception as e:
                    # Best effort: the server-side save below still runs.
                    warnings.warn(
                        "failed to save sparse table {} to {}: {}".format(
                            table_id, model_path, e
                        )
                    )
            # save sparse & distributed param on server
            self._worker.save_one_model(table_id, dirname, mode)
            values.extend(names)
        return values

1308 1309 1310
    def _save_distributed_persistables(
        self, executor, dirname, main_program, mode=0
    ):
        """Save the sparse tables, then the remaining persistables locally.

        Sparse tables go through ``_save_sparse_params``. Dense params are
        then pulled through the communicator, and every variable of
        *main_program* accepted by ``__exclude_vars(saved sparse names)`` is
        written with ``paddle.save(..., use_binary_format=True)`` to
        ``os.path.join(dirname, var.name)``.
        """
        denses = self.compiled_strategy.get_the_one_recv_context(
            is_dense=True,
            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            use_origin_program=True,
        )
        sparses = self.compiled_strategy.get_the_one_recv_context(
            is_dense=False,
            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            use_origin_program=True,
        )

        saved_varnames = self._save_sparse_params(
            executor, dirname, sparses, main_program, mode
        )

        # Pull dense params (presumably into the local scope) before saving.
        self._communicator.pull_dense(denses)

        remaining_vars = list(
            filter(
                TheOnePSRuntime.__exclude_vars(saved_varnames),
                main_program.list_vars(),
            )
        )

        import paddle

        for var in remaining_vars:
            tensor = var.get_value()
            paddle.save(
                tensor, os.path.join(dirname, var.name), use_binary_format=True
            )

    def _ps_inference_save_persistables(
        self, executor, dirname, main_program=None, mode=0, **kwargs
    ):
        """Validate the arguments and save the model via ``save_all_model``.

        No per-variable filtering of *main_program* happens here. The whole
        save is delegated to the worker fleet wrapper's
        ``save_all_model(dirname, mode)``.

        Args:
            executor: must be an ``Executor`` (``ParallelExecutor`` is
                rejected).
            dirname: target directory passed to ``save_all_model``.
            main_program: must not be a ``CompiledProgram``; defaults to the
                origin PS main program. It is only validated.
            mode: save mode forwarded to ``save_all_model``.
            **kwargs: accepted for compatibility; unused.

        Raises:
            TypeError: on an unsupported executor or program type.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save() function, executor must be as Executor type"
            )

        if main_program is None:
            main_program = self.compiled_strategy.get_origin_ps_main_program()

        if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

        # TODO(MrChengmo): Save optimizer status
        self._worker.save_all_model(dirname, mode)
T
tangwei12 已提交
1387

1388 1389 1390 1391 1392 1393 1394 1395 1396 1397
    def _ps_inference_save_inference_model(
        self,
        executor,
        dirname,
        feeded_var_names,
        target_vars,
        main_program=None,
        export_for_deployment=True,
        mode=0,
    ):
        """
        Prune the given `main_program` to build a new program especially for inference,
        and then save it and all related parameters to given `dirname` by the `executor`.

        Steps performed:
          1. Normalize the program (default: ``self.origin_main_program``)
             into an inference program for `feeded_var_names` ->
             `target_vars` and save it as ``__model__`` under the path
             returned by ``self._get_inference_model_path(dirname)``.
          2. Save the sparse tables via ``self._save_sparse_params``.
          3. Pull the dense tables through the communicator, then save every
             inference-program variable accepted by
             ``__exclude_vars(sparse_names)`` as a binary file next to
             ``__model__``.

        Args:
            executor: must be an ``Executor``; a ``ParallelExecutor`` is
                rejected.
            dirname: base directory the inference model path is derived from.
            feeded_var_names: names of the feed variables in the program.
            target_vars: fetch targets of the inference program.
            main_program: program to prune; defaults to
                ``self.origin_main_program``.
            export_for_deployment: accepted for API compatibility; not used.
            mode: save mode forwarded to ``self._save_sparse_params``.

        Raises:
            TypeError: if `executor` is a ``ParallelExecutor`` or not an
                ``Executor``, or if the program is a ``CompiledProgram``.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save() function, executor must be as Executor type"
            )

        import paddle

        program = (
            self.origin_main_program if main_program is None else main_program
        )

        if isinstance(program, CompiledProgram):
            raise TypeError(
                "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

        feed_vars = [
            program.global_block().var(name) for name in feeded_var_names
        ]
        infer_program = paddle.static.normalize_program(
            program, feed_vars, target_vars
        )
        infer_program._copy_dist_param_info_from(program)

        model_path = self._get_inference_model_path(dirname)
        model_basename = os.path.join(model_path, "__model__")
        paddle.save(infer_program, model_basename)

        sparses = self.compiled_strategy.get_the_one_recv_context(
            is_dense=False,
            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            use_origin_program=True,
        )
        # NOTE(review): the raw `main_program` argument (possibly None) is
        # forwarded here, not the resolved `program`; confirm that
        # `_save_sparse_params` expects this.
        sparse_names = self._save_sparse_params(
            executor, dirname, sparses, main_program, mode
        )

        denses = self.compiled_strategy.get_the_one_recv_context(
            is_dense=True,
            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            use_origin_program=True,
        )
        # TODO(zhaocaibei123): for GEO: should call GeoCommunicator::RecvDense
        # Pull the dense tables before their values are read below.
        self._communicator.pull_dense(denses)

        # Sparse tables were saved above. Every other variable that passes
        # the exclusion filter is written as its own binary file.
        remaining_vars = list(
            filter(
                TheOnePSRuntime.__exclude_vars(sparse_names),
                infer_program.list_vars(),
            )
        )

        for var in remaining_vars:
            tensor = var.get_value()
            paddle.save(
                tensor,
                os.path.join(model_path, var.name),
                use_binary_format=True,
            )

    def _save_inference_model(self, *args, **kwargs):
        """Save an inference model; alias of ``_ps_inference_save_inference_model``.

        All positional and keyword arguments are forwarded unchanged. The
        delegate's result is discarded, so this always returns None.
        """
        save_fn = self._ps_inference_save_inference_model
        save_fn(*args, **kwargs)

    def _save_persistables(self, *args, **kwargs):
        """Save persistables; alias of ``_ps_inference_save_persistables``.

        All positional and keyword arguments are forwarded unchanged. The
        delegate's result is discarded, so this always returns None.
        """
        save_fn = self._ps_inference_save_persistables
        save_fn(*args, **kwargs)

    def _load_sparse_params(self, dirname, context, main_program, mode):
        """
        Load every sparse table in `context` on the PS side.

        Args:
            dirname: directory forwarded to ``self._worker.load_one_table``.
            context: mapping of table id to a sequence of variable names. The
                first name of each entry is checked against the distributed
                sparse table names of the origin main program.
            main_program: not used.
            mode: load mode forwarded to ``self._worker.load_one_table``.

        Returns:
            list: all variable names from `context`, concatenated in
            iteration order.

        Note:
            For tables whose first name is not a distributed sparse table, a
            warning is emitted. The table is still loaded, despite the
            warning text saying "pass".
        """
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            get_sparse_tablenames,
        )

        distributed_varnames = get_sparse_tablenames(
            self.compiled_strategy.origin_main_program, True
        )

        loaded_names = []
        for table_id, table_varnames in context.items():
            if table_varnames[0] not in distributed_varnames:
                # TODO: only load sparse param from local
                warnings.warn("varname is not in distributed_varnames, pass")
            # Both plain sparse and distributed tables are loaded on the server.
            self._worker.load_one_table(table_id, dirname, mode)
            loaded_names += table_varnames
        return loaded_names

    def _ps_inference_load_inference_model(
        self, dirname, mode=0, main_program=None
    ):
        """
        Load sparse tables on the servers and dense values from local files.

        Steps performed:
          1. Sparse tables from the sparse recv context are loaded through
             ``self._load_sparse_params``.
          2. Of the remaining variables of `main_program`, only those named in
             the dense recv context are read with ``paddle.load``. They are
             read from ``./dnn_plugin`` when `dirname` starts with ``afs:`` or
             ``hdfs:``, otherwise from ``<dirname>/dnn_plugin``.
          3. The communicator is initialized with the dense context via
             ``init_params``.

        Args:
            dirname: model directory or ``afs:``/``hdfs:`` URI.
            mode: load mode forwarded to ``self._load_sparse_params``.
            main_program: program whose variables receive the loaded values;
                defaults to the origin PS main program.

        Raises:
            TypeError: if `main_program` is a ``CompiledProgram``.
        """
        if main_program is None:
            main_program = self.compiled_strategy.get_origin_ps_main_program()

        if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.load_model() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

        denses = self.compiled_strategy.get_the_one_recv_context(
            is_dense=True,
            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            use_origin_program=True,
        )
        sparses = self.compiled_strategy.get_the_one_recv_context(
            is_dense=False,
            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            use_origin_program=True,
        )

        loaded_varnames = self._load_sparse_params(
            dirname, sparses, main_program, mode
        )

        # A set keeps the per-variable membership test below O(1); it was
        # previously a list scanned once per program variable.
        recv_dense_varnames = set()
        for names in denses.values():
            recv_dense_varnames.update(names)

        remaining_vars = list(
            filter(
                TheOnePSRuntime.__exclude_vars(loaded_varnames),
                main_program.list_vars(),
            )
        )

        if dirname.startswith(("afs:", "hdfs:")):
            # NOTE(review): remote paths are not read directly. This assumes
            # the dense files were already fetched into ./dnn_plugin; confirm.
            model_path = "./dnn_plugin"
        else:
            model_path = os.path.join(dirname, "dnn_plugin")

        import paddle

        for var in remaining_vars:
            if var.name not in recv_dense_varnames:
                continue
            tensor = paddle.load(os.path.join(model_path, var.name))
            var.set_value(tensor)

        self._communicator.init_params(denses)

    def _load_distributed_persistables(self, path, mode):
        """Delegate loading of the model at `path` to the PS worker with `mode`."""
        worker = self._worker
        worker.load_model(path, mode)

    def load_model(self, path, mode):
        """
        Load a saved model from `path`.

        Modes 0 and 3 are handled by ``_load_distributed_persistables``, which
        calls the PS worker's ``load_model``. Any other mode is handled by
        ``_ps_inference_load_inference_model``.

        Args:
            path: model location, forwarded unchanged.
            mode: load mode, forwarded unchanged.
        """
        if mode not in (0, 3):
            self._ps_inference_load_inference_model(path, mode)
            return
        self._load_distributed_persistables(path, mode)

    def _shrink(self, threshold=None):
        """
        Ask the servers to shrink every sparse table.

        All workers meet at a barrier before and after. Only the first worker
        calls ``shrink_sparse_table`` for each table id in the sparse recv
        context.

        Args:
            threshold: forwarded to ``shrink_sparse_table``; 0 is used when it
                is None. Passing a value emits a warning that
                MemorySparseTable ignores it and that shrinking should be
                configured on the accessor instead.
        """
        if threshold is None:
            threshold = 0
        else:
            warnings.warn(
                "The param threshold is not used in MemorySparseTable, if you need to shrink, please set the config of accessor"
            )
        import paddle.distributed.fleet as fleet

        fleet.util.barrier()
        if self.role_maker._is_first_worker():
            sparse_ctx = self.compiled_strategy.get_the_one_recv_context(
                is_dense=False,
                split_dense_table=self.role_maker._is_heter_parameter_server_mode,
                use_origin_program=True,
            )
            # Only the table ids are needed; the variable names are ignored.
            for table_id, _ in sparse_ctx.items():
                self._worker.shrink_sparse_table(table_id, threshold)
        fleet.util.barrier()